src.tests.train_test
import numpy as np
from rich import print
from torch.nn import CrossEntropyLoss
from torch.optim.adam import Adam
from torch.utils.tensorboard.writer import SummaryWriter

from src.data import get_mnist_dataloaders
from src.model import BasicCNN
from src.train import training_iteration


def test_training_iteration(tmpdir):
    """Exercise the training iteration to ensure that everything performs as expected without errors.

    Calls ``training_iteration`` twice on the second loader returned by
    ``get_mnist_dataloaders`` (bound to ``dataloader_val``). It asserts that
    the returned average loss strictly decreases from one call to the next.

    Args:
        tmpdir: pytest temporary-directory fixture; used as the TensorBoard
            log directory so event files are written to a per-test location.
    """
    # Setup dataloader and writer
    _, dataloader_val, _, _ = get_mnist_dataloaders()
    writer = SummaryWriter(log_dir=str(tmpdir))

    # Create a model and learning parameters
    model = BasicCNN()
    loss_fn = CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=0.001)

    try:
        # Perform two training iterations to exercise the functionality
        min_loss = np.inf
        for global_step in range(2):
            average_loss = training_iteration(
                model=model,
                loss_fn=loss_fn,
                optimizer=optimizer,
                dataloader_train=dataloader_val,
                global_step=global_step,
                writer=writer,
            )

            print(f"Epoch: {global_step}")
            print(f" Avg Loss: {average_loss}")

            # Verify loss is decreasing
            assert average_loss < min_loss, "Loss is not converging!"  # type: ignore

            # The assertion above guarantees this is a new minimum.
            min_loss = average_loss
    finally:
        # Flush pending events and release the event file even when an
        # assertion (or the training step itself) fails.
        writer.close()
def test_training_iteration(tmpdir):
def test_training_iteration(tmpdir):
    """Exercise the training iteration to ensure that everything performs as expected without errors.

    Calls ``training_iteration`` twice on the second loader returned by
    ``get_mnist_dataloaders`` (bound to ``dataloader_val``). It asserts that
    the returned average loss strictly decreases from one call to the next.

    Args:
        tmpdir: pytest temporary-directory fixture; used as the TensorBoard
            log directory so event files are written to a per-test location.
    """
    # Setup dataloader and writer
    _, dataloader_val, _, _ = get_mnist_dataloaders()
    writer = SummaryWriter(log_dir=str(tmpdir))

    # Create a model and learning parameters
    model = BasicCNN()
    loss_fn = CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=0.001)

    try:
        # Perform two training iterations to exercise the functionality
        min_loss = np.inf
        for global_step in range(2):
            average_loss = training_iteration(
                model=model,
                loss_fn=loss_fn,
                optimizer=optimizer,
                dataloader_train=dataloader_val,
                global_step=global_step,
                writer=writer,
            )

            print(f"Epoch: {global_step}")
            print(f" Avg Loss: {average_loss}")

            # Verify loss is decreasing
            assert average_loss < min_loss, "Loss is not converging!"  # type: ignore

            # The assertion above guarantees this is a new minimum.
            min_loss = average_loss
    finally:
        # Flush pending events and release the event file even when an
        # assertion (or the training step itself) fails.
        writer.close()
Exercise the training iteration to ensure that everything performs as expected without errors.