src.eval
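"""Evaluation helpers for the trained audio embedding network.

Loads a saved AudioNetwork, embeds the anchor clips returned by
DatasetMusic.get_anchors(), and scores validation/test clips by Euclidean
distance to the nearest anchor embedding, falling back to "Other" when even
the closest anchor is farther than a fixed threshold.
"""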

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from sklearn.metrics import ConfusionMatrixDisplay
from torch.utils.data import DataLoader
from tqdm import tqdm

from src.data import DatasetMusic
from src.data import DatasetMusicTest
from src.network import AudioNetwork


def get_anchor_embeddings(network: torch.nn.Module, anchors_waveform: dict) -> dict:
    """Embed each anchor waveform with the network and return a {name: embedding} dict."""
    anchors_embedding = {}
    for name, waveform in anchors_waveform.items():
        # Run inference without gradients to get this anchor's embedding
        with torch.no_grad():
            output = network(waveform.unsqueeze(0))
            anchors_embedding[name] = output.detach().numpy()

    return anchors_embedding
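# Usage sketch (illustrative; it mirrors how the functions below build their
# anchor embeddings): load the trained weights, switch to eval mode, then
# embed the anchors returned by DatasetMusic.get_anchors().
#
#   network = AudioNetwork()
#   network.load_state_dict(torch.load("save/audio_network.pkl"))
#   network.eval()
#   anchors = DatasetMusic(train=False, spectrogram=True).get_anchors()
#   anchors_embedding = get_anchor_embeddings(network, anchors)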


def plot_confusion_matrix():
    # Path to the trained model weights
    model_save_path = "save/audio_network.pkl"

    # Create the network and load the trained weights
    network = AudioNetwork()
    network.load_state_dict(torch.load(model_save_path))
    network.eval()

    # Create the dataloader for validation
    dataset_valid = DatasetMusic(train=False, spectrogram=True)
    dataloader_valid = DataLoader(dataset_valid, batch_size=1)

    # Get the anchor waveforms and embed them
    anchors_waveforms = dataset_valid.get_anchors()
    anchors_embedding = get_anchor_embeddings(network, anchors_waveforms)

    y_true = []
    y_pred = []

    for sample in tqdm(dataloader_valid, desc="Validating", colour="red"):
        # Positive waveform and its label
        waveform_pos = sample[1]
        label_pos = sample[4][0]

        # Generate the embedding for this sample
        with torch.no_grad():
            prediction = network(waveform_pos)

        # Compare the prediction against every anchor embedding and keep the
        # closest one (Euclidean distance)
        association = None
        min_dist = float("inf")
        for key, val in anchors_embedding.items():
            dist = np.linalg.norm(prediction.detach().numpy() - val)
            if dist < min_dist:
                min_dist = dist
                association = key

        # If even the closest anchor is too far away, classify as "Other"
        if min_dist > 5:
            association = "Other"

        # Collect confusion matrix values
        y_true.append(label_pos)
        y_pred.append(association)

    ConfusionMatrixDisplay.from_predictions(
        y_true,
        y_pred,
        labels=["Beethoven", "Schubert", "Bach", "Brahms", "Other"],
    )

    print("Stop here to view the confusion matrix!")
    plt.show()


def infer_test_data():
    # Path to the trained model weights
    model_save_path = "save/audio_network.pkl"

    # Create the network and load the trained weights
    network = AudioNetwork()
    network.load_state_dict(torch.load(model_save_path))
    network.eval()

    # The validation set supplies the anchors; the test set supplies the clips to score
    dataset_valid = DatasetMusic(train=False, spectrogram=True)
    dataset_test = DatasetMusicTest(spectrogram=True)

    # Set up the test dataloader
    dataloader_test = DataLoader(dataset_test, batch_size=1)

    # Get the anchor waveforms and embed them
    anchors_waveforms = dataset_valid.get_anchors()
    anchors_embedding = get_anchor_embeddings(network, anchors_waveforms)

    distances = []

    for sample in tqdm(dataloader_test, desc="Testing", colour="red"):
        # Waveform and source filename
        waveform = sample[0]
        filename = sample[1][0]

        # Generate the embedding for this sample
        with torch.no_grad():
            prediction = network(waveform)

        # Compare the prediction against every anchor embedding and keep the
        # smallest Euclidean distance
        min_dist = float("inf")
        for key, val in anchors_embedding.items():
            dist = np.linalg.norm(prediction.detach().numpy() - val)
            if dist < min_dist:
                min_dist = dist

        # Record the distance to the closest anchor for this file
        distances.append((filename, min_dist))

    # Print the minimum distances, sorted from closest to farthest
    df = pd.DataFrame(distances, columns=["filename", "min_dist"])
    df = df.sort_values("min_dist").reset_index(drop=True)
    print(df)
    print("done!")


if __name__ == "__main__":
    plot_confusion_matrix()
    # infer_test_data()