Table of Contents
· Note 1: train() and eval() mode
· Note 2: save model
· Note 3: GPU/CPU
· Note 4: Random number
Note 1: train() and eval() mode
We should use eval mode when we perform evaluation; during the training stage, however, the model should remain in train mode.
It is also important to wrap the evaluation in torch.no_grad() when in eval mode, so that no gradients are tracked while evaluating.
import torch
from torch import nn, optim

# Classifier, trainloader and testloader are assumed to be defined elsewhere.
model = Classifier()
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.003)
epochs = 30
steps = 0
train_losses, test_losses = [], []
for e in range(epochs):
    running_loss = 0
    for images, labels in trainloader:
        optimizer.zero_grad()
        log_ps = model(images)
        loss = criterion(log_ps, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    else:
        # The else branch runs once the training loop above has finished.
        test_loss = 0
        accuracy = 0
        # Evaluation: no gradient tracking, layers switched to eval mode
        with torch.no_grad():
            model.eval()
            for images, labels in testloader:
                log_ps = model(images)
                # .item() stores plain Python floats instead of tensors
                test_loss += criterion(log_ps, labels).item()
                ps = torch.exp(log_ps)
                top_p, top_class = ps.topk(1, dim=1)
                equals = top_class == labels.view(*top_class.shape)
                accuracy += torch.mean(equals.type(torch.FloatTensor)).item()
        # Switch back to train mode for the next epoch
        model.train()
        train_losses.append(running_loss/len(trainloader))
        test_losses.append(test_loss/len(testloader))
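The loop above uses both model.eval() and torch.no_grad(), and the two do different jobs. The following minimal sketch, which uses a small made-up nn.Sequential model rather than the Classifier above, illustrates that eval() alone does not stop gradient tracking and that no_grad() alone does not change the train/eval flag:

```python
import torch
from torch import nn

# Hypothetical toy model, used only for illustration
model = nn.Sequential(nn.Linear(4, 8), nn.Dropout(p=0.5), nn.Linear(8, 2))
x = torch.randn(3, 4)

# eval() changes layer behavior, but autograd still records the graph
model.eval()
out_eval_only = model(x)
tracks_grad_in_eval = out_eval_only.requires_grad

# no_grad() disables gradient tracking for everything computed inside it
with torch.no_grad():
    out_no_grad = model(x)
tracks_grad_in_no_grad = out_no_grad.requires_grad

# no_grad() does not touch the train/eval flag of the model
model.train()
with torch.no_grad():
    still_training = model.training

print(tracks_grad_in_eval, tracks_grad_in_no_grad, still_training)
```

This is why evaluation code typically needs both calls: one for correct layer behavior, the other to save memory and computation.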
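In the inference stage, it is also important to use eval mode:
model.eval()
with torch.no_grad():
    # Calling the model directly is preferred over model.forward(img)
    output = model(img)
# Only needed if training continues afterwards
model.train()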
This is an important step because batch normalization behaves differently in the two modes: during training it normalizes each batch with that batch's own statistics, while in evaluation it uses the running statistics accumulated during training. Dropout is likewise only active in train mode.
Remember that you must call model.eval() to set dropout and batch normalization layers to evaluation mode before running inference. Failing to do this will yield inconsistent inference results. If you wish to resume training, call model.train() to ensure these layers are in training mode.
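To make the difference between the two modes concrete, here is a small sketch on standalone layers. The layer sizes and the input values are arbitrary choices for illustration:

```python
import torch
from torch import nn

torch.manual_seed(0)

# Dropout: random zeroing in train mode, identity in eval mode
dropout = nn.Dropout(p=0.5)
ones = torch.ones(1000)
dropout.train()
drop_train = dropout(ones)   # entries are either 0 or 1 / (1 - p) = 2
dropout.eval()
drop_eval = dropout(ones)    # input is passed through unchanged
zero_fraction = (drop_train == 0).float().mean().item()
eval_is_identity = torch.equal(drop_eval, ones)

# Batch norm: batch statistics in train mode, running statistics in eval mode
bn = nn.BatchNorm1d(3)
x = torch.randn(8, 3) * 5 + 10   # batch far from zero mean and unit variance
bn.train()
bn_train = bn(x)             # normalized with the statistics of this batch
bn.eval()
bn_eval = bn(x)              # normalized with the stored running statistics
train_mean = bn_train.mean(dim=0)
outputs_differ = not torch.allclose(bn_train, bn_eval)

print(eval_is_identity, outputs_differ)
```

The same input gives different outputs depending on the mode, which is why forgetting model.eval() leads to inconsistent inference results.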
Note 2: save model
Simple way
Method 1: state_dict
torch.save(model.state_dict(), 'checkpoint.pth')
state_dict = torch.load('checkpoint.pth')
model.load_state_dict(state_dict)
Method 2: save model
FILE = "model.pth"
torch.save(model, FILE)
loaded_model = torch.load(FILE)
loaded_model.eval()
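Both simple methods assume that the model has the same architecture at load time as when it was saved: Method 1 needs an identically structured model to receive the state_dict, and Method 2 needs the original class definition to be importable, because the whole object is pickled. Note that in recent PyTorch releases (2.6 and later) torch.load defaults to weights_only=True, so loading a whole model as in Method 2 requires passing weights_only=False.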
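The architecture assumption can be checked with a small round trip. The sketch below uses a hypothetical helper make_model and a temporary file; it saves a state_dict, loads it into a model of the same shape, and then shows that loading into a differently sized model fails:

```python
import os
import tempfile

import torch
from torch import nn


def make_model(hidden):
    # Hypothetical helper: a tiny network whose hidden width can be varied
    return nn.Sequential(nn.Linear(4, hidden), nn.ReLU(), nn.Linear(hidden, 2))


model = make_model(8)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint.pth")
    torch.save(model.state_dict(), path)
    state_dict = torch.load(path)

# Same architecture: the weights load and match the originals
same = make_model(8)
same.load_state_dict(state_dict)
params_match = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(), same.state_dict().values())
)

# Different architecture: load_state_dict reports a size mismatch
different = make_model(16)
try:
    different.load_state_dict(state_dict)
    mismatch_error = False
except RuntimeError:
    mismatch_error = True

print(params_match, mismatch_error)
```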
A better way is to keep additional information:
Method 3
checkpoint = {'input_size': 784,
              'output_size': 10,
              'hidden_layers': [each.out_features for each in model.hidden_layers],
              'state_dict': model.state_dict()}
torch.save(checkpoint, 'checkpoint.pth')

def load_checkpoint(filepath):
    checkpoint = torch.load(filepath)
    model = fc_model.Network(checkpoint['input_size'],
                             checkpoint['output_size'],
                             checkpoint['hidden_layers'])
    model.load_state_dict(checkpoint['state_dict'])
    return model
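Method 3 can be tried end to end without the fc_model module. In the sketch below, the Network class is a hypothetical stand-in for fc_model.Network (a plain feed-forward classifier with at least one hidden layer), and save_checkpoint is a helper name introduced here for illustration:

```python
import os
import tempfile

import torch
from torch import nn


class Network(nn.Module):
    """Hypothetical stand-in for fc_model.Network."""

    def __init__(self, input_size, output_size, hidden_layers):
        super().__init__()
        sizes = [input_size] + list(hidden_layers)
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(sizes[:-1], sizes[1:])]
        )
        self.output = nn.Linear(sizes[-1], output_size)

    def forward(self, x):
        for layer in self.hidden_layers:
            x = torch.relu(layer(x))
        return torch.log_softmax(self.output(x), dim=1)


def save_checkpoint(model, filepath):
    # Store the constructor arguments next to the weights
    checkpoint = {
        "input_size": model.hidden_layers[0].in_features,
        "output_size": model.output.out_features,
        "hidden_layers": [each.out_features for each in model.hidden_layers],
        "state_dict": model.state_dict(),
    }
    torch.save(checkpoint, filepath)


def load_checkpoint(filepath):
    # Rebuild the architecture from the stored sizes, then load the weights
    checkpoint = torch.load(filepath)
    model = Network(
        checkpoint["input_size"],
        checkpoint["output_size"],
        checkpoint["hidden_layers"],
    )
    model.load_state_dict(checkpoint["state_dict"])
    return model


model = Network(784, 10, [128, 64])
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint.pth")
    save_checkpoint(model, path)
    restored = load_checkpoint(path)

x = torch.randn(5, 784)
print(torch.allclose(model(x), restored(x)))
```

Because the layer sizes travel with the weights, the loading code does not need to know in advance how the saved network was configured.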
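It is also possible to export the model to ONNX:
Method 4
import torch.onnx as onnx

# Dummy input with the shape the model expects (here one 3x224x224 image)
input_image = torch.zeros((1,3,224,224))
onnx.export(model, input_image, 'data/model.onnx')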
It is also possible to save the optimizer state, which is needed to resume training:
checkpoint = {"epoch": 90,
              "model_state": model.state_dict(),
              "optim_state": optimizer.state_dict()}
FILE = "checkpoint.pth"
torch.save(checkpoint, FILE)

model = Model(n_input_features=6)
# lr=0 is only a placeholder; the saved learning rate is restored by load_state_dict below
optimizer = torch.optim.SGD(model.parameters(), lr=0)
checkpoint = torch.load(FILE)
model.load_state_dict(checkpoint['model_state'])
optimizer.load_state_dict(checkpoint['optim_state'])
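The following sketch runs this save-and-resume cycle on a tiny model to show what the optimizer state carries. The nn.Linear model, the momentum setting, and the epoch number are assumptions chosen for illustration:

```python
import os
import tempfile

import torch
from torch import nn

model = nn.Linear(6, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# One training step so that the optimizer holds momentum buffers
loss = model(torch.randn(4, 6)).pow(2).mean()
loss.backward()
optimizer.step()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint.pth")
    torch.save(
        {
            "epoch": 5,
            "model_state": model.state_dict(),
            "optim_state": optimizer.state_dict(),
        },
        path,
    )
    checkpoint = torch.load(path)

# Fresh objects, as they would exist in a new process
resumed_model = nn.Linear(6, 1)
resumed_optimizer = torch.optim.SGD(resumed_model.parameters(), lr=0)
resumed_model.load_state_dict(checkpoint["model_state"])
resumed_optimizer.load_state_dict(checkpoint["optim_state"])

start_epoch = checkpoint["epoch"] + 1                    # continue after the saved epoch
restored_lr = resumed_optimizer.param_groups[0]["lr"]    # taken from the checkpoint
n_buffers = len(resumed_optimizer.state)                 # one entry per parameter

print(start_epoch, restored_lr, n_buffers)
```

The learning rate and the momentum buffers both come back from the checkpoint, so training continues from where it stopped rather than from a freshly initialized optimizer.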
Note 3: GPU/CPU
3.1 Training
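When a GPU is used for training, make sure that the model parameters, the input data, and the labels are all moved to the GPU. Note that model.to() moves a module in place, whereas tensor.to() returns a new tensor that must be assigned back. This can be done using
device = torch.device('cuda:0')
model.to(device)
data = data.to(device)
label = label.to(device)
or
model.cuda()
data = data.cuda()
label = label.cuda()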
We can use the following function to detect GPU:
train_on_gpu = torch.cuda.is_available()
if not train_on_gpu:
    print('CUDA is not available. Training on CPU ...')
else:
    print('CUDA is available! Training on GPU ...')
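The check and the device placement can be combined into one device-agnostic pattern. The sketch below is one common way to write it; the toy model and tensors are made up for illustration, and the code falls back to the CPU when no GPU is present:

```python
import torch
from torch import nn

# Pick the GPU when available, otherwise fall back to the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = nn.Linear(4, 2).to(device)   # modules are moved in place
data = torch.randn(3, 4)
labels = torch.tensor([0, 1, 0])

# Tensors are not moved in place: .to() returns a new tensor, so reassign
data, labels = data.to(device), labels.to(device)

output = model(data)
loss = nn.functional.cross_entropy(output, labels)
print(device, loss.item())
```

Writing the code against a single device variable means the same script runs unchanged on machines with and without a GPU.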
3.2 Testing
A model trained on the GPU can be loaded and used on the CPU, and vice versa. For example, to save on the GPU and load on the CPU:
# Save on GPU
device = torch.device("cuda")
model.to(device)
torch.save(model.state_dict(), PATH)
# Load on CPU
device = torch.device('cpu')
model = Model(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))
Note 4: Random number
We need to fix the random seeds in order to get reproducible training results.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
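The four calls above can be wrapped in a helper and checked for reproducibility. In this sketch, seed_everything and draw are hypothetical names introduced for illustration, and 42 is an arbitrary seed:

```python
import os
import random

import numpy as np
import torch


def seed_everything(seed):
    # Seed each random number generator used in a typical training script
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Set at runtime, this only affects subprocesses started afterwards
    os.environ["PYTHONHASHSEED"] = str(seed)


def draw():
    # One sample from each generator
    return random.random(), np.random.rand(), torch.rand(1).item()


seed_everything(42)
first = draw()
seed_everything(42)
second = draw()

print(first == second)
```

Seeding makes the random draws repeatable, although some GPU operations can still be nondeterministic, so fully identical training runs may need additional settings.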