audio2midi-0.1.0-py2.py3-none-any.whl

@@ -0,0 +1,783 @@
1
+ from collections import defaultdict
2
+ from typing import DefaultDict, List, Optional, Dict, Tuple, Callable
3
+ from librosa import load as librosa_load, midi_to_hz, hz_to_midi
4
+ from librosa.core import frames_to_time
5
+ from librosa.util import frame as librosa_util_frame
6
+ from pretty_midi_fix import PrettyMIDI, PitchBend, Instrument, Note
7
+ from scipy.signal.windows import gaussian
8
+ from scipy.signal import argrelmax
9
+ from huggingface_hub import hf_hub_download
10
+ from torch import nn, Tensor
11
+ import torch.nn.functional as F
12
+ import numpy as np
13
+ import torch
14
+ import math
15
+
16
+ from nnAudio.features import CQT2010v2
17
+
18
+ FFT_HOP = 256
19
+ AUDIO_SAMPLE_RATE = 22050
20
+ AUDIO_WINDOW_LENGTH = 2
21
+ NOTES_BINS_PER_SEMITONE = 1
22
+ CONTOURS_BINS_PER_SEMITONE = 3
23
+ ANNOTATIONS_BASE_FREQUENCY = 27.5
24
+ ANNOTATIONS_N_SEMITONES = 88
26
+ MIDI_OFFSET = 21
27
+ N_PITCH_BEND_TICKS = 8192
28
+ MAX_FREQ_IDX = 87
29
+ N_OVERLAPPING_FRAMES = 30
30
+ ANNOTATIONS_FPS = AUDIO_SAMPLE_RATE // FFT_HOP
31
+ AUDIO_N_SAMPLES = AUDIO_SAMPLE_RATE * AUDIO_WINDOW_LENGTH - FFT_HOP
32
+ N_FFT = 8 * FFT_HOP
33
+ N_FREQ_BINS_NOTES = ANNOTATIONS_N_SEMITONES * NOTES_BINS_PER_SEMITONE
34
+ N_FREQ_BINS_CONTOURS = ANNOTATIONS_N_SEMITONES * CONTOURS_BINS_PER_SEMITONE
35
+ ANNOT_N_FRAMES = ANNOTATIONS_FPS * AUDIO_WINDOW_LENGTH
37
+ OVERLAP_LEN = N_OVERLAPPING_FRAMES * FFT_HOP
38
+ HOP_SIZE = AUDIO_N_SAMPLES - OVERLAP_LEN
39
+ MAX_N_SEMITONES = int(np.floor(12.0 * np.log2(0.5 * AUDIO_SAMPLE_RATE / ANNOTATIONS_BASE_FREQUENCY)))
40
+
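For reference, the derived constants above work out to the following values (plain arithmetic from the definitions above):

# ANNOTATIONS_FPS      = 22050 // 256    = 86 model frames per second
# AUDIO_N_SAMPLES      = 22050 * 2 - 256 = 43844 samples per window (~1.99 s)
# ANNOT_N_FRAMES       = 86 * 2          = 172 model frames per window
# N_FREQ_BINS_NOTES    = 88, N_FREQ_BINS_CONTOURS = 88 * 3 = 264
# OVERLAP_LEN          = 30 * 256        = 7680 samples
# HOP_SIZE             = 43844 - 7680    = 36164 samples between window starts
# MAX_N_SEMITONES      = floor(12 * log2(11025 / 27.5)) = 103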
41
+ def frame_with_pad(x: np.array, frame_length: int, hop_size: int) -> np.array:
42
+ """
43
+ Extends librosa.util.frame with end padding if required, similar to
44
+ tf.signal.frame(pad_end=True).
45
+
46
+ Returns:
47
+ framed_audio: array of shape (frame_length, n_windows); librosa.util.frame places windows along the last axis, so callers transpose to get (n_windows, AUDIO_N_SAMPLES)
48
+ """
49
+ n_frames = int(np.ceil((x.shape[0] - frame_length) / hop_size)) + 1
50
+ n_pads = (n_frames - 1) * hop_size + frame_length - x.shape[0]
51
+ x = np.pad(x, (0, n_pads), mode="constant")
52
+ framed_audio = librosa_util_frame(x, frame_length=frame_length, hop_length=hop_size)
53
+ return framed_audio
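A minimal shape check for frame_with_pad, as a sketch assuming a silent 3-second mono clip and the module constants defined above:

x = np.zeros(3 * AUDIO_SAMPLE_RATE, dtype=np.float32)   # 66150 samples
frames = frame_with_pad(x, AUDIO_N_SAMPLES, HOP_SIZE)
# ceil((66150 - 43844) / 36164) + 1 = 2 windows, so x is zero-padded to 80008 samples
assert frames.shape == (AUDIO_N_SAMPLES, 2)              # windows lie along the last axis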
54
+
55
+
56
+ def window_audio_file(audio_original: np.array, hop_size: int) -> Tuple[np.array, List[Dict[str, int]]]:
57
+ """
58
+ Pad the audio as needed and return it as a windowed signal
59
+ with window length AUDIO_N_SAMPLES.
60
+
61
+ Returns:
62
+ audio_windowed: array of shape (AUDIO_N_SAMPLES, n_windows),
63
+ the audio windowed into fixed-length chunks
64
+ window_times: list of {'start':.., 'end':...} objects (times in seconds)
65
+
66
+ """
67
+ audio_windowed = frame_with_pad(audio_original, AUDIO_N_SAMPLES, hop_size)
68
+ window_times = [
69
+ {
70
+ "start": t_start,
71
+ "end": t_start + (AUDIO_N_SAMPLES / AUDIO_SAMPLE_RATE),
72
+ }
73
+ for t_start in np.arange(audio_windowed.shape[-1]) * hop_size / AUDIO_SAMPLE_RATE  # one entry per window
74
+ ]
75
+ return audio_windowed, window_times
76
+
77
+
78
+ def get_audio_input(
79
+ audio_path: str, overlap_len: int, hop_size: int
80
+ ) -> Tuple[Tensor, List[Dict[str, int]], int]:
81
+ """
82
+ Read wave file (as mono), pad appropriately, and return as
83
+ windowed signal, with window length = AUDIO_N_SAMPLES
84
+
85
+ Returns:
86
+ audio_windowed: array of shape (AUDIO_N_SAMPLES, n_windows),
87
+ the audio windowed into fixed-length chunks
88
+ window_times: list of {'start': .., 'end': ..} dicts (times in seconds)
89
+ audio_original_length: int
90
+ length of the original audio, in samples, before padding.
91
+
92
+ """
93
+ assert overlap_len % 2 == 0, "overlap_length must be even, got {}".format(overlap_len)
94
+
95
+ audio_original, _ = librosa_load(str(audio_path), sr=AUDIO_SAMPLE_RATE, mono=True)
96
+
97
+ original_length = audio_original.shape[0]
98
+ audio_original = np.concatenate([np.zeros((int(overlap_len / 2),), dtype=np.float32), audio_original])
99
+ audio_windowed, window_times = window_audio_file(audio_original, hop_size)
100
+ return audio_windowed, window_times, original_length
101
+
102
+
103
+ def unwrap_output(output: Tensor, audio_original_length: int, n_overlapping_frames: int) -> np.array:
104
+ """Unwrap batched model predictions to a single matrix.
105
+
106
+ Args:
107
+ output: array (n_batches, n_times_short, n_freqs)
108
+ audio_original_length: length of original audio signal (in samples)
109
+ n_overlapping_frames: number of overlapping frames in the output
110
+
111
+ Returns:
112
+ array (n_times, n_freqs)
113
+ """
114
+ raw_output = output.cpu().detach().numpy()
115
+ if len(raw_output.shape) != 3:
116
+ return None
117
+
118
+ n_olap = int(0.5 * n_overlapping_frames)
119
+ if n_olap > 0:
120
+ # remove half of the overlapping frames from beginning and end
121
+ raw_output = raw_output[:, n_olap:-n_olap, :]
122
+
123
+ output_shape = raw_output.shape
124
+ n_output_frames_original = int(np.floor(audio_original_length * (ANNOTATIONS_FPS / AUDIO_SAMPLE_RATE)))
125
+ unwrapped_output = raw_output.reshape(output_shape[0] * output_shape[1], output_shape[2])
126
+ return unwrapped_output[:n_output_frames_original, :] # trim to original audio length
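A worked example of the trimming arithmetic, assuming a 10-second input at 22050 Hz:

# each window of model output holds ANNOT_N_FRAMES = 172 frames
# n_olap = 30 // 2 = 15 frames are dropped from both ends, keeping 142 frames per window
# the concatenated frames are then truncated to at most
# floor(220500 * 86 / 22050) = 860 frames for the 10-second clip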
127
+
128
+
129
+
130
+
131
+ def model_output_to_notes(
132
+ output: Dict[str, np.array],
133
+ onset_thresh: float,
134
+ frame_thresh: float,
135
+ infer_onsets: bool = True,
136
+ min_note_len: int = 11,
137
+ min_freq: Optional[float] = None,
138
+ max_freq: Optional[float] = None,
139
+ include_pitch_bends: bool = True,
140
+ multiple_pitch_bends: bool = False,
141
+ melodia_trick: bool = True,
142
+ midi_tempo: float = 120,
143
+ ) -> PrettyMIDI:
144
+ """Convert model output to MIDI
145
+
146
+ Args:
147
+ output: A dictionary of model outputs:
148
+ {
149
+ 'note': array of shape (n_times, n_freqs),
150
+ 'onset': array of shape (n_times, n_freqs),
151
+ 'contour': array of shape (n_times, 3*n_freqs)
152
+ }
153
+ representing the output of the basic pitch model.
154
+ onset_thresh: Minimum amplitude of an onset activation to be considered an onset.
+ frame_thresh: Minimum amplitude of a frame activation for a note to remain "on".
155
+ infer_onsets: If True, add additional onsets when there are large differences in frame amplitudes.
156
+ min_note_len: The minimum allowed note length in frames.
157
+ min_freq: Minimum allowed output frequency, in Hz. If None, all frequencies are used.
158
+ max_freq: Maximum allowed output frequency, in Hz. If None, all frequencies are used.
159
+ include_pitch_bends: If True, include pitch bends.
160
+ multiple_pitch_bends: If True, allow overlapping notes in midi file to have pitch bends.
161
+ melodia_trick: Use the melodia post-processing step.
+ midi_tempo: Tempo, in beats per minute, written to the generated MIDI file.
162
+
163
+ Returns:
164
+ midi: a PrettyMIDI object containing the estimated note events
166
+ """
167
+ frames = output["note"]
168
+ onsets = output["onset"]
169
+ contours = output["contour"]
170
+
171
+ estimated_notes = output_to_notes_polyphonic(
172
+ frames,
173
+ onsets,
174
+ onset_thresh=onset_thresh,
175
+ frame_thresh=frame_thresh,
176
+ infer_onsets=infer_onsets,
177
+ min_note_len=min_note_len,
178
+ min_freq=min_freq,
179
+ max_freq=max_freq,
180
+ melodia_trick=melodia_trick,
181
+ )
182
+ if include_pitch_bends:
183
+ estimated_notes_with_pitch_bend = get_pitch_bends(contours, estimated_notes)
184
+ else:
185
+ estimated_notes_with_pitch_bend = [(note[0], note[1], note[2], note[3], None) for note in estimated_notes]
186
+
187
+ times_s = model_frames_to_time(contours.shape[0])
188
+ estimated_notes_time_seconds = [
189
+ (times_s[note[0]], times_s[note[1]], note[2], note[3], note[4]) for note in estimated_notes_with_pitch_bend
190
+ ]
191
+
192
+ return note_events_to_midi(estimated_notes_time_seconds, multiple_pitch_bends, midi_tempo)
193
+
194
+
195
+ def midi_pitch_to_contour_bin(pitch_midi: int) -> np.array:
196
+ """Convert midi pitch to conrresponding index in contour matrix
197
+
198
+ Args:
199
+ pitch_midi: pitch in midi
200
+
201
+ Returns:
202
+ index in contour matrix
203
+
204
+ """
205
+ pitch_hz = midi_to_hz(pitch_midi)
206
+ return 12.0 * CONTOURS_BINS_PER_SEMITONE * np.log2(pitch_hz / ANNOTATIONS_BASE_FREQUENCY)
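A quick sanity check of this mapping, using concert A as an example:

# midi_to_hz(69) = 440 Hz, which sits exactly four octaves above 27.5 Hz,
# so the contour bin is 12 * 3 * log2(440 / 27.5) = 36 * 4 = 144
assert round(midi_pitch_to_contour_bin(69)) == 144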
207
+
208
+
209
+ def get_pitch_bends(
210
+ contours: np.ndarray, note_events: List[Tuple[int, int, int, float]], n_bins_tolerance: int = 25
211
+ ) -> List[Tuple[int, int, int, float, Optional[List[int]]]]:
212
+ """Given note events and contours, estimate pitch bends per note.
213
+ Pitch bends are represented as a sequence of evenly spaced midi pitch bend control units.
214
+ The time stamps of each pitch bend can be inferred by computing an evenly spaced grid between
215
+ the start and end times of each note event.
216
+
217
+ Args:
218
+ contours: Matrix of estimated pitch contours
219
+ note_events: list of note event tuples (start_idx, end_idx, pitch_midi, amplitude)
220
+ n_bins_tolerance: Pitch bend estimation range. Defaults to 25.
221
+
222
+ Returns:
223
+ note events with pitch bends
224
+ """
225
+ window_length = n_bins_tolerance * 2 + 1
226
+ freq_gaussian = gaussian(window_length, std=5)
227
+ note_events_with_pitch_bends = []
228
+ for start_idx, end_idx, pitch_midi, amplitude in note_events:
229
+ freq_idx = int(np.round(midi_pitch_to_contour_bin(pitch_midi)))
230
+ freq_start_idx = np.max([freq_idx - n_bins_tolerance, 0])
231
+ freq_end_idx = np.min([N_FREQ_BINS_CONTOURS, freq_idx + n_bins_tolerance + 1])
232
+
233
+ pitch_bend_submatrix = (
234
+ contours[start_idx:end_idx, freq_start_idx:freq_end_idx]
235
+ * freq_gaussian[
236
+ np.max([0, n_bins_tolerance - freq_idx]) : window_length
237
+ - np.max([0, freq_idx - (N_FREQ_BINS_CONTOURS - n_bins_tolerance - 1)])
238
+ ]
239
+ )
240
+ pb_shift = n_bins_tolerance - np.max([0, n_bins_tolerance - freq_idx])
241
+
242
+ bends: Optional[List[int]] = list(
243
+ np.argmax(pitch_bend_submatrix, axis=1) - pb_shift
244
+ ) # this is in units of 1/3 semitones
245
+ note_events_with_pitch_bends.append((start_idx, end_idx, pitch_midi, amplitude, bends))
246
+ return note_events_with_pitch_bends
247
+
248
+
249
+ def note_events_to_midi(
250
+ note_events_with_pitch_bends: List[Tuple[float, float, int, float, Optional[List[int]]]],
251
+ multiple_pitch_bends: bool = False,
252
+ midi_tempo: float = 120,
253
+ ) -> PrettyMIDI:
254
+ """Create a pretty_midi_fix object from note events
255
+
256
+ Args:
257
+ note_events_with_pitch_bends : list of tuples [(start_time_s, end_time_s, pitch_midi, amplitude, pitch_bends)]
258
+ where amplitude is a number between 0 and 1 and pitch_bends is an optional list of contour-bin offsets
259
+ multiple_pitch_bends : If True, allow overlapping notes to have pitch bends
260
+ Note: this will assign each pitch to its own midi instrument, as midi does not yet
261
+ support per-note pitch bends
262
+
263
+ Returns:
264
+ PrettyMIDI() object
265
+
266
+ """
267
+ mid = PrettyMIDI(initial_tempo=midi_tempo)
268
+ if not multiple_pitch_bends:
269
+ note_events_with_pitch_bends = drop_overlapping_pitch_bends(note_events_with_pitch_bends)
270
+ instruments: DefaultDict[int, Instrument] = defaultdict(
271
+ lambda: Instrument(program=40)
272
+ )
273
+ for start_time, end_time, note_number, amplitude, pitch_bend in note_events_with_pitch_bends:
274
+ instrument = instruments[note_number] if multiple_pitch_bends else instruments[0]
275
+ note = Note(
276
+ velocity=int(np.round(127 * amplitude)),
277
+ pitch=note_number,
278
+ start=start_time,
279
+ end=end_time,
280
+ )
281
+ instrument.notes.append(note)
282
+ if not pitch_bend:
283
+ continue
284
+ pitch_bend_times = np.linspace(start_time, end_time, len(pitch_bend))
285
+ pitch_bend_midi_ticks = np.round(np.array(pitch_bend) * 4096 / CONTOURS_BINS_PER_SEMITONE).astype(int)
286
+ # This supports pitch bends up to 2 semitones
287
+ # If we estimate pitch bends above/below 2 semitones, crop them here when adding them to the midi file
288
+ pitch_bend_midi_ticks[pitch_bend_midi_ticks > N_PITCH_BEND_TICKS - 1] = N_PITCH_BEND_TICKS - 1
289
+ pitch_bend_midi_ticks[pitch_bend_midi_ticks < -N_PITCH_BEND_TICKS] = -N_PITCH_BEND_TICKS
290
+ for pb_time, pb_midi in zip(pitch_bend_times, pitch_bend_midi_ticks):
291
+ instrument.pitch_bends.append(PitchBend(pb_midi, pb_time))
292
+ mid.instruments.extend(instruments.values())
293
+
294
+ return mid
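The tick scaling above can be read with a worked example, assuming the +/-2 semitone pitch-bend range implied by the clamping:

# one contour bin = 1/3 semitone, so a bend of +3 bins (one semitone) becomes
# round(3 * 4096 / 3) = 4096 ticks, i.e. half of the +/-8192 tick range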
295
+
296
+
297
+ def drop_overlapping_pitch_bends(
298
+ note_events_with_pitch_bends: List[Tuple[float, float, int, float, Optional[List[int]]]]
299
+ ) -> List[Tuple[float, float, int, float, Optional[List[int]]]]:
300
+ """Drop pitch bends from any notes that overlap in time with another note"""
301
+ note_events = sorted(note_events_with_pitch_bends)
302
+ for i in range(len(note_events) - 1):
303
+ for j in range(i + 1, len(note_events)):
304
+ if note_events[j][0] >= note_events[i][1]: # note j starts at or after note i ends
305
+ break
306
+ note_events[i] = note_events[i][:-1] + (None,) # last field is pitch bend
307
+ note_events[j] = note_events[j][:-1] + (None,)
308
+
309
+ return note_events
310
+
311
+
312
+ def get_infered_onsets(onsets: np.array, frames: np.array, n_diff: int = 2) -> np.array:
313
+ """Infer onsets from large changes in frame amplitudes.
314
+
315
+ Args:
316
+ onsets: Array of note onset predictions.
317
+ frames: Frame activation matrix.
318
+ n_diff: Number of frame differences used to detect onsets.
319
+
320
+ Returns:
321
+ The maximum between the predicted onsets and its differences.
322
+ """
323
+ diffs = []
324
+ for n in range(1, n_diff + 1):
325
+ frames_appended = np.concatenate([np.zeros((n, frames.shape[1])), frames])
326
+ diffs.append(frames_appended[n:, :] - frames_appended[:-n, :])
327
+ frame_diff = np.min(diffs, axis=0)
328
+ frame_diff[frame_diff < 0] = 0
329
+ frame_diff[:n_diff, :] = 0
330
+ frame_diff = np.max(onsets) * frame_diff / np.max(frame_diff) # rescale to have the same max as onsets
331
+
332
+ max_onsets_diff = np.max([onsets, frame_diff], axis=0) # use the max of the predicted onsets and the differences
333
+
334
+ return max_onsets_diff
335
+
336
+
337
+ def constrain_frequency(
338
+ onsets: np.array, frames: np.array, max_freq: Optional[float], min_freq: Optional[float]
339
+ ) -> Tuple[np.array, np.array]:
340
+ """Zero out activations above or below the max/min frequencies
341
+
342
+ Args:
343
+ onsets: Onset activation matrix (n_times, n_freqs)
344
+ frames: Frame activation matrix (n_times, n_freqs)
345
+ max_freq: The maximum frequency to keep.
346
+ min_freq: The minimum frequency to keep.
347
+
348
+ Returns:
349
+ The onset and frame activation matrices, with frequencies outside the min and max
350
+ frequency set to 0.
351
+ """
352
+ if max_freq is not None:
353
+ max_freq_idx = int(np.round(hz_to_midi(max_freq) - MIDI_OFFSET))
354
+ onsets[:, max_freq_idx:] = 0
355
+ frames[:, max_freq_idx:] = 0
356
+ if min_freq is not None:
357
+ min_freq_idx = int(np.round(hz_to_midi(min_freq) - MIDI_OFFSET))
358
+ onsets[:, :min_freq_idx] = 0
359
+ frames[:, :min_freq_idx] = 0
360
+
361
+ return onsets, frames
362
+
363
+
364
+ def model_frames_to_time(n_frames: int) -> np.ndarray:
365
+ original_times = frames_to_time(
366
+ np.arange(n_frames),
367
+ sr=AUDIO_SAMPLE_RATE,
368
+ hop_length=FFT_HOP,
369
+ )
370
+ window_numbers = np.floor(np.arange(n_frames) / ANNOT_N_FRAMES)
371
+ window_offset = (FFT_HOP / AUDIO_SAMPLE_RATE) * (
372
+ ANNOT_N_FRAMES - (AUDIO_N_SAMPLES / FFT_HOP)
373
+ ) + 0.0018 # this is a magic number, but it's needed for this to align properly
374
+ times = original_times - (window_offset * window_numbers)
375
+ return times
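Plugging in the constants, the per-window correction evaluates to:

# window_offset = (256 / 22050) * (172 - 43844 / 256) + 0.0018
#               = 0.01161 * 0.734375 + 0.0018
#               ~ 0.0103 s, i.e. about 10.3 ms subtracted per window index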
376
+
377
+
378
+ def output_to_notes_polyphonic(
379
+ frames: np.array,
380
+ onsets: np.array,
381
+ onset_thresh: float,
382
+ frame_thresh: float,
383
+ min_note_len: int,
384
+ infer_onsets: bool,
385
+ max_freq: Optional[float],
386
+ min_freq: Optional[float],
387
+ melodia_trick: bool = True,
388
+ energy_tol: int = 11,
389
+ ) -> List[Tuple[int, int, int, float]]:
390
+ """Decode raw model output to polyphonic note events
391
+
392
+ Args:
393
+ frames: Frame activation matrix (n_times, n_freqs).
394
+ onsets: Onset activation matrix (n_times, n_freqs).
395
+ onset_thresh: Minimum amplitude of an onset activation to be considered an onset.
396
+ frame_thresh: Minimum amplitude of a frame activation for a note to remain "on".
397
+ min_note_len: Minimum allowed note length in frames.
398
+ infer_onsets: If True, add additional onsets when there are large differences in frame amplitudes.
399
+ max_freq: Maximum allowed output frequency, in Hz.
400
+ min_freq: Minimum allowed output frequency, in Hz.
401
+ melodia_trick : Whether to use the melodia trick to better detect notes.
402
+ energy_tol: Number of consecutive frames allowed to fall below frame_thresh before a note is ended.
403
+
404
+ Returns:
405
+ list of tuples [(start_time_frames, end_time_frames, pitch_midi, amplitude)]
406
+ representing the note events, where amplitude is a number between 0 and 1
407
+ """
408
+
409
+ n_frames = frames.shape[0]
410
+
411
+ onsets, frames = constrain_frequency(onsets, frames, max_freq, min_freq)
412
+ # use onsets inferred from frames in addition to the predicted onsets
413
+ if infer_onsets:
414
+ onsets = get_infered_onsets(onsets, frames)
415
+
416
+ peak_thresh_mat = np.zeros(onsets.shape)
417
+ peaks = argrelmax(onsets, axis=0)
418
+ peak_thresh_mat[peaks] = onsets[peaks]
419
+
420
+ onset_idx = np.where(peak_thresh_mat >= onset_thresh)
421
+ onset_time_idx = onset_idx[0][::-1] # sort to go backwards in time
422
+ onset_freq_idx = onset_idx[1][::-1] # sort to go backwards in time
423
+
424
+ remaining_energy = np.zeros(frames.shape)
425
+ remaining_energy[:, :] = frames[:, :]
426
+
427
+ # loop over onsets
428
+ note_events = []
429
+ for note_start_idx, freq_idx in zip(onset_time_idx, onset_freq_idx):
430
+ # if we're too close to the end of the audio, continue
431
+ if note_start_idx >= n_frames - 1:
432
+ continue
433
+
434
+ # find time index at this frequency band where the frames drop below an energy threshold
435
+ i = note_start_idx + 1
436
+ k = 0 # number of frames since energy dropped below threshold
437
+ while i < n_frames - 1 and k < energy_tol:
438
+ if remaining_energy[i, freq_idx] < frame_thresh:
439
+ k += 1
440
+ else:
441
+ k = 0
442
+ i += 1
443
+
444
+ i -= k # go back to frame above threshold
445
+
446
+ # if the note is too short, skip it
447
+ if i - note_start_idx <= min_note_len:
448
+ continue
449
+
450
+ remaining_energy[note_start_idx:i, freq_idx] = 0
451
+ if freq_idx < MAX_FREQ_IDX:
452
+ remaining_energy[note_start_idx:i, freq_idx + 1] = 0
453
+ if freq_idx > 0:
454
+ remaining_energy[note_start_idx:i, freq_idx - 1] = 0
455
+
456
+ # add the note
457
+ amplitude = np.mean(frames[note_start_idx:i, freq_idx])
458
+ note_events.append(
459
+ (
460
+ note_start_idx,
461
+ i,
462
+ freq_idx + MIDI_OFFSET,
463
+ amplitude,
464
+ )
465
+ )
466
+
467
+ if melodia_trick:
468
+ energy_shape = remaining_energy.shape
469
+
470
+ while np.max(remaining_energy) > frame_thresh:
471
+ i_mid, freq_idx = np.unravel_index(np.argmax(remaining_energy), energy_shape)
472
+ remaining_energy[i_mid, freq_idx] = 0
473
+
474
+ # forward pass
475
+ i = i_mid + 1
476
+ k = 0
477
+ while i < n_frames - 1 and k < energy_tol:
478
+ if remaining_energy[i, freq_idx] < frame_thresh:
479
+ k += 1
480
+ else:
481
+ k = 0
482
+
483
+ remaining_energy[i, freq_idx] = 0
484
+ if freq_idx < MAX_FREQ_IDX:
485
+ remaining_energy[i, freq_idx + 1] = 0
486
+ if freq_idx > 0:
487
+ remaining_energy[i, freq_idx - 1] = 0
488
+
489
+ i += 1
490
+
491
+ i_end = i - 1 - k # go back to frame above threshold
492
+
493
+ # backward pass
494
+ i = i_mid - 1
495
+ k = 0
496
+ while i > 0 and k < energy_tol:
497
+ if remaining_energy[i, freq_idx] < frame_thresh:
498
+ k += 1
499
+ else:
500
+ k = 0
501
+
502
+ remaining_energy[i, freq_idx] = 0
503
+ if freq_idx < MAX_FREQ_IDX:
504
+ remaining_energy[i, freq_idx + 1] = 0
505
+ if freq_idx > 0:
506
+ remaining_energy[i, freq_idx - 1] = 0
507
+
508
+ i -= 1
509
+
510
+ i_start = i + 1 + k # go back to frame above threshold
511
+ assert i_start >= 0, "{}".format(i_start)
512
+ assert i_end < n_frames
513
+
514
+ if i_end - i_start <= min_note_len:
515
+ # note is too short, skip it
516
+ continue
517
+
518
+ # add the note
519
+ amplitude = np.mean(frames[i_start:i_end, freq_idx])
520
+ note_events.append(
521
+ (
522
+ i_start,
523
+ i_end,
524
+ freq_idx + MIDI_OFFSET,
525
+ amplitude,
526
+ )
527
+ )
528
+
529
+ return note_events
530
+
531
+
532
+
533
+
534
+
535
+ def log_base_b(x: Tensor, base: int) -> Tensor:
536
+ """
537
+ Compute log_b(x)
538
+ Args:
539
+ x : input
540
+ base : log base. E.g. for log10 base=10
541
+ Returns:
542
+ log_base(x)
543
+ """
544
+ numerator = torch.log(x)
545
+ denominator = torch.log(torch.tensor([base], dtype=numerator.dtype, device=numerator.device))
546
+ return numerator / denominator
547
+
548
+
549
+ def normalized_log(inputs: Tensor) -> Tensor:
550
+ """
551
+ Takes an input with a shape of either (batch, x, y, z) or (batch, y, z)
552
+ and rescales each (y, z) to dB, scaled 0 - 1.
553
+ Only x=1 is supported.
554
+ This layer adds 1e-10 to all values as a way to avoid NaN math.
555
+ """
556
+ power = torch.square(inputs)
557
+ log_power = 10 * log_base_b(power + 1e-10, 10)
558
+
559
+ log_power_min = torch.amin(log_power, dim=(1, 2)).reshape(inputs.shape[0], 1, 1)
560
+ log_power_offset = log_power - log_power_min
561
+ log_power_offset_max = torch.amax(log_power_offset, dim=(1, 2)).reshape(inputs.shape[0], 1, 1)
562
+ # equivalent to TF div_no_nan
563
+ log_power_normalized = log_power_offset / log_power_offset_max
564
+ log_power_normalized = torch.nan_to_num(log_power_normalized, nan=0.0)
565
+
566
+ return log_power_normalized.reshape(inputs.shape)
567
+
568
+
569
+ def get_cqt(
570
+ inputs: Tensor,
571
+ n_harmonics: int,
572
+ use_batch_norm: bool,
573
+ bn_layer: nn.BatchNorm2d,
574
+ ):
575
+ """Calculate the CQT of the input audio.
576
+
577
+ Input shape: (batch, number of audio samples)
578
+ Output shape: (batch, number of frequency bins, number of time frames)
579
+
580
+ Args:
581
+ inputs: The audio input.
582
+ n_harmonics: The number of harmonics to capture above the maximum output frequency.
583
+ Used to calculate the number of semitones for the CQT.
584
+ use_batch_norm: If True, applies batch normalization after computing the CQT.
+ bn_layer: The BatchNorm2d layer applied when use_batch_norm is True.
585
+
586
+ Returns:
587
+ The log-normalized CQT of the input audio.
588
+ """
589
+ n_semitones = np.min(
590
+ [
591
+ int(np.ceil(12.0 * np.log2(n_harmonics)) + ANNOTATIONS_N_SEMITONES),
592
+ MAX_N_SEMITONES,
593
+ ]
594
+ )
595
+ cqt_layer = CQT2010v2(
596
+ sr=AUDIO_SAMPLE_RATE,
597
+ hop_length=FFT_HOP,
598
+ fmin=ANNOTATIONS_BASE_FREQUENCY,
599
+ n_bins=n_semitones * CONTOURS_BINS_PER_SEMITONE,
600
+ bins_per_octave=12 * CONTOURS_BINS_PER_SEMITONE,
601
+ verbose=False,
602
+ )
603
+ cqt_layer.to(inputs.device)
604
+ x = cqt_layer(inputs)
605
+ x = torch.transpose(x, 1, 2)
606
+ x = normalized_log(x)
607
+
608
+ x = x.unsqueeze(1)
609
+ if use_batch_norm:
610
+ x = bn_layer(x)
611
+ x = x.squeeze(1)
612
+
613
+ return x
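With the eight-harmonic stack used by BasicPitchTorch below, the bin count works out as:

# n_semitones = min(ceil(12 * log2(8)) + 88, 103) = min(124, 103) = 103
# so the CQT layer is built with 103 * 3 = 309 bins at 36 bins per octave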
614
+
615
+
616
+ class HarmonicStacking(nn.Module):
617
+ """Harmonic stacking layer
618
+
619
+ Input shape: (n_batch, n_times, n_freqs)
620
+ Output shape: (n_batch, len(harmonics), n_times, n_output_freqs)
621
+
622
+ n_freqs should be much larger than n_output_freqs so that information from the upper
623
+ harmonics is captured.
624
+
625
+ Attributes:
626
+ bins_per_semitone: The number of bins per semitone of the input CQT
627
+ harmonics: List of harmonics to use. Should be positive numbers.
628
+ shifts: A list containing the number of bins to shift in frequency for each harmonic
629
+ n_output_freqs: The number of frequency bins in each harmonic layer.
630
+ """
631
+
632
+ def __init__(
633
+ self,
634
+ bins_per_semitone: int,
635
+ harmonics: List[float],
636
+ n_output_freqs: int,
637
+ ):
638
+ super().__init__()
639
+ self.bins_per_semitone = bins_per_semitone
640
+ self.harmonics = harmonics
641
+ self.n_output_freqs = n_output_freqs
642
+
643
+ self.shifts = [
644
+ int(round(12.0 * self.bins_per_semitone * math.log2(h))) for h in self.harmonics
645
+ ]
646
+
647
+ @torch.no_grad()
648
+ def forward(self, x):
649
+ # x: (batch, t, n_bins)
650
+ hcqt = []
651
+ for shift in self.shifts:
652
+ if shift == 0:
653
+ cur_cqt = x
654
+ elif shift > 0:
655
+ cur_cqt = F.pad(x[:, :, shift:], (0, shift))
656
+ elif shift < 0: # sub-harmonic
657
+ cur_cqt = F.pad(x[:, :, :shift], (-shift, 0))
658
+ hcqt.append(cur_cqt)
659
+ hcqt = torch.stack(hcqt, dim=1)
660
+ hcqt = hcqt[:, :, :, :self.n_output_freqs]
661
+ return hcqt
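For the configuration used by BasicPitchTorch below (bins_per_semitone=3, harmonics [0.5, 1, 2, 3, 4, 5, 6, 7]), the shifts evaluate to:

# round(36 * log2(h)) = [-36, 0, 36, 57, 72, 84, 93, 101]
# so output bin i of channel h reads CQT bin i + shift(h), aligning each harmonic's
# energy with the row of its fundamental before cropping to n_output_freqs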
662
+
663
+
664
+ class BasicPitchTorch(nn.Module):
665
+
666
+ def __init__(
667
+ self,
668
+ stack_harmonics=[0.5, 1, 2, 3, 4, 5, 6, 7],
669
+ ) -> None:
670
+ super().__init__()
671
+ self.stack_harmonics = stack_harmonics
672
+ if len(stack_harmonics) > 0:
673
+ self.hs = HarmonicStacking(
674
+ bins_per_semitone=CONTOURS_BINS_PER_SEMITONE,
675
+ harmonics=stack_harmonics,
676
+ n_output_freqs=ANNOTATIONS_N_SEMITONES * CONTOURS_BINS_PER_SEMITONE
677
+ )
678
+ num_in_channels = len(stack_harmonics)
679
+ else:
680
+ num_in_channels = 1
681
+
682
+ self.bn_layer = nn.BatchNorm2d(1, eps=0.001)
683
+ self.conv_contour = nn.Sequential(
684
+ # NOTE: in the original implementation, this part of the network is left dangling (unused):
685
+ # nn.Conv2d(num_in_channels, 32, kernel_size=5, padding="same"),
686
+ # nn.BatchNorm2d(32),
687
+ # nn.ReLU(),
688
+ nn.Conv2d(num_in_channels, 8, kernel_size=(3, 3 * 13), padding="same"),
689
+ nn.BatchNorm2d(8, eps=0.001),
690
+ nn.ReLU(),
691
+ nn.Conv2d(8, 1, kernel_size=5, padding="same"),
692
+ nn.Sigmoid()
693
+ )
694
+ self.conv_note = nn.Sequential(
695
+ nn.Conv2d(1, 32, kernel_size=7, stride=(1, 3)),
696
+ nn.ReLU(),
697
+ nn.Conv2d(32, 1, kernel_size=(7, 3), padding="same"),
698
+ nn.Sigmoid()
699
+ )
700
+ self.conv_onset_pre = nn.Sequential(
701
+ nn.Conv2d(num_in_channels, 32, kernel_size=5, stride=(1, 3)),
702
+ nn.BatchNorm2d(32, eps=0.001),
703
+ nn.ReLU(),
704
+ )
705
+ self.conv_onset_post = nn.Sequential(
706
+ nn.Conv2d(32 + 1, 1, kernel_size=3, stride=1, padding="same"),
707
+ nn.Sigmoid()
708
+ )
709
+
710
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
711
+ cqt = get_cqt(
712
+ x,
713
+ len(self.stack_harmonics),
714
+ True,
715
+ self.bn_layer,
716
+ )
717
+ if hasattr(self, "hs"):
718
+ cqt = self.hs(cqt)
719
+ else:
720
+ cqt = cqt.unsqueeze(1)
721
+
722
+ x_contour = self.conv_contour(cqt)
723
+
724
+ # for strided conv, padding is different between PyTorch and TensorFlow
725
+ # we use this equation: pad = [ (stride * (output-1)) - input + kernel ] / 2
726
+ # (172, 264) --(1, 3)--> (172, 88), pad = ((1 * 171 - 172 + 7) / 2, (3 * 87 - 264 + 7) / 2) = (3, 2)
727
+ # F.pad process from the last dimension, so it's (2, 2, 3, 3)
728
+ x_contour_for_note = F.pad(x_contour, (2,2,3,3))
729
+ x_note = self.conv_note(x_contour_for_note)
730
+
731
+ # (172, 264) --(1, 3)--> (172, 88), pad = ((1 * 171 - 172 + 5) / 2, (3 * 87 - 264 + 5) / 2) = (2, 1)
732
+ # F.pad process from the last dimension, so it's (1, 1, 2, 2)
733
+ cqt_for_onset = F.pad(cqt, (1,1,2,2))
734
+ x_onset_pre = self.conv_onset_pre(cqt_for_onset)
735
+ x_onset_pre = torch.cat([x_note, x_onset_pre], dim=1)
736
+ x_onset = self.conv_onset_post(x_onset_pre)
737
+ outputs = {"onset": x_onset.squeeze(1), "contour": x_contour.squeeze(1), "note": x_note.squeeze(1)}
738
+ return outputs
739
+
740
+
741
+ class BasicPitch():
742
+ def __init__(self, model_path=None, device="cpu"):
+ # resolve the default checkpoint lazily so the download happens at construction time, not at import
+ if model_path is None:
+ model_path = hf_hub_download("shethjenil/Audio2Midi_Models", "basicpitch/nmp.pth")
743
+ self.model = BasicPitchTorch()
744
+ self.model.load_state_dict(torch.load(model_path, map_location=device))
745
+ self.model.to(device)
746
+ self.model.eval()
747
+ self.device = device
748
+
749
+ def run_inference(
750
+ self,
751
+ audio_path: str,
752
+ progress_callback: Callable[[int, int], None] = None
753
+ ) -> Dict[str, np.array]:
754
+ audio_windowed, _, audio_original_length = get_audio_input(audio_path, OVERLAP_LEN, HOP_SIZE)
755
+ audio_windowed = torch.from_numpy(np.copy(audio_windowed)).T.to(self.device) # Shape: [num_windows, window_len]
756
+
757
+ outputs = []
758
+ total = audio_windowed.shape[0]
759
+
760
+ with torch.no_grad():
761
+ for i, window in enumerate(audio_windowed):
762
+ window = window.unsqueeze(0) # Add batch dimension
763
+ output = self.model(window)
764
+ outputs.append(output)
765
+
766
+ # Call the callback if provided
767
+ if progress_callback:
768
+ progress_callback(i + 1, total)
769
+
770
+ # Merge outputs (assuming model returns a dict of tensors)
771
+ merged_output = {}
772
+ for key in outputs[0]:
773
+ merged_output[key] = torch.cat([o[key] for o in outputs], dim=0)
774
+
775
+ unwrapped_output = {
776
+ k: unwrap_output(merged_output[k], audio_original_length, N_OVERLAPPING_FRAMES)
777
+ for k in merged_output
778
+ }
779
+ return unwrapped_output
780
+
781
+ def predict(self, audio, onset_thresh=0.5, frame_thresh=0.3, min_note_len=127.70, midi_tempo=120, infer_onsets=True, include_pitch_bends=True, multiple_pitch_bends=True, melodia_trick=True, progress_callback: Callable[[int, int], None] = None, min_freqat=None, max_freqat=None, output_file="output.mid"):
782
+ # min_note_len is given in milliseconds (matching basic-pitch's minimum_note_length default); convert to frames for the decoder
+ min_note_len_frames = int(np.round(min_note_len / 1000 * (AUDIO_SAMPLE_RATE / FFT_HOP)))
+ model_output_to_notes(self.run_inference(audio, progress_callback), onset_thresh=onset_thresh, frame_thresh=frame_thresh, infer_onsets=infer_onsets, min_note_len=min_note_len_frames, min_freq=min_freqat, max_freq=max_freqat, include_pitch_bends=include_pitch_bends, multiple_pitch_bends=multiple_pitch_bends, melodia_trick=melodia_trick, midi_tempo=midi_tempo).write(output_file)
783
+ return output_file
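A minimal end-to-end usage sketch (the audio and output file names are illustrative; the default checkpoint is fetched from the Hugging Face repo named in __init__):

transcriber = BasicPitch(device="cpu")
midi_path = transcriber.predict(
    "song.wav",
    onset_thresh=0.5,
    frame_thresh=0.3,
    output_file="song.mid",
    progress_callback=lambda done, total: print(f"window {done}/{total}"),
)
print(midi_path)  # -> "song.mid"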