audio2midi 0.1.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- audio2midi/__init__.py +0 -0
- audio2midi/basic_pitch_pitch_detector.py +783 -0
- audio2midi/crepe_pitch_detector.py +130 -0
- audio2midi/librosa_pitch_detector.py +153 -0
- audio2midi/melodia_pitch_detector.py +58 -0
- audio2midi/pop2piano.py +2604 -0
- audio2midi/py.typed +0 -0
- audio2midi/violin_pitch_detector.py +1281 -0
- audio2midi-0.1.0.dist-info/METADATA +100 -0
- audio2midi-0.1.0.dist-info/RECORD +11 -0
- audio2midi-0.1.0.dist-info/WHEEL +5 -0
@@ -0,0 +1,783 @@
from collections import defaultdict
from typing import DefaultDict, List, Optional, Dict, Tuple, Callable
from librosa import load as librosa_load, midi_to_hz, hz_to_midi
from librosa.core import frames_to_time
from librosa.util import frame as librosa_util_frame
from pretty_midi_fix import PrettyMIDI, PitchBend, Instrument, Note
from scipy.signal.windows import gaussian
from scipy.signal import argrelmax
from huggingface_hub import hf_hub_download
from torch import nn, Tensor
import torch.nn.functional as F
import numpy as np
import torch
import math

from nnAudio.features import CQT2010v2

FFT_HOP = 256
AUDIO_SAMPLE_RATE = 22050
AUDIO_WINDOW_LENGTH = 2
NOTES_BINS_PER_SEMITONE = 1
CONTOURS_BINS_PER_SEMITONE = 3
ANNOTATIONS_BASE_FREQUENCY = 27.5
ANNOTATIONS_N_SEMITONES = 88
MIDI_OFFSET = 21
N_PITCH_BEND_TICKS = 8192
MAX_FREQ_IDX = 87
N_OVERLAPPING_FRAMES = 30
ANNOTATIONS_FPS = AUDIO_SAMPLE_RATE // FFT_HOP
AUDIO_N_SAMPLES = AUDIO_SAMPLE_RATE * AUDIO_WINDOW_LENGTH - FFT_HOP
N_FFT = 8 * FFT_HOP
N_FREQ_BINS_NOTES = ANNOTATIONS_N_SEMITONES * NOTES_BINS_PER_SEMITONE
N_FREQ_BINS_CONTOURS = ANNOTATIONS_N_SEMITONES * CONTOURS_BINS_PER_SEMITONE
ANNOT_N_FRAMES = ANNOTATIONS_FPS * AUDIO_WINDOW_LENGTH
OVERLAP_LEN = N_OVERLAPPING_FRAMES * FFT_HOP
HOP_SIZE = AUDIO_N_SAMPLES - OVERLAP_LEN
MAX_N_SEMITONES = int(np.floor(12.0 * np.log2(0.5 * AUDIO_SAMPLE_RATE / ANNOTATIONS_BASE_FREQUENCY)))
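
# For reference, the derived constants above evaluate to:
#   ANNOTATIONS_FPS      = 22050 // 256          = 86 frames per second
#   AUDIO_N_SAMPLES      = 22050 * 2 - 256       = 43844 samples per window
#   N_FFT                = 8 * 256               = 2048
#   N_FREQ_BINS_NOTES    = 88 * 1                = 88
#   N_FREQ_BINS_CONTOURS = 88 * 3                = 264
#   ANNOT_N_FRAMES       = 86 * 2                = 172 frames per window
#   OVERLAP_LEN          = 30 * 256              = 7680 samples
#   HOP_SIZE             = 43844 - 7680          = 36164 samples between window starts
#   MAX_N_SEMITONES      = floor(12 * log2(11025 / 27.5)) = 103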


def frame_with_pad(x: np.array, frame_length: int, hop_size: int) -> np.array:
    """
    Extends librosa.util.frame with end padding if required, similar to
    tf.signal.frame(pad_end=True).

    Returns:
        framed_audio: array with shape (frame_length, n_windows)
    """
    n_frames = int(np.ceil((x.shape[0] - frame_length) / hop_size)) + 1
    n_pads = (n_frames - 1) * hop_size + frame_length - x.shape[0]
    x = np.pad(x, (0, n_pads), mode="constant")
    framed_audio = librosa_util_frame(x, frame_length=frame_length, hop_length=hop_size)
    return framed_audio


def window_audio_file(audio_original: np.array, hop_size: int) -> Tuple[np.array, List[Dict[str, int]]]:
    """
    Pad an audio signal appropriately and return it as a windowed signal
    with window length AUDIO_N_SAMPLES.

    Returns:
        audio_windowed: array with shape (AUDIO_N_SAMPLES, n_windows),
            the audio windowed into fixed-length chunks
        window_times: list of {'start': ..., 'end': ...} dicts (times in seconds)
    """
    audio_windowed = frame_with_pad(audio_original, AUDIO_N_SAMPLES, hop_size)
    window_times = [
        {
            "start": t_start,
            "end": t_start + (AUDIO_N_SAMPLES / AUDIO_SAMPLE_RATE),
        }
        for t_start in np.arange(audio_windowed.shape[1]) * hop_size / AUDIO_SAMPLE_RATE
    ]
    return audio_windowed, window_times


def get_audio_input(
    audio_path: str, overlap_len: int, hop_size: int
) -> Tuple[Tensor, List[Dict[str, int]], int]:
    """
    Read a wave file (as mono), pad it appropriately, and return it as a
    windowed signal with window length AUDIO_N_SAMPLES.

    Returns:
        audio_windowed: array with shape (AUDIO_N_SAMPLES, n_windows),
            the audio windowed into fixed-length chunks
        window_times: list of {'start': ..., 'end': ...} dicts (times in seconds)
        audio_original_length: length of the original audio file, in samples, BEFORE padding
    """
    assert overlap_len % 2 == 0, "overlap_length must be even, got {}".format(overlap_len)

    audio_original, _ = librosa_load(str(audio_path), sr=AUDIO_SAMPLE_RATE, mono=True)

    original_length = audio_original.shape[0]
    audio_original = np.concatenate([np.zeros((int(overlap_len / 2),), dtype=np.float32), audio_original])
    audio_windowed, window_times = window_audio_file(audio_original, hop_size)
    return audio_windowed, window_times, original_length


def unwrap_output(output: Tensor, audio_original_length: int, n_overlapping_frames: int) -> np.array:
    """Unwrap batched model predictions to a single matrix.

    Args:
        output: array (n_batches, n_times_short, n_freqs)
        audio_original_length: length of original audio signal (in samples)
        n_overlapping_frames: number of overlapping frames in the output

    Returns:
        array (n_times, n_freqs)
    """
    raw_output = output.cpu().detach().numpy()
    if len(raw_output.shape) != 3:
        return None

    n_olap = int(0.5 * n_overlapping_frames)
    if n_olap > 0:
        # remove half of the overlapping frames from beginning and end
        raw_output = raw_output[:, n_olap:-n_olap, :]

    output_shape = raw_output.shape
    n_output_frames_original = int(np.floor(audio_original_length * (ANNOTATIONS_FPS / AUDIO_SAMPLE_RATE)))
    unwrapped_output = raw_output.reshape(output_shape[0] * output_shape[1], output_shape[2])
    return unwrapped_output[:n_output_frames_original, :]  # trim to original audio length


def model_output_to_notes(
    output: Dict[str, np.array],
    onset_thresh: float,
    frame_thresh: float,
    infer_onsets: bool = True,
    min_note_len: int = 11,
    min_freq: Optional[float] = None,
    max_freq: Optional[float] = None,
    include_pitch_bends: bool = True,
    multiple_pitch_bends: bool = False,
    melodia_trick: bool = True,
    midi_tempo: float = 120,
) -> PrettyMIDI:
    """Convert model output to MIDI.

    Args:
        output: A dictionary of the form
            {
                'note': array of shape (n_times, n_freqs),
                'onset': array of shape (n_times, n_freqs),
                'contour': array of shape (n_times, 3*n_freqs)
            }
            representing the output of the basic pitch model.
        onset_thresh: Minimum amplitude of an onset activation to be considered an onset.
        frame_thresh: Minimum amplitude of a frame activation for a note to remain "on".
        infer_onsets: If True, add additional onsets when there are large differences in frame amplitudes.
        min_note_len: The minimum allowed note length, in frames.
        min_freq: Minimum allowed output frequency, in Hz. If None, all frequencies are used.
        max_freq: Maximum allowed output frequency, in Hz. If None, all frequencies are used.
        include_pitch_bends: If True, include pitch bends.
        multiple_pitch_bends: If True, allow overlapping notes in the MIDI file to have pitch bends.
        melodia_trick: Use the melodia post-processing step.
        midi_tempo: Tempo (in BPM) written to the output MIDI file.

    Returns:
        midi: PrettyMIDI object
    """
    frames = output["note"]
    onsets = output["onset"]
    contours = output["contour"]

    estimated_notes = output_to_notes_polyphonic(
        frames,
        onsets,
        onset_thresh=onset_thresh,
        frame_thresh=frame_thresh,
        infer_onsets=infer_onsets,
        min_note_len=min_note_len,
        min_freq=min_freq,
        max_freq=max_freq,
        melodia_trick=melodia_trick,
    )
    if include_pitch_bends:
        estimated_notes_with_pitch_bend = get_pitch_bends(contours, estimated_notes)
    else:
        estimated_notes_with_pitch_bend = [(note[0], note[1], note[2], note[3], None) for note in estimated_notes]

    times_s = model_frames_to_time(contours.shape[0])
    estimated_notes_time_seconds = [
        (times_s[note[0]], times_s[note[1]], note[2], note[3], note[4]) for note in estimated_notes_with_pitch_bend
    ]

    return note_events_to_midi(estimated_notes_time_seconds, multiple_pitch_bends, midi_tempo)


def midi_pitch_to_contour_bin(pitch_midi: int) -> np.array:
    """Convert a MIDI pitch to the corresponding index in the contour matrix.

    Args:
        pitch_midi: pitch in MIDI

    Returns:
        index in the contour matrix
    """
    pitch_hz = midi_to_hz(pitch_midi)
    return 12.0 * CONTOURS_BINS_PER_SEMITONE * np.log2(pitch_hz / ANNOTATIONS_BASE_FREQUENCY)


def get_pitch_bends(
    contours: np.ndarray, note_events: List[Tuple[int, int, int, float]], n_bins_tolerance: int = 25
) -> List[Tuple[int, int, int, float, Optional[List[int]]]]:
    """Given note events and contours, estimate pitch bends per note.

    Pitch bends are represented as a sequence of evenly spaced MIDI pitch bend control units.
    The time stamps of each pitch bend can be inferred by computing an evenly spaced grid between
    the start and end times of each note event.

    Args:
        contours: Matrix of estimated pitch contours
        note_events: note event tuples
        n_bins_tolerance: Pitch bend estimation range. Defaults to 25.

    Returns:
        note events with pitch bends
    """
    window_length = n_bins_tolerance * 2 + 1
    freq_gaussian = gaussian(window_length, std=5)
    note_events_with_pitch_bends = []
    for start_idx, end_idx, pitch_midi, amplitude in note_events:
        freq_idx = int(np.round(midi_pitch_to_contour_bin(pitch_midi)))
        freq_start_idx = np.max([freq_idx - n_bins_tolerance, 0])
        freq_end_idx = np.min([N_FREQ_BINS_CONTOURS, freq_idx + n_bins_tolerance + 1])

        pitch_bend_submatrix = (
            contours[start_idx:end_idx, freq_start_idx:freq_end_idx]
            * freq_gaussian[
                np.max([0, n_bins_tolerance - freq_idx]) : window_length
                - np.max([0, freq_idx - (N_FREQ_BINS_CONTOURS - n_bins_tolerance - 1)])
            ]
        )
        pb_shift = n_bins_tolerance - np.max([0, n_bins_tolerance - freq_idx])

        bends: Optional[List[int]] = list(
            np.argmax(pitch_bend_submatrix, axis=1) - pb_shift
        )  # this is in units of 1/3 semitones
        note_events_with_pitch_bends.append((start_idx, end_idx, pitch_midi, amplitude, bends))
    return note_events_with_pitch_bends


def note_events_to_midi(
    note_events_with_pitch_bends: List[Tuple[float, float, int, float, Optional[List[int]]]],
    multiple_pitch_bends: bool = False,
    midi_tempo: float = 120,
) -> PrettyMIDI:
    """Create a pretty_midi_fix object from note events.

    Args:
        note_events_with_pitch_bends: list of tuples
            [(start_time_seconds, end_time_seconds, pitch_midi, amplitude, pitch_bends)]
            where amplitude is a number between 0 and 1
        multiple_pitch_bends: If True, allow overlapping notes to have pitch bends.
            Note: this will assign each pitch to its own MIDI instrument, as MIDI does not yet
            support per-note pitch bends.
        midi_tempo: Tempo (in BPM) of the resulting MIDI file.

    Returns:
        PrettyMIDI object
    """
    mid = PrettyMIDI(initial_tempo=midi_tempo)
    if not multiple_pitch_bends:
        note_events_with_pitch_bends = drop_overlapping_pitch_bends(note_events_with_pitch_bends)
    instruments: DefaultDict[int, Instrument] = defaultdict(
        lambda: Instrument(program=40)
    )
    for start_time, end_time, note_number, amplitude, pitch_bend in note_events_with_pitch_bends:
        instrument = instruments[note_number] if multiple_pitch_bends else instruments[0]
        note = Note(
            velocity=int(np.round(127 * amplitude)),
            pitch=note_number,
            start=start_time,
            end=end_time,
        )
        instrument.notes.append(note)
        if not pitch_bend:
            continue
        pitch_bend_times = np.linspace(start_time, end_time, len(pitch_bend))
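        # For reference: bends are in contour bins (1/3 of a semitone each), and the
        # +/-8192-tick MIDI pitch bend range covers +/-2 semitones, so one semitone is
        # 4096 ticks and one contour bin maps to 4096 / 3 ~= 1365 ticks.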
        pitch_bend_midi_ticks = np.round(np.array(pitch_bend) * 4096 / CONTOURS_BINS_PER_SEMITONE).astype(int)
        # This supports pitch bends up to 2 semitones
        # If we estimate pitch bends above/below 2 semitones, crop them here when adding them to the midi file
        pitch_bend_midi_ticks[pitch_bend_midi_ticks > N_PITCH_BEND_TICKS - 1] = N_PITCH_BEND_TICKS - 1
        pitch_bend_midi_ticks[pitch_bend_midi_ticks < -N_PITCH_BEND_TICKS] = -N_PITCH_BEND_TICKS
        for pb_time, pb_midi in zip(pitch_bend_times, pitch_bend_midi_ticks):
            instrument.pitch_bends.append(PitchBend(pb_midi, pb_time))
    mid.instruments.extend(instruments.values())

    return mid


def drop_overlapping_pitch_bends(
    note_events_with_pitch_bends: List[Tuple[float, float, int, float, Optional[List[int]]]]
) -> List[Tuple[float, float, int, float, Optional[List[int]]]]:
    """Drop pitch bends from any notes that overlap in time with another note."""
    note_events = sorted(note_events_with_pitch_bends)
    for i in range(len(note_events) - 1):
        for j in range(i + 1, len(note_events)):
            if note_events[j][0] >= note_events[i][1]:  # start j >= end i
                break
            note_events[i] = note_events[i][:-1] + (None,)  # last field is pitch bend
            note_events[j] = note_events[j][:-1] + (None,)

    return note_events


def get_infered_onsets(onsets: np.array, frames: np.array, n_diff: int = 2) -> np.array:
    """Infer onsets from large changes in frame amplitudes.

    Args:
        onsets: Array of note onset predictions.
        frames: Audio frames.
        n_diff: Differences used to detect onsets.

    Returns:
        The maximum between the predicted onsets and its differences.
    """
    diffs = []
    for n in range(1, n_diff + 1):
        frames_appended = np.concatenate([np.zeros((n, frames.shape[1])), frames])
        diffs.append(frames_appended[n:, :] - frames_appended[:-n, :])
    frame_diff = np.min(diffs, axis=0)
    frame_diff[frame_diff < 0] = 0
    frame_diff[:n_diff, :] = 0
    frame_diff = np.max(onsets) * frame_diff / np.max(frame_diff)  # rescale to have the same max as onsets

    max_onsets_diff = np.max([onsets, frame_diff], axis=0)  # use the max of the predicted onsets and the differences

    return max_onsets_diff


def constrain_frequency(
    onsets: np.array, frames: np.array, max_freq: Optional[float], min_freq: Optional[float]
) -> Tuple[np.array, np.array]:
    """Zero out activations above or below the max/min frequencies.

    Args:
        onsets: Onset activation matrix (n_times, n_freqs)
        frames: Frame activation matrix (n_times, n_freqs)
        max_freq: The maximum frequency to keep.
        min_freq: The minimum frequency to keep.

    Returns:
        The onset and frame activation matrices, with frequencies outside the min and max
        frequency set to 0.
    """
    if max_freq is not None:
        max_freq_idx = int(np.round(hz_to_midi(max_freq) - MIDI_OFFSET))
        onsets[:, max_freq_idx:] = 0
        frames[:, max_freq_idx:] = 0
    if min_freq is not None:
        min_freq_idx = int(np.round(hz_to_midi(min_freq) - MIDI_OFFSET))
        onsets[:, :min_freq_idx] = 0
        frames[:, :min_freq_idx] = 0

    return onsets, frames


def model_frames_to_time(n_frames: int) -> np.ndarray:
    original_times = frames_to_time(
        np.arange(n_frames),
        sr=AUDIO_SAMPLE_RATE,
        hop_length=FFT_HOP,
    )
    window_numbers = np.floor(np.arange(n_frames) / ANNOT_N_FRAMES)
    window_offset = (FFT_HOP / AUDIO_SAMPLE_RATE) * (
        ANNOT_N_FRAMES - (AUDIO_N_SAMPLES / FFT_HOP)
    ) + 0.0018  # this is a magic number, but it's needed for this to align properly
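    # For reference, with the constants above this offset works out to
    # (256 / 22050) * (172 - 43844 / 256) + 0.0018 ~= 0.0103 seconds per window.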
    times = original_times - (window_offset * window_numbers)
    return times


def output_to_notes_polyphonic(
    frames: np.array,
    onsets: np.array,
    onset_thresh: float,
    frame_thresh: float,
    min_note_len: int,
    infer_onsets: bool,
    max_freq: Optional[float],
    min_freq: Optional[float],
    melodia_trick: bool = True,
    energy_tol: int = 11,
) -> List[Tuple[int, int, int, float]]:
    """Decode raw model output to polyphonic note events.

    Args:
        frames: Frame activation matrix (n_times, n_freqs).
        onsets: Onset activation matrix (n_times, n_freqs).
        onset_thresh: Minimum amplitude of an onset activation to be considered an onset.
        frame_thresh: Minimum amplitude of a frame activation for a note to remain "on".
        min_note_len: Minimum allowed note length in frames.
        infer_onsets: If True, add additional onsets when there are large differences in frame amplitudes.
        max_freq: Maximum allowed output frequency, in Hz.
        min_freq: Minimum allowed output frequency, in Hz.
        melodia_trick: Whether to use the melodia trick to better detect notes.
        energy_tol: Stop tracking a note after this many consecutive frames below frame_thresh.

    Returns:
        list of tuples [(start_time_frames, end_time_frames, pitch_midi, amplitude)]
        representing the note events, where amplitude is a number between 0 and 1
    """

    n_frames = frames.shape[0]

    onsets, frames = constrain_frequency(onsets, frames, max_freq, min_freq)
    # use onsets inferred from frames in addition to the predicted onsets
    if infer_onsets:
        onsets = get_infered_onsets(onsets, frames)

    peak_thresh_mat = np.zeros(onsets.shape)
    peaks = argrelmax(onsets, axis=0)
    peak_thresh_mat[peaks] = onsets[peaks]

    onset_idx = np.where(peak_thresh_mat >= onset_thresh)
    onset_time_idx = onset_idx[0][::-1]  # sort to go backwards in time
    onset_freq_idx = onset_idx[1][::-1]  # sort to go backwards in time

    remaining_energy = np.zeros(frames.shape)
    remaining_energy[:, :] = frames[:, :]

    # loop over onsets
    note_events = []
    for note_start_idx, freq_idx in zip(onset_time_idx, onset_freq_idx):
        # if we're too close to the end of the audio, continue
        if note_start_idx >= n_frames - 1:
            continue

        # find time index at this frequency band where the frames drop below an energy threshold
        i = note_start_idx + 1
        k = 0  # number of frames since energy dropped below threshold
        while i < n_frames - 1 and k < energy_tol:
            if remaining_energy[i, freq_idx] < frame_thresh:
                k += 1
            else:
                k = 0
            i += 1

        i -= k  # go back to frame above threshold

        # if the note is too short, skip it
        if i - note_start_idx <= min_note_len:
            continue

        remaining_energy[note_start_idx:i, freq_idx] = 0
        if freq_idx < MAX_FREQ_IDX:
            remaining_energy[note_start_idx:i, freq_idx + 1] = 0
        if freq_idx > 0:
            remaining_energy[note_start_idx:i, freq_idx - 1] = 0

        # add the note
        amplitude = np.mean(frames[note_start_idx:i, freq_idx])
        note_events.append(
            (
                note_start_idx,
                i,
                freq_idx + MIDI_OFFSET,
                amplitude,
            )
        )

    if melodia_trick:
        energy_shape = remaining_energy.shape

        while np.max(remaining_energy) > frame_thresh:
            i_mid, freq_idx = np.unravel_index(np.argmax(remaining_energy), energy_shape)
            remaining_energy[i_mid, freq_idx] = 0

            # forward pass
            i = i_mid + 1
            k = 0
            while i < n_frames - 1 and k < energy_tol:
                if remaining_energy[i, freq_idx] < frame_thresh:
                    k += 1
                else:
                    k = 0

                remaining_energy[i, freq_idx] = 0
                if freq_idx < MAX_FREQ_IDX:
                    remaining_energy[i, freq_idx + 1] = 0
                if freq_idx > 0:
                    remaining_energy[i, freq_idx - 1] = 0

                i += 1

            i_end = i - 1 - k  # go back to frame above threshold

            # backward pass
            i = i_mid - 1
            k = 0
            while i > 0 and k < energy_tol:
                if remaining_energy[i, freq_idx] < frame_thresh:
                    k += 1
                else:
                    k = 0

                remaining_energy[i, freq_idx] = 0
                if freq_idx < MAX_FREQ_IDX:
                    remaining_energy[i, freq_idx + 1] = 0
                if freq_idx > 0:
                    remaining_energy[i, freq_idx - 1] = 0

                i -= 1

            i_start = i + 1 + k  # go back to frame above threshold
            assert i_start >= 0, "{}".format(i_start)
            assert i_end < n_frames

            if i_end - i_start <= min_note_len:
                # note is too short, skip it
                continue

            # add the note
            amplitude = np.mean(frames[i_start:i_end, freq_idx])
            note_events.append(
                (
                    i_start,
                    i_end,
                    freq_idx + MIDI_OFFSET,
                    amplitude,
                )
            )

    return note_events


def log_base_b(x: Tensor, base: int) -> Tensor:
    """
    Compute log_b(x).

    Args:
        x: input
        base: log base. E.g. for log10, base=10

    Returns:
        log_base(x)
    """
    numerator = torch.log(x)
    denominator = torch.log(torch.tensor([base], dtype=numerator.dtype, device=numerator.device))
    return numerator / denominator


def normalized_log(inputs: Tensor) -> Tensor:
    """
    Takes an input with a shape of either (batch, x, y, z) or (batch, y, z)
    and rescales each (y, z) to dB, scaled 0 - 1.
    Only x=1 is supported.
    This layer adds 1e-10 to all values as a way to avoid NaN math.
    """
    power = torch.square(inputs)
    log_power = 10 * log_base_b(power + 1e-10, 10)

    log_power_min = torch.amin(log_power, dim=(1, 2)).reshape(inputs.shape[0], 1, 1)
    log_power_offset = log_power - log_power_min
    log_power_offset_max = torch.amax(log_power_offset, dim=(1, 2)).reshape(inputs.shape[0], 1, 1)
    # equivalent to TF div_no_nan
    log_power_normalized = log_power_offset / log_power_offset_max
    log_power_normalized = torch.nan_to_num(log_power_normalized, nan=0.0)

    return log_power_normalized.reshape(inputs.shape)


def get_cqt(
    inputs: Tensor,
    n_harmonics: int,
    use_batch_norm: bool,
    bn_layer: nn.BatchNorm2d,
):
    """Calculate the CQT of the input audio.

    Input shape: (batch, number of audio samples, 1)
    Output shape: (batch, number of frequency bins, number of time frames)

    Args:
        inputs: The audio input.
        n_harmonics: The number of harmonics to capture above the maximum output frequency.
            Used to calculate the number of semitones for the CQT.
        use_batch_norm: If True, applies batch normalization after computing the CQT.
        bn_layer: The batch normalization layer to apply.

    Returns:
        The log-normalized CQT of the input audio.
    """
    n_semitones = np.min(
        [
            int(np.ceil(12.0 * np.log2(n_harmonics)) + ANNOTATIONS_N_SEMITONES),
            MAX_N_SEMITONES,
        ]
    )
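    # For reference: with the 8 harmonics used by BasicPitchTorch below, this is
    # min(ceil(12 * log2(8)) + 88, 103) = 103 semitones, i.e. 103 * 3 = 309 CQT bins
    # at 12 * 3 = 36 bins per octave.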
    cqt_layer = CQT2010v2(
        sr=AUDIO_SAMPLE_RATE,
        hop_length=FFT_HOP,
        fmin=ANNOTATIONS_BASE_FREQUENCY,
        n_bins=n_semitones * CONTOURS_BINS_PER_SEMITONE,
        bins_per_octave=12 * CONTOURS_BINS_PER_SEMITONE,
        verbose=False,
    )
    cqt_layer.to(inputs.device)
    x = cqt_layer(inputs)
    x = torch.transpose(x, 1, 2)
    x = normalized_log(x)

    x = x.unsqueeze(1)
    if use_batch_norm:
        x = bn_layer(x)
    x = x.squeeze(1)

    return x


class HarmonicStacking(nn.Module):
    """Harmonic stacking layer.

    Input shape: (n_batch, n_times, n_freqs, 1)
    Output shape: (n_batch, n_times, n_output_freqs, len(harmonics))

    n_freqs should be much larger than n_output_freqs so that information from the upper
    harmonics is captured.

    Attributes:
        bins_per_semitone: The number of bins per semitone of the input CQT
        harmonics: List of harmonics to use. Should be positive numbers.
        shifts: A list containing the number of bins to shift in frequency for each harmonic
        n_output_freqs: The number of frequency bins in each harmonic layer.
    """

    def __init__(
        self,
        bins_per_semitone: int,
        harmonics: List[float],
        n_output_freqs: int,
    ):
        super().__init__()
        self.bins_per_semitone = bins_per_semitone
        self.harmonics = harmonics
        self.n_output_freqs = n_output_freqs

        self.shifts = [
            int(round(12.0 * self.bins_per_semitone * math.log2(h))) for h in self.harmonics
        ]
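        # For reference: with the default harmonics [0.5, 1, 2, 3, 4, 5, 6, 7] and
        # 3 bins per semitone, this yields shifts of [-36, 0, 36, 57, 72, 84, 93, 101] bins.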

    @torch.no_grad()
    def forward(self, x):
        # x: (batch, t, n_bins)
        hcqt = []
        for shift in self.shifts:
            if shift == 0:
                cur_cqt = x
            elif shift > 0:
                cur_cqt = F.pad(x[:, :, shift:], (0, shift))
            elif shift < 0:  # sub-harmonic
                cur_cqt = F.pad(x[:, :, :shift], (-shift, 0))
            hcqt.append(cur_cqt)
        hcqt = torch.stack(hcqt, dim=1)
        hcqt = hcqt[:, :, :, :self.n_output_freqs]
        return hcqt


class BasicPitchTorch(nn.Module):

    def __init__(
        self,
        stack_harmonics=[0.5, 1, 2, 3, 4, 5, 6, 7],
    ) -> None:
        super().__init__()
        self.stack_harmonics = stack_harmonics
        if len(stack_harmonics) > 0:
            self.hs = HarmonicStacking(
                bins_per_semitone=CONTOURS_BINS_PER_SEMITONE,
                harmonics=stack_harmonics,
                n_output_freqs=ANNOTATIONS_N_SEMITONES * CONTOURS_BINS_PER_SEMITONE,
            )
            num_in_channels = len(stack_harmonics)
        else:
            num_in_channels = 1

        self.bn_layer = nn.BatchNorm2d(1, eps=0.001)
        self.conv_contour = nn.Sequential(
            # NOTE: in the original implementation, this part of the network should be dangling...
            # nn.Conv2d(num_in_channels, 32, kernel_size=5, padding="same"),
            # nn.BatchNorm2d(32),
            # nn.ReLU(),
            nn.Conv2d(num_in_channels, 8, kernel_size=(3, 3 * 13), padding="same"),
            nn.BatchNorm2d(8, eps=0.001),
            nn.ReLU(),
            nn.Conv2d(8, 1, kernel_size=5, padding="same"),
            nn.Sigmoid(),
        )
        self.conv_note = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=7, stride=(1, 3)),
            nn.ReLU(),
            nn.Conv2d(32, 1, kernel_size=(7, 3), padding="same"),
            nn.Sigmoid(),
        )
        self.conv_onset_pre = nn.Sequential(
            nn.Conv2d(num_in_channels, 32, kernel_size=5, stride=(1, 3)),
            nn.BatchNorm2d(32, eps=0.001),
            nn.ReLU(),
        )
        self.conv_onset_post = nn.Sequential(
            nn.Conv2d(32 + 1, 1, kernel_size=3, stride=1, padding="same"),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        cqt = get_cqt(
            x,
            len(self.stack_harmonics),
            True,
            self.bn_layer,
        )
        if hasattr(self, "hs"):
            cqt = self.hs(cqt)
        else:
            cqt = cqt.unsqueeze(1)

        x_contour = self.conv_contour(cqt)

        # for strided conv, padding is different between PyTorch and TensorFlow
        # we use this equation: pad = [ (stride * (output-1)) - input + kernel ] / 2
        # (172, 264) --(1, 3)--> (172, 88), pad = ((1 * 171 - 172 + 7) / 2, (3 * 87 - 264 + 7) / 2) = (3, 2)
        # F.pad processes from the last dimension, so it's (2, 2, 3, 3)
        x_contour_for_note = F.pad(x_contour, (2, 2, 3, 3))
        x_note = self.conv_note(x_contour_for_note)

        # (172, 264) --(1, 3)--> (172, 88), pad = ((1 * 171 - 172 + 5) / 2, (3 * 87 - 264 + 5) / 2) = (2, 1)
        # F.pad processes from the last dimension, so it's (1, 1, 2, 2)
        cqt_for_onset = F.pad(cqt, (1, 1, 2, 2))
        x_onset_pre = self.conv_onset_pre(cqt_for_onset)
        x_onset_pre = torch.cat([x_note, x_onset_pre], dim=1)
        x_onset = self.conv_onset_post(x_onset_pre)
        outputs = {"onset": x_onset.squeeze(1), "contour": x_contour.squeeze(1), "note": x_note.squeeze(1)}
        return outputs


class BasicPitch:
    def __init__(self, model_path=hf_hub_download("shethjenil/Audio2Midi_Models", "basicpitch/nmp.pth"), device="cpu"):
        self.model = BasicPitchTorch()
        self.model.load_state_dict(torch.load(model_path))
        self.model.to(device)
        self.model.eval()
        self.device = device

    def run_inference(
        self,
        audio_path: str,
        progress_callback: Callable[[int, int], None] = None,
    ) -> Dict[str, np.array]:
        audio_windowed, _, audio_original_length = get_audio_input(audio_path, OVERLAP_LEN, HOP_SIZE)
        audio_windowed = torch.from_numpy(np.copy(audio_windowed)).T.to(self.device)  # Shape: [num_windows, window_len]

        outputs = []
        total = audio_windowed.shape[0]

        with torch.no_grad():
            for i, window in enumerate(audio_windowed):
                window = window.unsqueeze(0)  # Add batch dimension
                output = self.model(window)
                outputs.append(output)

                # Call the callback if provided
                if progress_callback:
                    progress_callback(i + 1, total)

        # Merge outputs (the model returns a dict of tensors)
        merged_output = {}
        for key in outputs[0]:
            merged_output[key] = torch.cat([o[key] for o in outputs], dim=0)

        unwrapped_output = {
            k: unwrap_output(merged_output[k], audio_original_length, N_OVERLAPPING_FRAMES)
            for k in merged_output
        }
        return unwrapped_output

    def predict(
        self,
        audio,
        onset_thresh=0.5,
        frame_thresh=0.3,
        min_note_len=127.70,
        midi_tempo=120,
        infer_onsets=True,
        include_pitch_bends=True,
        multiple_pitch_bends=True,
        melodia_trick=True,
        progress_callback: Callable[[int, int], None] = None,
        min_freqat=None,
        max_freqat=None,
        output_file="output.mid",
    ):
        model_output_to_notes(
            self.run_inference(audio, progress_callback),
            onset_thresh=onset_thresh,
            frame_thresh=frame_thresh,
            infer_onsets=infer_onsets,
            min_note_len=min_note_len,
            min_freq=min_freqat,
            max_freq=max_freqat,
            include_pitch_bends=include_pitch_bends,
            multiple_pitch_bends=multiple_pitch_bends,
            melodia_trick=melodia_trick,
            midi_tempo=midi_tempo,
        ).write(output_file)
        return output_file