src.tests.train_test

import numpy as np
from rich import print
from torch.nn import CrossEntropyLoss
from torch.optim.adam import Adam
from torch.utils.tensorboard.writer import SummaryWriter

from src.data import get_mnist_dataloaders
from src.model import BasicCNN
from src.train import training_iteration

def test_training_iteration(tmpdir):
    """Run the training iteration twice and verify it completes without errors and the loss decreases."""

    # Set up the dataloader and TensorBoard writer; the validation split
    # stands in for training data to keep the test lightweight
    _, dataloader_val, _, _ = get_mnist_dataloaders()
    log_dir = str(tmpdir)
    writer = SummaryWriter(log_dir=log_dir)

    # Create the model, loss function, and optimizer
    model = BasicCNN()
    loss_fn = CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=0.001)

    # Perform two training iterations to exercise the functionality
    min_loss = np.inf
    for global_step in range(2):
        average_loss = training_iteration(
            model=model,
            loss_fn=loss_fn,
            optimizer=optimizer,
            dataloader_train=dataloader_val,
            global_step=global_step,
            writer=writer,
        )

        # Verify the loss is decreasing (trivially true on the first step,
        # since min_loss starts at infinity)
        print(f"Epoch:          {global_step}")
        print(f"    Avg Loss:   {average_loss}")
        assert average_loss < min_loss, "Loss is not converging!"  # type: ignore

        # The assertion above guarantees this is the new minimum
        min_loss = average_loss
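
For reference, here is a minimal sketch of what the training_iteration helper under test might look like. The keyword signature is taken from the call above; the body, the batch-averaging, and the "loss/train" scalar tag are assumptions for illustration, not the actual src.train implementation:

def training_iteration(model, loss_fn, optimizer, dataloader_train, global_step, writer):
    """Run one pass over the dataloader and return the average batch loss (hypothetical sketch)."""
    model.train()
    total_loss = 0.0
    for inputs, targets in dataloader_train:
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    average_loss = total_loss / len(dataloader_train)
    writer.add_scalar("loss/train", average_loss, global_step)  # assumed tag name
    return average_loss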
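
The test relies on pytest's built-in tmpdir fixture, so the SummaryWriter logs land in a temporary directory that pytest cleans up afterwards. Assuming the module lives at src/tests/train_test.py (inferred from the module name above), it can be run on its own with:

pytest src/tests/train_test.py -v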