src.helper
"""Helpers for inspecting audio data.

Provides functions that print tensor statistics and MIDI file information and
that plot waveforms and spectrograms with matplotlib.
"""
# @markdown In this tutorial, we will use speech data from the [VOiCES dataset](https://iqtlabs.github.io/voices/), which is licensed under Creative Commons BY 4.0.
# NOTE(review): the comment above looks inherited from the tutorial this module
# was templated on; the code below only reads from DatasetMusic -- confirm.
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
from mido import MidiFile

from src.data import DatasetMusic


# Widen the default figure when it is narrower than 10 units so that
# time-axis plots have more horizontal room.
[width, height] = matplotlib.rcParams["figure.figsize"]
if width < 10:
    matplotlib.rcParams["figure.figsize"] = [width * 2.5, height]


def print_stats(waveform: torch.Tensor, sample_rate: int | None = None, src: str | None = None):
    """Print the shape, dtype and basic statistics of a waveform, then the tensor itself.

    Args:
        waveform (torch.Tensor): Audio waveform.
        sample_rate (int, optional): Sample rate; printed only when truthy. Defaults to None.
        src (str, optional): Path of the source file; printed only when truthy. Defaults to None.
    """
    if src:
        print("-" * 10)
        print("Source:", src)
        print("-" * 10)
    if sample_rate:
        print("Sample Rate:", sample_rate)
    print("Shape:", tuple(waveform.shape))
    print("Dtype:", waveform.dtype)
    print(f" - Max: {waveform.max().item():6.3f}")
    print(f" - Min: {waveform.min().item():6.3f}")
    # mean() and std() raise on integer tensors (e.g. int16 PCM), so compute
    # them on a float copy when the input is not already floating point.
    as_float = waveform if waveform.is_floating_point() else waveform.to(torch.float64)
    print(f" - Mean: {as_float.mean().item():6.3f}")
    print(f" - Std Dev: {as_float.std().item():6.3f}")
    print()
    print(waveform)
    print()


def print_midi(path: str):
    """Load a MIDI file and print the resulting ``MidiFile`` object.

    Args:
        path (str): Path of the MIDI file to load.
    """
    # clip=True is forwarded to mido; per its documentation it clips
    # out-of-range data values instead of raising -- TODO confirm.
    print(MidiFile(path, clip=True))


def plot_waveform(
    waveform: torch.Tensor,
    sample_rate: int,
    title: str = "Waveform",
    xlim: int | tuple[float, float] | None = None,
    ylim: int | tuple[float, float] | None = None,
):
    """Plot each channel of a waveform against time, one subplot per channel.

    Args:
        waveform (torch.Tensor): Audio waveform, unpacked as (channels, frames).
            A 1-D tensor is treated as a single channel.
        sample_rate (int): Sample rate, used to convert frame indices to seconds.
        title (str, optional): Figure title. Defaults to "Waveform".
        xlim (optional): Value forwarded to ``Axes.set_xlim`` on every subplot.
            Skipped only when None. Defaults to None.
        ylim (optional): Value forwarded to ``Axes.set_ylim`` on every subplot.
            Skipped only when None. Defaults to None.
    """
    # detach()/cpu() first: Tensor.numpy() refuses tensors that require grad
    # or that live on a non-CPU device.
    samples = np.atleast_2d(waveform.detach().cpu().numpy())

    num_channels, num_frames = samples.shape
    time_axis = np.arange(0, num_frames) / sample_rate

    # squeeze=False always returns a 2-D array of Axes, so the mono case
    # needs no special handling.
    figure, axes = plt.subplots(num_channels, 1, squeeze=False)
    for index, (ax, channel) in enumerate(zip(axes[:, 0], samples), start=1):
        ax.plot(time_axis, channel, linewidth=1)
        ax.grid(True)
        if num_channels > 1:
            ax.set_ylabel(f"Channel {index}")
        # Compare against None so that a legitimate limit of 0 is not dropped.
        if xlim is not None:
            ax.set_xlim(xlim)
        if ylim is not None:
            ax.set_ylim(ylim)
    figure.suptitle(title)
    plt.show(block=False)


def plot_specgram(
    waveform: torch.Tensor,
    sample_rate: int,
    title: str = "Spectrogram",
    xlim: int | tuple[float, float] | None = None,
):
    """Plot the spectrogram of each channel, one subplot per channel.

    Args:
        waveform (torch.Tensor): Audio waveform, unpacked as (channels, frames).
            A 1-D tensor is treated as a single channel.
        sample_rate (int): Sample rate, passed to ``Axes.specgram`` as ``Fs``.
        title (str, optional): Figure title. Defaults to "Spectrogram".
        xlim (optional): Value forwarded to ``Axes.set_xlim`` on every subplot.
            Skipped only when None. Defaults to None.
    """
    # detach()/cpu() first: Tensor.numpy() refuses tensors that require grad
    # or that live on a non-CPU device.
    samples = np.atleast_2d(waveform.detach().cpu().numpy())

    num_channels, _ = samples.shape

    # squeeze=False always returns a 2-D array of Axes, so the mono case
    # needs no special handling.
    figure, axes = plt.subplots(num_channels, 1, squeeze=False)
    for index, (ax, channel) in enumerate(zip(axes[:, 0], samples), start=1):
        ax.specgram(channel, Fs=sample_rate)
        if num_channels > 1:
            ax.set_ylabel(f"Channel {index}")
        # Compare against None so that a legitimate limit of 0 is not dropped.
        if xlim is not None:
            ax.set_xlim(xlim)

    figure.suptitle(title)
    plt.show(block=False)


def plot_music():
    """Print stats for and plot the first anchor waveform from ``DatasetMusic``.

    Plotting function for debugging and checking purposes.
    """

    # Template from: https://pytorch.org/tutorials/beginner/audio_preprocessing_tutorial.html#loading-audio-data-into-tensor

    # NOTE(review): the sample rate is hard-coded here rather than read from
    # the dataset -- confirm it matches the data.
    sample_rate = 44100

    # The first dataset item is unpacked as six values; only the first
    # (the anchor waveform) is used.
    waveform_anc, _, _, _, _, _ = next(iter(DatasetMusic()))
    print_stats(waveform_anc, sample_rate=sample_rate)
    plot_waveform(waveform_anc, sample_rate)
    plot_specgram(waveform_anc, sample_rate)


if __name__ == "__main__":
    plot_music()
    print("done")
def
print_stats( waveform: torch.Tensor, sample_rate: int | None = None, src: str | None = None):
17def print_stats(waveform: torch.Tensor, sample_rate: int | None = None, src: str | None = None): 18 """Prints out general stats. 19 20 Args: 21 waveform (tensor): Audio waveform 22 sample_rate (int, optional): Sample rate. Defaults to None. 23 src (path, optional): Path of file. Defaults to None. 24 """ 25 if src: 26 print("-" * 10) 27 print("Source:", src) 28 print("-" * 10) 29 if sample_rate: 30 print("Sample Rate:", sample_rate) 31 print("Shape:", tuple(waveform.shape)) 32 print("Dtype:", waveform.dtype) 33 print(f" - Max: {waveform.max().item():6.3f}") 34 print(f" - Min: {waveform.min().item():6.3f}") 35 print(f" - Mean: {waveform.mean().item():6.3f}") 36 print(f" - Std Dev: {waveform.std().item():6.3f}") 37 print() 38 print(waveform) 39 print()
Prints out general stats.
Arguments:
- waveform (tensor): Audio waveform
- sample_rate (int, optional): Sample rate. Defaults to None.
- src (path, optional): Path of file. Defaults to None.
def
print_midi(path: str):
def print_midi(path: str):
    """Load a MIDI file and print the resulting ``MidiFile`` object.

    Args:
        path (str): Path of the MIDI file to load.
    """
    # clip=True is forwarded to mido; per its documentation it clips
    # out-of-range data values instead of raising -- TODO confirm.
    print(MidiFile(path, clip=True))
Prints out general MIDI information.
Arguments:
- path (str): Path of the MIDI file to load.
def
plot_waveform( waveform: torch.Tensor, sample_rate: int, title: str = 'Waveform', xlim: int | None = None, ylim: int | None = None):
def plot_waveform(
    waveform: torch.Tensor,
    sample_rate: int,
    title: str = "Waveform",
    xlim: int | tuple[float, float] | None = None,
    ylim: int | tuple[float, float] | None = None,
):
    """Plot each channel of a waveform against time, one subplot per channel.

    Args:
        waveform (torch.Tensor): Audio waveform, unpacked as (channels, frames).
            A 1-D tensor is treated as a single channel.
        sample_rate (int): Sample rate, used to convert frame indices to seconds.
        title (str, optional): Figure title. Defaults to "Waveform".
        xlim (optional): Value forwarded to ``Axes.set_xlim`` on every subplot.
            Skipped only when None. Defaults to None.
        ylim (optional): Value forwarded to ``Axes.set_ylim`` on every subplot.
            Skipped only when None. Defaults to None.
    """
    # detach()/cpu() first: Tensor.numpy() refuses tensors that require grad
    # or that live on a non-CPU device.
    samples = np.atleast_2d(waveform.detach().cpu().numpy())

    num_channels, num_frames = samples.shape
    time_axis = np.arange(0, num_frames) / sample_rate

    # squeeze=False always returns a 2-D array of Axes, so the mono case
    # needs no special handling.
    figure, axes = plt.subplots(num_channels, 1, squeeze=False)
    for index, (ax, channel) in enumerate(zip(axes[:, 0], samples), start=1):
        ax.plot(time_axis, channel, linewidth=1)
        ax.grid(True)
        if num_channels > 1:
            ax.set_ylabel(f"Channel {index}")
        # Compare against None so that a legitimate limit of 0 is not dropped.
        if xlim is not None:
            ax.set_xlim(xlim)
        if ylim is not None:
            ax.set_ylim(ylim)
    figure.suptitle(title)
    plt.show(block=False)
Plots each channel of a waveform against time, one subplot per channel.
Arguments:
- waveform (torch.Tensor): Audio waveform.
- sample_rate (int): Sample rate.
- title (str, optional): Title. Defaults to "Waveform".
- xlim (int, optional): X-axis limits, passed to `set_xlim` on each subplot. Defaults to None.
- ylim (int, optional): Y-axis limits, passed to `set_ylim` on each subplot. Defaults to None.
def
plot_specgram( waveform: torch.Tensor, sample_rate: int, title: str = 'Spectrogram', xlim: int | None = None):
def plot_specgram(
    waveform: torch.Tensor,
    sample_rate: int,
    title: str = "Spectrogram",
    xlim: int | tuple[float, float] | None = None,
):
    """Plot the spectrogram of each channel, one subplot per channel.

    Args:
        waveform (torch.Tensor): Audio waveform, unpacked as (channels, frames).
            A 1-D tensor is treated as a single channel.
        sample_rate (int): Sample rate, passed to ``Axes.specgram`` as ``Fs``.
        title (str, optional): Figure title. Defaults to "Spectrogram".
        xlim (optional): Value forwarded to ``Axes.set_xlim`` on every subplot.
            Skipped only when None. Defaults to None.
    """
    # detach()/cpu() first: Tensor.numpy() refuses tensors that require grad
    # or that live on a non-CPU device.
    samples = np.atleast_2d(waveform.detach().cpu().numpy())

    num_channels, _ = samples.shape

    # squeeze=False always returns a 2-D array of Axes, so the mono case
    # needs no special handling.
    figure, axes = plt.subplots(num_channels, 1, squeeze=False)
    for index, (ax, channel) in enumerate(zip(axes[:, 0], samples), start=1):
        ax.specgram(channel, Fs=sample_rate)
        if num_channels > 1:
            ax.set_ylabel(f"Channel {index}")
        # Compare against None so that a legitimate limit of 0 is not dropped.
        if xlim is not None:
            ax.set_xlim(xlim)

    figure.suptitle(title)
    plt.show(block=False)
Plots the spectrogram of each channel, one subplot per channel.
Arguments:
- waveform (torch.Tensor): Audio waveform.
- sample_rate (int): Sample rate.
- title (str, optional): Title. Defaults to "Spectrogram".
- xlim (int, optional): X-axis limits, passed to `set_xlim` on each subplot. Defaults to None.
def
plot_music():
def plot_music():
    """Print stats for and plot the first anchor waveform from ``DatasetMusic``.

    Plotting function for debugging and checking purposes.
    """

    # Template from: https://pytorch.org/tutorials/beginner/audio_preprocessing_tutorial.html#loading-audio-data-into-tensor

    # NOTE(review): the sample rate is hard-coded here rather than read from
    # the dataset -- confirm it matches the data.
    sample_rate = 44100

    # The first dataset item is unpacked as six values; only the first
    # (the anchor waveform) is used.
    waveform_anc, _, _, _, _, _ = next(iter(DatasetMusic()))
    print_stats(waveform_anc, sample_rate=sample_rate)
    plot_waveform(waveform_anc, sample_rate)
    plot_specgram(waveform_anc, sample_rate)
Plotting function for debugging and checking purposes.