src.plot

  1import subprocess
  2from copy import deepcopy
  3
  4import matplotlib.pyplot as plt
  5import torch
  6from sklearn.metrics import confusion_matrix
  7from sklearn.metrics import ConfusionMatrixDisplay
  8from torch import nn
  9from torch.utils.data import DataLoader
 10from torch.utils.data import Dataset
 11from torch.utils.tensorboard.writer import SummaryWriter
 12
 13
 14def plot_confusion_matrix(
 15    model: nn.Module,
 16    dataloader_val: DataLoader,
 17    save_path: str = "save",
 18    fig_name: str = "",
 19    fig_x_size: int = 8,
 20    fig_y_size: int = 8,
 21    dpi: int = 250,
 22):
 23    """Given a model and a DataLoader (which provides inputs and labels), evaluate how well the model makes predictions on the data.  We compare the highest predicted value against the true label and generate a confusion matrix.
 24
 25    Args:
 26        model (nn.Module): Torch model.
 27        dataloader_val (DataLoader): DataLoader that outputs `(x, y)`.
 28        save_path (str, optional): Path to save plots. Defaults to "save".
 29        fig_name (str, optional): Figure name. Defaults to "".
 30        fig_x_size (int, optional): Figure x size. Defaults to 8.
 31        fig_y_size (int, optional): Figure y size. Defaults to 8.
 32        dpi (int, optional): Increase for higher resolution. Defaults to 250.
 33    """
 34
 35    # Make dirs
 36    subprocess.run(["mkdir", "-p", save_path])
 37
 38    # Preallocate lists
 39    y_pred = []
 40    y_true = []
 41
 42    # Iterate over all inputs and labels for the data iterator
 43    for x, y in dataloader_val:
 44        # Perform a forward pass with the input data x and get highest predicted value
 45        y_hat, _ = model(x)
 46        y_idx = torch.argmax(y_hat, 1)
 47
 48        # Add the prediction and true label to the preallocated list
 49        y_pred.extend(list(y_idx.numpy()))
 50        y_true.extend(list(y.numpy()))
 51
 52    # Generate the confusion matrix using the predictions versus the true labels
 53    cm = confusion_matrix(y_true, y_pred)
 54
 55    # Generate a figure and plot the results
 56    figure = plt.figure(figsize=(fig_x_size, fig_y_size))
 57    disp = ConfusionMatrixDisplay(confusion_matrix=cm)
 58    disp.plot()
 59
 60    # Add axes labels
 61    plt.title("Confusion Matrix")
 62    plt.ylabel("Predictions")
 63    plt.xlabel("Truth")
 64
 65    # Save off the plot
 66    plt.savefig(f"{save_path}/{fig_name}.png", dpi=dpi)
 67    plt.close("all")
 68
 69
 70def generate_projection(
 71    model: nn.Module,
 72    dataloader_val: DataLoader,
 73    writer: SummaryWriter,
 74    global_step: int = 0,
 75    projection_limit: int = 1000,
 76):
 77    """Create a projection of the data in N-dimensional space.  This can be visualized using Tensorboard.
 78
 79    Args:
 80        model (nn.Module): Torch model.
 81        dataloader_val (DataLoader): Dataloader that outputs `(x, y)`.
 82        writer (SummaryWriter): Tensorboard writer.
 83        global_step (int, optional): The global step tracker. Defaults to 0.
 84        projection_limit (int, optional): The maximum number of projections allowed. Defaults to 1000.
 85    """
 86    # Initialize variables
 87    embeddings = []
 88    labels = []
 89    images = []
 90
 91    # Iterate over all data
 92    for x, y in dataloader_val:
 93        # Remove gradient tape
 94        with torch.no_grad():
 95            _, embedding = model(x)  # get the embeddings (not predictions of the model)
 96            embeddings.append(embedding)  # save off embeddings
 97
 98        # Append the data
 99        labels.append(y)
100        images.append(x)
101
102    # Stack the output embeddings
103    stacked_embeddings = torch.vstack(embeddings)
104    stacked_labels = torch.hstack(labels).numpy()  # convert to numpy array instead of tensors
105    stacked_images = torch.vstack(images)
106
107    # Tensorboard is only able to plot a maximum number of data points in its projection.  Thus, we will limit the number of data points we will display.
108    mean, std, _ = (
109        torch.mean(stacked_embeddings),
110        torch.std(stacked_embeddings),
111        torch.var(stacked_embeddings),
112    )
113    stacked_embeddings = (stacked_embeddings - mean) / std
114    stacked_embeddings = stacked_embeddings[0:projection_limit, :]
115    stacked_labels = stacked_labels[0:projection_limit]
116    stacked_images = stacked_images[0:projection_limit, :, :, :]
117
118    # Add the embedding information to Tensorboard logs
119    writer.add_embedding(
120        stacked_embeddings,
121        metadata=stacked_labels,
122        # label_img=stacked_images,     # turn on to display images for each data point!
123        global_step=global_step,
124    )
125
126
def saliency_map(
    model: nn.Module,
    dataset_val: Dataset,
    figure_name: str = "saliency",
    global_step: int = 0,
    model_weights_path: str | None = None,
    save_path: str = "save",
):
    """Create a saliency map showing which inputs contributed the most toward the predictions.

    For up to 10 samples of ``dataset_val`` the gradient of the summed first
    model output with respect to the input is computed, once for the sample
    itself and once for a constant ``-1`` baseline of the same shape. Each
    sample fills one column of a 5-row grid (source, baseline gradient,
    saliency, delta, overlay) that is saved to
    ``{save_path}/{figure_name}_{global_step}.png``.

    A deep copy of ``model`` is used, so the caller's model and the dataset's
    tensors are not modified. If the dataset holds fewer samples than columns,
    the remaining columns are left blank.

    Args:
        model (nn.Module): Torch model whose forward pass returns a pair; the
            first element is differentiated.
        dataset_val (Dataset): The Dataset which provides `(x, y)`.
        figure_name (str): The name of the figure to save as file. Defaults to "saliency".
        global_step (int): The global step for tracking. Defaults to 0.
        model_weights_path (str | None): Torch model weights path.
        save_path (str): Save path.
    """
    # Function-scope import so the block does not rely on a new file-level import.
    import os

    # Create the output directory portably instead of shelling out to `mkdir -p`.
    os.makedirs(save_path, exist_ok=True)

    # One grid row per image kind, one column per sample.
    row_labels = ("source", "baseline", "saliency", "delta", "overlay")
    n_rows = len(row_labels)
    n_cols = 10

    # Work on a copy so loading weights never touches the caller's model.
    nn_model = deepcopy(model)
    if model_weights_path is not None:  # Load weights into model if available
        # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
        nn_model.load_state_dict(torch.load(model_weights_path))

    # A dedicated figure keeps repeated calls from drawing over each other or
    # over whatever figure the caller currently has active.
    fig, axes = plt.subplots(n_rows, n_cols, squeeze=False)
    try:
        # zip stops at n_cols samples or when the dataset runs out, whichever
        # comes first, instead of raising StopIteration on short datasets.
        for col, (x, _) in zip(range(n_cols), dataset_val):
            # Fresh leaf tensors: the dataset's own tensor is never switched to
            # requires_grad, and the baseline follows the sample's shape.
            x_pred = x.detach().clone().requires_grad_(True)
            x_base = torch.full_like(x_pred, -1.0).requires_grad_(True)

            y_base, _ = nn_model(x_base)  # forward prop baseline
            y_pred, _ = nn_model(x_pred)  # forward prop

            y_base.sum().backward()  # baseline backpropagation
            y_pred.sum().backward()  # prediction backpropagation

            img_source = x_pred.detach().squeeze()
            grad_base = x_base.grad.squeeze()
            grad_pred = x_pred.grad.squeeze()

            # Same order as row_labels.
            panels = (
                img_source,
                grad_base.abs(),
                grad_pred.abs(),
                grad_pred - grad_base,
                grad_pred.abs() * img_source,
            )
            for row, panel in enumerate(panels):
                axes[row, col].imshow(
                    panel.cpu().numpy(), cmap=plt.cm.viridis, aspect="auto"
                )

        # Add labels for first column only
        for row, label in enumerate(row_labels):
            axes[row, 0].set_ylabel(label)

        # Remove all ticks
        for ax in axes.flat:
            remove_ticks(ax)

        # Save figure
        fig.savefig(f"{save_path}/{figure_name}_{global_step}.png", dpi=250)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
218
219
220from typing import Type
221
222
def remove_ticks(axis) -> None:
    """Hide the tick marks and tick labels of both axes of a plot.

    Args:
        axis: Object exposing a Matplotlib-style ``tick_params`` method
            (for example a Matplotlib ``Axes``).
    """
    axis.tick_params(
        axis="y",  # changes apply to the y-axis
        which="both",  # both major and minor ticks are affected
        left=False,  # ticks along the left edge are off
        labelleft=False,  # labels along the left edge are off
    )

    axis.tick_params(
        axis="x",  # changes apply to the x-axis
        which="both",  # both major and minor ticks are affected
        bottom=False,  # ticks along the bottom edge are off
        top=False,  # ticks along the top edge are off
        labelbottom=False,  # labels along the bottom edge are off
    )
def plot_confusion_matrix( model: torch.nn.modules.module.Module, dataloader_val: torch.utils.data.dataloader.DataLoader, save_path: str = 'save', fig_name: str = '', fig_x_size: int = 8, fig_y_size: int = 8, dpi: int = 250):
def plot_confusion_matrix(
    model: nn.Module,
    dataloader_val: DataLoader,
    save_path: str = "save",
    fig_name: str = "",
    fig_x_size: int = 8,
    fig_y_size: int = 8,
    dpi: int = 250,
):
    """Plot and save a confusion matrix of the model's predictions.

    Every batch of ``dataloader_val`` is passed through ``model``. The index of
    the highest value along dim 1 of the first model output is taken as the
    predicted class and compared against the true label. The resulting matrix
    is written to ``{save_path}/{fig_name}.png``.

    The model's train/eval mode is left exactly as the caller provided it.

    Args:
        model (nn.Module): Torch model whose forward pass returns a pair; the
            first element holds the per-class scores.
        dataloader_val (DataLoader): DataLoader that outputs `(x, y)`.
        save_path (str, optional): Path to save plots. Defaults to "save".
        fig_name (str, optional): Figure name. Defaults to "".
        fig_x_size (int, optional): Figure x size. Defaults to 8.
        fig_y_size (int, optional): Figure y size. Defaults to 8.
        dpi (int, optional): Increase for higher resolution. Defaults to 250.
    """
    # Function-scope import so the block does not rely on a new file-level import.
    import os

    # Create the output directory portably instead of shelling out to `mkdir -p`.
    os.makedirs(save_path, exist_ok=True)

    y_pred = []
    y_true = []

    # Evaluation only: no autograd graph is needed for the forward passes.
    with torch.no_grad():
        for x, y in dataloader_val:
            # Forward pass; the predicted class is the arg-max over dim 1.
            y_hat, _ = model(x)
            y_idx = torch.argmax(y_hat, 1)

            # Move to CPU first so tensors on other devices can be collected too.
            y_pred.extend(y_idx.cpu().tolist())
            y_true.extend(y.cpu().tolist())

    # Rows of the matrix are the true labels, columns are the predictions.
    cm = confusion_matrix(y_true, y_pred)

    # Draw on an explicitly created axes: without `ax=`, the display object
    # opens its own figure and the requested figure size is silently ignored.
    figure, ax = plt.subplots(figsize=(fig_x_size, fig_y_size))
    try:
        ConfusionMatrixDisplay(confusion_matrix=cm).plot(ax=ax)

        # Axis labels follow the matrix layout: truth on y (rows),
        # predictions on x (columns).
        ax.set_title("Confusion Matrix")
        ax.set_ylabel("Truth")
        ax.set_xlabel("Predictions")

        figure.savefig(f"{save_path}/{fig_name}.png", dpi=dpi)
    finally:
        # Close only the figure created here, not the caller's figures.
        plt.close(figure)

Given a model and a DataLoader (which provides inputs and labels), evaluate how well the model makes predictions on the data. We compare the highest predicted value against the true label and generate a confusion matrix.

Arguments:
  • model (nn.Module): Torch model.
  • dataloader_val (DataLoader): DataLoader that outputs (x, y).
  • save_path (str, optional): Path to save plots. Defaults to "save".
  • fig_name (str, optional): Figure name. Defaults to "".
  • fig_x_size (int, optional): Figure x size. Defaults to 8.
  • fig_y_size (int, optional): Figure y size. Defaults to 8.
  • dpi (int, optional): Increase for higher resolution. Defaults to 250.
def generate_projection( model: torch.nn.modules.module.Module, dataloader_val: torch.utils.data.dataloader.DataLoader, writer: torch.utils.tensorboard.writer.SummaryWriter, global_step: int = 0, projection_limit: int = 1000):
 71def generate_projection(
 72    model: nn.Module,
 73    dataloader_val: DataLoader,
 74    writer: SummaryWriter,
 75    global_step: int = 0,
 76    projection_limit: int = 1000,
 77):
 78    """Create a projection of the data in N-dimensional space.  This can be visualized using Tensorboard.
 79
 80    Args:
 81        model (nn.Module): Torch model.
 82        dataloader_val (DataLoader): Dataloader that outputs `(x, y)`.
 83        writer (SummaryWriter): Tensorboard writer.
 84        global_step (int, optional): The global step tracker. Defaults to 0.
 85        projection_limit (int, optional): The maximum number of projections allowed. Defaults to 1000.
 86    """
 87    # Initialize variables
 88    embeddings = []
 89    labels = []
 90    images = []
 91
 92    # Iterate over all data
 93    for x, y in dataloader_val:
 94        # Remove gradient tape
 95        with torch.no_grad():
 96            _, embedding = model(x)  # get the embeddings (not predictions of the model)
 97            embeddings.append(embedding)  # save off embeddings
 98
 99        # Append the data
100        labels.append(y)
101        images.append(x)
102
103    # Stack the output embeddings
104    stacked_embeddings = torch.vstack(embeddings)
105    stacked_labels = torch.hstack(labels).numpy()  # convert to numpy array instead of tensors
106    stacked_images = torch.vstack(images)
107
108    # Tensorboard is only able to plot a maximum number of data points in its projection.  Thus, we will limit the number of data points we will display.
109    mean, std, _ = (
110        torch.mean(stacked_embeddings),
111        torch.std(stacked_embeddings),
112        torch.var(stacked_embeddings),
113    )
114    stacked_embeddings = (stacked_embeddings - mean) / std
115    stacked_embeddings = stacked_embeddings[0:projection_limit, :]
116    stacked_labels = stacked_labels[0:projection_limit]
117    stacked_images = stacked_images[0:projection_limit, :, :, :]
118
119    # Add the embedding information to Tensorboard logs
120    writer.add_embedding(
121        stacked_embeddings,
122        metadata=stacked_labels,
123        # label_img=stacked_images,     # turn on to display images for each data point!
124        global_step=global_step,
125    )

Create a projection of the data in N-dimensional space. This can be visualized using Tensorboard.

Arguments:
  • model (nn.Module): Torch model.
  • dataloader_val (DataLoader): Dataloader that outputs (x, y).
  • writer (SummaryWriter): Tensorboard writer.
  • global_step (int, optional): The global step tracker. Defaults to 0.
  • projection_limit (int, optional): The maximum number of projections allowed. Defaults to 1000.
def saliency_map( model: torch.nn.modules.module.Module, dataset_val: torch.utils.data.dataset.Dataset, figure_name: str = 'saliency', global_step: int = 0, model_weights_path: str | None = None, save_path: str = 'save'):
def saliency_map(
    model: nn.Module,
    dataset_val: Dataset,
    figure_name: str = "saliency",
    global_step: int = 0,
    model_weights_path: str | None = None,
    save_path: str = "save",
):
    """Create a saliency map showing which inputs contributed the most toward the predictions.

    For up to 10 samples of ``dataset_val`` the gradient of the summed first
    model output with respect to the input is computed, once for the sample
    itself and once for a constant ``-1`` baseline of the same shape. Each
    sample fills one column of a 5-row grid (source, baseline gradient,
    saliency, delta, overlay) that is saved to
    ``{save_path}/{figure_name}_{global_step}.png``.

    A deep copy of ``model`` is used, so the caller's model and the dataset's
    tensors are not modified. If the dataset holds fewer samples than columns,
    the remaining columns are left blank.

    Args:
        model (nn.Module): Torch model whose forward pass returns a pair; the
            first element is differentiated.
        dataset_val (Dataset): The Dataset which provides `(x, y)`.
        figure_name (str): The name of the figure to save as file. Defaults to "saliency".
        global_step (int): The global step for tracking. Defaults to 0.
        model_weights_path (str | None): Torch model weights path.
        save_path (str): Save path.
    """
    # Function-scope import so the block does not rely on a new file-level import.
    import os

    # Create the output directory portably instead of shelling out to `mkdir -p`.
    os.makedirs(save_path, exist_ok=True)

    # One grid row per image kind, one column per sample.
    row_labels = ("source", "baseline", "saliency", "delta", "overlay")
    n_rows = len(row_labels)
    n_cols = 10

    # Work on a copy so loading weights never touches the caller's model.
    nn_model = deepcopy(model)
    if model_weights_path is not None:  # Load weights into model if available
        # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
        nn_model.load_state_dict(torch.load(model_weights_path))

    # A dedicated figure keeps repeated calls from drawing over each other or
    # over whatever figure the caller currently has active.
    fig, axes = plt.subplots(n_rows, n_cols, squeeze=False)
    try:
        # zip stops at n_cols samples or when the dataset runs out, whichever
        # comes first, instead of raising StopIteration on short datasets.
        for col, (x, _) in zip(range(n_cols), dataset_val):
            # Fresh leaf tensors: the dataset's own tensor is never switched to
            # requires_grad, and the baseline follows the sample's shape.
            x_pred = x.detach().clone().requires_grad_(True)
            x_base = torch.full_like(x_pred, -1.0).requires_grad_(True)

            y_base, _ = nn_model(x_base)  # forward prop baseline
            y_pred, _ = nn_model(x_pred)  # forward prop

            y_base.sum().backward()  # baseline backpropagation
            y_pred.sum().backward()  # prediction backpropagation

            img_source = x_pred.detach().squeeze()
            grad_base = x_base.grad.squeeze()
            grad_pred = x_pred.grad.squeeze()

            # Same order as row_labels.
            panels = (
                img_source,
                grad_base.abs(),
                grad_pred.abs(),
                grad_pred - grad_base,
                grad_pred.abs() * img_source,
            )
            for row, panel in enumerate(panels):
                axes[row, col].imshow(
                    panel.cpu().numpy(), cmap=plt.cm.viridis, aspect="auto"
                )

        # Add labels for first column only
        for row, label in enumerate(row_labels):
            axes[row, 0].set_ylabel(label)

        # Remove all ticks
        for ax in axes.flat:
            remove_ticks(ax)

        # Save figure
        fig.savefig(f"{save_path}/{figure_name}_{global_step}.png", dpi=250)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

Create a saliency map showing which inputs contributed the most toward the predictions.

Arguments:
  • model (nn.Module): Torch model.
  • dataset_val (Dataset): The Dataset which provides (x, y).
  • figure_name (str): The name of the figure to save as file. Defaults to "saliency".
  • global_step (int): The global step for tracking. Defaults to 0.
  • model_weights_path (str | None): Torch model weights path.
  • save_path (str): Save path.
def remove_ticks(axis):
def remove_ticks(axis) -> None:
    """Hide the tick marks and tick labels of both axes of a plot.

    Args:
        axis: Object exposing a Matplotlib-style ``tick_params`` method
            (for example a Matplotlib ``Axes``).
    """
    axis.tick_params(
        axis="y",  # changes apply to the y-axis
        which="both",  # both major and minor ticks are affected
        left=False,  # ticks along the left edge are off
        labelleft=False,  # labels along the left edge are off
    )

    axis.tick_params(
        axis="x",  # changes apply to the x-axis
        which="both",  # both major and minor ticks are affected
        bottom=False,  # ticks along the bottom edge are off
        top=False,  # ticks along the top edge are off
        labelbottom=False,  # labels along the bottom edge are off
    )