audio2midi 0.1.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1281 @@
1
+ import numpy as np
2
+ from librosa.sequence import viterbi_discriminative
3
+ from librosa import note_to_hz, midi_to_hz, load as librosa_load
4
+ from scipy.stats import norm
5
+ from scipy.ndimage import gaussian_filter1d
6
+ from scipy.signal import medfilt, argrelmax
7
+ from torchaudio.models.conformer import ConformerLayer
8
+ from torch import cat as torch_cat, load as torch_load, from_numpy as torch_from_numpy, no_grad as torch_no_grad, mean as torch_mean, std as torch_std, sigmoid as torch_sigmoid, nan_to_num as torch_nan_to_num, nn
9
+ from pretty_midi_fix import PrettyMIDI, Instrument, Note, PitchBend, instrument_name_to_program, note_name_to_number
10
+ from typing import Callable, Dict, List, Optional, Tuple, Literal
11
+ from huggingface_hub import hf_hub_download
12
+
13
+ from mir_eval.melody import hz2cents
14
+
15
+
16
+
17
+
18
+ class RegressionPostProcessor(object):
19
+ def __init__(self, frames_per_second, classes_num, onset_threshold,
20
+ offset_threshold, frame_threshold, pedal_offset_threshold,
21
+ begin_note):
22
+ """Postprocess the output probabilities of a transription model to MIDI
23
+ events.
24
+
25
+ Args:
26
+ frames_per_second: float
27
+ classes_num: int
28
+ onset_threshold: float
29
+ offset_threshold: float
30
+ frame_threshold: float
31
+ pedal_offset_threshold: float
+ begin_note: int
32
+ """
33
+ self.frames_per_second = frames_per_second
34
+ self.classes_num = classes_num
35
+ self.onset_threshold = onset_threshold
36
+ self.offset_threshold = offset_threshold
37
+ self.frame_threshold = frame_threshold
38
+ self.pedal_offset_threshold = pedal_offset_threshold
39
+ self.begin_note = begin_note
40
+ self.velocity_scale = 128
41
+
42
+ def output_dict_to_midi_events(self, output_dict):
43
+ """Main function. Post process model outputs to MIDI events.
44
+
45
+ Args:
46
+ output_dict: {
47
+ 'reg_onset_output': (segment_frames, classes_num),
48
+ 'reg_offset_output': (segment_frames, classes_num),
49
+ 'frame_output': (segment_frames, classes_num),
50
+ 'velocity_output': (segment_frames, classes_num),
51
+ 'reg_pedal_onset_output': (segment_frames, 1),
52
+ 'reg_pedal_offset_output': (segment_frames, 1),
53
+ 'pedal_frame_output': (segment_frames, 1)}
54
+
55
+ Outputs:
56
+ est_note_events: list of dict, e.g. [
57
+ {'onset_time': 39.74, 'offset_time': 39.87, 'midi_note': 27, 'velocity': 83},
58
+ {'onset_time': 11.98, 'offset_time': 12.11, 'midi_note': 33, 'velocity': 88}]
59
+
60
+ est_pedal_events: list of dict, e.g. [
61
+ {'onset_time': 0.17, 'offset_time': 0.96},
62
+ {'onset_time': 1.17, 'offset_time': 2.65}]
63
+ """
64
+ output_dict['frame_output'] = output_dict['note']
65
+ output_dict['velocity_output'] = output_dict['note']
66
+ output_dict['reg_onset_output'] = output_dict['onset']
67
+ output_dict['reg_offset_output'] = output_dict['offset']
68
+ # Post process piano note outputs to piano note and pedal events information
69
+ (est_on_off_note_vels, est_pedal_on_offs) = \
70
+ self.output_dict_to_note_pedal_arrays(output_dict)
71
+ """est_on_off_note_vels: (events_num, 4), the four columns are: [onset_time, offset_time, piano_note, velocity],
72
+ est_pedal_on_offs: (pedal_events_num, 2), the two columns are: [onset_time, offset_time]"""
73
+
74
+ # Reformat notes to MIDI events
75
+ est_note_events = self.detected_notes_to_events(est_on_off_note_vels)
76
+
77
+ if est_pedal_on_offs is None:
78
+ est_pedal_events = None
79
+ else:
80
+ est_pedal_events = self.detected_pedals_to_events(est_pedal_on_offs)
81
+
82
+ return est_note_events, est_pedal_events
83
+
84
+ def output_dict_to_note_pedal_arrays(self, output_dict):
85
+ """Postprocess the output probabilities of a transription model to MIDI
86
+ events.
87
+
88
+ Args:
89
+ output_dict: dict, {
90
+ 'reg_onset_output': (frames_num, classes_num),
91
+ 'reg_offset_output': (frames_num, classes_num),
92
+ 'frame_output': (frames_num, classes_num),
93
+ 'velocity_output': (frames_num, classes_num),
94
+ ...}
95
+
96
+ Returns:
97
+ est_on_off_note_vels: (events_num, 4), the 4 columns are onset_time,
98
+ offset_time, piano_note and velocity. E.g. [
99
+ [39.74, 39.87, 27, 0.65],
100
+ [11.98, 12.11, 33, 0.69],
101
+ ...]
102
+
103
+ est_pedal_on_offs: (pedal_events_num, 2), the 2 columns are onset_time
104
+ and offset_time. E.g. [
105
+ [0.17, 0.96],
106
+ [1.17, 2.65],
107
+ ...]
108
+ """
109
+
110
+ # ------ 1. Process regression outputs to binarized outputs ------
111
+ # For example, onset or offset of [0., 0., 0.15, 0.30, 0.40, 0.35, 0.20, 0.05, 0., 0.]
112
+ # will be processed to [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]
113
+
114
+ # Calculate binarized onset output from regression output
115
+ (onset_output, onset_shift_output) = \
116
+ self.get_binarized_output_from_regression(
117
+ reg_output=output_dict['reg_onset_output'],
118
+ threshold=self.onset_threshold, neighbour=2)
119
+
120
+ output_dict['onset_output'] = onset_output # Values are 0 or 1
121
+ output_dict['onset_shift_output'] = onset_shift_output
122
+
123
+ # Calculate binarized offset output from regression output
124
+ (offset_output, offset_shift_output) = \
125
+ self.get_binarized_output_from_regression(
126
+ reg_output=output_dict['reg_offset_output'],
127
+ threshold=self.offset_threshold, neighbour=4)
128
+
129
+ output_dict['offset_output'] = offset_output # Values are 0 or 1
130
+ output_dict['offset_shift_output'] = offset_shift_output
131
+
132
+ if 'reg_pedal_onset_output' in output_dict.keys():
133
+ """Pedal onsets are not used in inference. Instead, frame-wise pedal
134
+ predictions are used to detect onsets. We empirically found this is
135
+ more accurate to detect pedal onsets."""
136
+ pass
137
+
138
+ if 'reg_pedal_offset_output' in output_dict.keys():
139
+ # Calculate binarized pedal offset output from regression output
140
+ (pedal_offset_output, pedal_offset_shift_output) = \
141
+ self.get_binarized_output_from_regression(
142
+ reg_output=output_dict['reg_pedal_offset_output'],
143
+ threshold=self.pedal_offset_threshold, neighbour=4)
144
+
145
+ output_dict['pedal_offset_output'] = pedal_offset_output # Values are 0 or 1
146
+ output_dict['pedal_offset_shift_output'] = pedal_offset_shift_output
147
+
148
+ # ------ 2. Process matrices results to event results ------
149
+ # Detect piano notes from output_dict
150
+ est_on_off_note_vels = self.output_dict_to_detected_notes(output_dict)
151
+
152
+ est_pedal_on_offs = None
153
+
154
+ return est_on_off_note_vels, est_pedal_on_offs
155
+
156
+ def get_binarized_output_from_regression(self, reg_output, threshold, neighbour):
157
+ """Calculate binarized output and shifts of onsets or offsets from the
158
+ regression results.
159
+
160
+ Args:
161
+ reg_output: (frames_num, classes_num)
162
+ threshold: float
163
+ neighbour: int
164
+
165
+ Returns:
166
+ binary_output: (frames_num, classes_num)
167
+ shift_output: (frames_num, classes_num)
168
+ """
169
+ binary_output = np.zeros_like(reg_output)
170
+ shift_output = np.zeros_like(reg_output)
171
+ (frames_num, classes_num) = reg_output.shape
172
+
173
+ for k in range(classes_num):
174
+ x = reg_output[:, k]
175
+ for n in range(neighbour, frames_num - neighbour):
176
+ if x[n] > threshold and self.is_monotonic_neighbour(x, n, neighbour):
177
+ binary_output[n, k] = 1
178
+
179
+ """See Section III-D in [1] for deduction.
180
+ [1] Q. Kong, et al., High-resolution Piano Transcription
181
+ with Pedals by Regressing Onsets and Offsets Times, 2020."""
182
+ if x[n - 1] > x[n + 1]:
183
+ shift = (x[n + 1] - x[n - 1]) / (x[n] - x[n + 1]) / 2
184
+ else:
185
+ shift = (x[n + 1] - x[n - 1]) / (x[n] - x[n - 1]) / 2
186
+ shift_output[n, k] = shift
187
+
188
+ return binary_output, shift_output
189
+
190
+ def is_monotonic_neighbour(self, x, n, neighbour):
191
+ """Detect if values are monotonic in both side of x[n].
192
+
193
+ Args:
194
+ x: (frames_num,)
195
+ n: int
196
+ neighbour: int
197
+
198
+ Returns:
199
+ monotonic: bool
200
+ """
201
+ monotonic = True
202
+ for i in range(neighbour):
203
+ if x[n - i] < x[n - i - 1]:
204
+ monotonic = False
205
+ if x[n + i] < x[n + i + 1]:
206
+ monotonic = False
207
+
208
+ return monotonic
209
+
210
+ def output_dict_to_detected_notes(self, output_dict):
211
+ """Postprocess output_dict to piano notes.
212
+
213
+ Args:
214
+ output_dict: dict, e.g. {
215
+ 'onset_output': (frames_num, classes_num),
216
+ 'onset_shift_output': (frames_num, classes_num),
217
+ 'offset_output': (frames_num, classes_num),
218
+ 'offset_shift_output': (frames_num, classes_num),
219
+ 'frame_output': (frames_num, classes_num),
220
+ 'velocity_output': (frames_num, classes_num),
221
+ ...}
222
+
223
+ Returns:
224
+ est_on_off_note_vels: (notes, 4), the four columns are onsets, offsets,
225
+ MIDI notes and velocities. E.g.,
226
+ [[39.7375, 39.7500, 27., 0.6638],
227
+ [11.9824, 12.5000, 33., 0.6892],
228
+ ...]
229
+ """
230
+
231
+ est_tuples = []
232
+ est_midi_notes = []
233
+ classes_num = output_dict['frame_output'].shape[-1]
234
+
235
+ for piano_note in range(classes_num):
236
+ """Detect piano notes"""
237
+ est_tuples_per_note = self.note_detection_with_onset_offset_regress(
238
+ frame_output=output_dict['frame_output'][:, piano_note],
239
+ onset_output=output_dict['onset_output'][:, piano_note],
240
+ onset_shift_output=output_dict['onset_shift_output'][:, piano_note],
241
+ offset_output=output_dict['offset_output'][:, piano_note],
242
+ offset_shift_output=output_dict['offset_shift_output'][:, piano_note],
243
+ velocity_output=output_dict['velocity_output'][:, piano_note],
244
+ frame_threshold=self.frame_threshold)
245
+
246
+ est_tuples += est_tuples_per_note
247
+ est_midi_notes += [piano_note + self.begin_note] * len(est_tuples_per_note)
248
+
249
+ est_tuples = np.array(est_tuples) # (notes, 5)
250
+ """(notes, 5), the five columns are onset, offset, onset_shift,
251
+ offset_shift and normalized_velocity"""
252
+
253
+ est_midi_notes = np.array(est_midi_notes) # (notes,)
254
+
255
+ onset_times = (est_tuples[:, 0] + est_tuples[:, 2]) / self.frames_per_second
256
+ offset_times = (est_tuples[:, 1] + est_tuples[:, 3]) / self.frames_per_second
257
+ velocities = est_tuples[:, 4]
258
+
259
+ est_on_off_note_vels = np.stack((onset_times, offset_times, est_midi_notes, velocities), axis=-1)
260
+ """(notes, 3), the three columns are onset_times, offset_times and velocity."""
261
+
262
+ est_on_off_note_vels = est_on_off_note_vels.astype(np.float32)
263
+
264
+ return est_on_off_note_vels
265
+
266
+ def detected_notes_to_events(self, est_on_off_note_vels):
267
+ """Reformat detected notes to midi events.
268
+
269
+ Args:
270
+ est_on_off_note_vels: (notes, 4), the four columns are onset_times,
271
+ offset_times, midi_notes and velocities. E.g.
272
+ [[32.8376, 35.7700, 27., 0.7932],
273
+ [37.3712, 39.9300, 33., 0.8058],
274
+ ...]
275
+
276
+ Returns:
277
+ midi_events, list, e.g.,
278
+ [{'onset_time': 39.7376, 'offset_time': 39.75, 'midi_note': 27, 'velocity': 84},
279
+ {'onset_time': 11.9824, 'offset_time': 12.50, 'midi_note': 33, 'velocity': 88},
280
+ ...]
281
+ """
282
+ midi_events = []
283
+ for i in range(est_on_off_note_vels.shape[0]):
284
+ midi_events.append({
285
+ 'onset_time': est_on_off_note_vels[i][0],
286
+ 'offset_time': est_on_off_note_vels[i][1],
287
+ 'midi_note': int(est_on_off_note_vels[i][2]),
288
+ 'velocity': int(est_on_off_note_vels[i][3] * self.velocity_scale)})
289
+
290
+ return midi_events
291
+
292
+ def note_detection_with_onset_offset_regress(self,frame_output, onset_output,
293
+ onset_shift_output, offset_output, offset_shift_output, velocity_output,
294
+ frame_threshold):
295
+ """Process prediction matrices to note events information.
296
+ First, detect onsets with onset outputs. Then, detect offsets
297
+ with frame and offset outputs.
298
+
299
+ Args:
300
+ frame_output: (frames_num,)
301
+ onset_output: (frames_num,)
302
+ onset_shift_output: (frames_num,)
303
+ offset_output: (frames_num,)
304
+ offset_shift_output: (frames_num,)
305
+ velocity_output: (frames_num,)
306
+ frame_threshold: float
307
+ Returns:
308
+ output_tuples: list of [bgn, fin, onset_shift, offset_shift, normalized_velocity],
309
+ e.g., [
310
+ [1821, 1909, 0.47498, 0.3048533, 0.72119445],
311
+ [1909, 1947, 0.30730522, -0.45764327, 0.64200014],
312
+ ...]
313
+ """
314
+ output_tuples = []
315
+ bgn = None
316
+ frame_disappear = None
317
+ offset_occur = None
318
+
319
+ for i in range(onset_output.shape[0]):
320
+ if onset_output[i] == 1:
321
+ """Onset detected"""
322
+ if bgn:
323
+ """Consecutive onsets. E.g., pedal is not released, but two
324
+ consecutive notes being played."""
325
+ fin = max(i - 1, 0)
326
+ output_tuples.append([bgn, fin, onset_shift_output[bgn],
327
+ 0, velocity_output[bgn]])
328
+ frame_disappear, offset_occur = None, None
329
+ bgn = i
330
+
331
+ if bgn and i > bgn:
332
+ """If onset found, then search offset"""
333
+ if frame_output[i] <= frame_threshold and not frame_disappear:
334
+ """Frame disappear detected"""
335
+ frame_disappear = i
336
+
337
+ if offset_output[i] == 1 and not offset_occur:
338
+ """Offset detected"""
339
+ offset_occur = i
340
+
341
+ if frame_disappear:
342
+ if offset_occur and offset_occur - bgn > frame_disappear - offset_occur:
343
+ """bgn --------- offset_occur --- frame_disappear"""
344
+ fin = offset_occur
345
+ else:
346
+ """bgn --- offset_occur --------- frame_disappear"""
347
+ fin = frame_disappear
348
+ output_tuples.append([bgn, fin, onset_shift_output[bgn],
349
+ offset_shift_output[fin], velocity_output[bgn]])
350
+ bgn, frame_disappear, offset_occur = None, None, None
351
+
352
+ if bgn and (i - bgn >= 600 or i == onset_output.shape[0] - 1):
353
+ """Offset not detected"""
354
+ fin = i
355
+ output_tuples.append([bgn, fin, onset_shift_output[bgn],
356
+ offset_shift_output[fin], velocity_output[bgn]])
357
+ bgn, frame_disappear, offset_occur = None, None, None
358
+
359
+ # Sort pairs by onsets
360
+ output_tuples.sort(key=lambda pair: pair[0])
361
+
362
+ return output_tuples
363
+
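A minimal, self-contained sketch (illustrative only, not part of the packaged file) of the onset binarization described above. It reuses the regression example quoted in output_dict_to_note_pedal_arrays; the constructor arguments are placeholder values. The peak at frame 4 becomes a one-frame onset with a sub-frame shift estimate; the full output_dict_to_midi_events path is driven from Violin_Pitch_Det.out2note (postprocessing='tiktok') further down in this file.

pp = RegressionPostProcessor(frames_per_second=100, classes_num=1,
                             onset_threshold=0.3, offset_threshold=0.3,
                             frame_threshold=0.3, pedal_offset_threshold=0.5,
                             begin_note=21)
reg = np.array([[0.], [0.], [0.15], [0.30], [0.40], [0.35], [0.20], [0.05], [0.], [0.]])
binary, shift = pp.get_binarized_output_from_regression(reg, threshold=0.3, neighbour=2)
# binary[:, 0] -> [0, 0, 0, 0, 1, 0, 0, 0, 0, 0]; shift[4, 0] holds the sub-frame correction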
364
+ class PerformanceLabel:
365
+ """
366
+ The dataset labeling class for performance representations. Currently includes onset, note, and fine-grained f0
367
+ representations. The note_min, note_max, and f0_bins_per_semitone values should be set per instrument. The default
368
+ values are for violin performance analysis. Fretted instruments might not require such fine f0 resolution per semitone.
369
+ """
370
+ def __init__(self, note_min='F#3', note_max='C8', f0_bins_per_semitone=9, f0_smooth_std_c=None,
371
+ onset_smooth_std=0.7, f0_tolerance_c=200):
372
+ midi_min = note_name_to_number(note_min)
373
+ midi_max = note_name_to_number(note_max)
374
+ self.midi_centers = np.arange(midi_min, midi_max)
375
+ self.onset_smooth_std=onset_smooth_std # onset smoothing along time axis (compensate for alignment)
376
+
377
+ f0_hz_range = note_to_hz([note_min, note_max])
378
+ f0_c_min, f0_c_max = hz2cents(f0_hz_range)
379
+ self.f0_granularity_c = 100/f0_bins_per_semitone
380
+ if not f0_smooth_std_c:
381
+ f0_smooth_std_c = self.f0_granularity_c * 5/4 # Keep the ratio from the CREPE paper (20 cents and 25 cents)
382
+ self.f0_smooth_std_c = f0_smooth_std_c
383
+
384
+ self.f0_centers_c = np.arange(f0_c_min, f0_c_max, self.f0_granularity_c)
385
+ self.f0_centers_hz = 10 * 2 ** (self.f0_centers_c / 1200)
386
+ self.f0_n_bins = len(self.f0_centers_c)
387
+
388
+ self.pdf_normalizer = norm.pdf(0)
389
+
390
+ self.f0_c2hz = lambda c: 10*2**(c/1200)
391
+ self.f0_hz2c = hz2cents
392
+ self.midi_centers_c = self.f0_hz2c(midi_to_hz(self.midi_centers))
393
+
394
+ self.f0_tolerance_bins = int(f0_tolerance_c/self.f0_granularity_c)
395
+ self.f0_transition_matrix = gaussian_filter1d(np.eye(2*self.f0_tolerance_bins + 1), 25/self.f0_granularity_c)
396
+
397
+ def f0_c2label(self, pitch_c):
398
+ """
399
+ Convert a single f0 value in cents to a one-hot label vector with smoothing (i.e., create a gaussian blur around
400
+ the target f0 bin for regularization and training stability). The blur is controlled by self.f0_smooth_std_c.
401
+ :param pitch_c: a single pitch value in cents
402
+ :return: one-hot label vector with frequency blur
403
+ """
404
+ result = norm.pdf((self.f0_centers_c - pitch_c) / self.f0_smooth_std_c).astype(np.float32)
405
+ result /= self.pdf_normalizer
406
+ return result
407
+
408
+ def f0_label2c(self, salience, center=None):
409
+ """
410
+ Convert the salience predictions to monophonic f0 in cents. Only outputs a single f0 value per frame!
411
+ :param salience: f0 activations
412
+ :param center: f0 center bin to calculate the weighted average. Use argmax if empty
413
+ :return: f0 array per frame (in cents).
414
+ """
415
+ if salience.ndim == 1:
416
+ if center is None:
417
+ center = int(np.argmax(salience))
418
+ start = max(0, center - 4)
419
+ end = min(len(salience), center + 5)
420
+ salience = salience[start:end]
421
+ product_sum = np.sum(salience * self.f0_centers_c[start:end])
422
+ weight_sum = np.sum(salience)
423
+ return product_sum / np.clip(weight_sum, 1e-8, None)
424
+ if salience.ndim == 2:
425
+ return np.array([self.f0_label2c(salience[i, :]) for i in range(salience.shape[0])])
426
+ raise Exception("label should be either 1d or 2d ndarray")
427
+
428
+ def fill_onset_matrix(self, onsets, window, feature_rate):
429
+ """
430
+ Create a sparse onset matrix from window and onsets (per-semitone). Apply a gaussian smoothing (along time)
431
+ so that we can better tolerate alignment problems. This is similar to the frequency smoothing for the f0.
432
+ The temporal smoothing is controlled by the parameter self.onset_smooth_std
433
+ :param onsets: A 2d np.array of individual note onsets with their respective time values
434
+ (Nx2: time in seconds - midi number)
435
+ :param window: Timestamps for the frame centers of the sparse matrix
436
+ :param feature_rate: Window timestamps are integers; this rate converts them to seconds
437
+ :return: onset_roll: A sparse matrix filled with temporally blurred onsets.
438
+ """
439
+ onsets = self.get_window_feats(onsets, window, feature_rate)
440
+ onset_roll = np.zeros((len(window), len(self.midi_centers)))
441
+ for onset in onsets:
442
+ onset, note = onset # it was a pair with time and midi note
443
+ if self.midi_centers[0] < note < self.midi_centers[-1]: # midi note should be in the range defined
444
+ note = int(note) - self.midi_centers[0] # find the note index in our range
445
+ onset = (onset*feature_rate)-window[0] # onset index (as float but in frames, not in seconds!)
446
+ start = max(0, int(onset) - 3)
447
+ end = min(len(window) - 1, int(onset) + 3)
448
+ try:
449
+ vals = norm.pdf(np.linspace(start - onset, end - onset, end - start + 1) / self.onset_smooth_std)
450
+ # if you increase 0.7 you smooth the peak
451
+ # if you decrease it, e.g., 0.1, it becomes too peaky! around 0.5-0.7 seems ok
452
+ vals /= self.pdf_normalizer
453
+ onset_roll[start:end + 1, note] += vals
454
+ except ValueError:
455
+ print('start',start, 'onset', onset, 'end', end)
456
+ return onset_roll, onsets
457
+
458
+ def fill_note_matrix(self, notes, window, feature_rate):
459
+ """
460
+ Create the note matrix (piano roll) from window timestamps and note values per frame.
461
+ :param notes: A 2d np.array of individual notes with their active time values Nx2
462
+ :param window: Timestamps for the frame centers of the output
463
+ :param feature_rate: Window timestamps are integers; this rate converts them to seconds
464
+ :return note_roll: The piano roll in the defined range of [note_min, note_max).
465
+ """
466
+ notes = self.get_window_feats(notes, window, feature_rate)
467
+
468
+ # take the notes in the midi range defined
469
+ notes = notes[np.logical_and(notes[:,1]>=self.midi_centers[0], notes[:,1]<=self.midi_centers[-1]),:]
470
+
471
+ times = (notes[:,0]*feature_rate - window[0]).astype(int) # in feature samples (fs:self.hop/self.sr)
472
+ notes = (notes[:,1] - self.midi_centers[0]).astype(int)
473
+
474
+ note_roll = np.zeros((len(window), len(self.midi_centers)))
475
+ note_roll[(times, notes)] = 1
476
+ return note_roll, notes
477
+
478
+ def fill_f0_matrix(self, f0s, window, feature_rate):
479
+ """
480
+ Unlike the labels for onsets and notes, f0 label is only relevant for strictly monophonic regions! Thus, this
481
+ function returns a boolean which represents where to apply the given values.
482
+ Never back-propagate without the boolean! Empty frames mean that the label is not that reliable.
483
+
484
+ :param f0s: A 2d np.array of f0 values with the time they belong to (2xN: time in seconds - f0 in Hz)
485
+ :param window: Timestamps for the frame centers of the output
486
+ :param feature_rate: Window timestamps are integers; this rate converts them to seconds
487
+
488
+ :return f0_roll: f0 label matrix and
489
+ f0_hz: f0 values in Hz
490
+ annotation_bool: A boolean array representing which frames have reliable f0 annotations.
491
+ """
492
+ f0s = self.get_window_feats(f0s, window, feature_rate)
493
+ f0_cents = np.zeros_like(window, dtype=float)
494
+ f0s[:,1] = self.f0_hz2c(f0s[:,1]) # convert f0 in hz to cents
495
+
496
+ annotation_bool = np.zeros_like(window, dtype=bool)
497
+ f0_roll = np.zeros((len(window), len(self.f0_centers_c)))
498
+ times_in_frame = f0s[:, 0]*feature_rate - window[0]
499
+ for t, f0 in enumerate(f0s):
500
+ t = times_in_frame[t]
501
+ if t%1 < 0.25: # only consider it an annotation if the f0 value is really close to the frame center
502
+ t = int(np.round(t))
503
+ f0_roll[t] = self.f0_c2label(f0[1])
504
+ annotation_bool[t] = True
505
+ f0_cents[t] = f0[1]
506
+
507
+ return f0_roll, f0_cents, annotation_bool
508
+
509
+ @staticmethod
510
+ def get_window_feats(time_feature_matrix, window, feature_rate):
511
+ """
512
+ Restrict the feature matrix to the features that are inside the window
513
+ :param window: Timestamps for the frame centers of the output
514
+ :param time_feature_matrix: A 2d array of Nx2 per the entire file.
515
+ :param feature_rate: Window timestamps are integers; this rate converts them to seconds
516
+ :return: window_features: the features inside the given window
517
+ """
518
+ start = time_feature_matrix[:,0]>(window[0]-0.5)/feature_rate
519
+ end = time_feature_matrix[:,0]<(window[-1]+0.5)/feature_rate
520
+ window_features = np.logical_and(start, end)
521
+ window_features = np.array(time_feature_matrix[window_features,:])
522
+ return window_features
523
+
524
+ def represent_midi(self, midi, feature_rate):
525
+ """
526
+ Represent a midi file as sparse matrices of onsets, offsets, and notes. No f0 is included.
527
+ :param midi: A midi file (either a path or a pretty_midi_fix.PrettyMIDI object)
528
+ :param feature_rate: The feature rate in Hz
529
+ :return: dict {onset, offset, note, time}: Same format as the model's training targets and outputs
530
+ """
531
+ def _get_onsets_offsets_frames(midi_content):
532
+ if isinstance(midi_content, str):
533
+ midi_content = PrettyMIDI(midi_content)
534
+ onsets = []
535
+ offsets = []
536
+ frames = []
537
+ for instrument in midi_content.instruments:
538
+ for note in instrument.notes:
539
+ start = int(np.round(note.start * feature_rate))
540
+ end = int(np.round(note.end * feature_rate))
541
+ note_times = (np.arange(start, end+0.5)/feature_rate)[:, np.newaxis]
542
+ note_pitch = np.full_like(note_times, fill_value=note.pitch)
543
+ onsets.append([note.start, note.pitch])
544
+ offsets.append([note.end, note.pitch])
545
+ frames.append(np.hstack([note_times, note_pitch]))
546
+ onsets = np.vstack(onsets)
547
+ offsets = np.vstack(offsets)
548
+ frames = np.vstack(frames)
549
+ return onsets, offsets, frames, midi_content
550
+ onset_array, offset_array, frame_array, midi_object = _get_onsets_offsets_frames(midi)
551
+ window = np.arange(frame_array[0, 0]*feature_rate, frame_array[-1, 0]*feature_rate, dtype=int)
552
+ onset_roll, _ = self.fill_onset_matrix(onset_array, window, feature_rate)
553
+ offset_roll, _ = self.fill_onset_matrix(offset_array, window, feature_rate)
554
+ note_roll, _ = self.fill_note_matrix(frame_array, window, feature_rate)
555
+ start_anchor = onset_array[onset_array[:, 0]==np.min(onset_array[:, 0])]
556
+ end_anchor = offset_array[offset_array[:, 0]==np.max(offset_array[:, 0])]
557
+ return {
558
+ 'midi': midi_object,
559
+ 'note': note_roll,
560
+ 'onset': onset_roll,
561
+ 'offset': offset_roll,
562
+ 'time': window/feature_rate,
563
+ 'start_anchor': start_anchor,
564
+ 'end_anchor': end_anchor
565
+ }
566
+
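A small round-trip sketch (illustrative, not part of the packaged file) for the f0 labeling with the default violin settings; note_to_hz, hz2cents and np are the module-level imports above, and 'A4' is just a placeholder pitch.

lab = PerformanceLabel()                              # default violin-oriented settings
pitch_c = hz2cents(np.array([note_to_hz('A4')]))[0]   # 440 Hz expressed in cents above 10 Hz
blurred = lab.f0_c2label(pitch_c)                     # gaussian-blurred one-hot over the f0 bins
recovered_c = lab.f0_label2c(blurred)                 # weighted-average decode, close to pitch_c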
567
+ class NoPadConvBlock(nn.Module):
568
+ def __init__(self, f, w, s, d, in_channels):
569
+ super().__init__()
570
+
571
+ self.conv2d = nn.Conv2d(in_channels=in_channels, out_channels=f, kernel_size=(w, 1), stride=(s, 1),
572
+ dilation=(d, 1))
573
+ self.relu = nn.ReLU()
574
+ self.bn = nn.BatchNorm2d(f)
575
+ self.pool = nn.MaxPool2d(kernel_size=(2, 1))
576
+ self.dropout = nn.Dropout(0.25)
577
+
578
+ def forward(self, x):
579
+ x = self.conv2d(x)
580
+ x = self.relu(x)
581
+ x = self.bn(x)
582
+ x = self.pool(x)
583
+ x = self.dropout(x)
584
+ return x
585
+
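A quick shape check (illustrative sketch, not part of the packaged file) for the un-padded convolution block: with a kernel of 64 and stride 4, a 1024-sample column yields (1024 - 64) / 4 + 1 = 241 frames after the convolution and 120 after the max-pool.

block = NoPadConvBlock(f=8, w=64, s=4, d=1, in_channels=1)
x = torch_from_numpy(np.random.randn(1, 1, 1024, 1).astype(np.float32))
print(block(x).shape)   # torch.Size([1, 8, 120, 1])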
586
+ class TinyPathway(nn.Module):
587
+ def __init__(self, dilation=1, hop=256, localize=False,
588
+ model_capacity="full", n_layers=6, chunk_size=256):
589
+ super().__init__()
590
+
591
+ capacity_multiplier = {
592
+ 'tiny': 4, 'small': 8, 'medium': 16, 'large': 24, 'full': 32
593
+ }[model_capacity]
594
+ self.layers = [1, 2, 3, 4, 5, 6]
595
+ self.layers = self.layers[:n_layers]
596
+ filters = [n * capacity_multiplier for n in [32, 8, 8, 8, 8, 8]]
597
+ filters = [1] + filters
598
+ widths = [512, 64, 64, 64, 32, 32]
599
+ strides = self.deter_dilations(hop//(4*(2**n_layers)), localize=localize)
600
+ strides[0] = strides[0]*4 # apply 4 times more stride at the first layer
601
+ dilations = self.deter_dilations(dilation)
602
+
603
+ for i in range(len(self.layers)):
604
+ f, w, s, d, in_channel = filters[i + 1], widths[i], strides[i], dilations[i], filters[i]
605
+ self.add_module("conv%d" % i, NoPadConvBlock(f, w, s, d, in_channel))
606
+ self.chunk_size = chunk_size
607
+ self.input_window, self.hop = self.find_input_size_for_pathway()
608
+ self.out_dim = filters[n_layers]
609
+
610
+ def find_input_size_for_pathway(self):
611
+ def find_input_size(output_size, kernel_size, stride, dilation, padding):
612
+ num = (stride*(output_size-1)) + 1
613
+ input_size = num - 2*padding + dilation*(kernel_size-1)
614
+ return input_size
615
+ conv_calc, n = {}, 0
616
+ for i in self.layers:
617
+ layer = self.__getattr__("conv%d" % (i-1))
618
+ for mm in layer.modules():
619
+ if hasattr(mm, 'kernel_size'):
620
+ try:
621
+ d = mm.dilation[0]
622
+ except TypeError:
623
+ d = mm.dilation
624
+ conv_calc[n] = [mm.kernel_size[0], mm.stride[0], 0, d]
625
+ n += 1
626
+ out = self.chunk_size
627
+ hop = 1
628
+ for n in sorted(conv_calc.keys())[::-1]:
629
+ kernel_size_n, stride_n, padding_n, dilation_n = conv_calc[n]
630
+ out = find_input_size(out, kernel_size_n, stride_n, dilation_n, padding_n)
631
+ hop = hop*stride_n
632
+ return out, hop
633
+
634
+ def deter_dilations(self, total_dilation, localize=False):
635
+ n_layers = len(self.layers)
636
+ if localize: # e.g., 32*1023 window and 3 layers -> [1, 1, 32]
637
+ a = [total_dilation] + [1 for _ in range(n_layers-1)]
638
+ else: # e.g., 32*1023 window and 3 layers -> [4, 4, 2]
639
+ total_dilation = int(np.log2(total_dilation))
640
+ a = []
641
+ for layer in range(n_layers):
642
+ this_dilation = int(np.ceil(total_dilation/(n_layers-layer)))
643
+ a.append(2**this_dilation)
644
+ total_dilation = total_dilation - this_dilation
645
+ return a[::-1]
646
+
647
+ def forward(self, x):
648
+ x = x.view(x.shape[0], 1, -1, 1)
649
+ for i in range(len(self.layers)):
650
+ x = self.__getattr__("conv%d" % i)(x)
651
+ x = x.permute(0, 3, 2, 1)
652
+ return x
653
+
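An illustrative sketch (not part of the packaged file) of how a pathway reports its geometry; the arguments here are placeholders in the same style as those used by Pitch_Det below, and the printed values depend entirely on them.

pathway = TinyPathway(dilation=1, hop=256, localize=True, n_layers=2, chunk_size=256)
print(pathway.input_window,   # samples the pathway needs to emit chunk_size frames
      pathway.hop,            # effective hop in samples
      pathway.out_dim)        # feature dimension fed to the stream merger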
654
+
655
+
656
+
657
+ class Pitch_Det(nn.Module):
658
+ def __init__(
659
+ self,
660
+ pathway_multiscale: int = 32,
661
+ num_pathway_layers: int = 2,
662
+ chunk_size: int = 256,
663
+ hop_length: int = 256,
664
+ encoder_dim: int = 256,
665
+ sr: int = 44100,
666
+ num_heads: int = 4,
667
+ ffn_dim: int = 128,
668
+ num_separator_layers: int = 16,
669
+ num_representation_layers: int = 4,
670
+ depthwise_conv_kernel_size: int = 31,
671
+ dropout: float = 0.25,
672
+ use_group_norm: bool = False,
673
+ convolution_first: bool = False,
674
+ labeling=PerformanceLabel(),
675
+ wiring='tiktok',
676
+ model_capacity="full"
677
+ ):
678
+ super().__init__()
679
+ self.labeling = labeling
680
+ self.sr = sr
681
+ self.window_size = 1024
682
+ self.hop_length = hop_length
683
+ self.f0_bins_per_semitone = int(np.round(100/self.labeling.f0_granularity_c))
684
+
685
+ self.main = TinyPathway(dilation=1, hop=hop_length, localize=True,
686
+ n_layers=num_pathway_layers, chunk_size=chunk_size,model_capacity=model_capacity)
687
+ self.attendant = TinyPathway(dilation=pathway_multiscale, hop=hop_length, localize=False,
688
+ n_layers=num_pathway_layers, chunk_size=chunk_size,model_capacity=model_capacity)
689
+ assert self.main.hop == self.attendant.hop # both pathways should output frames at the same rate
690
+ print('hop in samples:', self.main.hop)
691
+ self.input_window = self.attendant.input_window
692
+
693
+ self.encoder_dim = encoder_dim
694
+ self.dropout = nn.Dropout(dropout)
695
+
696
+ # merge two streams into a conformer input
697
+ self.stream_merger = nn.Sequential(self.dropout,
698
+ nn.Linear(self.main.out_dim + self.attendant.out_dim, self.encoder_dim))
699
+
700
+
701
+
702
+ print('main stream window:', self.main.input_window,
703
+ ', attendant stream window:', self.attendant.input_window,
704
+ ', conformer input dim:', self.encoder_dim)
705
+
706
+ center = ((chunk_size - 1) * self.main.hop) # region labeled with pitch track
707
+ main_overlap = self.main.input_window - center
708
+ main_overlap = [int(np.floor(main_overlap / 2)), int(np.ceil(main_overlap / 2))]
709
+ attendant_overlap = self.attendant.input_window - center
710
+ attendant_overlap = [int(np.floor(attendant_overlap / 2)), int(np.ceil(attendant_overlap / 2))]
711
+ print('main frame overlap:', main_overlap, ', attendant frame overlap:', attendant_overlap)
712
+ main_crop_relative = [attendant_overlap[0] - main_overlap[0], main_overlap[1] - attendant_overlap[1]]
713
+ print('crop for main pathway', main_crop_relative)
714
+ print("Total sequence duration is", self.attendant.input_window, 'samples')
715
+ print('Main stream receptive field for one frame is', (self.main.input_window - center), 'samples')
716
+ print('Attendant stream receptive field for one frame is', (self.attendant.input_window - center), 'samples')
717
+ self.frame_overlap = attendant_overlap
718
+
719
+ self.main_stream_crop = main_crop_relative
720
+ self.max_window_size = self.attendant.input_window
721
+ self.chunk_size = chunk_size
722
+
723
+ self.separator_stream = nn.ModuleList( # source-separation, reinvented
724
+ [
725
+ ConformerLayer(
726
+ input_dim=self.encoder_dim,
727
+ ffn_dim=ffn_dim,
728
+ num_attention_heads=num_heads,
729
+ depthwise_conv_kernel_size=depthwise_conv_kernel_size,
730
+ dropout=dropout,
731
+ use_group_norm=use_group_norm,
732
+ convolution_first=convolution_first,
733
+ )
734
+ for _ in range(num_separator_layers)
735
+ ]
736
+ )
737
+
738
+ self.f0_stream = nn.ModuleList(
739
+ [
740
+ ConformerLayer(
741
+ input_dim=self.encoder_dim,
742
+ ffn_dim=ffn_dim,
743
+ num_attention_heads=num_heads,
744
+ depthwise_conv_kernel_size=depthwise_conv_kernel_size,
745
+ dropout=dropout,
746
+ use_group_norm=use_group_norm,
747
+ convolution_first=convolution_first,
748
+ )
749
+ for _ in range(num_representation_layers)
750
+ ]
751
+ )
752
+ self.f0_head = nn.Linear(self.encoder_dim, len(self.labeling.f0_centers_c))
753
+
754
+ self.note_stream = nn.ModuleList(
755
+ [
756
+ ConformerLayer(
757
+ input_dim=self.encoder_dim,
758
+ ffn_dim=ffn_dim,
759
+ num_attention_heads=num_heads,
760
+ depthwise_conv_kernel_size=depthwise_conv_kernel_size,
761
+ dropout=dropout,
762
+ use_group_norm=use_group_norm,
763
+ convolution_first=convolution_first,
764
+ )
765
+ for _ in range(num_representation_layers)
766
+ ]
767
+ )
768
+ self.note_head = nn.Linear(self.encoder_dim, len(self.labeling.midi_centers))
769
+
770
+ self.onset_stream = nn.ModuleList(
771
+ [
772
+ ConformerLayer(
773
+ input_dim=self.encoder_dim,
774
+ ffn_dim=ffn_dim,
775
+ num_attention_heads=num_heads,
776
+ depthwise_conv_kernel_size=depthwise_conv_kernel_size,
777
+ dropout=dropout,
778
+ use_group_norm=use_group_norm,
779
+ convolution_first=convolution_first,
780
+ )
781
+ for _ in range(num_representation_layers)
782
+ ]
783
+ )
784
+ self.onset_head = nn.Linear(self.encoder_dim, len(self.labeling.midi_centers))
785
+
786
+ self.offset_stream = nn.ModuleList(
787
+ [
788
+ ConformerLayer(
789
+ input_dim=self.encoder_dim,
790
+ ffn_dim=ffn_dim,
791
+ num_attention_heads=num_heads,
792
+ depthwise_conv_kernel_size=depthwise_conv_kernel_size,
793
+ dropout=dropout,
794
+ use_group_norm=use_group_norm,
795
+ convolution_first=convolution_first,
796
+ )
797
+ for _ in range(num_representation_layers)
798
+ ]
799
+ )
800
+ self.offset_head = nn.Linear(self.encoder_dim, len(self.labeling.midi_centers))
801
+
802
+ self.labeling = labeling
803
+ self.double_merger = nn.Sequential(self.dropout, nn.Linear(2 * self.encoder_dim, self.encoder_dim))
804
+ self.triple_merger = nn.Sequential(self.dropout, nn.Linear(3 * self.encoder_dim, self.encoder_dim))
805
+ self.wiring = wiring
806
+
807
+ print('Total parameter count: ', self.count_parameters())
808
+
809
+ def count_parameters(self) -> int:
810
+ """ Count parameters of encoder """
811
+ return sum([p.numel() for p in self.parameters()])
812
+
813
+ def stream(self, x, representation, key_padding_mask=None):
814
+ for i, layer in enumerate(self.__getattr__('{}_stream'.format(representation))):
815
+ x = layer(x, key_padding_mask)
816
+ return x
817
+
818
+ def head(self, x, representation):
819
+ return self.__getattr__('{}_head'.format(representation))(x)
820
+
821
+ def forward(self, x, key_padding_mask=None):
822
+
823
+ # two auditory streams followed by the separator stream to ensure timbre-awareness
824
+ x_attendant = self.attendant(x)
825
+ x_main = self.main(x[:, self.main_stream_crop[0]:self.main_stream_crop[1]])
826
+ x = self.stream_merger(torch_cat((x_attendant, x_main), -1).squeeze(1))
827
+ x = self.stream(x, 'separator', key_padding_mask)
828
+
829
+ f0 = self.stream(x, 'f0', key_padding_mask) # they say this is a low level feature :)
830
+
831
+ if self.wiring == 'parallel':
832
+ note = self.stream(x, 'note', key_padding_mask)
833
+ onset = self.stream(x, 'onset', key_padding_mask)
834
+ offset = self.stream(x, 'offset', key_padding_mask)
835
+
836
+ elif self.wiring == 'tiktok':
837
+ onset = self.stream(x, 'onset', key_padding_mask)
838
+ offset = self.stream(x, 'offset', key_padding_mask)
839
+ # f0 is disconnected, note relies on separator, onset, and offset
840
+ note = self.stream(self.triple_merger(torch_cat((x, onset, offset), -1)), 'note', key_padding_mask)
841
+
842
+ elif self.wiring == 'tiktok2':
843
+ onset = self.stream(x, 'onset', key_padding_mask)
844
+ offset = self.stream(x, 'offset', key_padding_mask)
845
+ # note is connected to f0, onset, and offset
846
+ note = self.stream(self.triple_merger(torch_cat((f0, onset, offset), -1)), 'note', key_padding_mask)
847
+
848
+ elif self.wiring == 'spotify':
849
+ # note is connected to f0 only
850
+ note = self.stream(f0, 'note', key_padding_mask)
851
+ # here onsets and offsets are higher-level features informed by the separator and note streams
852
+ onset = self.stream(self.double_merger(torch_cat((x, note), -1)), 'onset', key_padding_mask)
853
+ offset = self.stream(self.double_merger(torch_cat((x, note), -1)), 'offset', key_padding_mask)
854
+
855
+ else:
856
+ # onset and offset are connected to f0 and separator streams
857
+ onset = self.stream(self.double_merger(torch_cat((x, f0), -1)), 'onset', key_padding_mask)
858
+ offset = self.stream(self.double_merger(torch_cat((x, f0), -1)), 'offset', key_padding_mask)
859
+ # note is connected to f0, onset, and offset streams
860
+ note = self.stream(self.triple_merger(torch_cat((f0, onset, offset), -1)), 'note', key_padding_mask)
861
+
862
+
863
+ return {'f0': self.head(f0, 'f0'),
864
+ 'note': self.head(note, 'note'),
865
+ 'onset': self.head(onset, 'onset'),
866
+ 'offset': self.head(offset, 'offset')}
867
+
868
+
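A forward-pass sketch (illustrative, not part of the packaged file); the layer counts are reduced here only to keep the example light and are not the defaults. The network consumes windows of max_window_size samples and returns per-frame logits for the four heads.

net = Pitch_Det(num_separator_layers=2, num_representation_layers=1)
x = torch_from_numpy(np.random.randn(2, net.max_window_size).astype(np.float32))
with torch_no_grad():
    out = net(x)   # dict with 'f0', 'note', 'onset' and 'offset' logits, one row per frame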
869
+ class Violin_Pitch_Det(Pitch_Det):
870
+ def __init__(self, model=hf_hub_download("shethjenil/Audio2Midi_Models", "violin.pt"), model_capacity: Literal['tiny', 'small', 'medium', 'large', 'full'] = "full", device="cpu"):
871
+ model_conf = {
872
+ "wiring": "parallel",
873
+ "sampling_rate": 44100,
874
+ "pathway_multiscale": 4,
875
+ "num_pathway_layers": 2,
876
+ "num_separator_layers": 16,
877
+ "num_representation_layers": 4,
878
+ "hop_length": 256,
879
+ "chunk_size": 512,
880
+ "minSNR": -32,
881
+ "maxSNR": 96,
882
+ "note_low": "F#3",
883
+ "note_high": "E8",
884
+ "f0_bins_per_semitone": 10,
885
+ "f0_smooth_std_c": 12,
886
+ "onset_smooth_std": 0.7
887
+ }
888
+ super().__init__(pathway_multiscale=model_conf['pathway_multiscale'], num_pathway_layers=model_conf['num_pathway_layers'], wiring=model_conf['wiring'], hop_length=model_conf['hop_length'], chunk_size=model_conf['chunk_size'], labeling=PerformanceLabel(note_min=model_conf['note_low'], note_max=model_conf['note_high'], f0_bins_per_semitone=model_conf['f0_bins_per_semitone'], f0_tolerance_c=200, f0_smooth_std_c=model_conf['f0_smooth_std_c'], onset_smooth_std=model_conf['onset_smooth_std']), sr=model_conf['sampling_rate'], model_capacity=model_capacity)
889
+ self.load_state_dict(torch_load(model, map_location=device,weights_only=True))
890
+ self.eval()
891
+
892
+ def out2note(self, output: Dict[str, np.array], postprocessing='spotify',
893
+ include_pitch_bends: bool = True,
894
+ ) -> List[Tuple[float, float, int, float, Optional[List[int]]]]:
895
+ """Convert model output to notes
896
+ """
897
+ if postprocessing == 'spotify':
898
+ estimated_notes = self.spotify_create_notes(
899
+ output["note"],
900
+ output["onset"],
901
+ note_low=self.labeling.midi_centers[0],
902
+ note_high=self.labeling.midi_centers[-1],
903
+ onset_thresh=0.5,
904
+ frame_thresh=0.3,
905
+ infer_onsets=True,
906
+ min_note_len=int(np.round(127.70 / 1000 * (self.sr / self.hop_length))), #127.70
907
+ melodia_trick=True,
908
+ )
909
+
910
+ elif postprocessing == 'tiktok':
911
+ postprocessor = RegressionPostProcessor(
912
+ frames_per_second=self.sr / self.hop_length,
913
+ classes_num=self.labeling.midi_centers.shape[0],
914
+ begin_note=self.labeling.midi_centers[0],
915
+ onset_threshold=0.2,
916
+ offset_threshold=0.2,
917
+ frame_threshold=0.3,
918
+ pedal_offset_threshold=0.5,
919
+ )
920
+ tiktok_note_dict, _ = postprocessor.output_dict_to_midi_events(output)
921
+ estimated_notes = []
922
+ for list_item in tiktok_note_dict:
923
+ if list_item['offset_time'] > 0.6 + list_item['onset_time']:
924
+ estimated_notes.append((int(np.floor(list_item['onset_time']/(output['time'][1]))),
925
+ int(np.ceil(list_item['offset_time']/(output['time'][1]))),
926
+ list_item['midi_note'], list_item['velocity']/128))
927
+
928
+ if include_pitch_bends:
929
+ estimated_notes_with_pitch_bend = self.get_pitch_bends(output["f0"], estimated_notes)
930
+ else:
931
+ estimated_notes_with_pitch_bend = [(note[0], note[1], note[2], note[3], None) for note in estimated_notes]
932
+
933
+ times_s = output['time']
934
+ estimated_notes_time_seconds = [
935
+ (times_s[note[0]], times_s[note[1]], note[2], note[3], note[4]) for note in estimated_notes_with_pitch_bend
936
+ ]
937
+
938
+ return estimated_notes_time_seconds
939
+
940
+ def note2midi(
941
+ self,
942
+ note_events_with_pitch_bends: List[Tuple[float, float, int, float, Optional[List[int]]]],
943
+ midi_tempo: float = 120,
944
+ ):
945
+ """Create a pretty_midi_fix object from note events
946
+ :param note_events_with_pitch_bends: list of tuples
947
+ [(start_time_seconds, end_time_seconds, pitch_midi, amplitude, [pitch_bend])]
948
+ :param midi_tempo: MIDI tempo (BPM)
949
+ :return: PrettyMIDI object
950
+ """
951
+ mid = PrettyMIDI(initial_tempo=midi_tempo)
952
+
953
+ # Create a single instrument (e.g., program=40 = violin)
954
+ instrument = Instrument(program=40)
955
+
956
+ for start_time, end_time, note_number, amplitude, pitch_bend in note_events_with_pitch_bends:
957
+ note = Note(
958
+ velocity=int(np.round(127 * amplitude)),
959
+ pitch=note_number,
960
+ start=start_time,
961
+ end=end_time,
962
+ )
963
+ instrument.notes.append(note)
964
+
965
+ if pitch_bend is not None and isinstance(pitch_bend, (list, np.ndarray)):
966
+ pitch_bend = np.asarray(pitch_bend)
967
+ pitch_bend_times = np.linspace(start_time, end_time, len(pitch_bend))
968
+ for pb_time, pb_midi in zip(pitch_bend_times, pitch_bend):
969
+ instrument.pitch_bends.append(PitchBend(pb_midi, pb_time))
970
+
971
+ # Add the single instrument to the MIDI object
972
+ mid.instruments.append(instrument)
973
+
974
+ return mid
975
+
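An illustrative call (not part of the packaged file), assuming `model` is a Violin_Pitch_Det instance; "one_note.mid" is a placeholder path.

events = [(0.0, 0.5, 69, 0.8, None)]   # a half-second A4 at ~80% amplitude, no pitch bend
model.note2midi(events, midi_tempo=120).write("one_note.mid")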
976
+ def get_pitch_bends(
977
+ self,
978
+ contours: np.ndarray, note_events: List[Tuple[int, int, int, float]],
979
+ timing_refinement_range: int = 0, to_midi: bool = True,
980
+ ) -> List[Tuple[int, int, int, float, Optional[List[int]]]]:
981
+ """
982
+ Given note events and contours, estimate pitch bends per note.
983
+ Pitch bends are represented as a sequence of evenly spaced midi pitch bend control units.
984
+ The time stamps of each pitch bend can be inferred by computing an evenly spaced grid between
985
+ the start and end times of each note event.
986
+ Args:
987
+ contours: Matrix of estimated pitch contours
988
+ note_events: note event tuple
989
+ timing_refinement_range: if > 0, refine onset/offset boundaries with f0 confidence
990
+ to_midi: whether to convert pitch bends to midi pitch bends. If False, return pitch estimates in the format
991
+ [time (index), pitch (Hz), confidence in range [0, 1]].
992
+ Returns:
993
+ note events with pitch bends
994
+ """
995
+
996
+ f0_matrix = [] # [time (index), pitch (Hz), confidence in range [0, 1]]
997
+ note_events_with_pitch_bends = []
998
+ for start_idx, end_idx, pitch_midi, amplitude in note_events:
999
+ if timing_refinement_range:
1000
+ start_idx = np.max([0, start_idx - timing_refinement_range])
1001
+ end_idx = np.min([contours.shape[0], end_idx + timing_refinement_range])
1002
+ freq_idx = int(np.round(self.midi_pitch_to_contour_bin(pitch_midi)))
1003
+ freq_start_idx = np.max([freq_idx - self.labeling.f0_tolerance_bins, 0])
1004
+ freq_end_idx = np.min([self.labeling.f0_n_bins, freq_idx + self.labeling.f0_tolerance_bins + 1])
1005
+
1006
+ trans_start_idx = np.max([0, self.labeling.f0_tolerance_bins - freq_idx])
1007
+ trans_end_idx = (2 * self.labeling.f0_tolerance_bins + 1) - \
1008
+ np.max([0, freq_idx - (self.labeling.f0_n_bins - self.labeling.f0_tolerance_bins - 1)])
1009
+
1010
+ # apply regional viterbi to estimate the intonation
1011
+ # observation probabilities come from the f0_roll matrix
1012
+ observation = contours[start_idx:end_idx, freq_start_idx:freq_end_idx]
1013
+ observation = observation / observation.sum(axis=1)[:, None]
1014
+ observation[np.isnan(observation.sum(axis=1)), :] = np.ones(freq_end_idx - freq_start_idx) * 1 / (
1015
+ freq_end_idx - freq_start_idx)
1016
+
1017
+ # transition probabilities assure continuity
1018
+ transition = self.labeling.f0_transition_matrix[trans_start_idx:trans_end_idx,
1019
+ trans_start_idx:trans_end_idx] + 1e-6
1020
+ transition = transition / np.sum(transition, axis=1)[:, None]
1021
+
1022
+ path = viterbi_discriminative(observation.T / observation.sum(axis=1), transition) + freq_start_idx
1023
+
1024
+ cents = np.array([self.labeling.f0_label2c(contours[i + start_idx, :], path[i]) for i in range(len(path))])
1025
+ bends = cents - self.labeling.midi_centers_c[pitch_midi - self.labeling.midi_centers[0]]
1026
+ if to_midi:
1027
+ bends = (bends * 4096 / 100).astype(int)
1028
+ bends[bends > 8191] = 8191
1029
+ bends[bends < -8192] = -8192
1030
+
1031
+ if timing_refinement_range:
1032
+ confidences = np.array([contours[i + start_idx, path[i]] for i in range(len(path))])
1033
+ threshold = np.median(confidences)
1034
+ threshold = (np.median(confidences > threshold) + threshold) / 2 # some magic
1035
+ median_kernel = 2 * (timing_refinement_range // 2) + 1 # some more magic
1036
+ confidences = medfilt(confidences, kernel_size=median_kernel)
1037
+ conf_bool = confidences > threshold
1038
+ onset_idx = np.argmax(conf_bool)
1039
+ offset_idx = len(confidences) - np.argmax(conf_bool[::-1])
1040
+ bends = bends[onset_idx:offset_idx]
1041
+ start_idx = start_idx + onset_idx
1042
+ end_idx = start_idx + offset_idx
1043
+
1044
+ note_events_with_pitch_bends.append((start_idx, end_idx, pitch_midi, amplitude, bends))
1045
+ else:
1046
+ confidences = np.array([contours[i + start_idx, path[i]] for i in range(len(path))])
1047
+ time_idx = np.arange(len(path)) + start_idx
1048
+ # f0_hz = self.labeling.f0_c2hz(cents)
1049
+ possible_f0s = np.array([time_idx, cents, confidences]).T
1050
+ f0_matrix.append(possible_f0s[np.abs(bends)<100]) # filter out pitch bends that are too large
1051
+ if not to_midi:
1052
+ return np.vstack(f0_matrix)
1053
+ else:
1054
+ return note_events_with_pitch_bends
1055
+
1056
+ def midi_pitch_to_contour_bin(self, pitch_midi: int) -> np.array:
1057
+ """Convert midi pitch to corresponding index in contour matrix
1058
+ Args:
1059
+ pitch_midi: pitch in midi
1060
+ Returns:
1061
+ index in contour matrix
1062
+ """
1063
+ pitch_hz = midi_to_hz(pitch_midi)
1064
+ return np.argmin(np.abs(self.labeling.f0_centers_hz - pitch_hz))
1065
+
1066
+ def get_inferred_onsets(self,onset_roll: np.array, note_roll: np.array, n_diff: int = 2) -> np.array:
1067
+ """
1068
+ Infer onsets from large changes in note roll matrix amplitudes.
1069
+ Modified from https://github.com/spotify/basic-pitch/blob/main/basic_pitch/note_creation.py
1070
+ :param onset_roll: Onset activation matrix (n_times, n_freqs).
1071
+ :param note_roll: Frame-level note activation matrix (n_times, n_freqs).
1072
+ :param n_diff: Differences used to detect onsets.
1073
+ :return: The maximum between the predicted onsets and its differences.
1074
+ """
1075
+
1076
+ diffs = []
1077
+ for n in range(1, n_diff + 1):
1078
+ frames_appended = np.concatenate([np.zeros((n, note_roll.shape[1])), note_roll])
1079
+ diffs.append(frames_appended[n:, :] - frames_appended[:-n, :])
1080
+ frame_diff = np.min(diffs, axis=0)
1081
+ frame_diff[frame_diff < 0] = 0
1082
+ frame_diff[:n_diff, :] = 0
1083
+ frame_diff = np.max(onset_roll) * frame_diff / np.max(frame_diff) # rescale to have the same max as onsets
1084
+
1085
+ max_onsets_diff = np.max([onset_roll, frame_diff],
1086
+ axis=0) # use the max of the predicted onsets and the differences
1087
+
1088
+ return max_onsets_diff
1089
+
1090
+ def spotify_create_notes(
1091
+ self,
1092
+ note_roll: np.array,
1093
+ onset_roll: np.array,
1094
+ onset_thresh: float,
1095
+ frame_thresh: float,
1096
+ min_note_len: int,
1097
+ infer_onsets: bool,
1098
+ note_low : int, #self.labeling.midi_centers[0]
1099
+ note_high : int, #self.labeling.midi_centers[-1],
1100
+ melodia_trick: bool = True,
1101
+ energy_tol: int = 11,
1102
+ ) -> List[Tuple[int, int, int, float]]:
1103
+ """Decode raw model output to polyphonic note events
1104
+ Modified from https://github.com/spotify/basic-pitch/blob/main/basic_pitch/note_creation.py
1105
+ Args:
1106
+ note_roll: Frame activation matrix (n_times, n_freqs).
1107
+ onset_roll: Onset activation matrix (n_times, n_freqs).
1108
+ onset_thresh: Minimum amplitude of an onset activation to be considered an onset.
1109
+ frame_thresh: Minimum amplitude of a frame activation for a note to remain "on".
1110
+ min_note_len: Minimum allowed note length in frames.
1111
+ infer_onsets: If True, add additional onsets when there are large differences in frame amplitudes.
1112
+ melodia_trick : Whether to use the melodia trick to better detect notes.
1113
+ energy_tol: Number of consecutive frames allowed below frame_thresh before a note is terminated.
1114
+ Returns:
1115
+ list of tuples [(start_time_frames, end_time_frames, pitch_midi, amplitude)]
1116
+ representing the note events, where amplitude is a number between 0 and 1
1117
+ """
1118
+
1119
+ n_frames = note_roll.shape[0]
1120
+
1121
+ # use onsets inferred from frames in addition to the predicted onsets
1122
+ if infer_onsets:
1123
+ onset_roll = self.get_inferred_onsets(onset_roll, note_roll)
1124
+
1125
+ peak_thresh_mat = np.zeros(onset_roll.shape)
1126
+ peaks = argrelmax(onset_roll, axis=0)
1127
+ peak_thresh_mat[peaks] = onset_roll[peaks]
1128
+
1129
+ onset_idx = np.where(peak_thresh_mat >= onset_thresh)
1130
+ onset_time_idx = onset_idx[0][::-1] # sort to go backwards in time
1131
+ onset_freq_idx = onset_idx[1][::-1] # sort to go backwards in time
1132
+
1133
+ remaining_energy = np.zeros(note_roll.shape)
1134
+ remaining_energy[:, :] = note_roll[:, :]
1135
+
1136
+ # loop over onsets
1137
+ note_events = []
1138
+ for note_start_idx, freq_idx in zip(onset_time_idx, onset_freq_idx):
1139
+ # if we're too close to the end of the audio, continue
1140
+ if note_start_idx >= n_frames - 1:
1141
+ continue
1142
+
1143
+ # find time index at this frequency band where the frames drop below an energy threshold
1144
+ i = note_start_idx + 1
1145
+ k = 0 # number of frames since energy dropped below threshold
1146
+ while i < n_frames - 1 and k < energy_tol:
1147
+ if remaining_energy[i, freq_idx] < frame_thresh:
1148
+ k += 1
1149
+ else:
1150
+ k = 0
1151
+ i += 1
1152
+
1153
+ i -= k # go back to frame above threshold
1154
+
1155
+ # if the note is too short, skip it
1156
+ if i - note_start_idx <= min_note_len:
1157
+ continue
1158
+
1159
+ remaining_energy[note_start_idx:i, freq_idx] = 0
1160
+ if freq_idx < note_high:
1161
+ remaining_energy[note_start_idx:i, freq_idx + 1] = 0
1162
+ if freq_idx > note_low:
1163
+ remaining_energy[note_start_idx:i, freq_idx - 1] = 0
1164
+
1165
+ # add the note
1166
+ amplitude = np.mean(note_roll[note_start_idx:i, freq_idx])
1167
+ note_events.append(
1168
+ (
1169
+ note_start_idx,
1170
+ i,
1171
+ freq_idx + note_low,
1172
+ amplitude,
1173
+ )
1174
+ )
1175
+
1176
+ if melodia_trick:
1177
+ energy_shape = remaining_energy.shape
1178
+
1179
+ while np.max(remaining_energy) > frame_thresh:
1180
+ i_mid, freq_idx = np.unravel_index(np.argmax(remaining_energy), energy_shape)
1181
+ remaining_energy[i_mid, freq_idx] = 0
1182
+
1183
+ # forward pass
1184
+ i = i_mid + 1
1185
+ k = 0
1186
+ while i < n_frames - 1 and k < energy_tol:
1187
+ if remaining_energy[i, freq_idx] < frame_thresh:
1188
+ k += 1
1189
+ else:
1190
+ k = 0
1191
+
1192
+ remaining_energy[i, freq_idx] = 0
1193
+ if freq_idx < note_high:
1194
+ remaining_energy[i, freq_idx + 1] = 0
1195
+ if freq_idx > note_low:
1196
+ remaining_energy[i, freq_idx - 1] = 0
1197
+
1198
+ i += 1
1199
+
1200
+ i_end = i - 1 - k # go back to frame above threshold
1201
+
1202
+ # backward pass
1203
+ i = i_mid - 1
1204
+ k = 0
1205
+ while i > 0 and k < energy_tol:
1206
+ if remaining_energy[i, freq_idx] < frame_thresh:
1207
+ k += 1
1208
+ else:
1209
+ k = 0
1210
+
1211
+ remaining_energy[i, freq_idx] = 0
1212
+ if freq_idx < note_high:
1213
+ remaining_energy[i, freq_idx + 1] = 0
1214
+ if freq_idx > note_low:
1215
+ remaining_energy[i, freq_idx - 1] = 0
1216
+
1217
+ i -= 1
1218
+
1219
+ i_start = i + 1 + k # go back to frame above threshold
1220
+ assert i_start >= 0, "{}".format(i_start)
1221
+ assert i_end < n_frames
1222
+
1223
+ if i_end - i_start <= min_note_len:
1224
+ # note is too short, skip it
1225
+ continue
1226
+
1227
+ # add the note
1228
+ amplitude = np.mean(note_roll[i_start:i_end, freq_idx])
1229
+ note_events.append(
1230
+ (
1231
+ i_start,
1232
+ i_end,
1233
+ freq_idx + note_low,
1234
+ amplitude,
1235
+ )
1236
+ )
1237
+
1238
+ return note_events
1239
+
1240
+ def read_audio(self, audio):
1241
+ """
1242
+ Read and resample an audio file, convert to mono, and unfold into representation frames.
1243
+ The time array represents the center of each small frame with a 5.8 ms hop length. This is different from the chunk-
1244
+ level frames: the chunk-level frames represent the entire sequence the model sees, whereas the model predicts at the
1245
+ small-frame interval (5.8 ms).
1246
+ :param audio: str, pathlib.Path
1247
+ :return: frames: (n_big_frames, frame_length), times: (n_small_frames,)
1248
+ """
1249
+ audio = torch_from_numpy(librosa_load(audio, sr=self.sr, mono=True)[0])
1250
+ len_audio = audio.shape[-1]
1251
+ n_frames = int(np.ceil((len_audio + sum(self.frame_overlap)) / (self.hop_length * self.chunk_size)))
1252
+ audio = nn.functional.pad(audio, (self.frame_overlap[0],self.frame_overlap[1] + (n_frames * self.hop_length * self.chunk_size) - len_audio))
1253
+ frames = audio.unfold(0, self.max_window_size, self.hop_length*self.chunk_size)
1254
+ times = np.arange(0, len_audio, self.hop_length) / self.sr # not tensor, we don't compute anything with it
1255
+ return frames, times
1256
+
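To make the framing concrete, an illustrative calculation (the recording length is a placeholder; the configuration values are the ones used by Violin_Pitch_Det):

sr, hop_length, chunk_size = 44100, 256, 512
n_samples = 10 * sr                                       # a hypothetical 10-second recording
n_small_frames = int(np.ceil(n_samples / hop_length))     # ~1723 entries in `times`
# The samples are grouped into ceil((n_samples + sum(frame_overlap)) / (hop_length * chunk_size))
# overlapping windows of max_window_size samples; model_predict below trims the concatenated
# per-frame predictions back to len(times).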
1257
+ def model_predict(self, audio, batch_size,progress_callback: Callable[[int, int], None]):
1258
+ device = self.main.conv0.conv2d.weight.device
1259
+ performance = {'f0': [], 'note': [], 'onset': [], 'offset': []}
1260
+ frames, times = self.read_audio(audio)
1261
+ with torch_no_grad():
1262
+ for i in range(0, len(frames), batch_size):
1263
+ f = frames[i:min(i + batch_size, len(frames))].to(device)
1264
+ f -= (torch_mean(f, axis=1).unsqueeze(-1))
1265
+ f /= (torch_std(f, axis=1).unsqueeze(-1))
1266
+ out = self.forward(f)
1267
+ for key, value in out.items():
1268
+ value = torch_sigmoid(value)
1269
+ value = torch_nan_to_num(value) # the model outputs nan when the frame is silent (this is an expected behavior due to normalization)
1270
+ value = value.view(-1, value.shape[-1])
1271
+ value = value.detach().cpu().numpy()
1272
+ performance[key].append(value)
1273
+ if progress_callback is not None: progress_callback(i, len(frames))
1274
+ performance = {key: np.concatenate(value, axis=0)[:len(times)] for key, value in performance.items()}
1275
+ performance['time'] = times
1276
+ return performance
1277
+
1278
+ def predict(self, audio, batch_size=32, postprocessing="spotify", include_pitch_bends=True, progress_callback: Callable[[int, int], None] = None, output_file="output.mid"):
1279
+ output = self.model_predict(audio, batch_size,progress_callback)
1280
+ self.note2midi(self.out2note(output, postprocessing, include_pitch_bends), 120).write(output_file)
1281
+ return output_file
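Finally, an end-to-end usage sketch (illustrative, not part of the packaged file): "violin.wav" and "violin.mid" are placeholder paths, and the checkpoint is fetched from the Hugging Face Hub by the constructor default.

model = Violin_Pitch_Det(device="cpu")
midi_path = model.predict("violin.wav", batch_size=8, postprocessing="spotify",
                          include_pitch_bends=True,
                          progress_callback=lambda done, total: print(done, "/", total),
                          output_file="violin.mid")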