src.tests.plot_test
"""Smoke tests for the plotting helpers imported from ``src.plot``."""
from os.path import exists

import torch
from torch.utils.tensorboard.writer import SummaryWriter

from src.data import get_mnist_dataloaders
from src.model import BasicCNN
from src.plot import generate_projection
from src.plot import plot_confusion_matrix
from src.plot import saliency_map


def test_plot_confusion_matrix(tmpdir):
    """Generate a confusion matrix using a dataloader. Used to exercise the code.

    Only checks that ``<tmpdir>/<figure_name>.png`` exists afterwards; the
    contents of the image are not inspected.
    """
    # Create a model and dataloader
    model = BasicCNN()
    _, dataloader_val, _, _ = get_mnist_dataloaders()
    figure_name = "confusion_matrix_test"

    # Exercise plotting
    save_path = f"{tmpdir}"
    plot_confusion_matrix(
        model=model,
        dataloader_val=dataloader_val,
        save_path=save_path,
        fig_name=figure_name,
    )

    # Check that a plot was generated
    assert exists(f"{save_path}/{figure_name}.png"), "Could not find confusion matrix image file!"


def test_generate_projection(tmpdir):
    """Generate a projection using an untrained model. Used to exercise the code.

    Only checks that ``projector_config.pbtxt`` appears in the log directory;
    it does not check that the logged data is correct.
    """
    # Create a model, writer, and dataloader
    model = BasicCNN()
    log_path = str(tmpdir)
    writer = SummaryWriter(log_dir=log_path)
    try:
        _, dataloader_val, _, _ = get_mnist_dataloaders()

        # Exercise plotting
        generate_projection(model=model, dataloader_val=dataloader_val, writer=writer)
    finally:
        # Always release the writer, even if the code under test raises,
        # so the test does not leak the writer's open event file.
        writer.close()

    # Verify that logs were created (does not check if they are correct!)
    assert exists(
        f"{log_path}/projector_config.pbtxt"
    ), "Could not find projector_config.pbtxt file!"


def test_saliency_map(tmpdir):
    """Generate a saliency map of some sample images. Used to exercise the code.

    The model is untrained, so the highlighted focus areas are not meaningful.
    Only checks that ``<tmpdir>/<figure_name>_<global_step>.png`` exists.
    """
    # Create a model and store its weights where saliency_map can be pointed at them
    model = BasicCNN()
    save_path = str(tmpdir)
    figure_name = "test"
    global_step = 12
    model_weights_path = f"{save_path}/model.pkl"
    torch.save(model.state_dict(), model_weights_path)

    # Get a Dataset
    _, _, _, dataset_val = get_mnist_dataloaders()

    saliency_map(
        model=model,
        dataset_val=dataset_val,
        figure_name=figure_name,
        global_step=global_step,
        model_weights_path=model_weights_path,
        save_path=save_path,
    )

    assert exists(
        f"{save_path}/{figure_name}_{global_step}.png"
    ), "Could not find saliency.png file!"
def
test_plot_confusion_matrix(tmpdir):
def test_plot_confusion_matrix(tmpdir):
    """Smoke-test ``plot_confusion_matrix`` with a freshly constructed model.

    Runs the plotting routine with the second loader returned by
    ``get_mnist_dataloaders`` and checks only that
    ``<tmpdir>/<figure_name>.png`` exists afterwards. The contents of the
    image are not inspected.
    """
    # Where the figure should land, and the file stem it should use.
    save_path = f"{tmpdir}"
    figure_name = "confusion_matrix_test"

    # Inputs for the code under test.
    model = BasicCNN()
    _, dataloader_val, _, _ = get_mnist_dataloaders()

    plot_confusion_matrix(
        model=model,
        dataloader_val=dataloader_val,
        save_path=save_path,
        fig_name=figure_name,
    )

    # The only observable outcome verified here is the image file on disk.
    expected_png = f"{save_path}/{figure_name}.png"
    assert exists(expected_png), "Could not find confusion matrix image file!"
Generate a confusion matrix using a dataloader. Used to exercise the code.
def
test_generate_projection(tmpdir):
def test_generate_projection(tmpdir):
    """Generate a projection using an untrained model. Used to exercise the code.

    A ``SummaryWriter`` logging into ``tmpdir`` is handed to
    ``generate_projection``. The test only verifies that
    ``projector_config.pbtxt`` appears in the log directory; it does not
    check that the logged data is correct.
    """
    # Create a model, writer, and dataloader
    model = BasicCNN()
    log_path = str(tmpdir)
    writer = SummaryWriter(log_dir=log_path)
    try:
        _, dataloader_val, _, _ = get_mnist_dataloaders()

        # Exercise plotting
        generate_projection(model=model, dataloader_val=dataloader_val, writer=writer)
    finally:
        # Always release the writer, even if the code under test raises,
        # so the test does not leak the writer's open event file.
        writer.close()

    # Verify that logs were created (does not check if they are correct!)
    assert exists(
        f"{log_path}/projector_config.pbtxt"
    ), "Could not find projector_config.pbtxt file!"
Generate a projection using an untrained model. Used to exercise the code.
def
test_saliency_map(tmpdir):
def test_saliency_map(tmpdir):
    """Smoke-test ``saliency_map`` on some sample images.

    The model is untrained, so the highlighted focus areas are not
    meaningful. The test only confirms that
    ``<tmpdir>/<figure_name>_<global_step>.png`` exists afterwards.
    """
    # Fixed naming inputs; together they determine the expected output file.
    figure_name = "test"
    global_step = 12
    save_path = str(tmpdir)
    model_weights_path = f"{save_path}/model.pkl"

    # Persist the fresh model's weights so saliency_map can be given their path.
    model = BasicCNN()
    torch.save(model.state_dict(), model_weights_path)

    # Only the fourth item returned by get_mnist_dataloaders is used here.
    _, _, _, dataset_val = get_mnist_dataloaders()

    saliency_arguments = {
        "model": model,
        "dataset_val": dataset_val,
        "figure_name": figure_name,
        "global_step": global_step,
        "model_weights_path": model_weights_path,
        "save_path": save_path,
    }
    saliency_map(**saliency_arguments)

    expected_png = f"{save_path}/{figure_name}_{global_step}.png"
    assert exists(expected_png), "Could not find saliency.png file!"
Generate a saliency map of some sample images showing the network's focus. This uses an untrained model, so the focus areas will not be valid. Used to exercise the code.