src.data

  1import os
  2import random
  3from glob import glob
  4from random import choice
  5
  6import numpy as np
  7import pandas as pd
  8import torch
  9import torchaudio
 10from matplotlib import pyplot as plt
 11from torch.utils.data import DataLoader
 12from torch.utils.data import Dataset
 13from torchaudio import transforms
 14from tqdm import tqdm
 15
 16
 17# Set random seed
 18random.seed(10)
 19
 20
 21def get_spectrogram(
 22    waveform: torch.Tensor,
 23    sample_rate: int,
 24    n_mels=64,
 25    n_fft=1024,
 26    hop_len=None,
 27    debug: bool = False,
 28) -> torch.TensorType:
 29    """This converts the audio file into a tensor with 2 channels with length and width which can be represented as an image.  This is easier to work with than the raw audio tensor format.
 30
 31    Args:
 32        waveform (torch.Tensor): Audio data
 33        sample_rate (int): Sample rate (should be 44.1kHz)
 34        n_mels (int, optional): _description_. Defaults to 64.
 35        n_fft (int, optional): _description_. Defaults to 1024.
 36        hop_len (_type_, optional): _description_. Defaults to None.
 37
 38    Returns:
 39        torch.TensorType: Tensor of shape 2x64x19654
 40    """
 41    # Cite: https://towardsdatascience.com/audio-deep-learning-made-simple-sound-classification-step-by-step-cebc936bbe5
 42
 43    top_db = 80
 44
 45    # spec has shape [channel, n_mels, time], where channel is mono, stereo etc
 46    spectrogram = transforms.MelSpectrogram(
 47        sample_rate, n_fft=n_fft, hop_length=hop_len, n_mels=n_mels
 48    )(torch.Tensor(waveform))
 49
 50    # Convert to decibels
 51    spectrogram = transforms.AmplitudeToDB(top_db=top_db)(spectrogram)
 52
 53    check1 = spectrogram.shape[0] == 2
 54    check2 = spectrogram.shape[1] == 64
 55    check3 = spectrogram.shape[2] == 2584
 56
 57    if (check1 and check2 and check3) is False:
 58        raise ValueError("Invalid sizing for Mel Spectogram!")
 59
 60    if debug:
 61        plt.figure()
 62        plt.imshow(spectrogram.log2()[0, :, :].numpy(), cmap="viridis")
 63        plt.close()
 64
 65    return spectrogram
 66
 67
 68class DatasetMusic(Dataset):
 69    def __init__(
 70        self,
 71        train: bool = True,
 72        spectrogram: bool = False,
 73    ) -> None:
 74        super().__init__()
 75        """Generate a dataframe that specifies the training and validation split of the data.
 76        """
 77
 78        # Set path to source data
 79        self.source_dir = "data/musicians"
 80
 81        # Initialize variables
 82        self._df = pd.DataFrame
 83        self._df_path = "save/music_df.pkl"
 84        self._new_sample_rate = 44100
 85        self._train = train
 86        self.spectogram = spectrogram
 87
 88        self.N_ANCHOR = 1
 89        self.N_POSITIVE = 1
 90        self.N_NEGATIVE = 1
 91        self.DEBUG = False
 92
 93        if os.path.isfile(self._df_path):
 94            self._df = pd.read_pickle(self._df_path)
 95        else:
 96            self.generate_df()
 97
 98    def generate_df(self):
 99        """Generate a dataframe that has the training and validation data."""
100
101        data = []
102
103        # Create iterator for musician folders
104        parent_folders = os.listdir(self.source_dir)
105        musician_folders = tqdm(parent_folders, colour="green")
106
107        # Iterate over all musician folders
108        for label in musician_folders:
109            musician_folders.set_description(f"Creating Dataset: {label}")
110
111            # Ignore irrelevant files
112            if label == ".DS_Store":
113                continue  # this is a MacOSX generated file
114
115            # Create iterator for MIDI files
116            wav_files_path = tqdm(glob(self.source_dir + "/" + label + "/*.wav"), colour="red")
117
118            # Iterate over all files under musician folders
119            for ii, wav_file_path in enumerate(wav_files_path):
120                # Create a flag for train/validation
121                prob = random.random()
122                if ii == 1:
123                    anchor = True
124                    flag_train = True
125                    flag_valid = True
126                elif prob <= 0.8:
127                    anchor = False
128                    flag_train = True
129                    flag_valid = False
130                else:
131                    anchor = False
132                    flag_train = False
133                    flag_valid = True
134
135                data.append((wav_file_path, label, anchor, flag_train, flag_valid))
136
137        # Create a dataframe that contains all the wav files and their labels
138        self._df = pd.DataFrame(
139            data, columns=["wav_file_path", "label", "anchor", "train", "valid"]
140        )
141
142        self._df.to_pickle(self._df_path)
143
144    def __getitem__(self, index: int) -> tuple:
145        """Get an item from the dataset (i.e. data and label)
146
147        Args:
148            index (int): Unused at the moment.
149            train (bool, optional): The training flag. Defaults to True.
150
151        Returns:
152            tuple: data, label
153        """
154        # Define musicians to choose from
155        musicians = ["Bach", "Beethoven", "Brahms", "Schubert"]
156
157        # Randomly select a musician
158        anchor_musician = choice(musicians)
159
160        # Filter on the selected musician
161        filter_positive = self._df["label"] == anchor_musician
162        filter_negative = self._df["label"] != anchor_musician
163        filter_training = self._df["train"] == True
164        filter_validate = self._df["valid"] == True
165        filter_anchor = self._df["anchor"] == True
166
167        # Refine based on training vs validation
168        if self._train is True:
169            filter_positive = filter_positive & filter_training
170            filter_negative = filter_negative & filter_training
171        if self._train is False:
172            filter_positive = filter_positive & filter_validate
173            filter_negative = filter_negative & filter_validate
174
175        # Get positive and negative data entries
176        filter_data_anc = self._df[filter_positive & filter_anchor]
177        filter_data_pos = self._df[filter_positive & ~filter_anchor]
178        filter_data_neg = self._df[filter_negative]
179
180        # Extract the selected info
181        df_anc = filter_data_anc.sample(n=self.N_ANCHOR).reset_index(drop=True)
182        df_pos = filter_data_pos.sample(n=self.N_POSITIVE).reset_index(drop=True)
183        df_neg = filter_data_neg.sample(n=self.N_NEGATIVE).reset_index(drop=True)
184
185        anc_path_ = df_anc["wav_file_path"][0]
186        pos_path_ = df_pos["wav_file_path"][0]
187        neg_path_ = df_neg["wav_file_path"][0]
188
189        anc_label = df_anc["label"][0]
190        pos_label = df_pos["label"][0]
191        neg_label = df_neg["label"][0]
192
193        # Load the torchaudio waveform
194        waveform_anc, sample_rate_anc = torchaudio.load(anc_path_)
195        waveform_pos, sample_rate_pos = torchaudio.load(pos_path_)
196        waveform_neg, sample_rate_neg = torchaudio.load(neg_path_)
197
198        waveform_anc = self.resample(waveform_anc, sample_rate_anc, anc_path_)
199        waveform_pos = self.resample(waveform_pos, sample_rate_pos, pos_path_)
200        waveform_neg = self.resample(waveform_neg, sample_rate_neg, neg_path_)
201
202        # Check shape of waveform
203        if waveform_anc.shape != (2, 1323001):
204            print("Invalid shape detected!")
205        if waveform_pos.shape != (2, 1323001):
206            print("Invalid shape detected!")
207        if waveform_neg.shape != (2, 1323001):
208            print("Invalid shape detected!")
209
210        if self.spectogram:
211            waveform_anc = get_spectrogram(waveform_anc, self._new_sample_rate)
212            waveform_pos = get_spectrogram(waveform_pos, self._new_sample_rate)
213            waveform_neg = get_spectrogram(waveform_neg, self._new_sample_rate)
214
215        return waveform_anc, waveform_pos, waveform_neg, anc_label, pos_label, neg_label
216
217    def __len__(self):
218        if self._train:
219            return sum(self._df["train"] == True)
220        else:
221            return sum(self._df["valid"] == True)
222
223    def get_anchors(self) -> dict:
224        """Retrieve all anchors from the dataset."""
225
226        my_dict = {}
227
228        # Extract the selected info
229        df_anc = self._df[self._df["anchor"]]
230
231        for ii, row in df_anc.iterrows():
232            # Extract row information
233            anc_path_ = row["wav_file_path"]
234            anc_label = row["label"]
235
236            # Load the torchaudio waveform
237            waveform_anc, sample_rate_anc = torchaudio.load(anc_path_)
238
239            # Resample to specified window
240            waveform_anc = self.resample(waveform_anc, sample_rate_anc, anc_path_)
241
242            # Convert to spectrogram
243            if self.spectogram:
244                waveform_anc = get_spectrogram(waveform_anc, self._new_sample_rate)
245
246            my_dict[anc_label] = waveform_anc
247
248        return my_dict
249
250    def resample(self, waveform: torch.Tensor, sample_rate: int, path: str):
251        # Conversion to 44.1kHz if not already
252        if sample_rate != self._new_sample_rate:
253            # Resample the data to match 44.1kHz sampling rate
254            waveform = torchaudio.transforms.Resample(sample_rate, self._new_sample_rate)(waveform)
255            print(f"Waveform Resampled: {path}")
256
257        # Crop out 30 seconds of data randomly from the music file
258        waveform = waveform.numpy()
259        num_channels, num_frames = waveform.shape
260        time_axis = np.arange(0, num_frames) / sample_rate
261
262        # Set window for 30 second interval
263        min_time_axis = 0
264        max_time_axis = 30
265        sel_time_axis = (time_axis >= min_time_axis) & (time_axis <= max_time_axis)
266        waveform = waveform[:, sel_time_axis]
267
268        # Adding debugging statement
269        if self.DEBUG:
270            torchaudio.save(
271                "test_dataset_resample.wav",
272                torch.tensor(waveform),
273                sample_rate=self._new_sample_rate,
274            )
275
276        return waveform
277
278
class DatasetMusicTest(Dataset):
    """Unlabelled test set: one item per .wav file directly under ``data/test``."""

    def __init__(self, spectrogram: bool = False) -> None:
        """Load, or build and cache, the dataframe listing the test files.

        Args:
            spectrogram (bool, optional): Return mel spectrograms instead of
                raw waveforms. Defaults to False.
        """
        super().__init__()

        # Set path to source data
        self.source_dir = "data/test"

        # Initialize variables
        self._df = pd.DataFrame()
        self._df_path = "save/test_music_df.pkl"
        self._new_sample_rate = 44100
        self.spectogram = spectrogram  # (sic) spelling matches DatasetMusic

        if os.path.isfile(self._df_path):
            # NOTE: only read pickles that were generated locally (trusted)
            self._df = pd.read_pickle(self._df_path)
        else:
            self.generate_df()

    def generate_df(self):
        """List the test .wav files into ``self._df`` and cache it to ``self._df_path``."""
        # Sorted so item indices are stable across runs
        wav_files = sorted(glob(os.path.join(self.source_dir, "*.wav")))
        data = list(tqdm(wav_files, colour="red"))

        # One row per file; there are no labels for the test set
        self._df = pd.DataFrame(data, columns=["wav_file_path"])

        # Make sure the cache directory exists before writing to it
        cache_dir = os.path.dirname(self._df_path)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        self._df.to_pickle(self._df_path)

    def __getitem__(self, index: int) -> tuple:
        """Load the test clip at ``index``.

        Args:
            index (int): Row label in the file dataframe (0 .. len - 1).

        Returns:
            tuple: ``(data, name)`` -- the cropped waveform (numpy array), or
            its mel spectrogram tensor when ``self.spectogram`` is True, and
            the file's base name.
        """
        path = self._df.loc[index, "wav_file_path"]
        name = os.path.basename(str(path))

        # Load the torchaudio waveform
        waveform, sample_rate = torchaudio.load(path)
        waveform = self.resample(waveform, sample_rate, path)

        # Expected (channels, samples) of a 30 s stereo clip at 44.1kHz
        if waveform.shape != (2, 1323001):
            print("Invalid shape detected!")

        if self.spectogram:
            waveform = get_spectrogram(waveform, self._new_sample_rate)

        return waveform, name

    def __len__(self):
        """Return the number of test files."""
        return len(self._df)

    def resample(self, waveform: torch.Tensor, sample_rate: int, path: str):
        """Resample ``waveform`` to ``self._new_sample_rate`` and keep its first 30 s.

        Args:
            waveform (torch.Tensor): Audio of shape ``(channels, frames)``.
            sample_rate (int): Sample rate of ``waveform``.
            path (str): Source file, only used in the log message.

        Returns:
            np.ndarray: The samples covering 0 s to 30 s inclusive at the new
            rate, i.e. at most ``30 * self._new_sample_rate + 1`` frames per
            channel (fewer if the audio is shorter).
        """
        # Conversion to 44.1kHz if not already
        if sample_rate != self._new_sample_rate:
            waveform = torchaudio.transforms.Resample(sample_rate, self._new_sample_rate)(waveform)
            print(f"Waveform Resampled: {path}")

        # Keep the first 30 seconds (not a random window). The crop must use
        # the *new* rate because the samples have already been converted.
        max_seconds = 30
        n_keep = int(max_seconds * self._new_sample_rate) + 1
        # Copy so the result does not keep the full-length buffer alive
        return waveform.numpy()[:, :n_keep].copy()
365
366
if __name__ == "__main__":
    # Smoke test: build both splits in spectrogram mode and draw a few items
    dataset_music_train = DatasetMusic(train=True, spectrogram=True)
    dataset_music_valid = DatasetMusic(train=False, spectrogram=True)

    # iter() falls back to calling __getitem__(0), __getitem__(1), ...; the
    # index is ignored, so each step yields a fresh random triplet.
    data_train = iter(dataset_music_train)
    data_valid = iter(dataset_music_valid)

    # Test training iterator
    for _ in range(10):
        next(data_train)

    # Test validation iterator
    for _ in range(10):
        next(data_valid)

    # Create one dataloader per split
    dataloader_train = DataLoader(
        dataset_music_train,
        batch_size=4,
        shuffle=True,
    )

    dataloader_valid = DataLoader(
        dataset_music_valid,
        batch_size=4,
        shuffle=True,
    )
def get_spectrogram( waveform: torch.Tensor, sample_rate: int, n_mels=64, n_fft=1024, hop_len=None, debug: bool = False) -> torch.TensorType:
def get_spectrogram(
    waveform: torch.Tensor,
    sample_rate: int,
    n_mels=64,
    n_fft=1024,
    hop_len=None,
    debug: bool = False,
    expected_frames=2584,
) -> torch.Tensor:
    """Convert a stereo waveform into a decibel-scaled mel spectrogram.

    The result has shape ``(channels, n_mels, frames)`` and can be handled like
    a multi-channel image, which is easier to work with than raw audio samples.

    Args:
        waveform (torch.Tensor): Audio samples (a tensor or anything
            ``torch.Tensor(...)`` accepts, e.g. a numpy array). Must be
            2-channel.
        sample_rate (int): Sample rate of ``waveform`` (should be 44.1kHz).
        n_mels (int, optional): Number of mel filter banks. Defaults to 64.
        n_fft (int, optional): FFT window size. Defaults to 1024.
        hop_len (int, optional): Samples between successive STFT frames.
            Defaults to None, which torchaudio treats as ``n_fft // 2``.
        debug (bool, optional): If True, plot channel 0 with matplotlib.
            Defaults to False.
        expected_frames (int, optional): Required number of time frames; pass
            None to accept any length. Defaults to 2584, the frame count of a
            1,323,001-sample clip (30 s at 44.1kHz) with the default
            ``n_fft``/``hop_len``.

    Returns:
        torch.Tensor: Tensor of shape ``(2, n_mels, frames)`` -- 2x64x2584 with
        the defaults.

    Raises:
        ValueError: If the spectrogram is not 3-D with 2 channels, ``n_mels``
            mel bins and (unless disabled) ``expected_frames`` frames.
    """
    # Cite: https://towardsdatascience.com/audio-deep-learning-made-simple-sound-classification-step-by-step-cebc936bbe5

    top_db = 80

    # spec has shape [channel, n_mels, time], where channel is mono, stereo etc
    spectrogram = transforms.MelSpectrogram(
        sample_rate, n_fft=n_fft, hop_length=hop_len, n_mels=n_mels
    )(torch.Tensor(waveform))

    # Convert power to decibels, limiting the dynamic range to top_db
    spectrogram = transforms.AmplitudeToDB(top_db=top_db)(spectrogram)

    # Validate against the requested n_mels rather than a hard-coded 64, and
    # check the rank first so bad input raises ValueError, not IndexError.
    shape_ok = (
        spectrogram.dim() == 3
        and spectrogram.shape[0] == 2
        and spectrogram.shape[1] == n_mels
        and (expected_frames is None or spectrogram.shape[2] == expected_frames)
    )
    if not shape_ok:
        raise ValueError("Invalid sizing for Mel Spectogram!")

    if debug:
        # Values are already in dB and can be negative, so plot them directly
        # (taking log2 of them would produce NaNs).
        plt.figure()
        plt.imshow(
            spectrogram[0].numpy(), cmap="viridis", origin="lower", aspect="auto"
        )
        plt.show()
        plt.close()

    return spectrogram

Converts a stereo waveform into a decibel-scaled mel spectrogram: a 2-channel tensor of shape (channels, n_mels, frames) that can be treated as an image. This is easier to work with than the raw audio tensor format.

Arguments:
  • waveform (torch.Tensor): Audio data
  • sample_rate (int): Sample rate (should be 44.1kHz)
  • n_mels (int, optional): Number of mel filter banks. Defaults to 64.
  • n_fft (int, optional): FFT window size. Defaults to 1024.
  • hop_len (int, optional): Samples between successive STFT frames. Defaults to None (torchaudio then uses n_fft // 2).
Returns:

torch.TensorType: Tensor of shape 2x64x2584 (channels × mel bins × time frames for a 30 s clip at 44.1 kHz).

class DatasetMusic(typing.Generic[+T_co]):
 69class DatasetMusic(Dataset):
 70    def __init__(
 71        self,
 72        train: bool = True,
 73        spectrogram: bool = False,
 74    ) -> None:
 75        super().__init__()
 76        """Generate a dataframe that specifies the training and validation split of the data.
 77        """
 78
 79        # Set path to source data
 80        self.source_dir = "data/musicians"
 81
 82        # Initialize variables
 83        self._df = pd.DataFrame
 84        self._df_path = "save/music_df.pkl"
 85        self._new_sample_rate = 44100
 86        self._train = train
 87        self.spectogram = spectrogram
 88
 89        self.N_ANCHOR = 1
 90        self.N_POSITIVE = 1
 91        self.N_NEGATIVE = 1
 92        self.DEBUG = False
 93
 94        if os.path.isfile(self._df_path):
 95            self._df = pd.read_pickle(self._df_path)
 96        else:
 97            self.generate_df()
 98
 99    def generate_df(self):
100        """Generate a dataframe that has the training and validation data."""
101
102        data = []
103
104        # Create iterator for musician folders
105        parent_folders = os.listdir(self.source_dir)
106        musician_folders = tqdm(parent_folders, colour="green")
107
108        # Iterate over all musician folders
109        for label in musician_folders:
110            musician_folders.set_description(f"Creating Dataset: {label}")
111
112            # Ignore irrelevant files
113            if label == ".DS_Store":
114                continue  # this is a MacOSX generated file
115
116            # Create iterator for MIDI files
117            wav_files_path = tqdm(glob(self.source_dir + "/" + label + "/*.wav"), colour="red")
118
119            # Iterate over all files under musician folders
120            for ii, wav_file_path in enumerate(wav_files_path):
121                # Create a flag for train/validation
122                prob = random.random()
123                if ii == 1:
124                    anchor = True
125                    flag_train = True
126                    flag_valid = True
127                elif prob <= 0.8:
128                    anchor = False
129                    flag_train = True
130                    flag_valid = False
131                else:
132                    anchor = False
133                    flag_train = False
134                    flag_valid = True
135
136                data.append((wav_file_path, label, anchor, flag_train, flag_valid))
137
138        # Create a dataframe that contains all the wav files and their labels
139        self._df = pd.DataFrame(
140            data, columns=["wav_file_path", "label", "anchor", "train", "valid"]
141        )
142
143        self._df.to_pickle(self._df_path)
144
145    def __getitem__(self, index: int) -> tuple:
146        """Get an item from the dataset (i.e. data and label)
147
148        Args:
149            index (int): Unused at the moment.
150            train (bool, optional): The training flag. Defaults to True.
151
152        Returns:
153            tuple: data, label
154        """
155        # Define musicians to choose from
156        musicians = ["Bach", "Beethoven", "Brahms", "Schubert"]
157
158        # Randomly select a musician
159        anchor_musician = choice(musicians)
160
161        # Filter on the selected musician
162        filter_positive = self._df["label"] == anchor_musician
163        filter_negative = self._df["label"] != anchor_musician
164        filter_training = self._df["train"] == True
165        filter_validate = self._df["valid"] == True
166        filter_anchor = self._df["anchor"] == True
167
168        # Refine based on training vs validation
169        if self._train is True:
170            filter_positive = filter_positive & filter_training
171            filter_negative = filter_negative & filter_training
172        if self._train is False:
173            filter_positive = filter_positive & filter_validate
174            filter_negative = filter_negative & filter_validate
175
176        # Get positive and negative data entries
177        filter_data_anc = self._df[filter_positive & filter_anchor]
178        filter_data_pos = self._df[filter_positive & ~filter_anchor]
179        filter_data_neg = self._df[filter_negative]
180
181        # Extract the selected info
182        df_anc = filter_data_anc.sample(n=self.N_ANCHOR).reset_index(drop=True)
183        df_pos = filter_data_pos.sample(n=self.N_POSITIVE).reset_index(drop=True)
184        df_neg = filter_data_neg.sample(n=self.N_NEGATIVE).reset_index(drop=True)
185
186        anc_path_ = df_anc["wav_file_path"][0]
187        pos_path_ = df_pos["wav_file_path"][0]
188        neg_path_ = df_neg["wav_file_path"][0]
189
190        anc_label = df_anc["label"][0]
191        pos_label = df_pos["label"][0]
192        neg_label = df_neg["label"][0]
193
194        # Load the torchaudio waveform
195        waveform_anc, sample_rate_anc = torchaudio.load(anc_path_)
196        waveform_pos, sample_rate_pos = torchaudio.load(pos_path_)
197        waveform_neg, sample_rate_neg = torchaudio.load(neg_path_)
198
199        waveform_anc = self.resample(waveform_anc, sample_rate_anc, anc_path_)
200        waveform_pos = self.resample(waveform_pos, sample_rate_pos, pos_path_)
201        waveform_neg = self.resample(waveform_neg, sample_rate_neg, neg_path_)
202
203        # Check shape of waveform
204        if waveform_anc.shape != (2, 1323001):
205            print("Invalid shape detected!")
206        if waveform_pos.shape != (2, 1323001):
207            print("Invalid shape detected!")
208        if waveform_neg.shape != (2, 1323001):
209            print("Invalid shape detected!")
210
211        if self.spectogram:
212            waveform_anc = get_spectrogram(waveform_anc, self._new_sample_rate)
213            waveform_pos = get_spectrogram(waveform_pos, self._new_sample_rate)
214            waveform_neg = get_spectrogram(waveform_neg, self._new_sample_rate)
215
216        return waveform_anc, waveform_pos, waveform_neg, anc_label, pos_label, neg_label
217
218    def __len__(self):
219        if self._train:
220            return sum(self._df["train"] == True)
221        else:
222            return sum(self._df["valid"] == True)
223
224    def get_anchors(self) -> dict:
225        """Retrieve all anchors from the dataset."""
226
227        my_dict = {}
228
229        # Extract the selected info
230        df_anc = self._df[self._df["anchor"]]
231
232        for ii, row in df_anc.iterrows():
233            # Extract row information
234            anc_path_ = row["wav_file_path"]
235            anc_label = row["label"]
236
237            # Load the torchaudio waveform
238            waveform_anc, sample_rate_anc = torchaudio.load(anc_path_)
239
240            # Resample to specified window
241            waveform_anc = self.resample(waveform_anc, sample_rate_anc, anc_path_)
242
243            # Convert to spectrogram
244            if self.spectogram:
245                waveform_anc = get_spectrogram(waveform_anc, self._new_sample_rate)
246
247            my_dict[anc_label] = waveform_anc
248
249        return my_dict
250
251    def resample(self, waveform: torch.Tensor, sample_rate: int, path: str):
252        # Conversion to 44.1kHz if not already
253        if sample_rate != self._new_sample_rate:
254            # Resample the data to match 44.1kHz sampling rate
255            waveform = torchaudio.transforms.Resample(sample_rate, self._new_sample_rate)(waveform)
256            print(f"Waveform Resampled: {path}")
257
258        # Crop out 30 seconds of data randomly from the music file
259        waveform = waveform.numpy()
260        num_channels, num_frames = waveform.shape
261        time_axis = np.arange(0, num_frames) / sample_rate
262
263        # Set window for 30 second interval
264        min_time_axis = 0
265        max_time_axis = 30
266        sel_time_axis = (time_axis >= min_time_axis) & (time_axis <= max_time_axis)
267        waveform = waveform[:, sel_time_axis]
268
269        # Adding debugging statement
270        if self.DEBUG:
271            torchaudio.save(
272                "test_dataset_resample.wav",
273                torch.tensor(waveform),
274                sample_rate=self._new_sample_rate,
275            )
276
277        return waveform

An abstract class representing a Dataset.

All datasets that represent a map from keys to data samples should subclass it. All subclasses should override __getitem__(), supporting fetching a data sample for a given key. Subclasses could also optionally override __len__(), which is expected to return the size of the dataset by many ~torch.utils.data.Sampler implementations and the default options of ~torch.utils.data.DataLoader. Subclasses could also optionally implement __getitems__() to speed up loading batched samples. This method accepts a list of sample indices for a batch and returns a list of samples.

Note: ~torch.utils.data.DataLoader by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.

DatasetMusic(train: bool = True, spectrogram: bool = False)
70    def __init__(
71        self,
72        train: bool = True,
73        spectrogram: bool = False,
74    ) -> None:
75        super().__init__()
76        """Generate a dataframe that specifies the training and validation split of the data.
77        """
78
79        # Set path to source data
80        self.source_dir = "data/musicians"
81
82        # Initialize variables
83        self._df = pd.DataFrame
84        self._df_path = "save/music_df.pkl"
85        self._new_sample_rate = 44100
86        self._train = train
87        self.spectogram = spectrogram
88
89        self.N_ANCHOR = 1
90        self.N_POSITIVE = 1
91        self.N_NEGATIVE = 1
92        self.DEBUG = False
93
94        if os.path.isfile(self._df_path):
95            self._df = pd.read_pickle(self._df_path)
96        else:
97            self.generate_df()
source_dir
spectogram
N_ANCHOR
N_POSITIVE
N_NEGATIVE
DEBUG
def generate_df(self):
 99    def generate_df(self):
100        """Generate a dataframe that has the training and validation data."""
101
102        data = []
103
104        # Create iterator for musician folders
105        parent_folders = os.listdir(self.source_dir)
106        musician_folders = tqdm(parent_folders, colour="green")
107
108        # Iterate over all musician folders
109        for label in musician_folders:
110            musician_folders.set_description(f"Creating Dataset: {label}")
111
112            # Ignore irrelevant files
113            if label == ".DS_Store":
114                continue  # this is a MacOSX generated file
115
116            # Create iterator for MIDI files
117            wav_files_path = tqdm(glob(self.source_dir + "/" + label + "/*.wav"), colour="red")
118
119            # Iterate over all files under musician folders
120            for ii, wav_file_path in enumerate(wav_files_path):
121                # Create a flag for train/validation
122                prob = random.random()
123                if ii == 1:
124                    anchor = True
125                    flag_train = True
126                    flag_valid = True
127                elif prob <= 0.8:
128                    anchor = False
129                    flag_train = True
130                    flag_valid = False
131                else:
132                    anchor = False
133                    flag_train = False
134                    flag_valid = True
135
136                data.append((wav_file_path, label, anchor, flag_train, flag_valid))
137
138        # Create a dataframe that contains all the wav files and their labels
139        self._df = pd.DataFrame(
140            data, columns=["wav_file_path", "label", "anchor", "train", "valid"]
141        )
142
143        self._df.to_pickle(self._df_path)

Generate a dataframe that has the training and validation data.

def get_anchors(self) -> dict:
224    def get_anchors(self) -> dict:
225        """Retrieve all anchors from the dataset."""
226
227        my_dict = {}
228
229        # Extract the selected info
230        df_anc = self._df[self._df["anchor"]]
231
232        for ii, row in df_anc.iterrows():
233            # Extract row information
234            anc_path_ = row["wav_file_path"]
235            anc_label = row["label"]
236
237            # Load the torchaudio waveform
238            waveform_anc, sample_rate_anc = torchaudio.load(anc_path_)
239
240            # Resample to specified window
241            waveform_anc = self.resample(waveform_anc, sample_rate_anc, anc_path_)
242
243            # Convert to spectrogram
244            if self.spectogram:
245                waveform_anc = get_spectrogram(waveform_anc, self._new_sample_rate)
246
247            my_dict[anc_label] = waveform_anc
248
249        return my_dict

Retrieve all anchors from the dataset.

def resample(self, waveform: torch.Tensor, sample_rate: int, path: str):
251    def resample(self, waveform: torch.Tensor, sample_rate: int, path: str):
252        # Conversion to 44.1kHz if not already
253        if sample_rate != self._new_sample_rate:
254            # Resample the data to match 44.1kHz sampling rate
255            waveform = torchaudio.transforms.Resample(sample_rate, self._new_sample_rate)(waveform)
256            print(f"Waveform Resampled: {path}")
257
258        # Crop out 30 seconds of data randomly from the music file
259        waveform = waveform.numpy()
260        num_channels, num_frames = waveform.shape
261        time_axis = np.arange(0, num_frames) / sample_rate
262
263        # Set window for 30 second interval
264        min_time_axis = 0
265        max_time_axis = 30
266        sel_time_axis = (time_axis >= min_time_axis) & (time_axis <= max_time_axis)
267        waveform = waveform[:, sel_time_axis]
268
269        # Adding debugging statement
270        if self.DEBUG:
271            torchaudio.save(
272                "test_dataset_resample.wav",
273                torch.tensor(waveform),
274                sample_rate=self._new_sample_rate,
275            )
276
277        return waveform
class DatasetMusicTest(torch.utils.data.Dataset):
class DatasetMusicTest(Dataset):
    """Unlabelled dataset over the ``.wav`` files in ``data/test``.

    The file index is cached as a pickled DataFrame at
    ``save/test_music_df.pkl`` and is rebuilt only when that file is missing.
    Each item is ``(data, name)``: the first 30 seconds of the track at
    44.1 kHz (optionally converted by ``get_spectrogram``) and the file's
    base name.
    """

    def __init__(self, spectrogram: bool = False) -> None:
        """Load the cached file index, or build it from ``source_dir``.

        Args:
            spectrogram (bool, optional): If True, ``__getitem__`` converts
                each cropped waveform with ``get_spectrogram``.
                Defaults to False.
        """
        super().__init__()

        # Set path to source data
        self.source_dir = "data/test"

        # Initialize variables. ``_df`` starts as an empty frame and is
        # replaced below by the cached or freshly generated index.
        self._df = pd.DataFrame()
        self._df_path = "save/test_music_df.pkl"
        self._new_sample_rate = 44100
        self.spectogram = spectrogram  # attribute name keeps the existing spelling

        if os.path.isfile(self._df_path):
            # NOTE(review): read_pickle can execute arbitrary code; only load
            # a cache file this project wrote itself.
            self._df = pd.read_pickle(self._df_path)
        else:
            self.generate_df()

    def generate_df(self):
        """Index every ``.wav`` file directly under ``self.source_dir``.

        Stores a one-column DataFrame (``wav_file_path``) in ``self._df`` and
        pickles it to ``self._df_path``, creating the parent directory first
        so that a missing ``save/`` directory does not make ``to_pickle`` fail.
        """
        # glob() order is filesystem-dependent; sorting makes dataset indices
        # reproducible across machines and re-runs.
        wav_files = sorted(glob(os.path.join(self.source_dir, "*.wav")))

        # Create a dataframe that contains all the wav file paths
        self._df = pd.DataFrame(wav_files, columns=["wav_file_path"])

        save_dir = os.path.dirname(self._df_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        self._df.to_pickle(self._df_path)

    def __getitem__(self, index: int) -> tuple:
        """Load one test track.

        Args:
            index (int): Row label in the file index (the 0-based position
                for a freshly generated index).

        Returns:
            tuple: ``(data, name)`` where ``data`` is the cropped waveform as
            a NumPy array (or the ``get_spectrogram`` output when
            ``spectogram`` is set) and ``name`` is the file's base name.
        """
        path = self._df.loc[index, "wav_file_path"]
        # basename() understands the platform's separator, unlike split("/").
        name = os.path.basename(str(path))

        # Load the torchaudio waveform and crop it to the fixed window
        waveform, sample_rate = torchaudio.load(path)
        waveform = self.resample(waveform, sample_rate, path)

        # 30 s at 44.1 kHz, inclusive of both ends -> 1323001 frames; shorter
        # or non-stereo files produce a different shape.
        if waveform.shape != (2, 1323001):
            print("Invalid shape detected!")

        if self.spectogram:
            waveform = get_spectrogram(waveform, self._new_sample_rate)

        return waveform, name

    def __len__(self):
        """Return the number of indexed ``.wav`` files."""
        return len(self._df)

    def resample(self, waveform: torch.Tensor, sample_rate: int, path: str):
        """Resample ``waveform`` to ``_new_sample_rate`` and keep its first 30 s.

        Args:
            waveform (torch.Tensor): Audio as returned by ``torchaudio.load``
                (channels x frames).
            sample_rate (int): Sample rate ``waveform`` was loaded at.
            path (str): Source file; only used in the log message.

        Returns:
            numpy.ndarray: Frames whose timestamp lies in [0, 30] seconds,
            i.e. at most ``30 * _new_sample_rate + 1`` frames per channel.
        """
        # Conversion to 44.1kHz if not already
        if sample_rate != self._new_sample_rate:
            waveform = torchaudio.transforms.Resample(sample_rate, self._new_sample_rate)(waveform)
            print(f"Waveform Resampled: {path}")

        # Keep the *first* 30 seconds (the window is fixed, not random).
        # Frame timestamps are measured at the post-resampling rate; using
        # the input rate would shrink the window whenever resampling happened.
        max_seconds = 30
        max_frames = int(max_seconds * self._new_sample_rate) + 1
        # copy() yields a contiguous array that does not keep the full-length
        # decoded buffer alive (a plain slice would be a view into it).
        return waveform.numpy()[:, :max_frames].copy()

An abstract class representing a Dataset.

All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite __getitem__(), supporting fetching a data sample for a given key. Subclasses could also optionally overwrite __len__(), which is expected to return the size of the dataset by many torch.utils.data.Sampler implementations and the default options of torch.utils.data.DataLoader. Subclasses could also optionally implement __getitems__() to speed up batched sample loading. This method accepts a list of sample indices for a batch and returns a list of samples.

Note: torch.utils.data.DataLoader by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.

DatasetMusicTest(spectrogram: bool = False)
def __init__(self, spectrogram: bool = False) -> None:
    """Load the cached index of test ``.wav`` files, or build it.

    Args:
        spectrogram (bool, optional): If True, ``__getitem__`` converts each
            cropped waveform with ``get_spectrogram``; otherwise the cropped
            waveform is returned as-is. Defaults to False.
    """
    super().__init__()

    # Set path to source data
    self.source_dir = "data/test"

    # Initialize variables. ``_df`` starts as an empty frame (an instance,
    # not the DataFrame class) and is replaced by the cached/generated index.
    self._df = pd.DataFrame()
    self._df_path = "save/test_music_df.pkl"
    self._new_sample_rate = 44100
    self.spectogram = spectrogram  # attribute name keeps the existing spelling

    if os.path.isfile(self._df_path):
        # NOTE(review): read_pickle can execute arbitrary code; only load a
        # cache file this project wrote itself.
        self._df = pd.read_pickle(self._df_path)
    else:
        self.generate_df()
source_dir
spectogram
def generate_df(self):
def generate_df(self):
    """Index every ``.wav`` file directly under ``self.source_dir``.

    Stores a one-column DataFrame (``wav_file_path``) in ``self._df`` and
    pickles it to ``self._df_path``, creating the parent directory first so
    that a missing ``save/`` directory does not make ``to_pickle`` fail.
    """
    # glob() order is filesystem-dependent; sorting makes dataset indices
    # reproducible across machines and re-runs.
    wav_files = sorted(glob(os.path.join(self.source_dir, "*.wav")))

    # Create a dataframe that contains all the wav file paths
    self._df = pd.DataFrame(wav_files, columns=["wav_file_path"])

    save_dir = os.path.dirname(self._df_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    self._df.to_pickle(self._df_path)
def resample(self, waveform: torch.Tensor, sample_rate: int, path: str):
def resample(self, waveform: torch.Tensor, sample_rate: int, path: str):
    """Resample ``waveform`` to ``self._new_sample_rate`` and keep its first 30 s.

    Args:
        waveform (torch.Tensor): Audio as returned by ``torchaudio.load``
            (channels x frames).
        sample_rate (int): Sample rate ``waveform`` was loaded at.
        path (str): Source file; only used in the log message.

    Returns:
        numpy.ndarray: Frames whose timestamp lies in [0, 30] seconds, i.e. at
        most ``30 * self._new_sample_rate + 1`` frames per channel.
    """
    # Conversion to the target rate (44.1kHz) if not already
    if sample_rate != self._new_sample_rate:
        waveform = torchaudio.transforms.Resample(sample_rate, self._new_sample_rate)(waveform)
        print(f"Waveform Resampled: {path}")

    # Keep the *first* 30 seconds (the window is fixed, not random). Frame
    # timestamps are measured at the post-resampling rate; using the input
    # rate here would shrink the window whenever resampling happened above.
    max_seconds = 30
    max_frames = int(max_seconds * self._new_sample_rate) + 1
    # copy() yields a contiguous array that does not keep the full-length
    # decoded buffer alive (a plain slice would be a view into it).
    return waveform.numpy()[:, :max_frames].copy()