audio2midi 0.2.0__py2.py3-none-any.whl → 0.4.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- audio2midi/crepe_pitch_detector.py +874 -44
- audio2midi/crepe_pitch_detector_tf.py +133 -0
- audio2midi/librosa_pitch_detector.py +3 -4
- audio2midi/violin_pitch_detector.py +2 -1
- audio2midi-0.4.0.dist-info/METADATA +207 -0
- audio2midi-0.4.0.dist-info/RECORD +12 -0
- audio2midi-0.2.0.dist-info/METADATA +0 -106
- audio2midi-0.2.0.dist-info/RECORD +0 -11
- {audio2midi-0.2.0.dist-info → audio2midi-0.4.0.dist-info}/WHEEL +0 -0
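
The headline change in this release is visible in the first file below: the Keras/TensorFlow CREPE model definition in `audio2midi/crepe_pitch_detector.py` is replaced by a PyTorch port (`CrepeTorch`, closely following torchcrepe), with a TensorFlow variant split out into the new `crepe_pitch_detector_tf.py`. Based only on the signatures visible in the diff, here is a minimal usage sketch of the updated detector; the MIDI-writing behavior implied by `output_file` and the exact return value of `predict` are assumptions, not verified against the release:

```python
# Sketch: using the torch-backed Crepe detector from audio2midi 0.4.0.
# Assumes the import path and hub-download behavior shown in the diff below.
from audio2midi.crepe_pitch_detector import Crepe

def on_progress(done: int, total: int) -> None:
    # Called once per inference batch, per the get_activation loop in the diff
    print(f"batch {done + 1}/{total}")

crepe = Crepe(model_type="full")  # fetches crepe_full.pt from the Hugging Face hub
crepe.predict(
    "input.wav",               # loaded internally at 16 kHz mono via librosa
    viterbi=True,              # smooth the pitch track with Viterbi decoding
    step_size=10,              # analysis hop in milliseconds
    min_confidence=0.8,        # mask frames whose activation peak is below this
    batch_size=32,
    progress_callback=on_progress,
    output_file="output.mid",  # presumably where the transcribed MIDI lands
)
```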
```diff
--- a/audio2midi/crepe_pitch_detector.py
+++ b/audio2midi/crepe_pitch_detector.py
@@ -1,54 +1,872 @@
-
-from typing import Callable
-from numpy.lib.stride_tricks import as_strided
-from keras.layers import Input, Reshape, Conv2D, BatchNormalization
-from keras.layers import MaxPool2D, Dropout, Permute, Flatten, Dense
-from keras.models import Model
-from keras.callbacks import Callback
-from hmmlearn.hmm import CategoricalHMM
-from librosa import load as librosa_load
-from pretty_midi_fix import PrettyMIDI , PitchBend , Note ,Instrument
+import warnings
 import numpy as np
+import torch
+import librosa
+from torch.nn import functional as F
+from tqdm import tqdm
+from functools import partial
 from huggingface_hub import hf_hub_download
+from scipy.stats import triang
+
+###############################################################################
+# Constants
+###############################################################################
+
+CENTS_PER_BIN = 20 # cents
+MAX_FMAX = 2006. # hz
+PITCH_BINS = 360
+SAMPLE_RATE = 16000 # hz
+WINDOW_SIZE = 1024 # samples
+UNVOICED = np.nan
+# Minimum decibel level
+MIN_DB = -100.
+
+# Reference decibel level
+REF_DB = 20.
+
+
+
+
+###############################################################################
+# Probability sequence decoding methods
+###############################################################################
+
+
+def argmax(logits):
+    """Sample observations by taking the argmax"""
+    bins = logits.argmax(dim=1)
+
+    # Convert to frequency in Hz
+    return bins, bins_to_frequency(bins)
+
+
+
+
+
+
+
+
+
+
+
+###############################################################################
+# Pitch unit conversions
+###############################################################################
+
+
+def bins_to_cents(bins):
+    """Converts pitch bins to cents"""
+    cents = CENTS_PER_BIN * bins + 1997.3794084376191
+
+    # Trade quantization error for noise
+    return dither(cents)
+
+
+def bins_to_frequency(bins):
+    """Converts pitch bins to frequency in Hz"""
+    return cents_to_frequency(bins_to_cents(bins))
+
+
+def cents_to_bins(cents, quantize_fn=torch.floor):
+    """Converts cents to pitch bins"""
+    bins = (cents - 1997.3794084376191) / CENTS_PER_BIN
+    return quantize_fn(bins).int()
+
+
+def cents_to_frequency(cents):
+    """Converts cents to frequency in Hz"""
+    return 10 * 2 ** (cents / 1200)
+
+
+def frequency_to_bins(frequency, quantize_fn=torch.floor):
+    """Convert frequency in Hz to pitch bins"""
+    return cents_to_bins(frequency_to_cents(frequency), quantize_fn)
+
+
+def frequency_to_cents(frequency):
+    """Convert frequency in Hz to cents"""
+    return 1200 * torch.log2(frequency / 10.)
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+###############################################################################
+# Pitch thresholding methods
+###############################################################################
+
+
+class At:
+    """Simple thresholding at a specified probability value"""
+
+    def __init__(self, value):
+        self.value = value
+
+    def __call__(self, pitch, periodicity):
+        # Make a copy to prevent in-place modification
+        pitch = torch.clone(pitch)
+
+        # Threshold
+        pitch[periodicity < self.value] = UNVOICED
+        return pitch
+
+
+class Hysteresis:
+    """Hysteresis thresholding"""
+
+    def __init__(self,
+                 lower_bound=.19,
+                 upper_bound=.31,
+                 width=.2,
+                 stds=1.7,
+                 return_threshold=False):
+        self.lower_bound = lower_bound
+        self.upper_bound = upper_bound
+        self.width = width
+        self.stds = stds
+        self.return_threshold = return_threshold
+
+    def __call__(self, pitch, periodicity):
+        # Save output device
+        device = pitch.device
+
+        # Perform hysteresis in log-2 space
+        pitch = torch.log2(pitch).detach().flatten().cpu().numpy()
+
+        # Flatten periodicity
+        periodicity = periodicity.flatten().cpu().numpy()
+
+        # Ignore confidently unvoiced pitch
+        pitch[periodicity < self.lower_bound] = UNVOICED
+
+        # Whiten pitch
+        mean, std = np.nanmean(pitch), np.nanstd(pitch)
+        pitch = (pitch - mean) / std
+
+        # Require high confidence to make predictions far from the mean
+        parabola = self.width * pitch ** 2 - self.width * self.stds ** 2
+        threshold = \
+            self.lower_bound + np.clip(parabola, 0, 1 - self.lower_bound)
+        threshold[np.isnan(threshold)] = self.lower_bound
+
+        # Apply hysteresis to prevent short, unconfident voiced regions
+        i = 0
+        while i < len(periodicity) - 1:
+
+            # Detect unvoiced to voiced transition
+            if periodicity[i] < threshold[i] and \
+                    periodicity[i + 1] > threshold[i + 1]:
+
+                # Grow region until next unvoiced or end of array
+                start, end, keep = i + 1, i + 1, False
+                while end < len(periodicity) and \
+                        periodicity[end] > threshold[end]:
+                    if periodicity[end] > self.upper_bound:
+                        keep = True
+                    end += 1
+
+                # Force unvoiced if we didn't pass the confidence required by
+                # the hysteresis
+                if not keep:
+                    threshold[start:end] = 1
+
+                i = end
+
+            else:
+                i += 1
+
+        # Remove pitch with low periodicity
+        pitch[periodicity < threshold] = UNVOICED
+
+        # Unwhiten
+        pitch = pitch * std + mean
+
+        # Convert to Hz
+        pitch = torch.tensor(2 ** pitch, device=device)[None, :]
+
+        # Optionally return threshold
+        if self.return_threshold:
+            return pitch, torch.tensor(threshold, device=device)
+
+        return pitch
+
+
+###############################################################################
+# Periodicity thresholding methods
+###############################################################################
+
+
+class Silence:
+    """Set periodicity to zero in silent regions"""
+
+    def __init__(self, value=-60):
+        self.value = value
+        self.a_weighted_weights = self.perceptual_weights()
+    def perceptual_weights(self):
+        """A-weighted frequency-dependent perceptual loudness weights"""
+        frequencies = librosa.fft_frequencies(sr=SAMPLE_RATE,n_fft=WINDOW_SIZE)
+
+        # A warning is raised for nearly inaudible frequencies, but it ends up
+        # defaulting to -100 db. That default is fine for our purposes.
+        with warnings.catch_warnings():
+            warnings.simplefilter('ignore', RuntimeWarning)
+            return librosa.A_weighting(frequencies)[:, None] - REF_DB
+
+    def a_weighted(self,audio, sample_rate, hop_length=None, pad=True):
+        """Retrieve the per-frame loudness"""
+        # Save device
+        device = audio.device
+
+        # Default hop length of 10 ms
+        hop_length = sample_rate // 100 if hop_length is None else hop_length
+
+        # Convert to numpy
+        audio = audio.detach().cpu().numpy().squeeze(0)
+
+        # Take stft
+        stft = librosa.stft(audio,
+                            n_fft=WINDOW_SIZE,
+                            hop_length=hop_length,
+                            win_length=WINDOW_SIZE,
+                            center=pad,
+                            pad_mode='constant')
+
+        # Compute magnitude on db scale
+        db = librosa.amplitude_to_db(np.abs(stft))
+
+        # Apply A-weighting
+        weighted = db + self.a_weighted_weights
+
+        # Threshold
+        weighted[weighted < MIN_DB] = MIN_DB
+
+        # Average over weighted frequencies
+        return torch.from_numpy(weighted.mean(axis=0)).float().to(device)[None]
+
+    def __call__(self,
+                 periodicity,
+                 audio,
+                 sample_rate=SAMPLE_RATE,
+                 hop_length=None,
+                 pad=True):
+        # Don't modify in-place
+        periodicity = torch.clone(periodicity)
+
+        # Compute loudness
+        loudness = self.a_weighted(
+            audio, sample_rate, hop_length, pad)
+
+        # Threshold silence
+        periodicity[loudness < self.value] = 0.
+
+        return periodicity
+
+
+
+
+
+
+
+###############################################################################
+# Sequence filters
+###############################################################################
+
+
+def mean(signals, win_length=9):
+    """Averave filtering for signals containing nan values
+
+    Arguments
+        signals (torch.tensor (shape=(batch, time)))
+            The signals to filter
+        win_length
+            The size of the analysis window
+
+    Returns
+        filtered (torch.tensor (shape=(batch, time)))
+    """
+
+    assert signals.dim() == 2, "Input tensor must have 2 dimensions (batch_size, width)"
+    signals = signals.unsqueeze(1)
+
+    # Apply the mask by setting masked elements to zero, or make NaNs zero
+    mask = ~torch.isnan(signals)
+    masked_x = torch.where(mask, signals, torch.zeros_like(signals))
+
+    # Create a ones kernel with the same number of channels as the input tensor
+    ones_kernel = torch.ones(signals.size(1), 1, win_length, device=signals.device)
+
+    # Perform sum pooling
+    sum_pooled = F.conv1d(
+        masked_x,
+        ones_kernel,
+        stride=1,
+        padding=win_length // 2,
+    )
+
+    # Count the non-masked (valid) elements in each pooling window
+    valid_count = F.conv1d(
+        mask.float(),
+        ones_kernel,
+        stride=1,
+        padding=win_length // 2,
+    )
+    valid_count = valid_count.clamp(min=1)  # Avoid division by zero
+
+    # Perform masked average pooling
+    avg_pooled = sum_pooled / valid_count
+
+    # Fill zero values with NaNs
+    avg_pooled[avg_pooled == 0] = float("nan")
+
+    return avg_pooled.squeeze(1)
+
+
+def median(signals, win_length):
+    """Median filtering for signals containing nan values
+
+    Arguments
+        signals (torch.tensor (shape=(batch, time)))
+            The signals to filter
+        win_length
+            The size of the analysis window
+
+    Returns
+        filtered (torch.tensor (shape=(batch, time)))
+    """
+
+    assert signals.dim() == 2, "Input tensor must have 2 dimensions (batch_size, width)"
+    signals = signals.unsqueeze(1)
+
+    mask = ~torch.isnan(signals)
+    masked_x = torch.where(mask, signals, torch.zeros_like(signals))
+    padding = win_length // 2
+
+    x = F.pad(masked_x, (padding, padding), mode="reflect")
+    mask = F.pad(mask.float(), (padding, padding), mode="constant", value=0)
+
+    x = x.unfold(2, win_length, 1)
+    mask = mask.unfold(2, win_length, 1)
+
+    x = x.contiguous().view(x.size()[:3] + (-1,))
+    mask = mask.contiguous().view(mask.size()[:3] + (-1,))
+
+    # Combine the mask with the input tensor
+    x_masked = torch.where(mask.bool(), x.float(), float("inf")).to(x)
+
+    # Sort the masked tensor along the last dimension
+    x_sorted, _ = torch.sort(x_masked, dim=-1)
+
+    # Compute the count of non-masked (valid) values
+    valid_count = mask.sum(dim=-1)
+
+    # Calculate the index of the median value for each pooling window
+    median_idx = ((valid_count - 1) // 2).clamp(min=0)
+
+    # Gather the median values using the calculated indices
+    median_pooled = x_sorted.gather(-1, median_idx.unsqueeze(-1).long()).squeeze(-1)
+
+    # Fill infinite values with NaNs
+    median_pooled[torch.isinf(median_pooled)] = float("nan")
 
-
-
+    return median_pooled.squeeze(1)
+
+
+###############################################################################
+# Utilities
+###############################################################################
+
+
+def nanfilter(signals, win_length, filter_fn):
+    """Filters a sequence, ignoring nan values
+
+    Arguments
+        signals (torch.tensor (shape=(batch, time)))
+            The signals to filter
+        win_length
+            The size of the analysis window
+        filter_fn (function)
+            The function to use for filtering
+
+    Returns
+        filtered (torch.tensor (shape=(batch, time)))
+    """
+    # Output buffer
+    filtered = torch.empty_like(signals)
+
+    # Loop over frames
+    for i in range(signals.size(1)):
+
+        # Get analysis window bounds
+        start = max(0, i - win_length // 2)
+        end = min(signals.size(1), i + win_length // 2 + 1)
+
+        # Apply filter to window
+        filtered[:, i] = filter_fn(signals[:, start:end])
+
+    return filtered
+
+
+def nanmean(signals):
+    """Computes the mean, ignoring nans
+
+    Arguments
+        signals (torch.tensor [shape=(batch, time)])
+            The signals to filter
+
+    Returns
+        filtered (torch.tensor [shape=(batch, time)])
+    """
+    signals = signals.clone()
+
+    # Find nans
+    nans = torch.isnan(signals)
+
+    # Set nans to 0.
+    signals[nans] = 0.
+
+    # Compute average
+    return signals.sum(dim=1) / (~nans).float().sum(dim=1)
+
+
+def nanmedian(signals):
+    """Computes the median, ignoring nans
+
+    Arguments
+        signals (torch.tensor [shape=(batch, time)])
+            The signals to filter
+
+    Returns
+        filtered (torch.tensor [shape=(batch, time)])
+    """
+    # Find nans
+    nans = torch.isnan(signals)
+
+    # Compute median for each slice
+    medians = [nanmedian1d(signal[~nan]) for signal, nan in zip(signals, nans)]
+
+    # Stack results
+    return torch.tensor(medians, dtype=signals.dtype, device=signals.device)
+
+
+def nanmedian1d(signal):
+    """Computes the median. If signal is empty, returns torch.nan
+
+    Arguments
+        signal (torch.tensor [shape=(time,)])
+
+    Returns
+        median (torch.tensor [shape=(1,)])
+    """
+    return torch.median(signal) if signal.numel() else np.nan
+
+
+def dither(cents):
+    """Dither the predicted pitch in cents to remove quantization error"""
+    noise = triang.rvs(c=0.5,loc=-CENTS_PER_BIN,scale=2 * CENTS_PER_BIN,size=cents.size())
+    return cents + cents.new_tensor(noise)
+
+def periodicity(probabilities, bins):
+    """Computes the periodicity from the network output and pitch bins"""
+    # shape=(batch * time / hop_length, 360)
+    probs_stacked = probabilities.transpose(1, 2).reshape(-1, PITCH_BINS)
+
+    # shape=(batch * time / hop_length, 1)
+    bins_stacked = bins.reshape(-1, 1).to(torch.int64)
+
+    # Use maximum logit over pitch bins as periodicity
+    periodicity = probs_stacked.gather(1, bins_stacked)
+
+    # shape=(batch, time / hop_length)
+    return periodicity.reshape(probabilities.size(0), probabilities.size(2))
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+class CrepeTorch(torch.nn.Module):
+
+    def __init__(self, model_type='full',model_path=None):
         super().__init__()
-
-
-
-
-
-
-
-        self.
+        xx, yy = np.meshgrid(range(360), range(360))
+        transition = np.maximum(12 - abs(xx - yy), 0)
+        self.viterbi_transition = transition / transition.sum(axis=1, keepdims=True)
+        model_type_importance = {'tiny': 4, 'small': 8, 'medium': 16, 'large': 24, 'full': 32}[model_type]
+        out_channels = [n * model_type_importance for n in [32, 4, 4, 4, 8, 16]]
+        in_channels = [n * model_type_importance for n in [32, 4, 4, 4, 8]]
+        in_channels.insert(0,1)
+        self.in_features = 64*model_type_importance
+        # Shared layer parameters
+        kernel_sizes = [(512, 1)] + 5 * [(64, 1)]
+        strides = [(4, 1)] + 5 * [(1, 1)]
+
+        # Overload with eps and momentum conversion given by MMdnn
+        batch_norm_fn = partial(torch.nn.BatchNorm2d,eps=0.0010000000474974513,momentum=0.0)
+
+        # Layer definitions
+        self.conv1 = torch.nn.Conv2d(
+            in_channels=in_channels[0],
+            out_channels=out_channels[0],
+            kernel_size=kernel_sizes[0],
+            stride=strides[0])
+        self.conv1_BN = batch_norm_fn(
+            num_features=out_channels[0])
+
+        self.conv2 = torch.nn.Conv2d(
+            in_channels=in_channels[1],
+            out_channels=out_channels[1],
+            kernel_size=kernel_sizes[1],
+            stride=strides[1])
+        self.conv2_BN = batch_norm_fn(
+            num_features=out_channels[1])
+
+        self.conv3 = torch.nn.Conv2d(
+            in_channels=in_channels[2],
+            out_channels=out_channels[2],
+            kernel_size=kernel_sizes[2],
+            stride=strides[2])
+        self.conv3_BN = batch_norm_fn(
+            num_features=out_channels[2])
+
+        self.conv4 = torch.nn.Conv2d(
+            in_channels=in_channels[3],
+            out_channels=out_channels[3],
+            kernel_size=kernel_sizes[3],
+            stride=strides[3])
+        self.conv4_BN = batch_norm_fn(
+            num_features=out_channels[3])
+
+        self.conv5 = torch.nn.Conv2d(
+            in_channels=in_channels[4],
+            out_channels=out_channels[4],
+            kernel_size=kernel_sizes[4],
+            stride=strides[4])
+        self.conv5_BN = batch_norm_fn(
+            num_features=out_channels[4])
+
+        self.conv6 = torch.nn.Conv2d(
+            in_channels=in_channels[5],
+            out_channels=out_channels[5],
+            kernel_size=kernel_sizes[5],
+            stride=strides[5])
+        self.conv6_BN = batch_norm_fn(
+            num_features=out_channels[5])
+
+        self.classifier = torch.nn.Linear(
+            in_features=self.in_features,
+            out_features=PITCH_BINS)
+        if not model_path:
+            model_path = hf_hub_download("shethjenil/Audio2Midi_Models",f"crepe_{model_type}.pt")
+        self.load_state_dict(torch.load(model_path))
+        self.eval()
+
+    def forward(self, x, embed=False):
+        # Forward pass through first five layers
+        x = self.embed(x)
+
+        if embed:
+            return x
+
+        # Forward pass through layer six
+        x = self.layer(x, self.conv6, self.conv6_BN)
+
+        # shape=(batch, self.in_features)
+        x = x.permute(0, 2, 1, 3).reshape(-1, self.in_features)
+
+        # Compute logits
+        return torch.sigmoid(self.classifier(x))
+
+    def embed(self, x):
+        """Map input audio to pitch embedding"""
+        # shape=(batch, 1, 1024, 1)
+        x = x[:, None, :, None]
 
+        # Forward pass through first five layers
+        x = self.layer(x, self.conv1, self.conv1_BN, (0, 0, 254, 254))
+        x = self.layer(x, self.conv2, self.conv2_BN)
+        x = self.layer(x, self.conv3, self.conv3_BN)
+        x = self.layer(x, self.conv4, self.conv4_BN)
+        x = self.layer(x, self.conv5, self.conv5_BN)
+        return x
+
+    def layer(self, x, conv, batch_norm, padding=(0, 0, 31, 32)):
+        """Forward pass through one layer"""
+        x = F.pad(x, padding)
+        x = conv(x)
+        x = F.relu(x)
+        x = batch_norm(x)
+        return F.max_pool2d(x, (2, 1), (2, 1))
+
+    def viterbi(self,logits):
+        """Sample observations using viterbi decoding"""
+        # Normalize logits
+        with torch.no_grad():
+            probs = torch.nn.functional.softmax(logits, dim=1)
+
+        # Convert to numpy
+        sequences = probs.cpu().numpy()
+
+        # Perform viterbi decoding
+        bins = np.array([
+            librosa.sequence.viterbi(sequence, self.viterbi_transition).astype(np.int64)
+            for sequence in sequences])
+
+        # Convert to pytorch
+        bins = torch.tensor(bins, device=probs.device)
+
+        # Convert to frequency in Hz
+        return bins, bins_to_frequency(bins)
+
+    def get_device(self):
+        return next(self.parameters()).device
+
+    def preprocess(self,audio,sample_rate,hop_length=None,batch_size=None,pad=True):
+        """Convert audio to model input
+
+        Arguments
+            audio (torch.tensor [shape=(1, time)])
+                The audio signals
+            sample_rate (int)
+                The sampling rate in Hz
+            hop_length (int)
+                The hop_length in samples
+            batch_size (int)
+                The number of frames per batch
+            pad (bool)
+                Whether to zero-pad the audio
+
+        Returns
+            frames (torch.tensor [shape=(1 + int(time // hop_length), 1024)])
+        """
+        # Default hop length of 10 ms
+        hop_length = sample_rate // 100 if hop_length is None else hop_length
+
+        # Get total number of frames
+
+        # Maybe pad
+        if pad:
+            total_frames = 1 + int(audio.size(1) // hop_length)
+            audio = torch.nn.functional.pad(
+                audio,
+                (WINDOW_SIZE // 2, WINDOW_SIZE // 2))
+        else:
+            total_frames = 1 + int((audio.size(1) - WINDOW_SIZE) // hop_length)
+
+        # Default to running all frames in a single batch
+        batch_size = total_frames if batch_size is None else batch_size
+
+        # Generate batches
+        for i in range(0, total_frames, batch_size):
+
+            # Batch indices
+            start = max(0, i * hop_length)
+            end = min(audio.size(1),
+                      (i + batch_size - 1) * hop_length + WINDOW_SIZE)
+
+            # Chunk
+            frames = torch.nn.functional.unfold(
+                audio[:, None, None, start:end],
+                kernel_size=(1, WINDOW_SIZE),
+                stride=(1, hop_length))
+
+            # shape=(1 + int(time / hop_length, 1024)
+            frames = frames.transpose(1, 2).reshape(-1, WINDOW_SIZE)
+
+            # Place on device
+            frames = frames.to(self.get_device())
+
+            # Mean-center
+            frames -= frames.mean(dim=1, keepdim=True)
+
+            # Scale
+            # Note: during silent frames, this produces very large values. But
+            # this seems to be what the network expects.
+            frames /= torch.max(torch.tensor(1e-10, device=frames.device),frames.std(dim=1, keepdim=True))
+
+            yield frames
+
+    def postprocess(self,probabilities,fmin=0.,fmax=MAX_FMAX,return_periodicity=False):
+        """Convert model output to F0 and periodicity
+
+        Arguments
+            probabilities (torch.tensor [shape=(1, 360, time / hop_length)])
+                The probabilities for each pitch bin inferred by the network
+            fmin (float)
+                The minimum allowable frequency in Hz
+            fmax (float)
+                The maximum allowable frequency in Hz
+            viterbi (bool)
+                Whether to use viterbi decoding
+            return_periodicity (bool)
+                Whether to also return the network confidence
+
+        Returns
+            pitch (torch.tensor [shape=(1, 1 + int(time // hop_length))])
+            periodicity (torch.tensor [shape=(1, 1 + int(time // hop_length))])
+        """
+        # Sampling is non-differentiable, so remove from graph
+        probabilities = probabilities.detach()
+
+        # Convert frequency range to pitch bin range
+        minidx = frequency_to_bins(torch.tensor(fmin))
+        maxidx = frequency_to_bins(torch.tensor(fmax),
+                                   torch.ceil)
+
+        # Remove frequencies outside of allowable range
+        probabilities[:, :minidx] = -float('inf')
+        probabilities[:, maxidx:] = -float('inf')
+
+        # Perform argmax or viterbi sampling
+        bins, pitch = self.viterbi(probabilities)
+
+        if not return_periodicity:
+            return pitch
+
+        # Compute periodicity from probabilities and decoded pitch bins
+        return pitch, periodicity(probabilities, bins)
+
+    def predict(self,audio,sample_rate,hop_length=None,fmin=50.,fmax=MAX_FMAX,return_periodicity=False,batch_size=None,pad=True):
+        """Performs pitch estimation
+
+        Arguments
+            audio (torch.tensor [shape=(1, time)])
+                The audio signal
+            sample_rate (int)
+                The sampling rate in Hz
+            hop_length (int)
+                The hop_length in samples
+            fmin (float)
+                The minimum allowable frequency in Hz
+            fmax (float)
+                The maximum allowable frequency in Hz
+            return_periodicity (bool)
+                Whether to also return the network confidence
+            batch_size (int)
+                The number of frames per batch
+            pad (bool)
+                Whether to zero-pad the audio
+
+        Returns
+            pitch (torch.tensor [shape=(1, 1 + int(time // hop_length))])
+            (Optional) periodicity (torch.tensor
+                [shape=(1, 1 + int(time // hop_length))])
+        """
+
+        results = []
+        with torch.no_grad():
+            print("prediction started")
+            for frames in self.preprocess(audio,sample_rate,hop_length,batch_size,pad):
+                # shape=(batch, 360, time / hop_length)
+                result = self.postprocess(self.forward(frames, embed=False).reshape(audio.size(0), -1, PITCH_BINS).transpose(1, 2),fmin,fmax,return_periodicity)
+                if isinstance(result, tuple):
+                    result = (result[0].to(audio.device),result[1].to(audio.device))
+                else:
+                    result = result.to(audio.device)
+                results.append(result)
+        print("prediction finished")
+        if return_periodicity:
+            pitch, periodicity = zip(*results)
+            return torch.cat(pitch, 1), torch.cat(periodicity, 1)
+        return torch.cat(results, 1)
+
+    def predict_from_file(self,audio_file,hop_length=None,fmin=50.,fmax=MAX_FMAX,return_periodicity=False,batch_size=None,pad=True):
+        audio, sample_rate = librosa.load(audio_file, sr=SAMPLE_RATE)
+        return self.predict(torch.from_numpy(audio).unsqueeze(0),sample_rate,hop_length,fmin,fmax,return_periodicity,batch_size,pad)
+
+    def predict_from_file_to_file(self,audio_file,output_pitch_file,output_periodicity_file=None,hop_length=None,fmin=50.,fmax=MAX_FMAX,batch_size=None,pad=True):
+        prediction = self.predict_from_file(audio_file,hop_length,fmin,fmax,False,output_periodicity_file is not None,batch_size,pad)
+        if output_periodicity_file is not None:
+            torch.save(prediction[0].detach(), output_pitch_file)
+            torch.save(prediction[1].detach(), output_periodicity_file)
+        else:
+            torch.save(prediction.detach(), output_pitch_file)
+
+    def predict_from_files_to_files(self,audio_files,output_pitch_files,output_periodicity_files=None,hop_length=None,fmin=50.,fmax=MAX_FMAX,batch_size=None,pad=True):
+        if output_periodicity_files is None:
+            output_periodicity_files = len(audio_files) * [None]
+        for audio_file, output_pitch_file, output_periodicity_file in tqdm(zip(audio_files, output_pitch_files, output_periodicity_files), desc='torchcrepe', dynamic_ncols=True):
+            self.predict_from_file_to_file(audio_file,output_pitch_file,None,output_periodicity_file,hop_length,fmin,fmax,batch_size,pad)
+
+    def embedding(self,audio,sample_rate,hop_length=None,batch_size=None,pad=True):
+        """Embeds audio to the output of CREPE's fifth maxpool layer
+
+        Arguments
+            audio (torch.tensor [shape=(1, time)])
+                The audio signals
+            sample_rate (int)
+                The sampling rate in Hz
+            hop_length (int)
+                The hop_length in samples
+            batch_size (int)
+                The number of frames per batch
+            pad (bool)
+                Whether to zero-pad the audio
+
+        Returns
+            embedding (torch.tensor [shape=(1,
+                1 + int(time // hop_length), 32, -1)])
+        """
+        # shape=(batch, time / hop_length, 32, embedding_size)
+        with torch.no_grad():
+            return torch.cat([self.forward(frames, embed=True).reshape(audio.size(0), frames.size(0), 32, -1).to(audio.device) for frames in self.preprocess(audio,sample_rate,hop_length,batch_size,pad)], 1)
+
+    def embedding_from_file(self,audio_file,hop_length=None,batch_size=None,pad=True):
+        audio, sample_rate = librosa.load(audio_file, sr=SAMPLE_RATE)
+        return self.embed(torch.from_numpy(audio).unsqueeze(0),sample_rate,hop_length,batch_size,pad)
+
+    def embedding_from_file_to_file(self,audio_file,output_file,hop_length=None,batch_size=None,pad=True):
+        with torch.no_grad():
+            torch.save(self.embed_from_file(audio_file,hop_length,batch_size,pad).detach(), output_file)
+
+    def embedding_from_files_to_files(self,audio_files,output_files,hop_length=None,batch_size=None,pad=True):
+        for audio_file, output_file in tqdm(zip(audio_files, output_files), desc='torchcrepe', dynamic_ncols=True):
+            self.embed_from_file_to_file(audio_file,output_file,hop_length,batch_size,pad)
+
+
+
+
+
+
+
+
+
+
+
+
+
+from hmmlearn.hmm import CategoricalHMM
+from typing import Callable
+from numpy.lib.stride_tricks import as_strided
+from pretty_midi_fix import PrettyMIDI , PitchBend , Note ,Instrument
+from huggingface_hub import hf_hub_download
+from torch.utils.data import DataLoader, TensorDataset
 
 class Crepe():
+
     def __init__(self,model_type="full",model_path=None):
-
-        model_path = hf_hub_download("shethjenil/Audio2Midi_Models",f"crepe_{model_type}.h5")
-        model_type_importance = {'tiny': 4, 'small': 8, 'medium': 16, 'large': 24, 'full': 32}[model_type]
-        filters = [n * model_type_importance for n in [32, 4, 4, 4, 8, 16]]
-        widths = [512, 64, 64, 64, 64, 64]
-        strides = [(4, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]
-        x = Input(shape=(1024,), name='input', dtype='float32')
-        y = Reshape(target_shape=(1024, 1, 1), name='input-reshape')(x)
-        layers = [1, 2, 3, 4, 5, 6]
-        for l, f, w, s in zip(layers, filters, widths, strides):
-            y = Conv2D(f, (w, 1), strides=s, padding='same', activation='relu', name="conv%d" % l)(y)
-            y = BatchNormalization(name="conv%d-BN" % l)(y)
-            y = MaxPool2D(pool_size=(2, 1), strides=None, padding='valid', name="conv%d-maxpool" % l)(y)
-            y = Dropout(0.25, name="conv%d-dropout" % l)(y)
-        y = Permute((2, 1, 3), name="transpose")(y)
-        y = Flatten(name="flatten")(y)
-        y = Dense(360, activation='sigmoid', name="classifier")(y)
-        self.model = Model(inputs=x, outputs=y)
-        self.model.load_weights(model_path)
-        self.model.compile('adam', 'binary_crossentropy')
+        self.model = CrepeTorch(model_type,model_path)
         self.cents_mapping=(np.linspace(0, 7180, 360) + 1997.3794084376191)
 
     def to_local_average_cents(self, salience, center=None):
+        if isinstance(salience, torch.Tensor):
+            salience = salience.numpy()
+
         if salience.ndim == 1:
             if center is None:
                 center = int(np.argmax(salience))
@@ -63,6 +881,8 @@ class Crepe():
         raise Exception("label should be either 1d or 2d ndarray")
 
     def to_viterbi_cents(self,salience):
+        if isinstance(salience, torch.Tensor):
+            salience = salience.numpy()
         starting = np.ones(360) / 360
         xx, yy = np.meshgrid(range(360), range(360))
         transition = np.maximum(12 - abs(xx - yy), 0)
@@ -84,11 +904,21 @@ class Crepe():
         frames = frames.transpose().copy()
         frames -= np.mean(frames, axis=1)[:, np.newaxis]
         frames /= np.clip(np.std(frames, axis=1)[:, np.newaxis], 1e-8, None)
-
-
+        device = self.model.get_device()
+        all_outputs = []
+        all_batch = list(DataLoader(TensorDataset(torch.from_numpy(frames)), batch_size=batch_size, shuffle=False))
+        total_batch = len(all_batch)
+        with torch.no_grad():
+            for i , batch in enumerate(all_batch):
+                inputs = batch[0].to(device)
+                outputs = self.model(inputs)
+                all_outputs.append(outputs.cpu())
+                progress_callback(i,total_batch)
+        return torch.cat(all_outputs, dim=0)
+
     def model_predict(self,audio:np.ndarray,viterbi, center, step_size,progress_callback,batch_size):
         activation = self.get_activation(audio.astype(np.float32), center, step_size,progress_callback,batch_size)
-        confidence = activation.max(axis=1)
+        confidence = activation.max(axis=1).values # Access the values from the named tuple
         cents = self.to_viterbi_cents(activation) if viterbi else self.to_local_average_cents(activation)
         frequency = 10 * 2 ** (cents / 1200)
         frequency[np.isnan(frequency)] = 0
@@ -96,7 +926,7 @@ class Crepe():
         return time, frequency, confidence
 
     def predict(self,audio_path,viterbi=False, center=True, step_size=10,min_confidence=0.8,batch_size=32,progress_callback: Callable[[int, int], None] = None,output_file= "output.mid"):
-        time, frequency, confidence = self.model_predict(
+        time, frequency, confidence = self.model_predict(librosa.load(audio_path, sr=16000, mono=True)[0],viterbi,center,step_size,progress_callback,batch_size)
         mask = confidence > min_confidence
         times = time[mask]
         frequencies = frequency[mask]
```
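
For orientation, the pitch representation used throughout the new torch code maps its 360 bins to frequency through cents: `cents = 20 * bin + 1997.3794084376191` and `frequency = 10 * 2 ** (cents / 1200)`, which matches the `cents_mapping` kept on the `Crepe` class (360 points spanning 7180 cents, i.e. 20 cents per bin). A quick sanity check of that mapping follows; this is a plain NumPy restatement of `bins_to_cents`/`cents_to_frequency`, whereas the shipped functions operate on torch tensors and add triangular dither via `dither`:

```python
# Sanity check of the bin -> Hz mapping defined in the diff above.
import numpy as np

CENTS_PER_BIN = 20  # cents per pitch bin, as in the new module's constants

def bins_to_hz(bins):
    # 10 Hz reference pitch, 1200 cents per octave
    cents = CENTS_PER_BIN * np.asarray(bins, dtype=float) + 1997.3794084376191
    return 10 * 2 ** (cents / 1200)

print(bins_to_hz([0, 359]))  # ~[31.7, 2005.0] Hz; the top bin sits just under MAX_FMAX = 2006
```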
|