src.train

  1import os
  2
  3import numpy as np
  4import torch
  5from torch import nn
  6from torch.optim.sgd import SGD
  7from torch.utils.data import DataLoader
  8from tqdm import tqdm
  9
 10from src.data import DatasetMusic
 11from src.network import AudioNetwork
 12
 13
 14def test_triplet_training():
 15    """Performs a test of the training pipeline."""
 16
 17    # Configure the network
 18    network = AudioNetwork()
 19    network.train()
 20
 21    # Setup optimizer
 22    optimizer = SGD(network.parameters(), lr=0.001)
 23
 24    # Setup the loss function
 25    triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
 26
 27    waveform_anc = torch.ones(4, 2, 64, 2584, requires_grad=True)
 28    waveform_pos = torch.zeros(4, 2, 64, 2584, requires_grad=True)
 29    waveform_neg = torch.randn(4, 2, 64, 2584, requires_grad=True)
 30
 31    # Iterate over epochs
 32    for epoch_idx in range(0, 3):
 33        # Perform training
 34        for iter_idx in range(0, 3):
 35            optimizer.zero_grad()
 36
 37            anchor__ = network(waveform_anc)
 38            positive = network(waveform_pos)
 39            negative = network(waveform_neg)
 40
 41            # citation: https://pytorch.org/docs/stable/generated/torch.nn.TripletMarginLoss.html
 42            output = triplet_loss(anchor__, positive, negative)
 43
 44            print(f"output loss={output}")
 45
 46            output.backward()
 47
 48            optimizer.step()
 49
 50
 51def triplet_training():
 52    """Performs the training process."""
 53
 54    # Variables
 55    model_save_path = "save/audio_network.pkl"
 56
 57    # Setup the dataset
 58    dataset_train = DatasetMusic(train=True, spectrogram=True)
 59    dataset_valid = DatasetMusic(train=False, spectrogram=True)
 60
 61    # Setup dataloaders
 62    dataloader_train = DataLoader(dataset_train, batch_size=16, num_workers=4, prefetch_factor=1)
 63    dataloader_valid = DataLoader(dataset_valid, batch_size=16, num_workers=4, prefetch_factor=1)
 64
 65    # Configure the network
 66    network = AudioNetwork()
 67
 68    if os.path.exists(model_save_path):
 69        network.load_state_dict(torch.load(model_save_path))
 70
 71    # Setup optimizer
 72    optimizer = SGD(network.parameters(), lr=0.001)
 73
 74    # Setup the loss function
 75    triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
 76
 77    # Iterate over epochs
 78    for epoch_idx in tqdm(range(0, 10), desc="Epoch", colour="green"):
 79        # Perform training
 80        network.train()
 81        for sample in tqdm(dataloader_train, desc="Iteration", colour="blue"):
 82            waveform_anc = sample[0]
 83            waveform_pos = sample[1]
 84            waveform_neg = sample[2]
 85
 86            optimizer.zero_grad()
 87
 88            anchor__ = network(waveform_anc)
 89            positive = network(waveform_pos)
 90            negative = network(waveform_neg)
 91
 92            # citation: https://pytorch.org/docs/stable/generated/torch.nn.TripletMarginLoss.html
 93            output = triplet_loss(anchor__, positive, negative)
 94            output.backward()
 95
 96            optimizer.step()
 97
 98        # Perform validation and calculate loss
 99        network.eval()
100        max_validation_loss = 100
101        losses = []
102        for sample in tqdm(dataloader_valid, desc="Validating", colour="red"):
103            waveform_anc = sample[0]
104            waveform_pos = sample[1]
105            waveform_neg = sample[2]
106
107            with torch.no_grad():
108                anchor__ = network(waveform_anc)
109                positive = network(waveform_pos)
110                negative = network(waveform_neg)
111
112                output = triplet_loss(anchor__, positive, negative)
113
114                losses.append(output)
115
116        if np.mean(losses) < max_validation_loss:
117            max_validation_loss = np.mean(losses)
118            torch.save(network.state_dict(), model_save_path)
119            print("Saving best model...")
120
121        print(f"Average Validation Loss: {np.mean(losses)}")
122
123
if __name__ == "__main__":
    # Script entry point: run the full training pipeline on the real data.
    # For a quick check on synthetic tensors, call test_triplet_training()
    # here instead.
    triplet_training()
def test_triplet_training():
def test_triplet_training():
    """Run the training loop on synthetic tensors as a pipeline check.

    An ``AudioNetwork`` in training mode is optimised with SGD (lr=0.001)
    against ``nn.TripletMarginLoss`` for 3 epochs of 3 iterations each. The
    anchor is an all-ones tensor, the positive an all-zeros tensor and the
    negative Gaussian noise, all of shape ``(4, 2, 64, 2584)``. The loss is
    printed on every iteration; nothing is returned or saved.
    """
    # Network in training mode.
    net = AudioNetwork()
    net.train()

    # Optimizer and loss function.
    opt = SGD(net.parameters(), lr=0.001)
    loss_fn = nn.TripletMarginLoss(margin=1.0, p=2)

    # Fixed synthetic inputs, reused on every iteration.
    dims = (4, 2, 64, 2584)
    inputs = (
        torch.ones(*dims, requires_grad=True),   # anchor
        torch.zeros(*dims, requires_grad=True),  # positive
        torch.randn(*dims, requires_grad=True),  # negative
    )

    epochs = 3
    iterations = 3
    for _ in range(epochs):
        for _ in range(iterations):
            opt.zero_grad()

            # Embed anchor, positive and negative, in that order.
            embeddings = [net(tensor) for tensor in inputs]

            # citation: https://pytorch.org/docs/stable/generated/torch.nn.TripletMarginLoss.html
            output = loss_fn(embeddings[0], embeddings[1], embeddings[2])

            print(f"output loss={output}")

            output.backward()
            opt.step()

Performs a test of the training pipeline.

def triplet_training():
 52def triplet_training():
 53    """Performs the training process."""
 54
 55    # Variables
 56    model_save_path = "save/audio_network.pkl"
 57
 58    # Setup the dataset
 59    dataset_train = DatasetMusic(train=True, spectrogram=True)
 60    dataset_valid = DatasetMusic(train=False, spectrogram=True)
 61
 62    # Setup dataloaders
 63    dataloader_train = DataLoader(dataset_train, batch_size=16, num_workers=4, prefetch_factor=1)
 64    dataloader_valid = DataLoader(dataset_valid, batch_size=16, num_workers=4, prefetch_factor=1)
 65
 66    # Configure the network
 67    network = AudioNetwork()
 68
 69    if os.path.exists(model_save_path):
 70        network.load_state_dict(torch.load(model_save_path))
 71
 72    # Setup optimizer
 73    optimizer = SGD(network.parameters(), lr=0.001)
 74
 75    # Setup the loss function
 76    triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
 77
 78    # Iterate over epochs
 79    for epoch_idx in tqdm(range(0, 10), desc="Epoch", colour="green"):
 80        # Perform training
 81        network.train()
 82        for sample in tqdm(dataloader_train, desc="Iteration", colour="blue"):
 83            waveform_anc = sample[0]
 84            waveform_pos = sample[1]
 85            waveform_neg = sample[2]
 86
 87            optimizer.zero_grad()
 88
 89            anchor__ = network(waveform_anc)
 90            positive = network(waveform_pos)
 91            negative = network(waveform_neg)
 92
 93            # citation: https://pytorch.org/docs/stable/generated/torch.nn.TripletMarginLoss.html
 94            output = triplet_loss(anchor__, positive, negative)
 95            output.backward()
 96
 97            optimizer.step()
 98
 99        # Perform validation and calculate loss
100        network.eval()
101        max_validation_loss = 100
102        losses = []
103        for sample in tqdm(dataloader_valid, desc="Validating", colour="red"):
104            waveform_anc = sample[0]
105            waveform_pos = sample[1]
106            waveform_neg = sample[2]
107
108            with torch.no_grad():
109                anchor__ = network(waveform_anc)
110                positive = network(waveform_pos)
111                negative = network(waveform_neg)
112
113                output = triplet_loss(anchor__, positive, negative)
114
115                losses.append(output)
116
117        if np.mean(losses) < max_validation_loss:
118            max_validation_loss = np.mean(losses)
119            torch.save(network.state_dict(), model_save_path)
120            print("Saving best model...")
121
122        print(f"Average Validation Loss: {np.mean(losses)}")

Performs the training process.