src.eval
"""Evaluation helpers: anchor embeddings, confusion matrix, test-set distances."""
import numpy as np
import pandas as pd
import torch
from sklearn.metrics import ConfusionMatrixDisplay
from torch.utils.data import DataLoader
from tqdm import tqdm

from src.data import DatasetMusic
from src.data import DatasetMusicTest
from src.network import AudioNetwork


def get_anchor_embeddings(network: torch.nn.Module, anchors_waveform: dict) -> dict:
    """Embed every anchor waveform with ``network``.

    Args:
        network: Model called on each waveform after a leading dimension of
            size 1 has been added with ``unsqueeze(0)``.
        anchors_waveform: Mapping from anchor name to waveform tensor.

    Returns:
        Mapping from anchor name to the network output as a NumPy array.
    """
    # Gradients are never needed here, so one no_grad scope covers all anchors.
    with torch.no_grad():
        return {
            name: network(waveform.unsqueeze(0)).detach().numpy()
            for name, waveform in anchors_waveform.items()
        }


def _load_network(model_save_path: str) -> torch.nn.Module:
    """Build an ``AudioNetwork``, load the saved weights and switch to eval mode."""
    network = AudioNetwork()
    # map_location lets a checkpoint saved on a GPU load on a CPU-only
    # machine; outputs are converted to NumPy later, which needs CPU tensors.
    network.load_state_dict(torch.load(model_save_path, map_location="cpu"))
    # Put any dropout / batch-norm layers into inference behaviour.
    network.eval()
    return network


def _nearest_anchor(embedding: np.ndarray, anchors_embedding: dict) -> tuple:
    """Return ``(name, distance)`` of the anchor closest to ``embedding``.

    Distance is the Euclidean norm of the difference. With no anchors the
    result is ``(None, inf)``.
    """
    best_name = None
    best_dist = float("inf")
    for name, anchor in anchors_embedding.items():
        dist = np.linalg.norm(embedding - anchor)
        if dist < best_dist:
            best_name, best_dist = name, dist
    return best_name, best_dist


def plot_confusion_matrix():
    """Plot a confusion matrix of nearest-anchor predictions on validation data.

    Every validation sample is labelled with the name of the anchor whose
    embedding is closest in Euclidean distance, or ``"Other"`` when even the
    closest anchor is farther away than ``other_threshold``.
    """
    model_save_path = "save/audio_network.pkl"
    # Embeddings farther than this from every anchor are classed as "Other".
    other_threshold = 5

    network = _load_network(model_save_path)

    # Create the dataloader for validation
    dataset_valid = DatasetMusic(train=False, spectrogram=True)
    dataloader_valid = DataLoader(dataset_valid, batch_size=1)

    # Embed the anchors once; every sample is compared against them.
    anchors_waveforms = dataset_valid.get_anchors()
    anchors_embedding = get_anchor_embeddings(network, anchors_waveforms)

    y_true = []
    y_pred = []

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for sample in tqdm(dataloader_valid, desc="Validating", colour="red"):
            # NOTE(review): index 1 is used as the waveform to classify and
            # index 4 as its label -- confirm against DatasetMusic.
            waveform_pos = sample[1]
            label_pos = sample[4][0]

            prediction = network(waveform_pos).detach().numpy()
            association, min_dist = _nearest_anchor(prediction, anchors_embedding)

            # Too far from every anchor: treat as an unknown class.
            if min_dist > other_threshold:
                association = "Other"

            y_true.append(label_pos)
            y_pred.append(association)

    ConfusionMatrixDisplay.from_predictions(
        y_true,
        y_pred,
        labels=["Beethoven", "Schubert", "Bach", "Brahms", "Other"],
    )

    print("Stop here to view the confusion matrix!")


def infer_test_data():
    """Print every test file with the distance to its nearest anchor embedding.

    The rows are sorted by ascending distance.
    """
    model_save_path = "save/audio_network.pkl"

    network = _load_network(model_save_path)

    # The validation dataset is only used as the source of the anchors.
    dataset_valid = DatasetMusic(train=False, spectrogram=True)
    dataset_test = DatasetMusicTest(spectrogram=True)

    # Setup dataloader
    dataloader_test = DataLoader(dataset_test, batch_size=1)

    # Embed the anchors once; every sample is compared against them.
    anchors_waveforms = dataset_valid.get_anchors()
    anchors_embedding = get_anchor_embeddings(network, anchors_waveforms)

    distances = []

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for sample in tqdm(dataloader_test, desc="Validating", colour="red"):
            waveform = sample[0]
            filename = sample[1][0]

            prediction = network(waveform).detach().numpy()
            _, min_dist = _nearest_anchor(prediction, anchors_embedding)
            distances.append((filename, min_dist))

    # Print out the minimum distances
    df = pd.DataFrame(distances, columns=["filename", "min_dist"])
    df = df.sort_values("min_dist").reset_index(drop=True)
    print(df)
    print("done!")


if __name__ == "__main__":
    plot_confusion_matrix()
    # infer_test_data()
def
get_anchor_embeddings(network: torch.nn.modules.module.Module, anchors_waveform: dict) -> dict:
def get_anchor_embeddings(network: torch.nn.Module, anchors_waveform: dict) -> dict:
    """Embed every anchor waveform with ``network``.

    Args:
        network: Model called on each waveform after a leading dimension of
            size 1 has been added with ``unsqueeze(0)``.
        anchors_waveform: Mapping from anchor name to waveform tensor.

    Returns:
        Mapping from anchor name to the network output as a NumPy array, in
        the same key order as ``anchors_waveform``.
    """
    # Gradients are never needed here, so one no_grad scope covers all anchors.
    with torch.no_grad():
        return {
            anchor_name: network(anchor.unsqueeze(0)).detach().numpy()
            for anchor_name, anchor in anchors_waveform.items()
        }
def
plot_confusion_matrix():
def plot_confusion_matrix():
    """Plot a confusion matrix of nearest-anchor predictions on validation data.

    Loads the saved ``AudioNetwork`` weights, embeds the anchors provided by
    the validation dataset, and labels every validation sample with the name
    of the anchor whose embedding is closest in Euclidean distance. A sample
    whose closest anchor is farther away than ``other_threshold`` is labelled
    ``"Other"``.
    """
    model_save_path = "save/audio_network.pkl"
    # Embeddings farther than this from every anchor are classed as "Other".
    other_threshold = 5

    # Create the network. map_location lets a checkpoint saved on a GPU load
    # on a CPU-only machine; outputs are converted to NumPy below, which
    # needs CPU tensors anyway.
    network = AudioNetwork()
    network.load_state_dict(torch.load(model_save_path, map_location="cpu"))
    # Put any dropout / batch-norm layers into inference behaviour so the
    # embeddings are deterministic.
    network.eval()

    # Create the dataloader for validation
    dataset_valid = DatasetMusic(train=False, spectrogram=True)
    dataloader_valid = DataLoader(dataset_valid, batch_size=1)

    # Embed the anchors once; every sample is compared against them.
    anchors_waveforms = dataset_valid.get_anchors()
    anchors_embedding = get_anchor_embeddings(network, anchors_waveforms)

    y_true = []
    y_pred = []

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for sample in tqdm(dataloader_valid, desc="Validating", colour="red"):
            # NOTE(review): index 1 is used as the waveform to classify and
            # index 4 as its label -- confirm against DatasetMusic.
            waveform_pos = sample[1]
            label_pos = sample[4][0]

            # Convert once per sample instead of once per anchor.
            prediction = network(waveform_pos).detach().numpy()

            # Nearest anchor by Euclidean distance.
            association = None
            min_dist = float("inf")
            for name, anchor in anchors_embedding.items():
                dist = np.linalg.norm(prediction - anchor)
                if dist < min_dist:
                    min_dist = dist
                    association = name

            # Too far from every anchor: treat as an unknown class.
            if min_dist > other_threshold:
                association = "Other"

            # Assign confusion matrix values
            y_true.append(label_pos)
            y_pred.append(association)

    ConfusionMatrixDisplay.from_predictions(
        y_true,
        y_pred,
        labels=["Beethoven", "Schubert", "Bach", "Brahms", "Other"],
    )

    print("Stop here to view the confusion matrix!")
def
infer_test_data():
def infer_test_data():
    """Print every test file with the distance to its nearest anchor embedding.

    Loads the saved ``AudioNetwork`` weights, embeds the anchors provided by
    the validation dataset, computes for each test sample the smallest
    Euclidean distance to any anchor embedding, and prints the results as a
    table sorted by ascending distance.
    """
    model_save_path = "save/audio_network.pkl"

    # Create the network. map_location lets a checkpoint saved on a GPU load
    # on a CPU-only machine; outputs are converted to NumPy below, which
    # needs CPU tensors anyway.
    network = AudioNetwork()
    network.load_state_dict(torch.load(model_save_path, map_location="cpu"))
    network.eval()

    # The validation dataset is only used as the source of the anchors.
    dataset_valid = DatasetMusic(train=False, spectrogram=True)
    dataset_test = DatasetMusicTest(spectrogram=True)

    # Setup dataloader
    dataloader_test = DataLoader(dataset_test, batch_size=1)

    # Embed the anchors once; every sample is compared against them.
    anchors_waveforms = dataset_valid.get_anchors()
    anchors_embedding = get_anchor_embeddings(network, anchors_waveforms)

    distances = []

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for sample in tqdm(dataloader_test, desc="Validating", colour="red"):
            waveform = sample[0]
            filename = sample[1][0]

            # Convert once per sample instead of once per anchor.
            prediction = network(waveform).detach().numpy()

            # Smallest Euclidean distance to any anchor. Starting from inf
            # (not a magic 1000) means large distances are reported as-is.
            min_dist = float("inf")
            for anchor in anchors_embedding.values():
                dist = np.linalg.norm(prediction - anchor)
                if dist < min_dist:
                    min_dist = dist

            distances.append((filename, min_dist))

    # Print out the minimum distances
    df = pd.DataFrame(distances, columns=["filename", "min_dist"])
    df = df.sort_values("min_dist").reset_index(drop=True)
    print(df)
    print("done!")