src.runner

import os

import numpy as np
import torch
from rich import print
from rich.panel import Panel
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.tensorboard.writer import SummaryWriter

from src.data import get_mnist_dataloaders
from src.eval import eval_iteration
from src.model import BasicCNN
from src.plot import generate_projection
from src.plot import plot_confusion_matrix
from src.plot import saliency_map
from src.train import generate_canocial
from src.train import training_iteration

# set global seed
torch.manual_seed(0)
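# Note: only torch's RNG is seeded here. If numpy-driven randomness (e.g.
# shuffling or augmentation) is added later, seed it as well: np.random.seed(0)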
def run_pipeline(
    model: nn.Module,
    dataloader_train: DataLoader,
    dataloader_val: DataLoader,
    dataset_train: Dataset,
    dataset_val: Dataset,
    n_epochs: int,
    min_loss: float,
    writer: SummaryWriter,
    model_weights_path: str,
):
    """The full training pipeline loops over the following:

    ```
    1) Train the model
    2) Perform evaluation
    3) Record the average evaluation loss
        If the model has the best average loss so far:
            Save off the model
            Run analysis (confusion matrix, projection, saliency)
        Else:
            continue training
    ```

    .. note::
        The dataloaders must provide `(x, y)` representing inputs and true labels.

    Args:
        model (nn.Module): Torch model.
        dataloader_train (DataLoader): DataLoader with `(x, y)` data.
        dataloader_val (DataLoader): DataLoader with `(x, y)` data.
        dataset_train (Dataset): Dataset with `(x, y)` data.
        dataset_val (Dataset): Dataset with `(x, y)` data.
        n_epochs (int): Number of epochs to run training.
        min_loss (float): Minimum evaluation loss recorded so far; a new
            checkpoint is saved whenever the current loss drops below it.
        writer (SummaryWriter): TensorBoard writer.
        model_weights_path (str): Path to save off the model.
    """
    # Create optimizer and loss function
    optimizer = Adam(params=model.parameters(), lr=1e-5)
    loss_fn = CrossEntropyLoss()
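    # Note: CrossEntropyLoss expects raw (unnormalized) logits and integer
    # class labels, so the model's forward pass should not apply softmax.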

    # Perform training on the data
    for global_step in range(n_epochs):
        print(f"Global step: {global_step}")

        # Perform training
        avg_train_loss = training_iteration(
            model=model,
            loss_fn=loss_fn,
            optimizer=optimizer,
            dataloader_train=dataloader_train,
            global_step=global_step,
            writer=writer,
        )

        # Perform evaluation
        avg_eval_loss = eval_iteration(
            model=model,
            loss_fn=loss_fn,
            dataloader_val=dataloader_val,
            global_step=global_step,
            writer=writer,
        )

        # Printout
        print(f"Epoch:               {global_step}/{n_epochs-1}")
        print(f"    Training   Loss: {avg_train_loss:.3f}")
        print(f"    Evaluation Loss: {avg_eval_loss:.3f}")

        # Save off best model
        if avg_eval_loss < min_loss:
            # Update min_loss as the new baseline
            min_loss = avg_eval_loss

            # Save off the best model
            torch.save(model.state_dict(), model_weights_path)
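            # To restore this checkpoint later, rebuild the architecture and
            # load the weights back in, e.g.:
            #   model = BasicCNN()
            #   model.load_state_dict(torch.load(model_weights_path))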

            # Run analysis
            run_analysis(
                model=model,
                dataloader_val=dataloader_val,
                dataset_val=dataset_val,
                global_step=global_step,
                writer=writer,
            )

            # Printout
            print(f"Saving best model with evaluation loss of: {avg_eval_loss:.2f}")


def run_analysis(
    model: nn.Module,
    dataloader_val: DataLoader,
    dataset_val: Dataset,
    global_step: int,
    writer: SummaryWriter,
):
    """Run analysis on each epoch where the model improves.

    This logs multiple metrics that can be used in post-analysis examination.

    Args:
        model (nn.Module): Torch model.
        dataloader_val (DataLoader): Dataloader for validation.
        dataset_val (Dataset): Dataset for validation.
        global_step (int): Global step for tracking.
        writer (SummaryWriter): TensorBoard writer.
    """

    # Save off confusion matrix
    plot_confusion_matrix(
        model=model,
        dataloader_val=dataloader_val,
        save_path="save/plots/confusion_matrix",
        fig_name=f"confusion_matrix_{global_step}",
    )
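    # fig_name embeds the global step, so each improving epoch writes its own
    # matrix instead of overwriting the previous one.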

    # Generate projection logs
    generate_projection(
        model=model,
        dataloader_val=dataloader_val,
        writer=writer,
        global_step=global_step,
    )
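    # These embeddings appear under TensorBoard's Projector tab; the Chrome
    # recommendation printed by run_mnist below applies to viewing them.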

    # Generate saliency plot
    saliency_map(
        model=model,
        dataset_val=dataset_val,
        figure_name="saliency",
        global_step=global_step,
        save_path="save/plots/saliency",
    )


def run_mnist(
    n_epochs: int = 2,
    min_loss: float = np.inf,
    batch_size: int = 32,
    path_model: str = "save/models",
    path_log: str = "save/logs",
    path_plots: str = "save/plots",
    model_name: str = "best_model.pkl",
):
    """Perform MNIST training.  This function performs the following:

        1) Create save/log directories
        2) Get data iterators for train and evaluation
        3) Create TensorBoard writer
        4) Create model
        5) Execute training pipeline using the data iterators, TensorBoard writer, and model provided

    Args:
        n_epochs (int, optional): Number of epochs to perform training. Defaults to 2.
        min_loss (float, optional): The minimum loss recorded thus far. Defaults to np.inf.
        batch_size (int, optional): The batch size to use when retrieving data from data iterators. Defaults to 32.
        path_model (str, optional): The directory for saved models. Defaults to "save/models".
        path_log (str, optional): The directory for TensorBoard logs. Defaults to "save/logs".
        path_plots (str, optional): The directory for saved plots. Defaults to "save/plots".
        model_name (str, optional): The file name of the saved model. Defaults to "best_model.pkl".
    """
    # Fully defined path to the model save file
    model_weights_path = f"{path_model}/{model_name}"

    # Create directories (portable equivalent of `mkdir -p`)
    os.makedirs(path_model, exist_ok=True)
    os.makedirs(path_log, exist_ok=True)
    os.makedirs(path_plots, exist_ok=True)

    # Preparing pipeline
    dataloader_train, dataloader_val, dataset_train, dataset_val = get_mnist_dataloaders(
        batch_size=batch_size
    )
    writer = SummaryWriter(path_log)  # TensorBoard writer
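    # Note: pointing several runs at the same logdir mixes their curves in
    # TensorBoard; a per-run subdirectory (e.g. f"{path_log}/run_1") keeps
    # experiments separate.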
    model = BasicCNN()  # model with default hyperparameters

    # Baseline saliency map before any training (logged at global_step=-1)
    saliency_map(
        model,
        dataset_val=dataset_val,
        figure_name="saliency",
        global_step=-1,
        save_path=f"{path_plots}/saliency",
    )

    # Save off the model architecture.  add_graph needs an example input tensor
    # (a dummy 1x1x28x28 MNIST-shaped batch) to trace a forward pass.
    writer.add_graph(model, torch.rand(1, 1, 28, 28))

    # Train the model
    run_pipeline(
        model=model,
        dataloader_train=dataloader_train,
        dataloader_val=dataloader_val,
        dataset_train=dataset_train,
        dataset_val=dataset_val,
        n_epochs=n_epochs,
        min_loss=min_loss,
        writer=writer,
        model_weights_path=model_weights_path,
    )

    # TODO: generate_canocial is not currently working
    writer.close()

    # Printout
    print(Panel("To open TensorBoard, run:"))
    print(f"     tensorboard --logdir={path_log}")
    print("     :warning: Chrome is recommended for viewing projections!")


if __name__ == "__main__":
    run_mnist()
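

# Usage sketch (assumes this file lives at src/runner.py and that src/ is
# importable from the project root):
#
#   python -m src.runner
#
# or, from Python, with custom settings:
#
#   from src.runner import run_mnist
#   run_mnist(n_epochs=5, batch_size=64, path_log="save/logs/exp1")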