In [1]:
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import torch
from tqdm import tqdm
In [2]:
import pytorch_lightning as pl

1. Data

Load data

In [3]:
# Human-readable class names; entry i is the name shown for integer label i.
label_list = [
    'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot',
]
In [4]:
def _read_idx_payload(path, header_bytes):
    """Return everything after the first ``header_bytes`` bytes of ``path`` as uint8.

    The file is read inside a ``with`` block so its handle is closed as soon
    as the contents are in memory (the bare ``open(...).read()`` form leaves
    the handle open until garbage collection).
    """
    with open(path, 'rb') as fh:
        return np.frombuffer(fh.read(), dtype=np.uint8, offset=header_bytes)


# Labels: the payload starts at byte offset 8; one uint8 class id per sample.
labels = _read_idx_payload('./t10k-labels-idx1-ubyte', 8)

# Images: the payload starts at byte offset 16; one flat row of 784 pixels
# (28 * 28) per label.
images = _read_idx_payload('./t10k-images-idx3-ubyte', 16).reshape(len(labels), 784)
In [5]:
# Display both array shapes as a quick check that the files were parsed.
images.shape, labels.shape
Out[5]:
((10000, 784), (10000,))
In [ ]:
# Pick a random sample. The upper bound comes from the loaded data instead of
# a hard-coded 10000, so the cell keeps working if another file is loaded.
idx = np.random.randint(len(labels))
plt.imshow(images[idx].reshape(28, 28), cmap='gray')
label_list[labels[idx]]

Dataset

In [14]:
class FashionDS(Dataset):
    """80/20 train/validation view over the notebook-level ``images``/``labels``.

    One random permutation of all sample indices is stored on the class
    (``FashionDS.index``) when the first instance is created and is reused by
    every later instance. The training view maps onto the first
    ``train_fraction`` entries of that permutation and the validation view
    onto the remaining ones, so the two views never share a sample.
    """

    # Shared permutation of sample indices; filled in by the first instance.
    index = None

    def __init__(self, train = True):
        """Create a view of the data.

        Args:
            train: True selects the training portion, False the validation
                portion.
        """
        self.imgs = images
        self.labels = labels
        self.label_lst = label_list
        self.train = train
        # Number of training samples: 80% of the data, rounded down.
        self.train_fraction = len(labels) * 8 // 10
        if FashionDS.index is None:
            FashionDS.index = np.random.permutation(np.arange(len(self.labels)))

    def __len__(self):
        """Return the number of samples in this view."""
        if self.train:
            return self.train_fraction
        return len(self.labels) - self.train_fraction

    def __getitem__(self, idx):
        """Return ``(image reshaped to 28x28, label)`` for position ``idx``.

        Negative positions count from the end of this view.

        Raises:
            IndexError: if ``idx`` lies outside this view. Without this check
                the training view would hand out validation samples for
                positions past its end.
        """
        size = len(self)
        if not -size <= idx < size:
            raise IndexError(f'index {idx} is out of range for a view of {size} samples')
        if idx < 0:
            idx += size
        # Validation positions start right after the training slice of the
        # shared permutation. (The attribute is `train_fraction`; the former
        # `self.fraction` does not exist and raised AttributeError.)
        offset = 0 if self.train else self.train_fraction
        idx2 = FashionDS.index[offset + idx]
        return self.imgs[idx2].reshape(28, 28), self.labels[idx2]
In [15]:
# Build the training view, print its size and show one randomly chosen sample.
ds = FashionDS()
print(len(ds))
idx = np.random.randint(len(ds))
x, y = ds[idx]
plt.imshow(x, cmap='gray')
ds.label_lst[y]
8000
Out[15]:
'Pullover'
In [16]:
# Build the validation view and print its size.
valds = FashionDS(False)
print(len(valds))
idx = np.random.randint(len(valds))
# NOTE(review): the sample below is read from `ds` (the training view), not
# from `valds`, so this cell never exercises validation indexing. Presumably
# `valds[idx]` / `valds.label_lst[y]` was intended; confirm before changing.
x, y = ds[idx]
plt.imshow(x, cmap='gray')
ds.label_lst[y]
2000
Out[16]:
'Bag'

Dataloader

In [17]:
# Training loader over `ds`: batches of 16, drawn in shuffled order.
dl = DataLoader(ds, batch_size=16, shuffle=True)
In [18]:
# Pull a single batch and display the batched image / label shapes.
x, y = next(iter(dl))
x.shape, y.shape
Out[18]:
(torch.Size([16, 28, 28]), torch.Size([16]))

2. Model

Vanilla pytorch model

In [ ]:
class Net(torch.nn.Module):
    """Minimal classifier: 3x3 convolution -> ReLU -> flatten -> linear layer.

    The unpadded 3x3 convolution turns a 1-channel 28x28 input into 5 feature
    maps of 26x26, which are flattened and projected onto 10 output logits.
    """

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels=1, out_channels=5, kernel_size=3)
        self.linear = torch.nn.Linear(in_features=26 * 26 * 5, out_features=10)

    def forward(self, x):
        """Return one row of 10 logits per input image."""
        activated = torch.relu(self.conv(x))
        flattened = activated.view(-1, 26 * 26 * 5)
        return self.linear(flattened)
In [80]:
# Fresh instance of the plain-PyTorch Net defined in the previous cell.
net = Net()
In [83]:
# Run one batch through the untrained network and display the output shape.
x, y = next(iter(dl))
x = x.float() / 255.  # uint8 pixel values -> floats in [0, 1]
yhat = net(x[:, None, ...])  # insert the channel axis the conv layer expects
yhat.shape
Out[83]:
torch.Size([16, 10])

Loss

In [99]:
# Cross-entropy of the batch above; the labels are cast to long (int64) before being passed in.
torch.nn.functional.cross_entropy(yhat, y.long()).item()
Out[99]:
2.314405679702759

Optimizer

In [95]:
# SGD with momentum over the parameters of `net`.
opt = torch.optim.SGD(net.parameters(), lr=3e-4, momentum=0.9)

Train

In [ ]:
loss_lst = []  # training loss of every batch, as Python floats
acc = []       # training accuracy of every batch, as Python floats
for epoch in range(5):
    print(epoch)
    for x, y in tqdm(dl):

        opt.zero_grad()
        x = x.float() / 255.
        yhat = net(x[:, None, ...])
        # Store a plain float, like `loss_lst` does, rather than a 0-d tensor,
        # so the history can be plotted and averaged without tensor conversion.
        acc += [(yhat.argmax(dim=1) == y).float().mean().item()]

        loss = torch.nn.functional.cross_entropy(yhat, y.long())
        loss_lst += [loss.item()]
        loss.backward()

        opt.step()
In [ ]:
# Per-batch training loss over all epochs.
plt.plot(loss_lst)
In [ ]:
# Per-batch training accuracy over all epochs.
plt.plot(acc)
In [116]:
# Average of the per-batch accuracies collected during the whole run.
np.mean(acc)
Out[116]:
0.81164

Lightning model

In [125]:
class Net(pl.LightningModule):                    ## changed
    """Same conv -> ReLU -> linear classifier, now deriving from LightningModule.

    The unpadded 3x3 convolution turns a 1-channel 28x28 input into 5 feature
    maps of 26x26, which are flattened and projected onto 10 output logits.
    """

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels=1, out_channels=5, kernel_size=3)
        self.linear = torch.nn.Linear(in_features=26 * 26 * 5, out_features=10)

    def forward(self, x):
        """Return one row of 10 logits per input image."""
        activated = torch.relu(self.conv(x))
        flattened = activated.view(-1, 26 * 26 * 5)
        return self.linear(flattened)
In [126]:
# Fresh instance of the LightningModule-based Net defined in the previous cell.
net = Net()

Sanity check

In [127]:
x, y = next(iter(dl))
x = x.float() / 255.
yhat = net(x[:, None, ...])
# A bare `yhat.shape` in the middle of a cell is never displayed, so the
# expected (batch, 10) output shape is checked explicitly instead.
assert yhat.shape == (len(y), 10)

torch.nn.functional.cross_entropy(yhat, y.long()).item()
Out[127]:
2.3112006187438965
In [128]:
# SGD with momentum over the parameters of `net`.
opt = torch.optim.SGD(net.parameters(), lr=3e-4, momentum=0.9)
In [ ]:
loss_lst = []  # training loss of every batch, as Python floats
acc = []       # training accuracy of every batch, as Python floats
for epoch in range(5):
    print(epoch)
    for x, y in tqdm(dl):

        opt.zero_grad()
        x = x.float() / 255.
        yhat = net(x[:, None, ...])
        # Store a plain float, like `loss_lst` does, rather than a 0-d tensor,
        # so the history can be plotted and averaged without tensor conversion.
        acc += [(yhat.argmax(dim=1) == y).float().mean().item()]

        loss = torch.nn.functional.cross_entropy(yhat, y.long())
        loss_lst += [loss.item()]
        loss.backward()

        opt.step()
In [ ]:
# Per-batch training loss over all epochs.
plt.plot(loss_lst)
In [ ]:
# Per-batch training accuracy over all epochs.
plt.plot(acc)
In [132]:
# Average of the per-batch accuracies collected during the whole run.
np.mean(acc)
Out[132]:
0.6867

Integrated training loop

In [21]:
class Net(pl.LightningModule):
    """Conv -> ReLU -> linear classifier with the training step built in.

    The unpadded 3x3 convolution turns a 1-channel 28x28 input into 5 feature
    maps of 26x26, which are flattened and projected onto 10 output logits.
    """

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 5, 3)
        self.linear = torch.nn.Linear(26*26*5, 10)

    def forward(self, x):
        """Return one row of 10 logits per input image."""
        out = self.conv(x)
        out = torch.nn.functional.relu(out)
        out = out.view(-1, 26*26*5)
        out = self.linear(out)
        return out

    def training_step(self, batch, batch_id):                   ## changed
        """Return the cross-entropy loss for one ``(images, labels)`` batch.

        Images are scaled by 1/255 and given a channel axis before the
        forward pass; labels are cast to long for the loss.
        """
        x, y = batch
        x = x.float() / 255.
        yhat = self(x[:, None, ...])
        loss = torch.nn.functional.cross_entropy(yhat, y.long())
        return loss

    def configure_optimizers(self):                             ## changed
        """Return SGD (lr 3e-4, momentum 0.9) over this module's parameters."""
        # Use `self`, not the notebook-level `net`: the global may name a
        # different (or stale) object than the instance being trained.
        opt = torch.optim.SGD(self.parameters(), 3e-4, 0.9)
        return opt
In [22]:
# Fresh instance of the Net defined in the previous cell.
net = Net()

Sanity check

In [23]:
x, y = next(iter(dl))
x = x.float() / 255.
yhat = net(x[:, None, ...])
# A bare `yhat.shape` in the middle of a cell is never displayed, so the
# expected (batch, 10) output shape is checked explicitly instead.
assert yhat.shape == (len(y), 10)

torch.nn.functional.cross_entropy(yhat, y.long()).item()
Out[23]:
2.3086438179016113
In [ ]:
# The epoch limit is passed to the constructor: assigning `trainer.max_epochs`
# after construction is not a supported way to configure the Trainer.
trainer = pl.Trainer(max_epochs=5)  ## changed
trainer.fit(net, dl)                ## changed

Log training loss and accuracy

In [48]:
class Net(pl.LightningModule):
    """Conv -> ReLU -> linear classifier that logs training loss and accuracy.

    The unpadded 3x3 convolution turns a 1-channel 28x28 input into 5 feature
    maps of 26x26, which are flattened and projected onto 10 output logits.
    """

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 5, 3)
        self.linear = torch.nn.Linear(26*26*5, 10)

    def forward(self, x):
        """Return one row of 10 logits per input image."""
        out = self.conv(x)
        out = torch.nn.functional.relu(out)
        out = out.view(-1, 26*26*5)
        out = self.linear(out)
        return out

    def training_step(self, batch, batch_id):
        """Log batch loss and accuracy, then return the cross-entropy loss.

        Images are scaled by 1/255 and given a channel axis before the
        forward pass; labels are cast to long for the loss.
        """
        x, y = batch
        x = x.float() / 255.
        yhat = self(x[:, None, ...])
        # Fraction of samples in the batch whose arg-max logit equals the label.
        acc = (yhat.argmax(dim=1) == y).sum()*1.0 / len(y)
        loss = torch.nn.functional.cross_entropy(yhat, y.long())
        self.log('train loss', loss.item())                     ## changed
        self.log('train acc', acc)                              ## changed
        return loss

    def configure_optimizers(self):
        """Return SGD (lr 3e-4, momentum 0.9) over this module's parameters."""
        # Use `self`, not the notebook-level `net`: the global may name a
        # different (or stale) object than the instance being trained.
        opt = torch.optim.SGD(self.parameters(), 3e-4, 0.9)
        return opt
In [49]:
# Fresh instance of the Net defined in the previous cell.
net = Net()

Sanity check

In [50]:
x, y = next(iter(dl))
x = x.float() / 255.
yhat = net(x[:, None, ...])
# A bare `yhat.shape` in the middle of a cell is never displayed, so the
# expected (batch, 10) output shape is checked explicitly instead.
assert yhat.shape == (len(y), 10)

torch.nn.functional.cross_entropy(yhat, y.long()).item()
Out[50]:
2.2562389373779297
In [ ]:
# The epoch limit is passed to the constructor: assigning `trainer.max_epochs`
# after construction is not a supported way to configure the Trainer.
trainer = pl.Trainer(max_epochs=5)
trainer.fit(net, dl)

Integrate training dataloader into model

In [19]:
class Net(pl.LightningModule):
    """Conv -> ReLU -> linear classifier that also supplies its training data.

    The unpadded 3x3 convolution turns a 1-channel 28x28 input into 5 feature
    maps of 26x26, which are flattened and projected onto 10 output logits.
    """

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 5, 3)
        self.linear = torch.nn.Linear(26*26*5, 10)

    def forward(self, x):
        """Return one row of 10 logits per input image."""
        out = self.conv(x)
        out = torch.nn.functional.relu(out)
        out = out.view(-1, 26*26*5)
        out = self.linear(out)
        return out

    def training_step(self, batch, batch_id):
        """Log batch loss and accuracy, then return the cross-entropy loss.

        Images are scaled by 1/255 and given a channel axis before the
        forward pass; labels are cast to long for the loss.
        """
        x, y = batch
        x = x.float() / 255.
        yhat = self(x[:, None, ...])
        # Fraction of samples in the batch whose arg-max logit equals the label.
        acc = (yhat.argmax(dim=1) == y).sum()*1.0 / len(y)
        loss = torch.nn.functional.cross_entropy(yhat, y.long())
        self.log('train loss', loss.item())
        self.log('train acc', acc)
        return loss

    def configure_optimizers(self):
        """Return SGD (lr 3e-4, momentum 0.9) over this module's parameters."""
        # Use `self`, not the notebook-level `net`: the global may name a
        # different (or stale) object than the instance being trained.
        opt = torch.optim.SGD(self.parameters(), 3e-4, 0.9)
        return opt

    def train_dataloader(self):                       ## changed
        """Return a loader over the training view of FashionDS, batch size 64."""
        return DataLoader(FashionDS(), batch_size=64)
In [20]:
# Fresh instance of the Net defined in the previous cell.
net = Net()

Sanity check

In [21]:
x, y = next(iter(net.train_dataloader()))
x = x.float() / 255.
yhat = net(x[:, None, ...])
# A bare `yhat.shape` in the middle of a cell is never displayed, so the
# expected (batch, 10) output shape is checked explicitly instead.
assert yhat.shape == (len(y), 10)

torch.nn.functional.cross_entropy(yhat, y.long()).item()
Out[21]:
2.28153133392334
In [ ]:
# The epoch limit is passed to the constructor: assigning `trainer.max_epochs`
# after construction is not a supported way to configure the Trainer.
trainer = pl.Trainer(max_epochs=5)
trainer.fit(net)        ## changed

LR Find

In [62]:
class Net(pl.LightningModule):
    """Conv -> ReLU -> linear classifier with a tunable learning rate.

    The unpadded 3x3 convolution turns a 1-channel 28x28 input into 5 feature
    maps of 26x26, which are flattened and projected onto 10 output logits.

    The learning rate lives in ``self.hparams.lr`` so that it can be changed
    from outside the class (for example ``net.hparams.lr = 1e-2``) before
    training starts.
    """

    def __init__(self, lr=3e-4):
        """
        Args:
            lr: initial learning rate for SGD; stored as ``self.hparams.lr``.
        """
        super().__init__()
        # Records the constructor arguments (here: `lr`) under self.hparams.
        self.save_hyperparameters()
        self.conv = torch.nn.Conv2d(1, 5, 3)
        self.linear = torch.nn.Linear(26*26*5, 10)

    def forward(self, x):
        """Return one row of 10 logits per input image."""
        out = self.conv(x)
        out = torch.nn.functional.relu(out)
        out = out.view(-1, 26*26*5)
        out = self.linear(out)
        return out

    def training_step(self, batch, batch_id):
        """Log batch loss and accuracy, then return the cross-entropy loss.

        Images are scaled by 1/255 and given a channel axis before the
        forward pass; labels are cast to long for the loss.
        """
        x, y = batch
        x = x.float() / 255.
        yhat = self(x[:, None, ...])
        # Fraction of samples in the batch whose arg-max logit equals the label.
        acc = (yhat.argmax(dim=1) == y).sum()*1.0 / len(y)
        loss = torch.nn.functional.cross_entropy(yhat, y.long())
        self.log('train loss', loss.item())
        self.log('train acc', acc)
        return loss

    def configure_optimizers(self):
        """Return SGD (momentum 0.9) using the current ``self.hparams.lr``."""
        # Read the rate instead of re-assigning it here: resetting it to 3e-4
        # on every call discarded any value chosen after construction.
        # Use `self`, not the notebook-level `net`: the global may name a
        # different (or stale) object than the instance being trained.
        opt = torch.optim.SGD(self.parameters(), self.hparams.lr, 0.9)       ## changed
        return opt

    def train_dataloader(self):
        """Return a loader over the training view of FashionDS, batch size 64."""
        return DataLoader(FashionDS(), batch_size=64)
In [63]:
# Fresh instance of the Net defined in the previous cell.
net = Net()

Sanity check

In [64]:
x, y = next(iter(net.train_dataloader()))
x = x.float() / 255.
yhat = net(x[:, None, ...])
# A bare `yhat.shape` in the middle of a cell is never displayed, so the
# expected (batch, 10) output shape is checked explicitly instead.
assert yhat.shape == (len(y), 10)

torch.nn.functional.cross_entropy(yhat, y.long()).item()
Out[64]:
2.321258306503296
In [ ]:
# The epoch limit is passed to the constructor: assigning `trainer.max_epochs`
# after construction is not a supported way to configure the Trainer.
trainer = pl.Trainer(max_epochs=10)
In [ ]:
# Run the tuner's learning-rate finder on `net` and plot its result with the
# suggested rate marked (`suggest=True`).
lrf = trainer.tuner.lr_find(net)  ## changed
fig = lrf.plot(suggest=True)      ## changed
fig.show()                        ## changed
In [67]:
# Override the learning rate stored in the model's hparams before the final fit.
# NOTE(review): this only affects training if Net.configure_optimizers reads
# `self.hparams.lr` — verify against the class definition.
net.hparams.lr = 1e-2            ## changed
In [ ]:
# Train with the Trainer configured above; the data comes from net.train_dataloader().
trainer.fit(net)

LR Schedulers & log LR

In [75]:
class Net(pl.LightningModule):
    """Conv -> ReLU -> linear classifier with a per-step LR schedule.

    The unpadded 3x3 convolution turns a 1-channel 28x28 input into 5 feature
    maps of 26x26, which are flattened and projected onto 10 output logits.

    The learning rate lives in ``self.hparams.lr`` so that it can be changed
    from outside the class (for example ``net.hparams.lr = 1e-2``) before
    training starts.
    """

    def __init__(self, lr=3e-4):
        """
        Args:
            lr: initial learning rate for SGD; stored as ``self.hparams.lr``.
        """
        super().__init__()
        # Records the constructor arguments (here: `lr`) under self.hparams.
        self.save_hyperparameters()
        self.conv = torch.nn.Conv2d(1, 5, 3)
        self.linear = torch.nn.Linear(26*26*5, 10)

    def forward(self, x):
        """Return one row of 10 logits per input image."""
        out = self.conv(x)
        out = torch.nn.functional.relu(out)
        out = out.view(-1, 26*26*5)
        out = self.linear(out)
        return out

    def training_step(self, batch, batch_id):
        """Log batch loss and accuracy, then return the cross-entropy loss.

        Images are scaled by 1/255 and given a channel axis before the
        forward pass; labels are cast to long for the loss.
        """
        x, y = batch
        x = x.float() / 255.
        yhat = self(x[:, None, ...])
        # Fraction of samples in the batch whose arg-max logit equals the label.
        acc = (yhat.argmax(dim=1) == y).sum()*1.0 / len(y)
        loss = torch.nn.functional.cross_entropy(yhat, y.long())
        self.log('train loss', loss.item())
        self.log('train acc', acc)
        return loss

    def configure_optimizers(self):
        """Return SGD plus a cosine warm-restart scheduler stepped every batch."""
        # Read the rate instead of re-assigning it here: resetting it to 3e-4
        # on every call discarded any value chosen after construction.
        # Use `self`, not the notebook-level `net`: the global may name a
        # different (or stale) object than the instance being trained.
        opt = torch.optim.SGD(self.parameters(), self.hparams.lr, 0.9)
        sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, 200) ## changed
        return [opt], [{                                                       ## changed
                'scheduler': sched,                                            ## changed
                'interval': 'step', # The unit of the scheduler's step size    ## changed
            }]

    def train_dataloader(self):
        """Return a loader over the training view of FashionDS, batch size 64."""
        return DataLoader(FashionDS(), batch_size=64)
In [76]:
# Fresh instance of the Net defined in the previous cell.
net = Net()

Sanity check

In [73]:
x, y = next(iter(net.train_dataloader()))
x = x.float() / 255.
yhat = net(x[:, None, ...])
# A bare `yhat.shape` in the middle of a cell is never displayed, so the
# expected (batch, 10) output shape is checked explicitly instead.
assert yhat.shape == (len(y), 10)

torch.nn.functional.cross_entropy(yhat, y.long()).item()
Out[73]:
2.314138889312744
In [ ]:
# The epoch limit is passed to the constructor: assigning `trainer.max_epochs`
# after construction is not a supported way to configure the Trainer.
# LearningRateMonitor('step') records the learning rate at every step.
trainer = pl.Trainer(
    max_epochs=3,
    callbacks=[pl.callbacks.LearningRateMonitor('step')],  ## changed
)
# NOTE(review): this override only affects training if
# Net.configure_optimizers reads `self.hparams.lr` — verify against the class.
net.hparams.lr = 1e-2
trainer.fit(net)

3. Visualize

In [ ]:
# Load the TensorBoard notebook extension and open it on ./lightning_logs
# (presumably where the Trainer runs above wrote their logs — confirm the path).
%load_ext tensorboard
%tensorboard --logdir ./lightning_logs
In [ ]: