src.tests.plot_test

 1from os.path import exists
 2
 3import torch
 4from torch.utils.tensorboard.writer import SummaryWriter
 5
 6from src.data import get_mnist_dataloaders
 7from src.model import BasicCNN
 8from src.plot import generate_projection
 9from src.plot import plot_confusion_matrix
10from src.plot import saliency_map
11
12
def test_plot_confusion_matrix(tmpdir):
    """Exercise ``plot_confusion_matrix`` with a freshly constructed model.

    The figure is written into pytest's temporary directory; the test only
    checks that the PNG file exists, not what it contains.
    """
    figure_name = "confusion_matrix_test"
    save_path = str(tmpdir)

    # Build the model first, then fetch the validation loader (second of the
    # four values returned by get_mnist_dataloaders).
    model = BasicCNN()
    _, dataloader_val, _, _ = get_mnist_dataloaders()

    plot_confusion_matrix(
        model=model,
        dataloader_val=dataloader_val,
        save_path=save_path,
        fig_name=figure_name,
    )

    expected_png = f"{save_path}/{figure_name}.png"
    assert exists(expected_png), "Could not find confusion matrix image file!"
32
33
def test_generate_projection(tmpdir):
    """Generate a projection using an untrained model.  Used to exercise the code.

    Only checks that ``generate_projection`` leaves a ``projector_config.pbtxt``
    in the writer's log directory; the projection contents are not verified.
    """
    # Create a model and dataloader
    model = BasicCNN()
    log_path = str(tmpdir)
    _, dataloader_val, _, _ = get_mnist_dataloaders()

    # Exercise plotting.  Always close the writer (even if plotting raises) so
    # pending events are flushed and the event file is released instead of
    # leaking into later tests.
    writer = SummaryWriter(log_dir=log_path)
    try:
        generate_projection(model=model, dataloader_val=dataloader_val, writer=writer)
    finally:
        writer.close()

    # Verify that logs were created (does not check if they are correct!)
    assert exists(
        f"{log_path}/projector_config.pbtxt"
    ), "Could not find projector_config.pbtxt file!"
50
51
def test_saliency_map(tmpdir):
    """Generate a saliency map of some sample images showing the network's focus.

    This uses an untrained model, so the highlighted focus areas are not
    meaningful.  Used to exercise the code: the test only checks that the
    expected PNG file is written.
    """
    # Create a model and save its (untrained) weights so saliency_map can be
    # given a weights file path.
    model = BasicCNN()
    save_path = str(tmpdir)
    figure_name = "test"
    global_step = 12
    model_weights_path = f"{save_path}/model.pkl"
    torch.save(model.state_dict(), model_weights_path)

    # Get a Dataset (the fourth value returned by get_mnist_dataloaders)
    _, _, _, dataset_val = get_mnist_dataloaders()

    saliency_map(
        model=model,
        dataset_val=dataset_val,
        figure_name=figure_name,
        global_step=global_step,
        model_weights_path=model_weights_path,
        save_path=save_path,
    )

    # saliency_map is expected to name its figure "<figure_name>_<global_step>.png";
    # report that exact path on failure rather than a generic "saliency.png".
    expected_file = f"{save_path}/{figure_name}_{global_step}.png"
    assert exists(expected_file), f"Could not find {expected_file} file!"
def test_plot_confusion_matrix(tmpdir):
    """Exercise ``plot_confusion_matrix`` with a freshly constructed model.

    The figure is written into pytest's temporary directory; the test only
    checks that the PNG file exists, not what it contains.
    """
    figure_name = "confusion_matrix_test"
    save_path = str(tmpdir)

    # Build the model first, then fetch the validation loader (second of the
    # four values returned by get_mnist_dataloaders).
    model = BasicCNN()
    _, dataloader_val, _, _ = get_mnist_dataloaders()

    plot_confusion_matrix(
        model=model,
        dataloader_val=dataloader_val,
        save_path=save_path,
        fig_name=figure_name,
    )

    expected_png = f"{save_path}/{figure_name}.png"
    assert exists(expected_png), "Could not find confusion matrix image file!"

Generate a confusion matrix using a dataloader. Used to exercise the code.

def test_generate_projection(tmpdir):
    """Generate a projection using an untrained model.  Used to exercise the code.

    Only checks that ``generate_projection`` leaves a ``projector_config.pbtxt``
    in the writer's log directory; the projection contents are not verified.
    """
    # Create a model and dataloader
    model = BasicCNN()
    log_path = str(tmpdir)
    _, dataloader_val, _, _ = get_mnist_dataloaders()

    # Exercise plotting.  Always close the writer (even if plotting raises) so
    # pending events are flushed and the event file is released instead of
    # leaking into later tests.
    writer = SummaryWriter(log_dir=log_path)
    try:
        generate_projection(model=model, dataloader_val=dataloader_val, writer=writer)
    finally:
        writer.close()

    # Verify that logs were created (does not check if they are correct!)
    assert exists(
        f"{log_path}/projector_config.pbtxt"
    ), "Could not find projector_config.pbtxt file!"

Generate a projection using an untrained model. Used to exercise the code.

def test_saliency_map(tmpdir):
    """Generate a saliency map of some sample images showing the network's focus.

    This uses an untrained model, so the highlighted focus areas are not
    meaningful.  Used to exercise the code: the test only checks that the
    expected PNG file is written.
    """
    # Create a model and save its (untrained) weights so saliency_map can be
    # given a weights file path.
    model = BasicCNN()
    save_path = str(tmpdir)
    figure_name = "test"
    global_step = 12
    model_weights_path = f"{save_path}/model.pkl"
    torch.save(model.state_dict(), model_weights_path)

    # Get a Dataset (the fourth value returned by get_mnist_dataloaders)
    _, _, _, dataset_val = get_mnist_dataloaders()

    saliency_map(
        model=model,
        dataset_val=dataset_val,
        figure_name=figure_name,
        global_step=global_step,
        model_weights_path=model_weights_path,
        save_path=save_path,
    )

    # saliency_map is expected to name its figure "<figure_name>_<global_step>.png";
    # report that exact path on failure rather than a generic "saliency.png".
    expected_file = f"{save_path}/{figure_name}_{global_step}.png"
    assert exists(expected_file), f"Could not find {expected_file} file!"

Generate a saliency map of some sample images showing the network's focus. This uses an untrained model, so the focus areas will not be meaningful. Used to exercise the code.