audio2midi-0.1.0-py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- audio2midi/__init__.py +0 -0
- audio2midi/basic_pitch_pitch_detector.py +783 -0
- audio2midi/crepe_pitch_detector.py +130 -0
- audio2midi/librosa_pitch_detector.py +153 -0
- audio2midi/melodia_pitch_detector.py +58 -0
- audio2midi/pop2piano.py +2604 -0
- audio2midi/py.typed +0 -0
- audio2midi/violin_pitch_detector.py +1281 -0
- audio2midi-0.1.0.dist-info/METADATA +100 -0
- audio2midi-0.1.0.dist-info/RECORD +11 -0
- audio2midi-0.1.0.dist-info/WHEEL +5 -0
audio2midi/violin_pitch_detector.py
@@ -0,0 +1,1281 @@
import numpy as np
from librosa.sequence import viterbi_discriminative
from librosa import note_to_hz, midi_to_hz, load as librosa_load
from scipy.stats import norm
from scipy.ndimage import gaussian_filter1d
from scipy.signal import medfilt, argrelmax
from torchaudio.models.conformer import ConformerLayer
from torch import cat as torch_cat, load as torch_load, from_numpy as torch_from_numpy, no_grad as torch_no_grad, mean as torch_mean, std as torch_std, sigmoid as torch_sigmoid, nan_to_num as torch_nan_to_num, nn
from pretty_midi_fix import PrettyMIDI, Instrument, Note, PitchBend, instrument_name_to_program, note_name_to_number
from typing import Callable, Dict, List, Optional, Tuple, Literal
from huggingface_hub import hf_hub_download

from mir_eval.melody import hz2cents


class RegressionPostProcessor(object):
    def __init__(self, frames_per_second, classes_num, onset_threshold,
                 offset_threshold, frame_threshold, pedal_offset_threshold,
                 begin_note):
        """Postprocess the output probabilities of a transcription model to MIDI
        events.

        Args:
            frames_per_second: float
            classes_num: int
            onset_threshold: float
            offset_threshold: float
            frame_threshold: float
            pedal_offset_threshold: float
        """
        self.frames_per_second = frames_per_second
        self.classes_num = classes_num
        self.onset_threshold = onset_threshold
        self.offset_threshold = offset_threshold
        self.frame_threshold = frame_threshold
        self.pedal_offset_threshold = pedal_offset_threshold
        self.begin_note = begin_note
        self.velocity_scale = 128

    def output_dict_to_midi_events(self, output_dict):
        """Main function. Post process model outputs to MIDI events.

        Args:
            output_dict: {
                'reg_onset_output': (segment_frames, classes_num),
                'reg_offset_output': (segment_frames, classes_num),
                'frame_output': (segment_frames, classes_num),
                'velocity_output': (segment_frames, classes_num),
                'reg_pedal_onset_output': (segment_frames, 1),
                'reg_pedal_offset_output': (segment_frames, 1),
                'pedal_frame_output': (segment_frames, 1)}

        Outputs:
            est_note_events: list of dict, e.g. [
                {'onset_time': 39.74, 'offset_time': 39.87, 'midi_note': 27, 'velocity': 83},
                {'onset_time': 11.98, 'offset_time': 12.11, 'midi_note': 33, 'velocity': 88}]

            est_pedal_events: list of dict, e.g. [
                {'onset_time': 0.17, 'offset_time': 0.96},
                {'onset_time': 1.17, 'offset_time': 2.65}]
        """
        output_dict['frame_output'] = output_dict['note']
        output_dict['velocity_output'] = output_dict['note']
        output_dict['reg_onset_output'] = output_dict['onset']
        output_dict['reg_offset_output'] = output_dict['offset']
        # Post process piano note outputs to piano note and pedal events information
        (est_on_off_note_vels, est_pedal_on_offs) = \
            self.output_dict_to_note_pedal_arrays(output_dict)
        """est_on_off_note_vels: (events_num, 4), the four columns are: [onset_time, offset_time, piano_note, velocity],
        est_pedal_on_offs: (pedal_events_num, 2), the two columns are: [onset_time, offset_time]"""

        # Reformat notes to MIDI events
        est_note_events = self.detected_notes_to_events(est_on_off_note_vels)

        if est_pedal_on_offs is None:
            est_pedal_events = None
        else:
            est_pedal_events = self.detected_pedals_to_events(est_pedal_on_offs)

        return est_note_events, est_pedal_events

    def output_dict_to_note_pedal_arrays(self, output_dict):
        """Postprocess the output probabilities of a transcription model to MIDI
        events.

        Args:
            output_dict: dict, {
                'reg_onset_output': (frames_num, classes_num),
                'reg_offset_output': (frames_num, classes_num),
                'frame_output': (frames_num, classes_num),
                'velocity_output': (frames_num, classes_num),
                ...}

        Returns:
            est_on_off_note_vels: (events_num, 4), the 4 columns are onset_time,
                offset_time, piano_note and velocity. E.g. [
                [39.74, 39.87, 27, 0.65],
                [11.98, 12.11, 33, 0.69],
                ...]

            est_pedal_on_offs: (pedal_events_num, 2), the 2 columns are onset_time
                and offset_time. E.g. [
                [0.17, 0.96],
                [1.17, 2.65],
                ...]
        """

        # ------ 1. Process regression outputs to binarized outputs ------
        # For example, onset or offset of [0., 0., 0.15, 0.30, 0.40, 0.35, 0.20, 0.05, 0., 0.]
        # will be processed to [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]

        # Calculate binarized onset output from regression output
        (onset_output, onset_shift_output) = \
            self.get_binarized_output_from_regression(
                reg_output=output_dict['reg_onset_output'],
                threshold=self.onset_threshold, neighbour=2)

        output_dict['onset_output'] = onset_output  # Values are 0 or 1
        output_dict['onset_shift_output'] = onset_shift_output

        # Calculate binarized offset output from regression output
        (offset_output, offset_shift_output) = \
            self.get_binarized_output_from_regression(
                reg_output=output_dict['reg_offset_output'],
                threshold=self.offset_threshold, neighbour=4)

        output_dict['offset_output'] = offset_output  # Values are 0 or 1
        output_dict['offset_shift_output'] = offset_shift_output

        if 'reg_pedal_onset_output' in output_dict.keys():
            """Pedal onsets are not used in inference. Instead, frame-wise pedal
            predictions are used to detect onsets. We empirically found this is
            more accurate to detect pedal onsets."""
            pass

        if 'reg_pedal_offset_output' in output_dict.keys():
            # Calculate binarized pedal offset output from regression output
            (pedal_offset_output, pedal_offset_shift_output) = \
                self.get_binarized_output_from_regression(
                    reg_output=output_dict['reg_pedal_offset_output'],
                    threshold=self.pedal_offset_threshold, neighbour=4)

            output_dict['pedal_offset_output'] = pedal_offset_output  # Values are 0 or 1
            output_dict['pedal_offset_shift_output'] = pedal_offset_shift_output

        # ------ 2. Process matrices results to event results ------
        # Detect piano notes from output_dict
        est_on_off_note_vels = self.output_dict_to_detected_notes(output_dict)

        est_pedal_on_offs = None

        return est_on_off_note_vels, est_pedal_on_offs

    def get_binarized_output_from_regression(self, reg_output, threshold, neighbour):
        """Calculate binarized output and shifts of onsets or offsets from the
        regression results.

        Args:
            reg_output: (frames_num, classes_num)
            threshold: float
            neighbour: int

        Returns:
            binary_output: (frames_num, classes_num)
            shift_output: (frames_num, classes_num)
        """
        binary_output = np.zeros_like(reg_output)
        shift_output = np.zeros_like(reg_output)
        (frames_num, classes_num) = reg_output.shape

        for k in range(classes_num):
            x = reg_output[:, k]
            for n in range(neighbour, frames_num - neighbour):
                if x[n] > threshold and self.is_monotonic_neighbour(x, n, neighbour):
                    binary_output[n, k] = 1

                    """See Section III-D in [1] for deduction.
                    [1] Q. Kong, et al., High-resolution Piano Transcription
                    with Pedals by Regressing Onsets and Offsets Times, 2020."""
                    if x[n - 1] > x[n + 1]:
                        shift = (x[n + 1] - x[n - 1]) / (x[n] - x[n + 1]) / 2
                    else:
                        shift = (x[n + 1] - x[n - 1]) / (x[n] - x[n - 1]) / 2
                    shift_output[n, k] = shift

        return binary_output, shift_output

    def is_monotonic_neighbour(self, x, n, neighbour):
        """Detect if values are monotonic on both sides of x[n].

        Args:
            x: (frames_num,)
            n: int
            neighbour: int

        Returns:
            monotonic: bool
        """
        monotonic = True
        for i in range(neighbour):
            if x[n - i] < x[n - i - 1]:
                monotonic = False
            if x[n + i] < x[n + i + 1]:
                monotonic = False

        return monotonic

    def output_dict_to_detected_notes(self, output_dict):
        """Postprocess output_dict to piano notes.

        Args:
            output_dict: dict, e.g. {
                'onset_output': (frames_num, classes_num),
                'onset_shift_output': (frames_num, classes_num),
                'offset_output': (frames_num, classes_num),
                'offset_shift_output': (frames_num, classes_num),
                'frame_output': (frames_num, classes_num),
                'velocity_output': (frames_num, classes_num),
                ...}

        Returns:
            est_on_off_note_vels: (notes, 4), the four columns are onsets, offsets,
                MIDI notes and velocities. E.g.,
                [[39.7375, 39.7500, 27., 0.6638],
                 [11.9824, 12.5000, 33., 0.6892],
                 ...]
        """

        est_tuples = []
        est_midi_notes = []
        classes_num = output_dict['frame_output'].shape[-1]

        for piano_note in range(classes_num):
            """Detect piano notes"""
            est_tuples_per_note = self.note_detection_with_onset_offset_regress(
                frame_output=output_dict['frame_output'][:, piano_note],
                onset_output=output_dict['onset_output'][:, piano_note],
                onset_shift_output=output_dict['onset_shift_output'][:, piano_note],
                offset_output=output_dict['offset_output'][:, piano_note],
                offset_shift_output=output_dict['offset_shift_output'][:, piano_note],
                velocity_output=output_dict['velocity_output'][:, piano_note],
                frame_threshold=self.frame_threshold)

            est_tuples += est_tuples_per_note
            est_midi_notes += [piano_note + self.begin_note] * len(est_tuples_per_note)

        est_tuples = np.array(est_tuples)   # (notes, 5)
        """(notes, 5), the five columns are onset, offset, onset_shift,
        offset_shift and normalized_velocity"""

        est_midi_notes = np.array(est_midi_notes)   # (notes,)

        onset_times = (est_tuples[:, 0] + est_tuples[:, 2]) / self.frames_per_second
        offset_times = (est_tuples[:, 1] + est_tuples[:, 3]) / self.frames_per_second
        velocities = est_tuples[:, 4]

        est_on_off_note_vels = np.stack((onset_times, offset_times, est_midi_notes, velocities), axis=-1)
        """(notes, 4), the four columns are onset_times, offset_times, est_midi_notes and velocities."""

        est_on_off_note_vels = est_on_off_note_vels.astype(np.float32)

        return est_on_off_note_vels

    def detected_notes_to_events(self, est_on_off_note_vels):
        """Reformat detected notes to midi events.

        Args:
            est_on_off_note_vels: (notes, 3), the three columns are onset_times,
                offset_times and velocity. E.g.
                [[32.8376, 35.7700, 0.7932],
                 [37.3712, 39.9300, 0.8058],
                 ...]

        Returns:
            midi_events: list, e.g.,
                [{'onset_time': 39.7376, 'offset_time': 39.75, 'midi_note': 27, 'velocity': 84},
                 {'onset_time': 11.9824, 'offset_time': 12.50, 'midi_note': 33, 'velocity': 88},
                 ...]
        """
        midi_events = []
        for i in range(est_on_off_note_vels.shape[0]):
            midi_events.append({
                'onset_time': est_on_off_note_vels[i][0],
                'offset_time': est_on_off_note_vels[i][1],
                'midi_note': int(est_on_off_note_vels[i][2]),
                'velocity': int(est_on_off_note_vels[i][3] * self.velocity_scale)})

        return midi_events

    def note_detection_with_onset_offset_regress(self, frame_output, onset_output,
                                                 onset_shift_output, offset_output, offset_shift_output,
                                                 velocity_output, frame_threshold):
        """Process prediction matrices to note events information.
        First, detect onsets with onset outputs. Then, detect offsets
        with frame and offset outputs.

        Args:
            frame_output: (frames_num,)
            onset_output: (frames_num,)
            onset_shift_output: (frames_num,)
            offset_output: (frames_num,)
            offset_shift_output: (frames_num,)
            velocity_output: (frames_num,)
            frame_threshold: float
        Returns:
            output_tuples: list of [bgn, fin, onset_shift, offset_shift, normalized_velocity],
            e.g., [
                [1821, 1909, 0.47498, 0.3048533, 0.72119445],
                [1909, 1947, 0.30730522, -0.45764327, 0.64200014],
                ...]
        """
        output_tuples = []
        bgn = None
        frame_disappear = None
        offset_occur = None

        for i in range(onset_output.shape[0]):
            if onset_output[i] == 1:
                """Onset detected"""
                if bgn:
                    """Consecutive onsets. E.g., pedal is not released, but two
                    consecutive notes being played."""
                    fin = max(i - 1, 0)
                    output_tuples.append([bgn, fin, onset_shift_output[bgn],
                                          0, velocity_output[bgn]])
                    frame_disappear, offset_occur = None, None
                bgn = i

            if bgn and i > bgn:
                """If onset found, then search offset"""
                if frame_output[i] <= frame_threshold and not frame_disappear:
                    """Frame disappear detected"""
                    frame_disappear = i

                if offset_output[i] == 1 and not offset_occur:
                    """Offset detected"""
                    offset_occur = i

                if frame_disappear:
                    if offset_occur and offset_occur - bgn > frame_disappear - offset_occur:
                        """bgn --------- offset_occur --- frame_disappear"""
                        fin = offset_occur
                    else:
                        """bgn --- offset_occur --------- frame_disappear"""
                        fin = frame_disappear
                    output_tuples.append([bgn, fin, onset_shift_output[bgn],
                                          offset_shift_output[fin], velocity_output[bgn]])
                    bgn, frame_disappear, offset_occur = None, None, None

                if bgn and (i - bgn >= 600 or i == onset_output.shape[0] - 1):
                    """Offset not detected"""
                    fin = i
                    output_tuples.append([bgn, fin, onset_shift_output[bgn],
                                          offset_shift_output[fin], velocity_output[bgn]])
                    bgn, frame_disappear, offset_occur = None, None, None

        # Sort pairs by onsets
        output_tuples.sort(key=lambda pair: pair[0])

        return output_tuples


class PerformanceLabel:
    """
    The dataset labeling class for performance representations. Currently includes onset, note, and fine-grained f0
    representations. Note min, note max, and f0_bins_per_semitone values are to be arranged per instrument. The default
    values are for violin performance analysis. Fretted instruments might not require such f0 resolutions per semitone.
    """
    def __init__(self, note_min='F#3', note_max='C8', f0_bins_per_semitone=9, f0_smooth_std_c=None,
                 onset_smooth_std=0.7, f0_tolerance_c=200):
        midi_min = note_name_to_number(note_min)
        midi_max = note_name_to_number(note_max)
        self.midi_centers = np.arange(midi_min, midi_max)
        self.onset_smooth_std = onset_smooth_std  # onset smoothing along time axis (compensate for alignment)

        f0_hz_range = note_to_hz([note_min, note_max])
        f0_c_min, f0_c_max = hz2cents(f0_hz_range)
        self.f0_granularity_c = 100 / f0_bins_per_semitone
        if not f0_smooth_std_c:
            f0_smooth_std_c = self.f0_granularity_c * 5 / 4  # Keep the ratio from the CREPE paper (20 cents and 25 cents)
        self.f0_smooth_std_c = f0_smooth_std_c

        self.f0_centers_c = np.arange(f0_c_min, f0_c_max, self.f0_granularity_c)
        self.f0_centers_hz = 10 * 2 ** (self.f0_centers_c / 1200)
        self.f0_n_bins = len(self.f0_centers_c)

        self.pdf_normalizer = norm.pdf(0)

        self.f0_c2hz = lambda c: 10 * 2 ** (c / 1200)
        self.f0_hz2c = hz2cents
        self.midi_centers_c = self.f0_hz2c(midi_to_hz(self.midi_centers))

        self.f0_tolerance_bins = int(f0_tolerance_c / self.f0_granularity_c)
        self.f0_transition_matrix = gaussian_filter1d(np.eye(2 * self.f0_tolerance_bins + 1), 25 / self.f0_granularity_c)

    def f0_c2label(self, pitch_c):
        """
        Convert a single f0 value in cents to a one-hot label vector with smoothing (i.e., create a gaussian blur around
        the target f0 bin for regularization and training stability). The blur is controlled by self.f0_smooth_std_c.
        :param pitch_c: a single pitch value in cents
        :return: one-hot label vector with frequency blur
        """
        result = norm.pdf((self.f0_centers_c - pitch_c) / self.f0_smooth_std_c).astype(np.float32)
        result /= self.pdf_normalizer
        return result

    def f0_label2c(self, salience, center=None):
        """
        Convert the salience predictions to monophonic f0 in cents. Only outputs a single f0 value per frame!
        :param salience: f0 activations
        :param center: f0 center bin to calculate the weighted average. Use argmax if empty
        :return: f0 array per frame (in cents).
        """
        if salience.ndim == 1:
            if center is None:
                center = int(np.argmax(salience))
            start = max(0, center - 4)
            end = min(len(salience), center + 5)
            salience = salience[start:end]
            product_sum = np.sum(salience * self.f0_centers_c[start:end])
            weight_sum = np.sum(salience)
            return product_sum / np.clip(weight_sum, 1e-8, None)
        if salience.ndim == 2:
            return np.array([self.f0_label2c(salience[i, :]) for i in range(salience.shape[0])])
        raise Exception("label should be either 1d or 2d ndarray")

    def fill_onset_matrix(self, onsets, window, feature_rate):
        """
        Create a sparse onset matrix from window and onsets (per-semitone). Apply a gaussian smoothing (along time)
        so that we can better tolerate alignment problems. This is similar to the frequency smoothing for the f0.
        The temporal smoothing is controlled by the parameter self.onset_smooth_std
        :param onsets: A 2d np.array of individual note onsets with their respective time values
            (Nx2: time in seconds - midi number)
        :param window: Timestamps for the frame centers of the sparse matrix
        :param feature_rate: Window timestamps are integer, this is to convert them to seconds
        :return: onset_roll: A sparse matrix filled with temporally blurred onsets.
        """
        onsets = self.get_window_feats(onsets, window, feature_rate)
        onset_roll = np.zeros((len(window), len(self.midi_centers)))
        for onset in onsets:
            onset, note = onset  # it was a pair with time and midi note
            if self.midi_centers[0] < note < self.midi_centers[-1]:  # midi note should be in the range defined
                note = int(note) - self.midi_centers[0]  # find the note index in our range
                onset = (onset * feature_rate) - window[0]  # onset index (as float but in frames, not in seconds!)
                start = max(0, int(onset) - 3)
                end = min(len(window) - 1, int(onset) + 3)
                try:
                    vals = norm.pdf(np.linspace(start - onset, end - onset, end - start + 1) / self.onset_smooth_std)
                    # if you increase 0.7 you smooth the peak
                    # if you decrease it, e.g., 0.1, it becomes too peaky! around 0.5-0.7 seems ok
                    vals /= self.pdf_normalizer
                    onset_roll[start:end + 1, note] += vals
                except ValueError:
                    print('start', start, 'onset', onset, 'end', end)
        return onset_roll, onsets

    def fill_note_matrix(self, notes, window, feature_rate):
        """
        Create the note matrix (piano roll) from window timestamps and note values per frame.
        :param notes: A 2d np.array of individual notes with their active time values Nx2
        :param window: Timestamps for the frame centers of the output
        :param feature_rate: Window timestamps are integer, this is to convert them to seconds
        :return note_roll: The piano roll in the defined range of [note_min, note_max).
        """
        notes = self.get_window_feats(notes, window, feature_rate)

        # take the notes in the midi range defined
        notes = notes[np.logical_and(notes[:, 1] >= self.midi_centers[0], notes[:, 1] <= self.midi_centers[-1]), :]

        times = (notes[:, 0] * feature_rate - window[0]).astype(int)  # in feature samples (fs: self.hop/self.sr)
        notes = (notes[:, 1] - self.midi_centers[0]).astype(int)

        note_roll = np.zeros((len(window), len(self.midi_centers)))
        note_roll[(times, notes)] = 1
        return note_roll, notes

    def fill_f0_matrix(self, f0s, window, feature_rate):
        """
        Unlike the labels for onsets and notes, the f0 label is only relevant for strictly monophonic regions! Thus, this
        function returns a boolean which represents where to apply the given values.
        Never back-propagate without the boolean! Empty frames mean that the label is not that reliable.

        :param f0s: A 2d np.array of f0 values with the time they belong to (2xN: time in seconds - f0 in Hz)
        :param window: Timestamps for the frame centers of the output
        :param feature_rate: Window timestamps are integer, this is to convert them to seconds

        :return f0_roll: f0 label matrix and
                f0_hz: f0 values in Hz
                annotation_bool: A boolean array representing which frames have reliable f0 annotations.
        """
        f0s = self.get_window_feats(f0s, window, feature_rate)
        f0_cents = np.zeros_like(window, dtype=float)
        f0s[:, 1] = self.f0_hz2c(f0s[:, 1])  # convert f0 in hz to cents

        annotation_bool = np.zeros_like(window, dtype=bool)
        f0_roll = np.zeros((len(window), len(self.f0_centers_c)))
        times_in_frame = f0s[:, 0] * feature_rate - window[0]
        for t, f0 in enumerate(f0s):
            t = times_in_frame[t]
            if t % 1 < 0.25:  # only consider it as annotation if the f0 value is really close to the frame center
                t = int(np.round(t))
                f0_roll[t] = self.f0_c2label(f0[1])
                annotation_bool[t] = True
                f0_cents[t] = f0[1]

        return f0_roll, f0_cents, annotation_bool

    @staticmethod
    def get_window_feats(time_feature_matrix, window, feature_rate):
        """
        Restrict the feature matrix to the features that are inside the window
        :param window: Timestamps for the frame centers of the output
        :param time_feature_matrix: A 2d array of Nx2 per the entire file.
        :param feature_rate: Window timestamps are integer, this is to convert them to seconds
        :return: window_features: the features inside the given window
        """
        start = time_feature_matrix[:, 0] > (window[0] - 0.5) / feature_rate
        end = time_feature_matrix[:, 0] < (window[-1] + 0.5) / feature_rate
        window_features = np.logical_and(start, end)
        window_features = np.array(time_feature_matrix[window_features, :])
        return window_features

    def represent_midi(self, midi, feature_rate):
        """
        Represent a midi file as sparse matrices of onsets, offsets, and notes. No f0 is included.
        :param midi: A midi file (either a path or a pretty_midi_fix.PrettyMIDI object)
        :param feature_rate: The feature rate in Hz
        :return: dict {onset, offset, note, time}: Same format with the model's learning and outputs
        """
        def _get_onsets_offsets_frames(midi_content):
            if isinstance(midi_content, str):
                midi_content = PrettyMIDI(midi_content)
            onsets = []
            offsets = []
            frames = []
            for instrument in midi_content.instruments:
                for note in instrument.notes:
                    start = int(np.round(note.start * feature_rate))
                    end = int(np.round(note.end * feature_rate))
                    note_times = (np.arange(start, end + 0.5) / feature_rate)[:, np.newaxis]
                    note_pitch = np.full_like(note_times, fill_value=note.pitch)
                    onsets.append([note.start, note.pitch])
                    offsets.append([note.end, note.pitch])
                    frames.append(np.hstack([note_times, note_pitch]))
            onsets = np.vstack(onsets)
            offsets = np.vstack(offsets)
            frames = np.vstack(frames)
            return onsets, offsets, frames, midi_content
        onset_array, offset_array, frame_array, midi_object = _get_onsets_offsets_frames(midi)
        window = np.arange(frame_array[0, 0] * feature_rate, frame_array[-1, 0] * feature_rate, dtype=int)
        onset_roll, _ = self.fill_onset_matrix(onset_array, window, feature_rate)
        offset_roll, _ = self.fill_onset_matrix(offset_array, window, feature_rate)
        note_roll, _ = self.fill_note_matrix(frame_array, window, feature_rate)
        start_anchor = onset_array[onset_array[:, 0] == np.min(onset_array[:, 0])]
        end_anchor = offset_array[offset_array[:, 0] == np.max(offset_array[:, 0])]
        return {
            'midi': midi_object,
            'note': note_roll,
            'onset': onset_roll,
            'offset': offset_roll,
            'time': window / feature_rate,
            'start_anchor': start_anchor,
            'end_anchor': end_anchor
        }


class NoPadConvBlock(nn.Module):
    def __init__(self, f, w, s, d, in_channels):
        super().__init__()

        self.conv2d = nn.Conv2d(in_channels=in_channels, out_channels=f, kernel_size=(w, 1), stride=(s, 1),
                                dilation=(d, 1))
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm2d(f)
        self.pool = nn.MaxPool2d(kernel_size=(2, 1))
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        x = self.conv2d(x)
        x = self.relu(x)
        x = self.bn(x)
        x = self.pool(x)
        x = self.dropout(x)
        return x


class TinyPathway(nn.Module):
    def __init__(self, dilation=1, hop=256, localize=False,
                 model_capacity="full", n_layers=6, chunk_size=256):
        super().__init__()

        capacity_multiplier = {
            'tiny': 4, 'small': 8, 'medium': 16, 'large': 24, 'full': 32
        }[model_capacity]
        self.layers = [1, 2, 3, 4, 5, 6]
        self.layers = self.layers[:n_layers]
        filters = [n * capacity_multiplier for n in [32, 8, 8, 8, 8, 8]]
        filters = [1] + filters
        widths = [512, 64, 64, 64, 32, 32]
        strides = self.deter_dilations(hop // (4 * (2 ** n_layers)), localize=localize)
        strides[0] = strides[0] * 4  # apply 4 times more stride at the first layer
        dilations = self.deter_dilations(dilation)

        for i in range(len(self.layers)):
            f, w, s, d, in_channel = filters[i + 1], widths[i], strides[i], dilations[i], filters[i]
            self.add_module("conv%d" % i, NoPadConvBlock(f, w, s, d, in_channel))
        self.chunk_size = chunk_size
        self.input_window, self.hop = self.find_input_size_for_pathway()
        self.out_dim = filters[n_layers]

    def find_input_size_for_pathway(self):
        def find_input_size(output_size, kernel_size, stride, dilation, padding):
            num = (stride * (output_size - 1)) + 1
            input_size = num - 2 * padding + dilation * (kernel_size - 1)
            return input_size
        conv_calc, n = {}, 0
        for i in self.layers:
            layer = self.__getattr__("conv%d" % (i - 1))
            for mm in layer.modules():
                if hasattr(mm, 'kernel_size'):
                    try:
                        d = mm.dilation[0]
                    except TypeError:
                        d = mm.dilation
                    conv_calc[n] = [mm.kernel_size[0], mm.stride[0], 0, d]
                    n += 1
        out = self.chunk_size
        hop = 1
        for n in sorted(conv_calc.keys())[::-1]:
            kernel_size_n, stride_n, padding_n, dilation_n = conv_calc[n]
            out = find_input_size(out, kernel_size_n, stride_n, dilation_n, padding_n)
            hop = hop * stride_n
        return out, hop

    def deter_dilations(self, total_dilation, localize=False):
        n_layers = len(self.layers)
        if localize:  # e.g., 32*1023 window and 3 layers -> [1, 1, 32]
            a = [total_dilation] + [1 for _ in range(n_layers - 1)]
        else:  # e.g., 32*1023 window and 3 layers -> [4, 4, 2]
            total_dilation = int(np.log2(total_dilation))
            a = []
            for layer in range(n_layers):
                this_dilation = int(np.ceil(total_dilation / (n_layers - layer)))
                a.append(2 ** this_dilation)
                total_dilation = total_dilation - this_dilation
        return a[::-1]

    def forward(self, x):
        x = x.view(x.shape[0], 1, -1, 1)
        for i in range(len(self.layers)):
            x = self.__getattr__("conv%d" % i)(x)
        x = x.permute(0, 3, 2, 1)
        return x


class Pitch_Det(nn.Module):
    def __init__(
            self,
            pathway_multiscale: int = 32,
            num_pathway_layers: int = 2,
            chunk_size: int = 256,
            hop_length: int = 256,
            encoder_dim: int = 256,
            sr: int = 44100,
            num_heads: int = 4,
            ffn_dim: int = 128,
            num_separator_layers: int = 16,
            num_representation_layers: int = 4,
            depthwise_conv_kernel_size: int = 31,
            dropout: float = 0.25,
            use_group_norm: bool = False,
            convolution_first: bool = False,
            labeling=PerformanceLabel(),
            wiring='tiktok',
            model_capacity="full"
    ):
        super().__init__()
        self.labeling = labeling
        self.sr = sr
        self.window_size = 1024
        self.hop_length = hop_length
        self.f0_bins_per_semitone = int(np.round(100 / self.labeling.f0_granularity_c))

        self.main = TinyPathway(dilation=1, hop=hop_length, localize=True,
                                n_layers=num_pathway_layers, chunk_size=chunk_size, model_capacity=model_capacity)
        self.attendant = TinyPathway(dilation=pathway_multiscale, hop=hop_length, localize=False,
                                     n_layers=num_pathway_layers, chunk_size=chunk_size, model_capacity=model_capacity)
        assert self.main.hop == self.attendant.hop  # they should output with the same sample rate
        print('hop in samples:', self.main.hop)
        self.input_window = self.attendant.input_window

        self.encoder_dim = encoder_dim
        self.dropout = nn.Dropout(dropout)

        # merge two streams into a conformer input
        self.stream_merger = nn.Sequential(self.dropout,
                                           nn.Linear(self.main.out_dim + self.attendant.out_dim, self.encoder_dim))

        print('main stream window:', self.main.input_window,
              ', attendant stream window:', self.attendant.input_window,
              ', conformer input dim:', self.encoder_dim)

        center = ((chunk_size - 1) * self.main.hop)  # region labeled with pitch track
        main_overlap = self.main.input_window - center
        main_overlap = [int(np.floor(main_overlap / 2)), int(np.ceil(main_overlap / 2))]
        attendant_overlap = self.attendant.input_window - center
        attendant_overlap = [int(np.floor(attendant_overlap / 2)), int(np.ceil(attendant_overlap / 2))]
        print('main frame overlap:', main_overlap, ', attendant frame overlap:', attendant_overlap)
        main_crop_relative = [attendant_overlap[0] - main_overlap[0], main_overlap[1] - attendant_overlap[1]]
        print('crop for main pathway', main_crop_relative)
        print("Total sequence duration is", self.attendant.input_window, 'samples')
        print('Main stream receptive field for one frame is', (self.main.input_window - center), 'samples')
        print('Attendant stream receptive field for one frame is', (self.attendant.input_window - center), 'samples')
        self.frame_overlap = attendant_overlap

        self.main_stream_crop = main_crop_relative
        self.max_window_size = self.attendant.input_window
        self.chunk_size = chunk_size

        self.separator_stream = nn.ModuleList(  # source-separation, reinvented
            [
                ConformerLayer(
                    input_dim=self.encoder_dim,
                    ffn_dim=ffn_dim,
                    num_attention_heads=num_heads,
                    depthwise_conv_kernel_size=depthwise_conv_kernel_size,
                    dropout=dropout,
                    use_group_norm=use_group_norm,
                    convolution_first=convolution_first,
                )
                for _ in range(num_separator_layers)
            ]
        )

        self.f0_stream = nn.ModuleList(
            [
                ConformerLayer(
                    input_dim=self.encoder_dim,
                    ffn_dim=ffn_dim,
                    num_attention_heads=num_heads,
                    depthwise_conv_kernel_size=depthwise_conv_kernel_size,
                    dropout=dropout,
                    use_group_norm=use_group_norm,
                    convolution_first=convolution_first,
                )
                for _ in range(num_representation_layers)
            ]
        )
        self.f0_head = nn.Linear(self.encoder_dim, len(self.labeling.f0_centers_c))

        self.note_stream = nn.ModuleList(
            [
                ConformerLayer(
                    input_dim=self.encoder_dim,
                    ffn_dim=ffn_dim,
                    num_attention_heads=num_heads,
                    depthwise_conv_kernel_size=depthwise_conv_kernel_size,
                    dropout=dropout,
                    use_group_norm=use_group_norm,
                    convolution_first=convolution_first,
                )
                for _ in range(num_representation_layers)
            ]
        )
        self.note_head = nn.Linear(self.encoder_dim, len(self.labeling.midi_centers))

        self.onset_stream = nn.ModuleList(
            [
                ConformerLayer(
                    input_dim=self.encoder_dim,
                    ffn_dim=ffn_dim,
                    num_attention_heads=num_heads,
                    depthwise_conv_kernel_size=depthwise_conv_kernel_size,
                    dropout=dropout,
                    use_group_norm=use_group_norm,
                    convolution_first=convolution_first,
                )
                for _ in range(num_representation_layers)
            ]
        )
        self.onset_head = nn.Linear(self.encoder_dim, len(self.labeling.midi_centers))

        self.offset_stream = nn.ModuleList(
            [
                ConformerLayer(
                    input_dim=self.encoder_dim,
                    ffn_dim=ffn_dim,
                    num_attention_heads=num_heads,
                    depthwise_conv_kernel_size=depthwise_conv_kernel_size,
                    dropout=dropout,
                    use_group_norm=use_group_norm,
                    convolution_first=convolution_first,
                )
                for _ in range(num_representation_layers)
            ]
        )
        self.offset_head = nn.Linear(self.encoder_dim, len(self.labeling.midi_centers))

        self.labeling = labeling
        self.double_merger = nn.Sequential(self.dropout, nn.Linear(2 * self.encoder_dim, self.encoder_dim))
        self.triple_merger = nn.Sequential(self.dropout, nn.Linear(3 * self.encoder_dim, self.encoder_dim))
        self.wiring = wiring

        print('Total parameter count: ', self.count_parameters())

    def count_parameters(self) -> int:
        """ Count parameters of encoder """
        return sum([p.numel() for p in self.parameters()])

    def stream(self, x, representation, key_padding_mask=None):
        for i, layer in enumerate(self.__getattr__('{}_stream'.format(representation))):
            x = layer(x, key_padding_mask)
        return x

    def head(self, x, representation):
        return self.__getattr__('{}_head'.format(representation))(x)

    def forward(self, x, key_padding_mask=None):

        # two auditory streams followed by the separator stream to ensure timbre-awareness
        x_attendant = self.attendant(x)
        x_main = self.main(x[:, self.main_stream_crop[0]:self.main_stream_crop[1]])
        x = self.stream_merger(torch_cat((x_attendant, x_main), -1).squeeze(1))
        x = self.stream(x, 'separator', key_padding_mask)

        f0 = self.stream(x, 'f0', key_padding_mask)  # they say this is a low level feature :)

        if self.wiring == 'parallel':
            note = self.stream(x, 'note', key_padding_mask)
            onset = self.stream(x, 'onset', key_padding_mask)
            offset = self.stream(x, 'offset', key_padding_mask)

        elif self.wiring == 'tiktok':
            onset = self.stream(x, 'onset', key_padding_mask)
            offset = self.stream(x, 'offset', key_padding_mask)
            # f0 is disconnected, note relies on separator, onset, and offset
            note = self.stream(self.triple_merger(torch_cat((x, onset, offset), -1)), 'note', key_padding_mask)

        elif self.wiring == 'tiktok2':
            onset = self.stream(x, 'onset', key_padding_mask)
            offset = self.stream(x, 'offset', key_padding_mask)
            # note is connected to f0, onset, and offset
            note = self.stream(self.triple_merger(torch_cat((f0, onset, offset), -1)), 'note', key_padding_mask)

        elif self.wiring == 'spotify':
            # note is connected to f0 only
            note = self.stream(f0, 'note', key_padding_mask)
            # here onset and offset are higher-level features informed by the separator and note
            onset = self.stream(self.double_merger(torch_cat((x, note), -1)), 'onset', key_padding_mask)
            offset = self.stream(self.double_merger(torch_cat((x, note), -1)), 'offset', key_padding_mask)

        else:
            # onset and offset are connected to f0 and separator streams
            onset = self.stream(self.double_merger(torch_cat((x, f0), -1)), 'onset', key_padding_mask)
            offset = self.stream(self.double_merger(torch_cat((x, f0), -1)), 'offset', key_padding_mask)
            # note is connected to f0, onset, and offset streams
            note = self.stream(self.triple_merger(torch_cat((f0, onset, offset), -1)), 'note', key_padding_mask)

        return {'f0': self.head(f0, 'f0'),
                'note': self.head(note, 'note'),
                'onset': self.head(onset, 'onset'),
                'offset': self.head(offset, 'offset')}


class Violin_Pitch_Det(Pitch_Det):
    def __init__(self, model=hf_hub_download("shethjenil/Audio2Midi_Models", "violin.pt"),
                 model_capacity: Literal['tiny', 'small', 'medium', 'large', 'full'] = "full", device="cpu"):
        model_conf = {
            "wiring": "parallel",
            "sampling_rate": 44100,
            "pathway_multiscale": 4,
            "num_pathway_layers": 2,
            "num_separator_layers": 16,
            "num_representation_layers": 4,
            "hop_length": 256,
            "chunk_size": 512,
            "minSNR": -32,
            "maxSNR": 96,
            "note_low": "F#3",
            "note_high": "E8",
            "f0_bins_per_semitone": 10,
            "f0_smooth_std_c": 12,
            "onset_smooth_std": 0.7
        }
        super().__init__(
            pathway_multiscale=model_conf['pathway_multiscale'],
            num_pathway_layers=model_conf['num_pathway_layers'],
            wiring=model_conf['wiring'],
            hop_length=model_conf['hop_length'],
            chunk_size=model_conf['chunk_size'],
            labeling=PerformanceLabel(
                note_min=model_conf['note_low'], note_max=model_conf['note_high'],
                f0_bins_per_semitone=model_conf['f0_bins_per_semitone'], f0_tolerance_c=200,
                f0_smooth_std_c=model_conf['f0_smooth_std_c'], onset_smooth_std=model_conf['onset_smooth_std']),
            sr=model_conf['sampling_rate'],
            model_capacity=model_capacity)
        self.load_state_dict(torch_load(model, map_location=device, weights_only=True))
        self.eval()

    def out2note(self, output: Dict[str, np.array], postprocessing='spotify',
                 include_pitch_bends: bool = True,
                 ) -> List[Tuple[float, float, int, float, Optional[List[int]]]]:
        """Convert model output to notes
        """
        if postprocessing == 'spotify':
            estimated_notes = self.spotify_create_notes(
                output["note"],
                output["onset"],
                note_low=self.labeling.midi_centers[0],
                note_high=self.labeling.midi_centers[-1],
                onset_thresh=0.5,
                frame_thresh=0.3,
                infer_onsets=True,
                min_note_len=int(np.round(127.70 / 1000 * (self.sr / self.hop_length))),  # 127.70
                melodia_trick=True,
            )

        elif postprocessing == 'tiktok':
            postprocessor = RegressionPostProcessor(
                frames_per_second=self.sr / self.hop_length,
                classes_num=self.labeling.midi_centers.shape[0],
                begin_note=self.labeling.midi_centers[0],
                onset_threshold=0.2,
                offset_threshold=0.2,
                frame_threshold=0.3,
                pedal_offset_threshold=0.5,
            )
            tiktok_note_dict, _ = postprocessor.output_dict_to_midi_events(output)
            estimated_notes = []
            for list_item in tiktok_note_dict:
                if list_item['offset_time'] > 0.6 + list_item['onset_time']:
                    estimated_notes.append((int(np.floor(list_item['onset_time'] / (output['time'][1]))),
                                            int(np.ceil(list_item['offset_time'] / (output['time'][1]))),
                                            list_item['midi_note'], list_item['velocity'] / 128))

        if include_pitch_bends:
            estimated_notes_with_pitch_bend = self.get_pitch_bends(output["f0"], estimated_notes)
        else:
            estimated_notes_with_pitch_bend = [(note[0], note[1], note[2], note[3], None) for note in estimated_notes]

        times_s = output['time']
        estimated_notes_time_seconds = [
            (times_s[note[0]], times_s[note[1]], note[2], note[3], note[4]) for note in estimated_notes_with_pitch_bend
        ]

        return estimated_notes_time_seconds

    def note2midi(
            self,
            note_events_with_pitch_bends: List[Tuple[float, float, int, float, Optional[List[int]]]],
            midi_tempo: float = 120,
    ):
        """Create a pretty_midi_fix object from note events
        :param note_events_with_pitch_bends: list of tuples
            [(start_time_seconds, end_time_seconds, pitch_midi, amplitude, [pitch_bend])]
        :param midi_tempo: MIDI tempo (BPM)
        :return: PrettyMIDI object
        """
        mid = PrettyMIDI(initial_tempo=midi_tempo)

        # Create a single instrument (e.g., program=40 = violin)
        instrument = Instrument(program=40)

        for start_time, end_time, note_number, amplitude, pitch_bend in note_events_with_pitch_bends:
            note = Note(
                velocity=int(np.round(127 * amplitude)),
                pitch=note_number,
                start=start_time,
                end=end_time,
            )
            instrument.notes.append(note)

            if pitch_bend is not None and isinstance(pitch_bend, (list, np.ndarray)):
                pitch_bend = np.asarray(pitch_bend)
                pitch_bend_times = np.linspace(start_time, end_time, len(pitch_bend))
                for pb_time, pb_midi in zip(pitch_bend_times, pitch_bend):
                    instrument.pitch_bends.append(PitchBend(pb_midi, pb_time))

        # Add the single instrument to the MIDI object
        mid.instruments.append(instrument)

        return mid

    def get_pitch_bends(
            self,
            contours: np.ndarray, note_events: List[Tuple[int, int, int, float]],
            timing_refinement_range: int = 0, to_midi: bool = True,
    ) -> List[Tuple[int, int, int, float, Optional[List[int]]]]:
        """
        Given note events and contours, estimate pitch bends per note.
        Pitch bends are represented as a sequence of evenly spaced midi pitch bend control units.
        The time stamps of each pitch bend can be inferred by computing an evenly spaced grid between
        the start and end times of each note event.
        Args:
            contours: Matrix of estimated pitch contours
            note_events: note event tuple
            timing_refinement_range: if > 0, refine onset/offset boundaries with f0 confidence
            to_midi: whether to convert pitch bends to midi pitch bends. If False, return pitch estimates in the format
                [time (index), pitch (Hz), confidence in range [0, 1]].
        Returns:
            note events with pitch bends
        """

        f0_matrix = []  # [time (index), pitch (Hz), confidence in range [0, 1]]
        note_events_with_pitch_bends = []
        for start_idx, end_idx, pitch_midi, amplitude in note_events:
            if timing_refinement_range:
                start_idx = np.max([0, start_idx - timing_refinement_range])
                end_idx = np.min([contours.shape[0], end_idx + timing_refinement_range])
            freq_idx = int(np.round(self.midi_pitch_to_contour_bin(pitch_midi)))
            freq_start_idx = np.max([freq_idx - self.labeling.f0_tolerance_bins, 0])
            freq_end_idx = np.min([self.labeling.f0_n_bins, freq_idx + self.labeling.f0_tolerance_bins + 1])

            trans_start_idx = np.max([0, self.labeling.f0_tolerance_bins - freq_idx])
            trans_end_idx = (2 * self.labeling.f0_tolerance_bins + 1) - \
                np.max([0, freq_idx - (self.labeling.f0_n_bins - self.labeling.f0_tolerance_bins - 1)])

            # apply regional viterbi to estimate the intonation
            # observation probabilities come from the f0_roll matrix
            observation = contours[start_idx:end_idx, freq_start_idx:freq_end_idx]
            observation = observation / observation.sum(axis=1)[:, None]
            observation[np.isnan(observation.sum(axis=1)), :] = np.ones(freq_end_idx - freq_start_idx) * 1 / (
                freq_end_idx - freq_start_idx)

            # transition probabilities assure continuity
            transition = self.labeling.f0_transition_matrix[trans_start_idx:trans_end_idx,
                                                            trans_start_idx:trans_end_idx] + 1e-6
            transition = transition / np.sum(transition, axis=1)[:, None]

            path = viterbi_discriminative(observation.T / observation.sum(axis=1), transition) + freq_start_idx

            cents = np.array([self.labeling.f0_label2c(contours[i + start_idx, :], path[i]) for i in range(len(path))])
            bends = cents - self.labeling.midi_centers_c[pitch_midi - self.labeling.midi_centers[0]]
            if to_midi:
                bends = (bends * 4096 / 100).astype(int)
                bends[bends > 8191] = 8191
                bends[bends < -8192] = -8192

                if timing_refinement_range:
                    confidences = np.array([contours[i + start_idx, path[i]] for i in range(len(path))])
                    threshold = np.median(confidences)
                    threshold = (np.median(confidences > threshold) + threshold) / 2  # some magic
                    median_kernel = 2 * (timing_refinement_range // 2) + 1  # some more magic
                    confidences = medfilt(confidences, kernel_size=median_kernel)
                    conf_bool = confidences > threshold
                    onset_idx = np.argmax(conf_bool)
                    offset_idx = len(confidences) - np.argmax(conf_bool[::-1])
                    bends = bends[onset_idx:offset_idx]
                    start_idx = start_idx + onset_idx
                    end_idx = start_idx + offset_idx

                note_events_with_pitch_bends.append((start_idx, end_idx, pitch_midi, amplitude, bends))
            else:
                confidences = np.array([contours[i + start_idx, path[i]] for i in range(len(path))])
                time_idx = np.arange(len(path)) + start_idx
                # f0_hz = self.labeling.f0_c2hz(cents)
                possible_f0s = np.array([time_idx, cents, confidences]).T
                f0_matrix.append(possible_f0s[np.abs(bends) < 100])  # filter out pitch bends that are too large
        if not to_midi:
            return np.vstack(f0_matrix)
        else:
            return note_events_with_pitch_bends

    def midi_pitch_to_contour_bin(self, pitch_midi: int) -> np.array:
        """Convert midi pitch to corresponding index in contour matrix
        Args:
            pitch_midi: pitch in midi
        Returns:
            index in contour matrix
        """
        pitch_hz = midi_to_hz(pitch_midi)
        return np.argmin(np.abs(self.labeling.f0_centers_hz - pitch_hz))

    def get_inferred_onsets(self, onset_roll: np.array, note_roll: np.array, n_diff: int = 2) -> np.array:
        """
        Infer onsets from large changes in note roll matrix amplitudes.
        Modified from https://github.com/spotify/basic-pitch/blob/main/basic_pitch/note_creation.py
        :param onset_roll: Onset activation matrix (n_times, n_freqs).
        :param note_roll: Frame-level note activation matrix (n_times, n_freqs).
        :param n_diff: Differences used to detect onsets.
        :return: The maximum between the predicted onsets and its differences.
        """

        diffs = []
        for n in range(1, n_diff + 1):
            frames_appended = np.concatenate([np.zeros((n, note_roll.shape[1])), note_roll])
            diffs.append(frames_appended[n:, :] - frames_appended[:-n, :])
        frame_diff = np.min(diffs, axis=0)
        frame_diff[frame_diff < 0] = 0
        frame_diff[:n_diff, :] = 0
        frame_diff = np.max(onset_roll) * frame_diff / np.max(frame_diff)  # rescale to have the same max as onsets

        max_onsets_diff = np.max([onset_roll, frame_diff],
                                 axis=0)  # use the max of the predicted onsets and the differences

        return max_onsets_diff

    def spotify_create_notes(
            self,
            note_roll: np.array,
            onset_roll: np.array,
            onset_thresh: float,
            frame_thresh: float,
            min_note_len: int,
            infer_onsets: bool,
            note_low: int,   # self.labeling.midi_centers[0]
            note_high: int,  # self.labeling.midi_centers[-1]
            melodia_trick: bool = True,
            energy_tol: int = 11,
    ) -> List[Tuple[int, int, int, float]]:
        """Decode raw model output to polyphonic note events
        Modified from https://github.com/spotify/basic-pitch/blob/main/basic_pitch/note_creation.py
        Args:
            note_roll: Frame activation matrix (n_times, n_freqs).
            onset_roll: Onset activation matrix (n_times, n_freqs).
            onset_thresh: Minimum amplitude of an onset activation to be considered an onset.
            frame_thresh: Minimum amplitude of a frame activation for a note to remain "on".
            min_note_len: Minimum allowed note length in frames.
            infer_onsets: If True, add additional onsets when there are large differences in frame amplitudes.
            melodia_trick: Whether to use the melodia trick to better detect notes.
            energy_tol: Drop notes below this energy.
        Returns:
            list of tuples [(start_time_frames, end_time_frames, pitch_midi, amplitude)]
            representing the note events, where amplitude is a number between 0 and 1
        """

        n_frames = note_roll.shape[0]

        # use onsets inferred from frames in addition to the predicted onsets
        if infer_onsets:
            onset_roll = self.get_inferred_onsets(onset_roll, note_roll)

        peak_thresh_mat = np.zeros(onset_roll.shape)
        peaks = argrelmax(onset_roll, axis=0)
        peak_thresh_mat[peaks] = onset_roll[peaks]

        onset_idx = np.where(peak_thresh_mat >= onset_thresh)
        onset_time_idx = onset_idx[0][::-1]  # sort to go backwards in time
        onset_freq_idx = onset_idx[1][::-1]  # sort to go backwards in time

        remaining_energy = np.zeros(note_roll.shape)
        remaining_energy[:, :] = note_roll[:, :]

        # loop over onsets
        note_events = []
        for note_start_idx, freq_idx in zip(onset_time_idx, onset_freq_idx):
            # if we're too close to the end of the audio, continue
            if note_start_idx >= n_frames - 1:
                continue

            # find time index at this frequency band where the frames drop below an energy threshold
            i = note_start_idx + 1
            k = 0  # number of frames since energy dropped below threshold
            while i < n_frames - 1 and k < energy_tol:
                if remaining_energy[i, freq_idx] < frame_thresh:
                    k += 1
                else:
                    k = 0
                i += 1

            i -= k  # go back to frame above threshold

            # if the note is too short, skip it
            if i - note_start_idx <= min_note_len:
                continue

            remaining_energy[note_start_idx:i, freq_idx] = 0
            if freq_idx < note_high:
                remaining_energy[note_start_idx:i, freq_idx + 1] = 0
            if freq_idx > note_low:
                remaining_energy[note_start_idx:i, freq_idx - 1] = 0

            # add the note
            amplitude = np.mean(note_roll[note_start_idx:i, freq_idx])
            note_events.append(
                (
                    note_start_idx,
                    i,
                    freq_idx + note_low,
                    amplitude,
                )
            )

        if melodia_trick:
            energy_shape = remaining_energy.shape

            while np.max(remaining_energy) > frame_thresh:
                i_mid, freq_idx = np.unravel_index(np.argmax(remaining_energy), energy_shape)
                remaining_energy[i_mid, freq_idx] = 0

                # forward pass
                i = i_mid + 1
                k = 0
                while i < n_frames - 1 and k < energy_tol:
                    if remaining_energy[i, freq_idx] < frame_thresh:
                        k += 1
                    else:
                        k = 0

                    remaining_energy[i, freq_idx] = 0
                    if freq_idx < note_high:
                        remaining_energy[i, freq_idx + 1] = 0
                    if freq_idx > note_low:
                        remaining_energy[i, freq_idx - 1] = 0

                    i += 1

                i_end = i - 1 - k  # go back to frame above threshold

                # backward pass
                i = i_mid - 1
                k = 0
                while i > 0 and k < energy_tol:
                    if remaining_energy[i, freq_idx] < frame_thresh:
                        k += 1
                    else:
                        k = 0

                    remaining_energy[i, freq_idx] = 0
                    if freq_idx < note_high:
                        remaining_energy[i, freq_idx + 1] = 0
                    if freq_idx > note_low:
                        remaining_energy[i, freq_idx - 1] = 0

                    i -= 1

                i_start = i + 1 + k  # go back to frame above threshold
                assert i_start >= 0, "{}".format(i_start)
                assert i_end < n_frames

                if i_end - i_start <= min_note_len:
                    # note is too short, skip it
                    continue

                # add the note
                amplitude = np.mean(note_roll[i_start:i_end, freq_idx])
                note_events.append(
                    (
                        i_start,
                        i_end,
                        freq_idx + note_low,
                        amplitude,
                    )
                )

        return note_events

    def read_audio(self, audio):
        """
        Read and resample an audio file, convert to mono, and unfold into representation frames.
        The time array represents the center of each small frame with a 5.8 ms hop length. This is different from the
        chunk-level frames: the chunk-level frames represent the entire sequence the model sees, whereas predictions
        are made at the small-frame interval (5.8 ms).
        :param audio: str, pathlib.Path
        :return: frames: (n_big_frames, frame_length), times: (n_small_frames,)
        """
        audio = torch_from_numpy(librosa_load(audio, sr=self.sr, mono=True)[0])
        len_audio = audio.shape[-1]
        n_frames = int(np.ceil((len_audio + sum(self.frame_overlap)) / (self.hop_length * self.chunk_size)))
        audio = nn.functional.pad(audio, (self.frame_overlap[0],
                                          self.frame_overlap[1] + (n_frames * self.hop_length * self.chunk_size) - len_audio))
        frames = audio.unfold(0, self.max_window_size, self.hop_length * self.chunk_size)
        times = np.arange(0, len_audio, self.hop_length) / self.sr  # not a tensor, we don't compute anything with it
        return frames, times

    def model_predict(self, audio, batch_size, progress_callback: Callable[[int, int], None]):
        device = self.main.conv0.conv2d.weight.device
        performance = {'f0': [], 'note': [], 'onset': [], 'offset': []}
        frames, times = self.read_audio(audio)
        with torch_no_grad():
            for i in range(0, len(frames), batch_size):
                f = frames[i:min(i + batch_size, len(frames))].to(device)
                f -= (torch_mean(f, axis=1).unsqueeze(-1))
                f /= (torch_std(f, axis=1).unsqueeze(-1))
                out = self.forward(f)
                for key, value in out.items():
                    value = torch_sigmoid(value)
                    value = torch_nan_to_num(value)  # the model outputs nan when the frame is silent (expected behavior due to normalization)
                    value = value.view(-1, value.shape[-1])
                    value = value.detach().cpu().numpy()
                    performance[key].append(value)
                progress_callback(i, len(frames))
        performance = {key: np.concatenate(value, axis=0)[:len(times)] for key, value in performance.items()}
        performance['time'] = times
        return performance

    def predict(self, audio, batch_size=32, postprocessing="spotify", include_pitch_bends=True,
                progress_callback: Callable[[int, int], None] = None, output_file="output.mid"):
        output = self.model_predict(audio, batch_size, progress_callback)
        self.note2midi(self.out2note(output, postprocessing, include_pitch_bends), 120).write(output_file)
        return output_file