src.helper

  1# @markdown In this tutorial, we will use speech data from the [VOiCES dataset](https://iqtlabs.github.io/voices/), which is licensed under Creative Commons BY 4.0.
  2import matplotlib
  3import matplotlib.pyplot as plt
  4import numpy as np
  5import torch
  6from mido import MidiFile
  7
  8from src.data import DatasetMusic
  9
 10
 11[width, height] = matplotlib.rcParams["figure.figsize"]
 12if width < 10:
 13    matplotlib.rcParams["figure.figsize"] = [width * 2.5, height]
 14
 15
 16def print_stats(waveform: torch.Tensor, sample_rate: int | None = None, src: str | None = None):
 17    """Prints out general stats.
 18
 19    Args:
 20        waveform (tensor): Audio waveform
 21        sample_rate (int, optional): Sample rate. Defaults to None.
 22        src (path, optional): Path of file. Defaults to None.
 23    """
 24    if src:
 25        print("-" * 10)
 26        print("Source:", src)
 27        print("-" * 10)
 28    if sample_rate:
 29        print("Sample Rate:", sample_rate)
 30    print("Shape:", tuple(waveform.shape))
 31    print("Dtype:", waveform.dtype)
 32    print(f" - Max:     {waveform.max().item():6.3f}")
 33    print(f" - Min:     {waveform.min().item():6.3f}")
 34    print(f" - Mean:    {waveform.mean().item():6.3f}")
 35    print(f" - Std Dev: {waveform.std().item():6.3f}")
 36    print()
 37    print(waveform)
 38    print()
 39
 40
 41def print_midi(path: str):
 42    """Prints out general MIDI information.
 43
 44    Args:
 45        path (_type_): _description_
 46    """
 47    mid = MidiFile(path, clip=True)
 48    print(mid)
 49
 50
 51def plot_waveform(
 52    waveform: torch.Tensor,
 53    sample_rate: int,
 54    title: str = "Waveform",
 55    xlim: int | None = None,
 56    ylim: int | None = None,
 57):
 58    """_summary_
 59
 60    Args:
 61        waveform (torch.Tensor): Audio waveform.
 62        sample_rate (int): Sample rate.
 63        title (str, optional): Title. Defaults to "Waveform".
 64        xlim (int, optional): Limits. Defaults to None.
 65        ylim (int, optional): Limits. Defaults to None.
 66    """
 67    waveform = waveform.numpy()
 68
 69    num_channels, num_frames = waveform.shape
 70    time_axis = np.arange(0, num_frames) / sample_rate
 71
 72    figure, axes = plt.subplots(num_channels, 1)
 73    if num_channels == 1:
 74        axes = [axes]
 75    for c in range(num_channels):
 76        axes[c].plot(time_axis, waveform[c], linewidth=1)
 77        axes[c].grid(True)
 78        if num_channels > 1:
 79            axes[c].set_ylabel(f"Channel {c+1}")
 80        if xlim:
 81            axes[c].set_xlim(xlim)
 82        if ylim:
 83            axes[c].set_ylim(ylim)
 84    figure.suptitle(title)
 85    plt.show(block=False)
 86
 87
 88def plot_specgram(
 89    waveform: torch.Tensor,
 90    sample_rate: int,
 91    title: str = "Spectrogram",
 92    xlim: int | None = None,
 93):
 94    """Plots the spectogram as 2 channels.
 95
 96    Args:
 97        waveform (torch.Tensor): Audio waveform.
 98        sample_rate (int): Sample rate.
 99        title (str, optional): Title. Defaults to "Spectrogram".
100        xlim (int, optional): Limit. Defaults to None.
101    """
102    waveform = waveform.numpy()
103
104    num_channels, num_frames = waveform.shape
105
106    figure, axes = plt.subplots(num_channels, 1)
107    if num_channels == 1:
108        axes = [axes]
109    for c in range(num_channels):
110        sliced_waveform = waveform[c]
111        axes[c].specgram(sliced_waveform, Fs=sample_rate)
112        if num_channels > 1:
113            axes[c].set_ylabel(f"Channel {c+1}")
114        if xlim:
115            axes[c].set_xlim(xlim)
116
117    figure.suptitle(title)
118    plt.show(block=False)
119
120
def plot_music(sample_rate: int = 44100):
    """Plotting function for debugging and checking purposes.

    Takes the first item of ``DatasetMusic``, prints statistics for its first
    element (the anchor waveform), and plots that waveform and its spectrogram.

    Args:
        sample_rate (int, optional): Sample rate used for the stats printout,
            the waveform time axis, and the spectrogram. Defaults to 44100;
            presumably this must match what DatasetMusic yields — TODO confirm.
    """
    # Template from: https://pytorch.org/tutorials/beginner/audio_preprocessing_tutorial.html#loading-audio-data-into-tensor

    dataset = iter(DatasetMusic())

    # Each item unpacks into six parts; only the first one is inspected here.
    waveform_anc, _, _, _, _, _ = next(dataset)
    print_stats(waveform_anc, sample_rate=sample_rate)
    plot_waveform(waveform_anc, sample_rate)
    plot_specgram(waveform_anc, sample_rate)
134
135
if __name__ == "__main__":
    # Script entry point: run the visual debugging check, then report completion.
    plot_music()
    print("done")
def plot_waveform( waveform: torch.Tensor, sample_rate: int, title: str = 'Waveform', xlim: int | None = None, ylim: int | None = None):
def plot_waveform(
    waveform: torch.Tensor,
    sample_rate: int,
    title: str = "Waveform",
    xlim: tuple[float, float] | None = None,
    ylim: tuple[float, float] | None = None,
):
    """Plots each channel of a waveform against time in seconds.

    Every channel gets its own vertically stacked subplot; the figure is shown
    with ``plt.show(block=False)``.

    Args:
        waveform (torch.Tensor): Audio waveform shaped ``(channels, frames)``.
            A 1-D tensor is treated as a single channel.
        sample_rate (int): Sample rate, used to convert frame indices to seconds.
        title (str, optional): Figure title. Defaults to "Waveform".
        xlim (tuple, optional): ``(left, right)`` x-axis limits in seconds,
            passed to ``Axes.set_xlim``. Defaults to None.
        ylim (tuple, optional): ``(bottom, top)`` y-axis limits, passed to
            ``Axes.set_ylim``. Defaults to None.
    """
    # detach/cpu so tensors that require grad or live on a GPU can be converted.
    samples = np.atleast_2d(waveform.detach().cpu().numpy())

    num_channels, num_frames = samples.shape
    time_axis = np.arange(num_frames) / sample_rate

    # squeeze=False always returns a 2-D array of Axes, even for one channel.
    figure, axes = plt.subplots(num_channels, 1, squeeze=False)
    for c, ax in enumerate(axes[:, 0]):
        ax.plot(time_axis, samples[c], linewidth=1)
        ax.grid(True)
        if num_channels > 1:
            ax.set_ylabel(f"Channel {c+1}")
        if xlim is not None:
            ax.set_xlim(xlim)
        if ylim is not None:
            ax.set_ylim(ylim)
    figure.suptitle(title)
    plt.show(block=False)

Plots each channel of the waveform in its own subplot, with time in seconds on the x-axis.

Arguments:
  • waveform (torch.Tensor): Audio waveform.
  • sample_rate (int): Sample rate.
  • title (str, optional): Title. Defaults to "Waveform".
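  • xlim (tuple, optional): (left, right) x-axis limits in seconds, passed to `Axes.set_xlim`. Defaults to None.
  • ylim (tuple, optional): (bottom, top) y-axis limits, passed to `Axes.set_ylim`. Defaults to None.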
def plot_specgram( waveform: torch.Tensor, sample_rate: int, title: str = 'Spectrogram', xlim: int | None = None):
 89def plot_specgram(
 90    waveform: torch.Tensor,
 91    sample_rate: int,
 92    title: str = "Spectrogram",
 93    xlim: int | None = None,
 94):
 95    """Plots the spectogram as 2 channels.
 96
 97    Args:
 98        waveform (torch.Tensor): Audio waveform.
 99        sample_rate (int): Sample rate.
100        title (str, optional): Title. Defaults to "Spectrogram".
101        xlim (int, optional): Limit. Defaults to None.
102    """
103    waveform = waveform.numpy()
104
105    num_channels, num_frames = waveform.shape
106
107    figure, axes = plt.subplots(num_channels, 1)
108    if num_channels == 1:
109        axes = [axes]
110    for c in range(num_channels):
111        sliced_waveform = waveform[c]
112        axes[c].specgram(sliced_waveform, Fs=sample_rate)
113        if num_channels > 1:
114            axes[c].set_ylabel(f"Channel {c+1}")
115        if xlim:
116            axes[c].set_xlim(xlim)
117
118    figure.suptitle(title)
119    plt.show(block=False)

Plots a spectrogram of each channel of the waveform in its own subplot.

Arguments:
  • waveform (torch.Tensor): Audio waveform.
  • sample_rate (int): Sample rate.
  • title (str, optional): Title. Defaults to "Spectrogram".
  • xlim (tuple, optional): (left, right) x-axis limits, passed to `Axes.set_xlim`. Defaults to None.
def plot_music():
def plot_music(sample_rate: int = 44100):
    """Plotting function for debugging and checking purposes.

    Takes the first item of ``DatasetMusic``, prints statistics for its first
    element (the anchor waveform), and plots that waveform and its spectrogram.

    Args:
        sample_rate (int, optional): Sample rate used for the stats printout,
            the waveform time axis, and the spectrogram. Defaults to 44100;
            presumably this must match what DatasetMusic yields — TODO confirm.
    """
    # Template from: https://pytorch.org/tutorials/beginner/audio_preprocessing_tutorial.html#loading-audio-data-into-tensor

    dataset = iter(DatasetMusic())

    # Each item unpacks into six parts; only the first one is inspected here.
    waveform_anc, _, _, _, _, _ = next(dataset)
    print_stats(waveform_anc, sample_rate=sample_rate)
    plot_waveform(waveform_anc, sample_rate)
    plot_specgram(waveform_anc, sample_rate)

Plotting function for debugging and checking purposes.