src.train
import os

import numpy as np
import torch
from torch import nn
from torch.optim.sgd import SGD
from torch.utils.data import DataLoader
from tqdm import tqdm

from src.data import DatasetMusic
from src.network import AudioNetwork


def test_triplet_training():
    """Performs a test of the training pipeline.

    Smoke-tests the training loop on synthetic data: builds an ``AudioNetwork``,
    runs 3 x 3 SGD steps on fixed anchor (all ones), positive (all zeros) and
    negative (standard normal) batches, and prints the triplet margin loss after
    each step. Nothing is returned or saved.
    """

    # Configure the network
    network = AudioNetwork()
    network.train()

    # Setup optimizer
    optimizer = SGD(network.parameters(), lr=0.001)

    # Setup the loss function
    triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)

    # Synthetic batches of shape (4, 2, 64, 2584) -- presumably
    # (batch, channels, frequency bins, frames); TODO confirm against AudioNetwork.
    # They are plain inputs: only the network's parameters are optimized, so they
    # do not need requires_grad (which would accumulate an unused .grad on them).
    input_shape = (4, 2, 64, 2584)
    waveform_anc = torch.ones(*input_shape)
    waveform_pos = torch.zeros(*input_shape)
    waveform_neg = torch.randn(*input_shape)

    # Iterate over epochs
    for _epoch in range(3):
        # Perform training
        for _step in range(3):
            optimizer.zero_grad()

            anchor__ = network(waveform_anc)
            positive = network(waveform_pos)
            negative = network(waveform_neg)

            # citation: https://pytorch.org/docs/stable/generated/torch.nn.TripletMarginLoss.html
            output = triplet_loss(anchor__, positive, negative)

            print(f"output loss={output}")

            output.backward()

            optimizer.step()


def triplet_training():
    """Performs the training process.

    Trains an ``AudioNetwork`` with a triplet margin loss on the training split of
    ``DatasetMusic`` for 10 epochs. After every epoch, the mean per-batch triplet
    loss on the validation split is computed, and the network's ``state_dict`` is
    written to ``save/audio_network.pkl`` whenever that loss is the lowest seen so
    far in this run. If a checkpoint already exists at that path it is loaded
    first, so training resumes from it.
    """

    # Variables
    model_save_path = "save/audio_network.pkl"

    # Setup the dataset
    dataset_train = DatasetMusic(train=True, spectrogram=True)
    dataset_valid = DatasetMusic(train=False, spectrogram=True)

    # Setup dataloaders
    dataloader_train = DataLoader(dataset_train, batch_size=16, num_workers=4, prefetch_factor=1)
    dataloader_valid = DataLoader(dataset_valid, batch_size=16, num_workers=4, prefetch_factor=1)

    # Configure the network
    network = AudioNetwork()

    if os.path.exists(model_save_path):
        # The network is never moved off the CPU in this function; map_location
        # lets a checkpoint written on a GPU machine load here as well.
        network.load_state_dict(torch.load(model_save_path, map_location="cpu"))

    # Setup optimizer
    optimizer = SGD(network.parameters(), lr=0.001)

    # Setup the loss function
    triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)

    # torch.save does not create missing directories.
    save_dir = os.path.dirname(model_save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Lowest validation loss seen so far. Initialized once, outside the epoch
    # loop, so a checkpoint is only written when validation actually improves.
    best_validation_loss = float("inf")

    # Iterate over epochs
    for _epoch in tqdm(range(0, 10), desc="Epoch", colour="green"):
        # Perform training
        network.train()
        for sample in tqdm(dataloader_train, desc="Iteration", colour="blue"):
            waveform_anc = sample[0]
            waveform_pos = sample[1]
            waveform_neg = sample[2]

            optimizer.zero_grad()

            anchor__ = network(waveform_anc)
            positive = network(waveform_pos)
            negative = network(waveform_neg)

            # citation: https://pytorch.org/docs/stable/generated/torch.nn.TripletMarginLoss.html
            output = triplet_loss(anchor__, positive, negative)
            output.backward()

            optimizer.step()

        # Perform validation and calculate loss
        network.eval()
        losses = []
        with torch.no_grad():
            for sample in tqdm(dataloader_valid, desc="Validating", colour="red"):
                anchor__ = network(sample[0])
                positive = network(sample[1])
                negative = network(sample[2])

                # .item() turns the 0-d loss tensor into a Python float, so the
                # average below never has to convert tensors.
                losses.append(triplet_loss(anchor__, positive, negative).item())

        if not losses:
            # Empty validation split: there is nothing to compare, so skip saving.
            print("Average Validation Loss: n/a (empty validation set)")
            continue

        validation_loss = float(np.mean(losses))
        if validation_loss < best_validation_loss:
            best_validation_loss = validation_loss
            torch.save(network.state_dict(), model_save_path)
            print("Saving best model...")

        print(f"Average Validation Loss: {validation_loss}")


if __name__ == "__main__":
    # Test the pipeline
    # test_triplet_training()

    # Run the pipeline with the real data
    triplet_training()
def
test_triplet_training():
def test_triplet_training():
    """Performs a test of the training pipeline.

    Smoke-tests the training loop on synthetic data: builds an ``AudioNetwork``,
    runs 3 x 3 SGD steps on fixed anchor (all ones), positive (all zeros) and
    negative (standard normal) batches, and prints the triplet margin loss after
    each step. Nothing is returned or saved.
    """

    # Configure the network
    network = AudioNetwork()
    network.train()

    # Setup optimizer
    optimizer = SGD(network.parameters(), lr=0.001)

    # Setup the loss function
    triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)

    # Synthetic batches of shape (4, 2, 64, 2584) -- presumably
    # (batch, channels, frequency bins, frames); TODO confirm against AudioNetwork.
    # They are plain inputs: only the network's parameters are optimized, so they
    # do not need requires_grad (which would accumulate an unused .grad on them).
    input_shape = (4, 2, 64, 2584)
    waveform_anc = torch.ones(*input_shape)
    waveform_pos = torch.zeros(*input_shape)
    waveform_neg = torch.randn(*input_shape)

    # Iterate over epochs
    for _epoch in range(3):
        # Perform training
        for _step in range(3):
            optimizer.zero_grad()

            anchor__ = network(waveform_anc)
            positive = network(waveform_pos)
            negative = network(waveform_neg)

            # citation: https://pytorch.org/docs/stable/generated/torch.nn.TripletMarginLoss.html
            output = triplet_loss(anchor__, positive, negative)

            print(f"output loss={output}")

            output.backward()

            optimizer.step()
Performs a test of the training pipeline.
def
triplet_training():
def triplet_training():
    """Performs the training process.

    Trains an ``AudioNetwork`` with a triplet margin loss on the training split of
    ``DatasetMusic`` for 10 epochs. After every epoch, the mean per-batch triplet
    loss on the validation split is computed, and the network's ``state_dict`` is
    written to ``save/audio_network.pkl`` whenever that loss is the lowest seen so
    far in this run. If a checkpoint already exists at that path it is loaded
    first, so training resumes from it.
    """

    # Variables
    model_save_path = "save/audio_network.pkl"

    # Setup the dataset
    dataset_train = DatasetMusic(train=True, spectrogram=True)
    dataset_valid = DatasetMusic(train=False, spectrogram=True)

    # Setup dataloaders
    dataloader_train = DataLoader(dataset_train, batch_size=16, num_workers=4, prefetch_factor=1)
    dataloader_valid = DataLoader(dataset_valid, batch_size=16, num_workers=4, prefetch_factor=1)

    # Configure the network
    network = AudioNetwork()

    if os.path.exists(model_save_path):
        # The network is never moved off the CPU in this function; map_location
        # lets a checkpoint written on a GPU machine load here as well.
        network.load_state_dict(torch.load(model_save_path, map_location="cpu"))

    # Setup optimizer
    optimizer = SGD(network.parameters(), lr=0.001)

    # Setup the loss function
    triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)

    # torch.save does not create missing directories.
    save_dir = os.path.dirname(model_save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Lowest validation loss seen so far. Initialized once, outside the epoch
    # loop, so a checkpoint is only written when validation actually improves.
    best_validation_loss = float("inf")

    # Iterate over epochs
    for _epoch in tqdm(range(0, 10), desc="Epoch", colour="green"):
        # Perform training
        network.train()
        for sample in tqdm(dataloader_train, desc="Iteration", colour="blue"):
            waveform_anc = sample[0]
            waveform_pos = sample[1]
            waveform_neg = sample[2]

            optimizer.zero_grad()

            anchor__ = network(waveform_anc)
            positive = network(waveform_pos)
            negative = network(waveform_neg)

            # citation: https://pytorch.org/docs/stable/generated/torch.nn.TripletMarginLoss.html
            output = triplet_loss(anchor__, positive, negative)
            output.backward()

            optimizer.step()

        # Perform validation and calculate loss
        network.eval()
        losses = []
        with torch.no_grad():
            for sample in tqdm(dataloader_valid, desc="Validating", colour="red"):
                anchor__ = network(sample[0])
                positive = network(sample[1])
                negative = network(sample[2])

                # .item() turns the 0-d loss tensor into a Python float, so the
                # average below never has to convert tensors.
                losses.append(triplet_loss(anchor__, positive, negative).item())

        if not losses:
            # Empty validation split: there is nothing to compare, so skip saving.
            print("Average Validation Loss: n/a (empty validation set)")
            continue

        validation_loss = float(np.mean(losses))
        if validation_loss < best_validation_loss:
            best_validation_loss = validation_loss
            torch.save(network.state_dict(), model_save_path)
            print("Saving best model...")

        print(f"Average Validation Loss: {validation_loss}")
Performs the training process.