src.plot
import os
import subprocess
from copy import deepcopy
from typing import Type

import matplotlib.pyplot as plt
import torch
from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.tensorboard.writer import SummaryWriter


def plot_confusion_matrix(
    model: nn.Module,
    dataloader_val: DataLoader,
    save_path: str = "save",
    fig_name: str = "",
    fig_x_size: int = 8,
    fig_y_size: int = 8,
    dpi: int = 250,
):
    """Plot and save the confusion matrix of a model's predictions.

    Every batch from ``dataloader_val`` is passed through ``model``. The
    predicted class is the arg-max of the scores along dimension 1, and it is
    compared with the true label. The resulting confusion matrix is saved as
    ``{save_path}/{fig_name}.png``.

    The model must return a pair whose first element holds the class scores
    (the second element is ignored). The model runs in whatever mode it is
    currently in; call ``model.eval()`` beforehand if it uses dropout or
    batch norm.

    Args:
        model (nn.Module): Torch model.
        dataloader_val (DataLoader): DataLoader that outputs `(x, y)`.
        save_path (str, optional): Directory to save plots in; created if
            missing. Defaults to "save".
        fig_name (str, optional): Figure file name without extension.
            Defaults to "".
        fig_x_size (int, optional): Figure x size in inches. Defaults to 8.
        fig_y_size (int, optional): Figure y size in inches. Defaults to 8.
        dpi (int, optional): Increase for higher resolution. Defaults to 250.
    """
    # Portable equivalent of `mkdir -p`; raises if the directory cannot be made.
    os.makedirs(save_path, exist_ok=True)

    y_pred = []
    y_true = []

    # Inference only: no autograd graph is needed to read off predictions.
    with torch.no_grad():
        for x, y in dataloader_val:
            y_hat, _ = model(x)
            # `.cpu()` so the conversion also works for tensors on an accelerator.
            y_pred.extend(torch.argmax(y_hat, 1).cpu().tolist())
            y_true.extend(y.cpu().tolist())

    cm = confusion_matrix(y_true, y_pred)

    # Draw on an explicitly sized Axes. Without `ax=`, `ConfusionMatrixDisplay.plot`
    # opens a figure of its own and the requested figure size is ignored.
    fig, ax = plt.subplots(figsize=(fig_x_size, fig_y_size))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm)
    disp.plot(ax=ax)

    # sklearn's confusion matrix has true labels on the rows (y-axis) and
    # predicted labels on the columns (x-axis).
    ax.set_title("Confusion Matrix")
    ax.set_ylabel("Truth")
    ax.set_xlabel("Predictions")

    fig.savefig(os.path.join(save_path, f"{fig_name}.png"), dpi=dpi)
    # Close only this figure so unrelated figures owned by the caller survive.
    plt.close(fig)


def generate_projection(
    model: nn.Module,
    dataloader_val: DataLoader,
    writer: SummaryWriter,
    global_step: int = 0,
    projection_limit: int = 1000,
    include_images: bool = False,
):
    """Log a TensorBoard projection of the model's embeddings.

    Every batch from ``dataloader_val`` is passed through ``model`` without
    gradient tracking, and the second element of the model's output is used
    as the embedding. The embeddings are:

    * standardized with a single global mean and standard deviation,
    * truncated to the first ``projection_limit`` samples,
    * written with ``writer.add_embedding``, using the labels as metadata.

    Args:
        model (nn.Module): Torch model returning ``(predictions, embeddings)``.
        dataloader_val (DataLoader): Dataloader that outputs `(x, y)`.
        writer (SummaryWriter): Tensorboard writer.
        global_step (int, optional): The global step tracker. Defaults to 0.
        projection_limit (int, optional): Maximum number of data points sent
            to the projector. Defaults to 1000.
        include_images (bool, optional): Also attach each input ``x`` as the
            point's thumbnail (``label_img``). The stacked inputs must then be
            shaped ``(N, C, H, W)``. Defaults to False.
    """
    embeddings = []
    labels = []
    images = []

    # Inference only: no autograd graph is needed to collect embeddings.
    with torch.no_grad():
        for x, y in dataloader_val:
            _, embedding = model(x)
            embeddings.append(embedding.cpu())
            labels.append(y.cpu())
            # Inputs are only kept when they will actually be logged.
            if include_images:
                images.append(x.cpu())

    stacked_embeddings = torch.vstack(embeddings)
    stacked_labels = torch.hstack(labels).numpy()

    # Standardize with scalar statistics over all embeddings. Skip the division
    # when the spread is zero (or undefined), which would otherwise fill the
    # projection with inf/NaN values.
    mean = torch.mean(stacked_embeddings)
    std = torch.std(stacked_embeddings)
    stacked_embeddings = stacked_embeddings - mean
    if std > 0:
        stacked_embeddings = stacked_embeddings / std

    # TensorBoard's projector can only display a limited number of points,
    # so keep only the first `projection_limit` samples.
    stacked_embeddings = stacked_embeddings[:projection_limit]
    stacked_labels = stacked_labels[:projection_limit]
    stacked_images = (
        torch.vstack(images)[:projection_limit] if include_images else None
    )

    writer.add_embedding(
        stacked_embeddings,
        metadata=stacked_labels,
        label_img=stacked_images,
        global_step=global_step,
    )


def saliency_map(
    model: nn.Module,
    dataset_val: Dataset,
    figure_name: str = "saliency",
    global_step: int = 0,
    model_weights_path: str | None = None,
    save_path: str = "save",
):
    """Plot gradient saliency maps for the first samples of a dataset.

    For each of the first 10 samples, the gradient of the summed model scores
    is taken twice: with respect to the input, and with respect to a constant
    baseline input filled with -1. Five panels are drawn per sample (one
    column per sample):

    * source   - the input itself,
    * baseline - absolute gradient at the baseline input,
    * saliency - absolute gradient at the sample,
    * delta    - signed gradient difference (sample minus baseline),
    * overlay  - saliency multiplied by the input.

    The figure is saved as ``{save_path}/{figure_name}_{global_step}.png``.
    The model must return a pair whose first element holds the scores, and
    each sample ``x`` must squeeze down to a 2-D image.

    Args:
        model (nn.Module): Torch model.
        dataset_val (Dataset): The Dataset which provides `(x, y)`.
        figure_name (str): The name of the figure to save as file. Defaults to "saliency".
        global_step (int): The global step for tracking. Defaults to 0.
        model_weights_path (str | None): Path to weights loaded into a copy of
            the model. Defaults to None (the model is used as given).
        save_path (str): Directory in which the figure is saved; created if
            missing. Defaults to "save".
    """
    os.makedirs(save_path, exist_ok=True)

    # Grid layout: one column per sample, one row per panel type.
    n_rows = 5
    n_cols = 10

    # Work on a copy so loading weights never touches the caller's model.
    # Only input gradients are needed, so parameter gradients are switched off.
    nn_model = deepcopy(model)
    if model_weights_path is not None:
        # NOTE: torch.load unpickles the file - only load weights from trusted paths.
        nn_model.load_state_dict(torch.load(model_weights_path))
    nn_model.requires_grad_(False)

    # Draw on a fresh figure rather than whatever figure happens to be current.
    fig = plt.figure()

    # zip() stops early if the dataset has fewer than n_cols samples.
    for col, (x, _) in zip(range(1, n_cols + 1), dataset_val):
        # Leaf tensors that record gradients w.r.t. the input. Cloning keeps the
        # dataset's own tensor from having `requires_grad` switched on.
        x_pred = x.detach().clone().requires_grad_(True)
        x_base = torch.full_like(x_pred, -1.0).requires_grad_(True)

        y_base, _ = nn_model(x_base)
        y_pred, _ = nn_model(x_pred)
        y_base.sum().backward()
        y_pred.sum().backward()

        grad_base = x_base.grad.squeeze().cpu()
        grad_pred = x_pred.grad.squeeze().cpu()
        img_source = x.squeeze().detach().cpu().numpy()

        panels = (
            ("source", img_source),
            ("baseline", grad_base.abs().numpy()),
            ("saliency", grad_pred.abs().numpy()),
            ("delta", (grad_pred - grad_base).numpy()),
            ("overlay", grad_pred.abs().numpy() * img_source),
        )
        for row, (label, img) in enumerate(panels):
            ax = fig.add_subplot(n_rows, n_cols, col + row * n_cols)
            ax.imshow(img, cmap=plt.cm.viridis, aspect="auto")
            # Row labels only on the first column.
            if col == 1:
                ax.set_ylabel(label)
            remove_ticks(ax)

    fig.savefig(os.path.join(save_path, f"{figure_name}_{global_step}.png"), dpi=250)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def remove_ticks(axis):
    """Hide all tick marks and tick labels on a Matplotlib Axes.

    Args:
        axis: The Matplotlib ``Axes`` to clean up (anything exposing
            ``tick_params``).
    """
    # (axis name, flags) pairs, applied in this order. Every call passes
    # which="both", so major and minor ticks are both affected.
    settings = (
        # y-axis: no tick marks or labels on the left edge.
        ("y", {"left": False, "labelleft": False}),
        # x-axis: no tick marks on the bottom/top edges, no bottom labels.
        ("x", {"bottom": False, "top": False, "labelbottom": False}),
    )
    for axis_name, flags in settings:
        axis.tick_params(axis=axis_name, which="both", **flags)
def
plot_confusion_matrix( model: torch.nn.modules.module.Module, dataloader_val: torch.utils.data.dataloader.DataLoader, save_path: str = 'save', fig_name: str = '', fig_x_size: int = 8, fig_y_size: int = 8, dpi: int = 250):
def plot_confusion_matrix(
    model: nn.Module,
    dataloader_val: DataLoader,
    save_path: str = "save",
    fig_name: str = "",
    fig_x_size: int = 8,
    fig_y_size: int = 8,
    dpi: int = 250,
):
    """Plot and save the confusion matrix of a model's predictions.

    Every batch from ``dataloader_val`` is passed through ``model``. The
    predicted class is the arg-max of the scores along dimension 1, and it is
    compared with the true label. The resulting confusion matrix is saved as
    ``{save_path}/{fig_name}.png``.

    The model must return a pair whose first element holds the class scores
    (the second element is ignored). The model runs in whatever mode it is
    currently in; call ``model.eval()`` beforehand if it uses dropout or
    batch norm.

    Args:
        model (nn.Module): Torch model.
        dataloader_val (DataLoader): DataLoader that outputs `(x, y)`.
        save_path (str, optional): Directory to save plots in; created if
            missing. Defaults to "save".
        fig_name (str, optional): Figure file name without extension.
            Defaults to "".
        fig_x_size (int, optional): Figure x size in inches. Defaults to 8.
        fig_y_size (int, optional): Figure y size in inches. Defaults to 8.
        dpi (int, optional): Increase for higher resolution. Defaults to 250.
    """
    import os

    # Portable equivalent of `mkdir -p`; raises if the directory cannot be made.
    os.makedirs(save_path, exist_ok=True)

    y_pred = []
    y_true = []

    # Inference only: no autograd graph is needed to read off predictions.
    with torch.no_grad():
        for x, y in dataloader_val:
            y_hat, _ = model(x)
            # `.cpu()` so the conversion also works for tensors on an accelerator.
            y_pred.extend(torch.argmax(y_hat, 1).cpu().tolist())
            y_true.extend(y.cpu().tolist())

    cm = confusion_matrix(y_true, y_pred)

    # Draw on an explicitly sized Axes. Without `ax=`, `ConfusionMatrixDisplay.plot`
    # opens a figure of its own and the requested figure size is ignored.
    fig, ax = plt.subplots(figsize=(fig_x_size, fig_y_size))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm)
    disp.plot(ax=ax)

    # sklearn's confusion matrix has true labels on the rows (y-axis) and
    # predicted labels on the columns (x-axis).
    ax.set_title("Confusion Matrix")
    ax.set_ylabel("Truth")
    ax.set_xlabel("Predictions")

    fig.savefig(os.path.join(save_path, f"{fig_name}.png"), dpi=dpi)
    # Close only this figure so unrelated figures owned by the caller survive.
    plt.close(fig)
Given a model and a DataLoader (which provides inputs and labels), evaluate how well the model makes predictions on the data. We compare the highest predicted value against the true label and generate a confusion matrix.
Arguments:
- model (nn.Module): Torch model.
- dataloader_val (DataLoader): DataLoader that outputs `(x, y)`.
- save_path (str, optional): Path to save plots. Defaults to "save".
- fig_name (str, optional): Figure name. Defaults to "".
- fig_x_size (int, optional): Figure x size. Defaults to 8.
- fig_y_size (int, optional): Figure y size. Defaults to 8.
- dpi (int, optional): Increase for higher resolution. Defaults to 250.
def
generate_projection( model: torch.nn.modules.module.Module, dataloader_val: torch.utils.data.dataloader.DataLoader, writer: torch.utils.tensorboard.writer.SummaryWriter, global_step: int = 0, projection_limit: int = 1000):
def generate_projection(
    model: nn.Module,
    dataloader_val: DataLoader,
    writer: SummaryWriter,
    global_step: int = 0,
    projection_limit: int = 1000,
    include_images: bool = False,
):
    """Log a TensorBoard projection of the model's embeddings.

    Every batch from ``dataloader_val`` is passed through ``model`` without
    gradient tracking, and the second element of the model's output is used
    as the embedding. The embeddings are:

    * standardized with a single global mean and standard deviation,
    * truncated to the first ``projection_limit`` samples,
    * written with ``writer.add_embedding``, using the labels as metadata.

    Args:
        model (nn.Module): Torch model returning ``(predictions, embeddings)``.
        dataloader_val (DataLoader): Dataloader that outputs `(x, y)`.
        writer (SummaryWriter): Tensorboard writer.
        global_step (int, optional): The global step tracker. Defaults to 0.
        projection_limit (int, optional): Maximum number of data points sent
            to the projector. Defaults to 1000.
        include_images (bool, optional): Also attach each input ``x`` as the
            point's thumbnail (``label_img``). The stacked inputs must then be
            shaped ``(N, C, H, W)``. Defaults to False.
    """
    embeddings = []
    labels = []
    images = []

    # Inference only: no autograd graph is needed to collect embeddings.
    with torch.no_grad():
        for x, y in dataloader_val:
            _, embedding = model(x)
            embeddings.append(embedding.cpu())
            labels.append(y.cpu())
            # Inputs are only kept when they will actually be logged.
            if include_images:
                images.append(x.cpu())

    stacked_embeddings = torch.vstack(embeddings)
    stacked_labels = torch.hstack(labels).numpy()

    # Standardize with scalar statistics over all embeddings. Skip the division
    # when the spread is zero (or undefined), which would otherwise fill the
    # projection with inf/NaN values.
    mean = torch.mean(stacked_embeddings)
    std = torch.std(stacked_embeddings)
    stacked_embeddings = stacked_embeddings - mean
    if std > 0:
        stacked_embeddings = stacked_embeddings / std

    # TensorBoard's projector can only display a limited number of points,
    # so keep only the first `projection_limit` samples.
    stacked_embeddings = stacked_embeddings[:projection_limit]
    stacked_labels = stacked_labels[:projection_limit]
    stacked_images = (
        torch.vstack(images)[:projection_limit] if include_images else None
    )

    writer.add_embedding(
        stacked_embeddings,
        metadata=stacked_labels,
        label_img=stacked_images,
        global_step=global_step,
    )
Create a projection of the data in N-dimensional space. This can be visualized using Tensorboard.
Arguments:
- model (nn.Module): Torch model.
- dataloader_val (DataLoader): Dataloader that outputs `(x, y)`.
- writer (SummaryWriter): Tensorboard writer.
- global_step (int, optional): The global step tracker. Defaults to 0.
- projection_limit (int, optional): The maximum number of data points included in the projection. Defaults to 1000.
def
saliency_map( model: torch.nn.modules.module.Module, dataset_val: torch.utils.data.dataset.Dataset, figure_name: str = 'saliency', global_step: int = 0, model_weights_path: str | None = None, save_path: str = 'save'):
def saliency_map(
    model: nn.Module,
    dataset_val: Dataset,
    figure_name: str = "saliency",
    global_step: int = 0,
    model_weights_path: str | None = None,
    save_path: str = "save",
):
    """Plot gradient saliency maps for the first samples of a dataset.

    For each of the first 10 samples, the gradient of the summed model scores
    is taken twice: with respect to the input, and with respect to a constant
    baseline input filled with -1. Five panels are drawn per sample (one
    column per sample):

    * source   - the input itself,
    * baseline - absolute gradient at the baseline input,
    * saliency - absolute gradient at the sample,
    * delta    - signed gradient difference (sample minus baseline),
    * overlay  - saliency multiplied by the input.

    The figure is saved as ``{save_path}/{figure_name}_{global_step}.png``.
    The model must return a pair whose first element holds the scores, and
    each sample ``x`` must squeeze down to a 2-D image.

    Args:
        model (nn.Module): Torch model.
        dataset_val (Dataset): The Dataset which provides `(x, y)`.
        figure_name (str): The name of the figure to save as file. Defaults to "saliency".
        global_step (int): The global step for tracking. Defaults to 0.
        model_weights_path (str | None): Path to weights loaded into a copy of
            the model. Defaults to None (the model is used as given).
        save_path (str): Directory in which the figure is saved; created if
            missing. Defaults to "save".
    """
    import os

    os.makedirs(save_path, exist_ok=True)

    # Grid layout: one column per sample, one row per panel type.
    n_rows = 5
    n_cols = 10

    # Work on a copy so loading weights never touches the caller's model.
    # Only input gradients are needed, so parameter gradients are switched off.
    nn_model = deepcopy(model)
    if model_weights_path is not None:
        # NOTE: torch.load unpickles the file - only load weights from trusted paths.
        nn_model.load_state_dict(torch.load(model_weights_path))
    nn_model.requires_grad_(False)

    # Draw on a fresh figure rather than whatever figure happens to be current.
    fig = plt.figure()

    # zip() stops early if the dataset has fewer than n_cols samples.
    for col, (x, _) in zip(range(1, n_cols + 1), dataset_val):
        # Leaf tensors that record gradients w.r.t. the input. Cloning keeps the
        # dataset's own tensor from having `requires_grad` switched on.
        x_pred = x.detach().clone().requires_grad_(True)
        x_base = torch.full_like(x_pred, -1.0).requires_grad_(True)

        y_base, _ = nn_model(x_base)
        y_pred, _ = nn_model(x_pred)
        y_base.sum().backward()
        y_pred.sum().backward()

        grad_base = x_base.grad.squeeze().cpu()
        grad_pred = x_pred.grad.squeeze().cpu()
        img_source = x.squeeze().detach().cpu().numpy()

        panels = (
            ("source", img_source),
            ("baseline", grad_base.abs().numpy()),
            ("saliency", grad_pred.abs().numpy()),
            ("delta", (grad_pred - grad_base).numpy()),
            ("overlay", grad_pred.abs().numpy() * img_source),
        )
        for row, (label, img) in enumerate(panels):
            ax = fig.add_subplot(n_rows, n_cols, col + row * n_cols)
            ax.imshow(img, cmap=plt.cm.viridis, aspect="auto")
            # Row labels only on the first column.
            if col == 1:
                ax.set_ylabel(label)
            remove_ticks(ax)

    fig.savefig(os.path.join(save_path, f"{figure_name}_{global_step}.png"), dpi=250)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
Create a saliency map showing which inputs contributed the most toward the predictions.
Arguments:
- model (nn.Module): Torch model.
- dataset_val (Dataset): The Dataset which provides `(x, y)`.
- figure_name (str): The name of the figure to save as file. Defaults to "saliency".
- global_step (int): The global step for tracking. Defaults to 0.
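- model_weights_path (str | None): Path to Torch model weights that are loaded into a copy of the model. Defaults to None (the model is used as given).
- save_path (str): Directory in which the figure is saved. Defaults to "save".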
def
remove_ticks(axis):
def remove_ticks(axis):
    """Hide all tick marks and tick labels on a Matplotlib Axes.

    Args:
        axis: The Matplotlib ``Axes`` to clean up (anything exposing
            ``tick_params``).
    """
    # (axis name, flags) pairs, applied in this order. Every call passes
    # which="both", so major and minor ticks are both affected.
    settings = (
        # y-axis: no tick marks or labels on the left edge.
        ("y", {"left": False, "labelleft": False}),
        # x-axis: no tick marks on the bottom/top edges, no bottom labels.
        ("x", {"bottom": False, "top": False, "labelbottom": False}),
    )
    for axis_name, flags in settings:
        axis.tick_params(axis=axis_name, which="both", **flags)