src.data
```python
import os
import random
from glob import glob
from random import choice

import numpy as np
import pandas as pd
import torch
import torchaudio
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchaudio import transforms
from tqdm import tqdm


# Set random seed for reproducibility
random.seed(10)


def get_spectrogram(
    waveform: torch.Tensor,
    sample_rate: int,
    n_mels=64,
    n_fft=1024,
    hop_len=None,
    debug: bool = False,
) -> torch.Tensor:
    """Convert an audio waveform into a two-channel tensor with height and
    width, which can be treated as an image. This is easier to work with than
    the raw audio tensor format.

    Args:
        waveform (torch.Tensor): Audio data
        sample_rate (int): Sample rate (should be 44.1 kHz)
        n_mels (int, optional): Number of mel filterbanks. Defaults to 64.
        n_fft (int, optional): FFT window size. Defaults to 1024.
        hop_len (int, optional): Hop length between STFT frames. Defaults to
            None, in which case torchaudio uses n_fft // 2.
        debug (bool, optional): If True, render the spectrogram with matplotlib.

    Returns:
        torch.Tensor: Tensor of shape (2, 64, 2584)
    """
    # Cite: https://towardsdatascience.com/audio-deep-learning-made-simple-sound-classification-step-by-step-cebc936bbe5

    top_db = 80

    # spec has shape [channel, n_mels, time], where channel is mono, stereo, etc.
    spectrogram = transforms.MelSpectrogram(
        sample_rate, n_fft=n_fft, hop_length=hop_len, n_mels=n_mels
    )(torch.Tensor(waveform))

    # Convert to decibels
    spectrogram = transforms.AmplitudeToDB(top_db=top_db)(spectrogram)

    if spectrogram.shape != (2, 64, 2584):
        raise ValueError("Invalid sizing for Mel Spectrogram!")

    if debug:
        plt.figure()
        plt.imshow(spectrogram.log2()[0, :, :].numpy(), cmap="viridis")
        plt.close()

    return spectrogram


class DatasetMusic(Dataset):
    def __init__(
        self,
        train: bool = True,
        spectrogram: bool = False,
    ) -> None:
        """Generate a dataframe that specifies the training and validation
        split of the data."""
        super().__init__()

        # Set path to source data
        self.source_dir = "data/musicians"

        # Initialize variables
        self._df = pd.DataFrame()
        self._df_path = "save/music_df.pkl"
        self._new_sample_rate = 44100
        self._train = train
        self.spectrogram = spectrogram

        self.N_ANCHOR = 1
        self.N_POSITIVE = 1
        self.N_NEGATIVE = 1
        self.DEBUG = False

        # Load the cached dataframe if it exists; otherwise build it
        if os.path.isfile(self._df_path):
            self._df = pd.read_pickle(self._df_path)
        else:
            self.generate_df()

    def generate_df(self):
        """Generate a dataframe that holds the training and validation data."""

        data = []

        # Create iterator for musician folders
        parent_folders = os.listdir(self.source_dir)
        musician_folders = tqdm(parent_folders, colour="green")

        # Iterate over all musician folders
        for label in musician_folders:
            musician_folders.set_description(f"Creating Dataset: {label}")

            # Ignore irrelevant files
            if label == ".DS_Store":
                continue  # this is a macOS-generated file

            # Create iterator for WAV files
            wav_files_path = tqdm(glob(self.source_dir + "/" + label + "/*.wav"), colour="red")

            # Iterate over all files under the musician folder
            for ii, wav_file_path in enumerate(wav_files_path):
                # Flag the second file as the anchor; split the rest roughly
                # 80/20 between training and validation
                prob = random.random()
                if ii == 1:
                    anchor = True
                    flag_train = True
                    flag_valid = True
                elif prob <= 0.8:
                    anchor = False
                    flag_train = True
                    flag_valid = False
                else:
                    anchor = False
                    flag_train = False
                    flag_valid = True

                data.append((wav_file_path, label, anchor, flag_train, flag_valid))

        # Create a dataframe that contains all the wav files and their labels
        self._df = pd.DataFrame(
            data, columns=["wav_file_path", "label", "anchor", "train", "valid"]
        )

        self._df.to_pickle(self._df_path)

    def __getitem__(self, index: int) -> tuple:
        """Get an item from the dataset (i.e. data and labels)

        Args:
            index (int): Unused at the moment; a triplet is sampled randomly.

        Returns:
            tuple: anchor, positive, and negative waveforms and their labels
        """
        # Define musicians to choose from
        musicians = ["Bach", "Beethoven", "Brahms", "Schubert"]

        # Randomly select a musician
        anchor_musician = choice(musicians)

        # Filter on the selected musician
        filter_positive = self._df["label"] == anchor_musician
        filter_negative = self._df["label"] != anchor_musician
        filter_training = self._df["train"]
        filter_validate = self._df["valid"]
        filter_anchor = self._df["anchor"]

        # Refine based on training vs validation
        if self._train:
            filter_positive = filter_positive & filter_training
            filter_negative = filter_negative & filter_training
        else:
            filter_positive = filter_positive & filter_validate
            filter_negative = filter_negative & filter_validate

        # Get anchor, positive, and negative data entries
        filter_data_anc = self._df[filter_positive & filter_anchor]
        filter_data_pos = self._df[filter_positive & ~filter_anchor]
        filter_data_neg = self._df[filter_negative]

        # Sample one entry from each group
        df_anc = filter_data_anc.sample(n=self.N_ANCHOR).reset_index(drop=True)
        df_pos = filter_data_pos.sample(n=self.N_POSITIVE).reset_index(drop=True)
        df_neg = filter_data_neg.sample(n=self.N_NEGATIVE).reset_index(drop=True)

        anc_path_ = df_anc["wav_file_path"][0]
        pos_path_ = df_pos["wav_file_path"][0]
        neg_path_ = df_neg["wav_file_path"][0]

        anc_label = df_anc["label"][0]
        pos_label = df_pos["label"][0]
        neg_label = df_neg["label"][0]

        # Load the torchaudio waveforms
        waveform_anc, sample_rate_anc = torchaudio.load(anc_path_)
        waveform_pos, sample_rate_pos = torchaudio.load(pos_path_)
        waveform_neg, sample_rate_neg = torchaudio.load(neg_path_)

        # Resample to 44.1 kHz and crop to 30 seconds
        waveform_anc = self.resample(waveform_anc, sample_rate_anc, anc_path_)
        waveform_pos = self.resample(waveform_pos, sample_rate_pos, pos_path_)
        waveform_neg = self.resample(waveform_neg, sample_rate_neg, neg_path_)

        # Check shapes of the waveforms
        for waveform in (waveform_anc, waveform_pos, waveform_neg):
            if waveform.shape != (2, 1323001):
                print("Invalid shape detected!")

        if self.spectrogram:
            waveform_anc = get_spectrogram(waveform_anc, self._new_sample_rate)
            waveform_pos = get_spectrogram(waveform_pos, self._new_sample_rate)
            waveform_neg = get_spectrogram(waveform_neg, self._new_sample_rate)

        return waveform_anc, waveform_pos, waveform_neg, anc_label, pos_label, neg_label

    def __len__(self):
        if self._train:
            return int(self._df["train"].sum())
        else:
            return int(self._df["valid"].sum())

    def get_anchors(self) -> dict:
        """Retrieve all anchors from the dataset."""

        anchors = {}

        # Select the anchor rows
        df_anc = self._df[self._df["anchor"]]

        for _, row in df_anc.iterrows():
            # Extract row information
            anc_path_ = row["wav_file_path"]
            anc_label = row["label"]

            # Load the torchaudio waveform
            waveform_anc, sample_rate_anc = torchaudio.load(anc_path_)

            # Resample to 44.1 kHz and crop to 30 seconds
            waveform_anc = self.resample(waveform_anc, sample_rate_anc, anc_path_)

            # Convert to spectrogram
            if self.spectrogram:
                waveform_anc = get_spectrogram(waveform_anc, self._new_sample_rate)

            anchors[anc_label] = waveform_anc

        return anchors

    def resample(self, waveform: torch.Tensor, sample_rate: int, path: str):
        """Resample to 44.1 kHz and crop the first 30 seconds of the clip."""
        # Convert to 44.1 kHz if not already
        if sample_rate != self._new_sample_rate:
            waveform = torchaudio.transforms.Resample(sample_rate, self._new_sample_rate)(waveform)
            sample_rate = self._new_sample_rate  # the time axis below must use the new rate
            print(f"Waveform Resampled: {path}")

        # Crop the first 30 seconds of data from the music file
        waveform = waveform.numpy()
        num_channels, num_frames = waveform.shape
        time_axis = np.arange(0, num_frames) / sample_rate

        # Set window for the 30-second interval
        min_time_axis = 0
        max_time_axis = 30
        sel_time_axis = (time_axis >= min_time_axis) & (time_axis <= max_time_axis)
        waveform = waveform[:, sel_time_axis]

        # Optionally save the cropped clip for inspection
        if self.DEBUG:
            torchaudio.save(
                "test_dataset_resample.wav",
                torch.tensor(waveform),
                sample_rate=self._new_sample_rate,
            )

        return waveform


class DatasetMusicTest(Dataset):
    def __init__(self, spectrogram: bool = False) -> None:
        """Generate a dataframe that is used to test the data."""
        super().__init__()

        # Set path to source data
        self.source_dir = "data/test"

        # Initialize variables
        self._df = pd.DataFrame()
        self._df_path = "save/test_music_df.pkl"
        self._new_sample_rate = 44100
        self.spectrogram = spectrogram

        # Load the cached dataframe if it exists; otherwise build it
        if os.path.isfile(self._df_path):
            self._df = pd.read_pickle(self._df_path)
        else:
            self.generate_df()

    def generate_df(self):
        """Generate a dataframe that lists the test WAV files."""
        data = []

        # Create iterator for WAV files
        wav_files_path = tqdm(glob(self.source_dir + "/*.wav"), colour="red")

        # Iterate over all files in the test folder
        for wav_file_path in wav_files_path:
            data.append(wav_file_path)

        # Create a dataframe that contains all the wav files
        self._df = pd.DataFrame(data, columns=["wav_file_path"])

        self._df.to_pickle(self._df_path)

    def __getitem__(self, index: int) -> tuple:
        """Get an item from the dataset (i.e. data only)

        Args:
            index (int): Row index into the test dataframe.

        Returns:
            tuple: waveform, file name
        """

        # Look up the indexed entry
        indexed_data = self._df.loc[index]

        # Extract the selected info
        path = indexed_data["wav_file_path"]
        name = str(path).split("/")[-1]

        # Load the torchaudio waveform
        waveform, sample_rate = torchaudio.load(path)
        waveform = self.resample(waveform, sample_rate, path)

        # Check shape of waveform
        if waveform.shape != (2, 1323001):
            print("Invalid shape detected!")

        if self.spectrogram:
            waveform = get_spectrogram(waveform, self._new_sample_rate)

        return waveform, name

    def __len__(self):
        return len(self._df)

    def resample(self, waveform: torch.Tensor, sample_rate: int, path: str):
        """Resample to 44.1 kHz and crop the first 30 seconds of the clip."""
        # Convert to 44.1 kHz if not already
        if sample_rate != self._new_sample_rate:
            waveform = torchaudio.transforms.Resample(sample_rate, self._new_sample_rate)(waveform)
            sample_rate = self._new_sample_rate  # the time axis below must use the new rate
            print(f"Waveform Resampled: {path}")

        # Crop the first 30 seconds of data from the music file
        waveform = waveform.numpy()
        num_channels, num_frames = waveform.shape
        time_axis = np.arange(0, num_frames) / sample_rate

        # Set window for the 30-second interval
        min_time_axis = 0
        max_time_axis = 30
        sel_time_axis = (time_axis >= min_time_axis) & (time_axis <= max_time_axis)
        waveform = waveform[:, sel_time_axis]

        return waveform


if __name__ == "__main__":
    # Set up training/validation datasets
    dataset_music_train = DatasetMusic(train=True)
    dataset_music_valid = DatasetMusic(train=False)

    # Enable spectrogram output
    dataset_music_train.spectrogram = True
    dataset_music_valid.spectrogram = True

    # Convert to iterators
    data_train = iter(dataset_music_train)
    data_valid = iter(dataset_music_valid)

    # Test training iterator
    for ii in range(0, 10):
        next(data_train)

    # Test validation iterator
    for ii in range(0, 10):
        next(data_valid)

    # Create dataloaders for training/validation
    dataloader_train = DataLoader(
        dataset_music_train,
        batch_size=4,
        shuffle=True,
    )

    dataloader_valid = DataLoader(
        dataset_music_valid,
        batch_size=4,
        shuffle=True,
    )
```
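Because `DatasetMusic.__getitem__` returns an (anchor, positive, negative) triplet plus labels, a natural consumer is a triplet loss. Below is a minimal sketch; the embedding network `embed` is a hypothetical stand-in and not part of this module:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader

from src.data import DatasetMusic

# Hypothetical embedding network; any module mapping a (2, 64, 2584)
# spectrogram to a vector would do.
embed = nn.Sequential(nn.Flatten(), nn.Linear(2 * 64 * 2584, 128))

loader = DataLoader(DatasetMusic(train=True, spectrogram=True), batch_size=4)
criterion = nn.TripletMarginLoss(margin=1.0)

# One training step: pull anchor/positive embeddings together,
# push anchor/negative embeddings apart
anc, pos, neg, *_ = next(iter(loader))
loss = criterion(embed(anc), embed(pos), embed(neg))
loss.backward()
```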
get_spectrogram(waveform: torch.Tensor, sample_rate: int, n_mels=64, n_fft=1024, hop_len=None, debug=False) -> torch.Tensor
Converts an audio waveform into a two-channel tensor with height and width, which can be treated as an image. This format is easier to work with than the raw audio tensor.
Arguments:
- waveform (torch.Tensor): Audio data
- sample_rate (int): Sample rate (should be 44.1 kHz)
- n_mels (int, optional): Number of mel filterbanks. Defaults to 64.
- n_fft (int, optional): FFT window size. Defaults to 1024.
- hop_len (int, optional): Hop length between STFT frames. Defaults to None, in which case torchaudio uses n_fft // 2.
Returns:
torch.Tensor: Tensor of shape (2, 64, 2584)
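A quick usage sketch; the random tensor below is a stand-in for a real 30-second stereo clip at 44.1 kHz:

```python
import torch

from src.data import get_spectrogram

# Stand-in for a real 30-second stereo clip (2 channels, 44100 * 30 + 1 samples)
waveform = torch.randn(2, 1323001)

spec = get_spectrogram(waveform, sample_rate=44100)
print(spec.shape)  # torch.Size([2, 64, 2584])
```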
class DatasetMusic(train: bool = True, spectrogram: bool = False)

Generate a dataframe that specifies the training and validation split of the data.
DatasetMusic.generate_df()
Generate a dataframe that holds the training and validation data.
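A short sketch for inspecting the cached split; the path matches the `_df_path` default above:

```python
import pandas as pd

df = pd.read_pickle("save/music_df.pkl")
print(df.columns.tolist())
# ['wav_file_path', 'label', 'anchor', 'train', 'valid']

# Per-musician counts of training, validation, and anchor files
print(df.groupby("label")[["train", "valid", "anchor"]].sum())
```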
DatasetMusic.get_anchors() -> dict
Retrieve all anchors from the dataset.
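For example, with one anchor flagged per musician folder by `generate_df`:

```python
from src.data import DatasetMusic

dataset = DatasetMusic(train=True, spectrogram=True)
anchors = dataset.get_anchors()

# One spectrogram per musician label
for label, spec in anchors.items():
    print(label, tuple(spec.shape))  # e.g. Bach (2, 64, 2584)
```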
DatasetMusic.resample(waveform: torch.Tensor, sample_rate: int, path: str)

Resample a waveform to 44.1 kHz if needed and crop it to the first 30 seconds.
class DatasetMusicTest(spectrogram: bool = False)

Generate a dataframe that is used to test the data.
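For context, a sketch of how the test set and the anchors might be combined at inference time; the embedding network `embed` is hypothetical and not part of this module:

```python
import torch
from torch import nn

from src.data import DatasetMusic, DatasetMusicTest

# Hypothetical embedding network; any module mapping a (2, 64, 2584)
# spectrogram to a vector would do.
embed = nn.Sequential(nn.Flatten(), nn.Linear(2 * 64 * 2584, 128))

with torch.no_grad():
    # Embed each labeled anchor once
    anchors = DatasetMusic(train=True, spectrogram=True).get_anchors()
    anchor_vecs = {label: embed(spec.unsqueeze(0)) for label, spec in anchors.items()}

    # Predict the composer of each test clip by its nearest anchor
    # in embedding space
    test_set = DatasetMusicTest(spectrogram=True)
    for ii in range(len(test_set)):
        spec, name = test_set[ii]
        vec = embed(spec.unsqueeze(0))
        pred = min(anchor_vecs, key=lambda label: torch.dist(vec, anchor_vecs[label]))
        print(name, "->", pred)
```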