audio2midi-0.3.0-py2.py3-none-any.whl → audio2midi-0.5.0-py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,57 +1,872 @@
1
- from math import ceil as math_ceil
2
- from typing import Callable
3
- from numpy.lib.stride_tricks import as_strided
4
- from keras.layers import Input, Reshape, Conv2D, BatchNormalization
5
- from keras.layers import MaxPool2D, Dropout, Permute, Flatten, Dense
6
- from keras.models import Model
7
- from keras.callbacks import Callback
8
- from hmmlearn.hmm import CategoricalHMM
9
- from librosa import load as librosa_load
10
- from pretty_midi_fix import PrettyMIDI , PitchBend , Note ,Instrument
1
+ import warnings
11
2
  import numpy as np
3
+ import torch
4
+ import librosa
5
+ from torch.nn import functional as F
6
+ from tqdm import tqdm
7
+ from functools import partial
12
8
  from huggingface_hub import hf_hub_download
9
+ from scipy.stats import triang
10
+
11
+ ###############################################################################
12
+ # Constants
13
+ ###############################################################################
14
+
15
+ CENTS_PER_BIN = 20 # cents
16
+ MAX_FMAX = 2006. # hz
17
+ PITCH_BINS = 360
18
+ SAMPLE_RATE = 16000 # hz
19
+ WINDOW_SIZE = 1024 # samples
20
+ UNVOICED = np.nan
21
+ # Minimum decibel level
22
+ MIN_DB = -100.
23
+
24
+ # Reference decibel level
25
+ REF_DB = 20.
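For orientation, the framing these constants imply at the 16 kHz model rate (a quick sketch using the names defined above):

    hop_length = SAMPLE_RATE // 100               # default 10 ms hop = 160 samples
    window_ms = 1000 * WINDOW_SIZE / SAMPLE_RATE  # 1024-sample analysis window = 64 ms
    span_cents = CENTS_PER_BIN * PITCH_BINS       # 360 bins * 20 cents = 7200 cents (6 octaves)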
26
+
27
+
28
+
29
+
30
+ ###############################################################################
31
+ # Probability sequence decoding methods
32
+ ###############################################################################
33
+
34
+
35
+ def argmax(logits):
36
+ """Sample observations by taking the argmax"""
37
+ bins = logits.argmax(dim=1)
38
+
39
+ # Convert to frequency in Hz
40
+ return bins, bins_to_frequency(bins)
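A minimal usage sketch for the decoder above, assuming this module's functions are importable (the tensor is random, purely for illustration):

    probs = torch.rand(1, PITCH_BINS, 100)   # (batch, 360, frames) stand-in for network output
    bins, frequency = argmax(probs)          # per-frame bin index and dithered frequency in Hz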
41
+
42
+
43
+
44
+
45
+
46
+
47
+
48
+
49
+
50
+
51
+
52
+ ###############################################################################
53
+ # Pitch unit conversions
54
+ ###############################################################################
55
+
56
+
57
+ def bins_to_cents(bins):
58
+ """Converts pitch bins to cents"""
59
+ cents = CENTS_PER_BIN * bins + 1997.3794084376191
60
+
61
+ # Trade quantization error for noise
62
+ return dither(cents)
63
+
64
+
65
+ def bins_to_frequency(bins):
66
+ """Converts pitch bins to frequency in Hz"""
67
+ return cents_to_frequency(bins_to_cents(bins))
68
+
69
+
70
+ def cents_to_bins(cents, quantize_fn=torch.floor):
71
+ """Converts cents to pitch bins"""
72
+ bins = (cents - 1997.3794084376191) / CENTS_PER_BIN
73
+ return quantize_fn(bins).int()
74
+
75
+
76
+ def cents_to_frequency(cents):
77
+ """Converts cents to frequency in Hz"""
78
+ return 10 * 2 ** (cents / 1200)
79
+
80
+
81
+ def frequency_to_bins(frequency, quantize_fn=torch.floor):
82
+ """Convert frequency in Hz to pitch bins"""
83
+ return cents_to_bins(frequency_to_cents(frequency), quantize_fn)
84
+
85
+
86
+ def frequency_to_cents(frequency):
87
+ """Convert frequency in Hz to cents"""
88
+ return 1200 * torch.log2(frequency / 10.)
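A quick round-trip check of these conversions (no dithering on this path; the input value is arbitrary):

    f = torch.tensor([440.0])                     # A4
    cents = frequency_to_cents(f)                 # 1200 * log2(44) ≈ 6551.3 cents
    assert torch.allclose(cents_to_frequency(cents), f)
    bins = frequency_to_bins(f)                   # floor((6551.3 - 1997.4) / 20) == 227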
89
+
90
+
91
+
92
+
93
+
94
+
95
+
96
+
97
+
98
+
99
+
100
+
101
+
102
+
103
+
104
+ ###############################################################################
105
+ # Pitch thresholding methods
106
+ ###############################################################################
107
+
108
+
109
+ class At:
110
+ """Simple thresholding at a specified probability value"""
111
+
112
+ def __init__(self, value):
113
+ self.value = value
114
+
115
+ def __call__(self, pitch, periodicity):
116
+ # Make a copy to prevent in-place modification
117
+ pitch = torch.clone(pitch)
118
+
119
+ # Threshold
120
+ pitch[periodicity < self.value] = UNVOICED
121
+ return pitch
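Usage sketch for the threshold above (tensors made up for illustration):

    pitch = torch.full((1, 4), 220.0)
    conf = torch.tensor([[0.9, 0.2, 0.7, 0.1]])
    voiced = At(0.5)(pitch, conf)   # frames with confidence < 0.5 become NaN (UNVOICED)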
122
+
123
+
124
+ class Hysteresis:
125
+ """Hysteresis thresholding"""
126
+
127
+ def __init__(self,
128
+ lower_bound=.19,
129
+ upper_bound=.31,
130
+ width=.2,
131
+ stds=1.7,
132
+ return_threshold=False):
133
+ self.lower_bound = lower_bound
134
+ self.upper_bound = upper_bound
135
+ self.width = width
136
+ self.stds = stds
137
+ self.return_threshold = return_threshold
138
+
139
+ def __call__(self, pitch, periodicity):
140
+ # Save output device
141
+ device = pitch.device
142
+
143
+ # Perform hysteresis in log-2 space
144
+ pitch = torch.log2(pitch).detach().flatten().cpu().numpy()
145
+
146
+ # Flatten periodicity
147
+ periodicity = periodicity.flatten().cpu().numpy()
148
+
149
+ # Ignore confidently unvoiced pitch
150
+ pitch[periodicity < self.lower_bound] = UNVOICED
151
+
152
+ # Whiten pitch
153
+ mean, std = np.nanmean(pitch), np.nanstd(pitch)
154
+ pitch = (pitch - mean) / std
155
+
156
+ # Require high confidence to make predictions far from the mean
157
+ parabola = self.width * pitch ** 2 - self.width * self.stds ** 2
158
+ threshold = \
159
+ self.lower_bound + np.clip(parabola, 0, 1 - self.lower_bound)
160
+ threshold[np.isnan(threshold)] = self.lower_bound
161
+
162
+ # Apply hysteresis to prevent short, unconfident voiced regions
163
+ i = 0
164
+ while i < len(periodicity) - 1:
165
+
166
+ # Detect unvoiced to voiced transition
167
+ if periodicity[i] < threshold[i] and \
168
+ periodicity[i + 1] > threshold[i + 1]:
169
+
170
+ # Grow region until next unvoiced or end of array
171
+ start, end, keep = i + 1, i + 1, False
172
+ while end < len(periodicity) and \
173
+ periodicity[end] > threshold[end]:
174
+ if periodicity[end] > self.upper_bound:
175
+ keep = True
176
+ end += 1
177
+
178
+ # Force unvoiced if we didn't pass the confidence required by
179
+ # the hysteresis
180
+ if not keep:
181
+ threshold[start:end] = 1
182
+
183
+ i = end
184
+
185
+ else:
186
+ i += 1
187
+
188
+ # Remove pitch with low periodicity
189
+ pitch[periodicity < threshold] = UNVOICED
190
+
191
+ # Unwhiten
192
+ pitch = pitch * std + mean
193
+
194
+ # Convert to Hz
195
+ pitch = torch.tensor(2 ** pitch, device=device)[None, :]
196
+
197
+ # Optionally return threshold
198
+ if self.return_threshold:
199
+ return pitch, torch.tensor(threshold, device=device)
200
+
201
+ return pitch
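A usage sketch: Hysteresis expects pitch in Hz and periodicity, both of shape (1, time), and returns pitch with unconfident regions set to NaN (inputs below are synthetic):

    pitch = 200.0 + 50.0 * torch.rand(1, 100)
    conf = torch.rand(1, 100)
    voiced = Hysteresis()(pitch, conf)   # shape (1, 100); NaN in unvoiced regions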
202
+
203
+
204
+ ###############################################################################
205
+ # Periodicity thresholding methods
206
+ ###############################################################################
207
+
208
+
209
+ class Silence:
210
+ """Set periodicity to zero in silent regions"""
211
+
212
+ def __init__(self, value=-60):
213
+ self.value = value
214
+ self.a_weighted_weights = self.perceptual_weights()
215
+ def perceptual_weights(self):
216
+ """A-weighted frequency-dependent perceptual loudness weights"""
217
+ frequencies = librosa.fft_frequencies(sr=SAMPLE_RATE,n_fft=WINDOW_SIZE)
218
+
219
+ # A warning is raised for nearly inaudible frequencies, but it ends up
220
+ # defaulting to -100 db. That default is fine for our purposes.
221
+ with warnings.catch_warnings():
222
+ warnings.simplefilter('ignore', RuntimeWarning)
223
+ return librosa.A_weighting(frequencies)[:, None] - REF_DB
224
+
225
+ def a_weighted(self,audio, sample_rate, hop_length=None, pad=True):
226
+ """Retrieve the per-frame loudness"""
227
+ # Save device
228
+ device = audio.device
229
+
230
+ # Default hop length of 10 ms
231
+ hop_length = sample_rate // 100 if hop_length is None else hop_length
232
+
233
+ # Convert to numpy
234
+ audio = audio.detach().cpu().numpy().squeeze(0)
235
+
236
+ # Take stft
237
+ stft = librosa.stft(audio,
238
+ n_fft=WINDOW_SIZE,
239
+ hop_length=hop_length,
240
+ win_length=WINDOW_SIZE,
241
+ center=pad,
242
+ pad_mode='constant')
243
+
244
+ # Compute magnitude on db scale
245
+ db = librosa.amplitude_to_db(np.abs(stft))
246
+
247
+ # Apply A-weighting
248
+ weighted = db + self.a_weighted_weights
249
+
250
+ # Threshold
251
+ weighted[weighted < MIN_DB] = MIN_DB
252
+
253
+ # Average over weighted frequencies
254
+ return torch.from_numpy(weighted.mean(axis=0)).float().to(device)[None]
255
+
256
+ def __call__(self,
257
+ periodicity,
258
+ audio,
259
+ sample_rate=SAMPLE_RATE,
260
+ hop_length=None,
261
+ pad=True):
262
+ # Don't modify in-place
263
+ periodicity = torch.clone(periodicity)
264
+
265
+ # Compute loudness
266
+ loudness = self.a_weighted(
267
+ audio, sample_rate, hop_length, pad)
268
+
269
+ # Threshold silence
270
+ periodicity[loudness < self.value] = 0.
271
+
272
+ return periodicity
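Usage sketch: on one second of silence the A-weighted loudness sits at the -100 dB floor, well below the default -60 dB threshold, so every frame's periodicity is zeroed (frame count matches the default 10 ms hop):

    audio = torch.zeros(1, SAMPLE_RATE)            # 1 s of silence
    hop = SAMPLE_RATE // 100
    conf = torch.rand(1, 1 + SAMPLE_RATE // hop)   # (1, 101)
    silenced = Silence()(conf, audio)              # all zeros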
273
+
274
+
275
+
276
+
277
+
278
+
279
+
280
+ ###############################################################################
281
+ # Sequence filters
282
+ ###############################################################################
283
+
284
+
285
+ def mean(signals, win_length=9):
286
+ """Averave filtering for signals containing nan values
287
+
288
+ Arguments
289
+ signals (torch.tensor (shape=(batch, time)))
290
+ The signals to filter
291
+ win_length
292
+ The size of the analysis window
293
+
294
+ Returns
295
+ filtered (torch.tensor (shape=(batch, time)))
296
+ """
297
+
298
+ assert signals.dim() == 2, "Input tensor must have 2 dimensions (batch_size, width)"
299
+ signals = signals.unsqueeze(1)
300
+
301
+ # Apply the mask by setting masked elements to zero, or make NaNs zero
302
+ mask = ~torch.isnan(signals)
303
+ masked_x = torch.where(mask, signals, torch.zeros_like(signals))
304
+
305
+ # Create a ones kernel with the same number of channels as the input tensor
306
+ ones_kernel = torch.ones(signals.size(1), 1, win_length, device=signals.device)
307
+
308
+ # Perform sum pooling
309
+ sum_pooled = F.conv1d(
310
+ masked_x,
311
+ ones_kernel,
312
+ stride=1,
313
+ padding=win_length // 2,
314
+ )
315
+
316
+ # Count the non-masked (valid) elements in each pooling window
317
+ valid_count = F.conv1d(
318
+ mask.float(),
319
+ ones_kernel,
320
+ stride=1,
321
+ padding=win_length // 2,
322
+ )
323
+ valid_count = valid_count.clamp(min=1) # Avoid division by zero
324
+
325
+ # Perform masked average pooling
326
+ avg_pooled = sum_pooled / valid_count
327
+
328
+ # Fill zero values with NaNs
329
+ avg_pooled[avg_pooled == 0] = float("nan")
330
+
331
+ return avg_pooled.squeeze(1)
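Sketch of the NaN-aware mean filter on a short signal (values for illustration only):

    x = torch.tensor([[1.0, float('nan'), 3.0, 5.0, float('nan')]])
    smoothed = mean(x, win_length=3)   # NaNs are excluded from each window's average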
332
+
333
+
334
+ def median(signals, win_length):
335
+ """Median filtering for signals containing nan values
336
+
337
+ Arguments
338
+ signals (torch.tensor (shape=(batch, time)))
339
+ The signals to filter
340
+ win_length
341
+ The size of the analysis window
342
+
343
+ Returns
344
+ filtered (torch.tensor (shape=(batch, time)))
345
+ """
346
+
347
+ assert signals.dim() == 2, "Input tensor must have 2 dimensions (batch_size, width)"
348
+ signals = signals.unsqueeze(1)
349
+
350
+ mask = ~torch.isnan(signals)
351
+ masked_x = torch.where(mask, signals, torch.zeros_like(signals))
352
+ padding = win_length // 2
353
+
354
+ x = F.pad(masked_x, (padding, padding), mode="reflect")
355
+ mask = F.pad(mask.float(), (padding, padding), mode="constant", value=0)
356
+
357
+ x = x.unfold(2, win_length, 1)
358
+ mask = mask.unfold(2, win_length, 1)
359
+
360
+ x = x.contiguous().view(x.size()[:3] + (-1,))
361
+ mask = mask.contiguous().view(mask.size()[:3] + (-1,))
362
+
363
+ # Combine the mask with the input tensor
364
+ x_masked = torch.where(mask.bool(), x.float(), float("inf")).to(x)
365
+
366
+ # Sort the masked tensor along the last dimension
367
+ x_sorted, _ = torch.sort(x_masked, dim=-1)
368
+
369
+ # Compute the count of non-masked (valid) values
370
+ valid_count = mask.sum(dim=-1)
371
+
372
+ # Calculate the index of the median value for each pooling window
373
+ median_idx = ((valid_count - 1) // 2).clamp(min=0)
374
+
375
+ # Gather the median values using the calculated indices
376
+ median_pooled = x_sorted.gather(-1, median_idx.unsqueeze(-1).long()).squeeze(-1)
377
+
378
+ # Fill infinite values with NaNs
379
+ median_pooled[torch.isinf(median_pooled)] = float("nan")
13
380
 
14
- class PredictProgressCallback(Callback):
15
- def __init__(self, total_batches,progress_callback: Callable[[int, int], None] = None):
381
+ return median_pooled.squeeze(1)
382
+
383
+
384
+ ###############################################################################
385
+ # Utilities
386
+ ###############################################################################
387
+
388
+
389
+ def nanfilter(signals, win_length, filter_fn):
390
+ """Filters a sequence, ignoring nan values
391
+
392
+ Arguments
393
+ signals (torch.tensor (shape=(batch, time)))
394
+ The signals to filter
395
+ win_length
396
+ The size of the analysis window
397
+ filter_fn (function)
398
+ The function to use for filtering
399
+
400
+ Returns
401
+ filtered (torch.tensor (shape=(batch, time)))
402
+ """
403
+ # Output buffer
404
+ filtered = torch.empty_like(signals)
405
+
406
+ # Loop over frames
407
+ for i in range(signals.size(1)):
408
+
409
+ # Get analysis window bounds
410
+ start = max(0, i - win_length // 2)
411
+ end = min(signals.size(1), i + win_length // 2 + 1)
412
+
413
+ # Apply filter to window
414
+ filtered[:, i] = filter_fn(signals[:, start:end])
415
+
416
+ return filtered
417
+
418
+
419
+ def nanmean(signals):
420
+ """Computes the mean, ignoring nans
421
+
422
+ Arguments
423
+ signals (torch.tensor [shape=(batch, time)])
424
+ The signals to filter
425
+
426
+ Returns
427
+ filtered (torch.tensor [shape=(batch, time)])
428
+ """
429
+ signals = signals.clone()
430
+
431
+ # Find nans
432
+ nans = torch.isnan(signals)
433
+
434
+ # Set nans to 0.
435
+ signals[nans] = 0.
436
+
437
+ # Compute average
438
+ return signals.sum(dim=1) / (~nans).float().sum(dim=1)
439
+
440
+
441
+ def nanmedian(signals):
442
+ """Computes the median, ignoring nans
443
+
444
+ Arguments
445
+ signals (torch.tensor [shape=(batch, time)])
446
+ The signals to filter
447
+
448
+ Returns
449
+ filtered (torch.tensor [shape=(batch, time)])
450
+ """
451
+ # Find nans
452
+ nans = torch.isnan(signals)
453
+
454
+ # Compute median for each slice
455
+ medians = [nanmedian1d(signal[~nan]) for signal, nan in zip(signals, nans)]
456
+
457
+ # Stack results
458
+ return torch.tensor(medians, dtype=signals.dtype, device=signals.device)
459
+
460
+
461
+ def nanmedian1d(signal):
462
+ """Computes the median. If signal is empty, returns torch.nan
463
+
464
+ Arguments
465
+ signal (torch.tensor [shape=(time,)])
466
+
467
+ Returns
468
+ median (torch.tensor [shape=(1,)])
469
+ """
470
+ return torch.median(signal) if signal.numel() else np.nan
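These helpers compose: nanfilter slides a window over the signal and applies a NaN-aware reducer such as nanmean, e.g. (synthetic pitch track):

    track = torch.tensor([[100.0, float('nan'), 104.0, 103.0, float('nan'), 101.0]])
    smoothed = nanfilter(track, win_length=3, filter_fn=nanmean)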
471
+
472
+
473
+ def dither(cents):
474
+ """Dither the predicted pitch in cents to remove quantization error"""
475
+ noise = triang.rvs(c=0.5,loc=-CENTS_PER_BIN,scale=2 * CENTS_PER_BIN,size=cents.size())
476
+ return cents + cents.new_tensor(noise)
477
+
478
+ def periodicity(probabilities, bins):
479
+ """Computes the periodicity from the network output and pitch bins"""
480
+ # shape=(batch * time / hop_length, 360)
481
+ probs_stacked = probabilities.transpose(1, 2).reshape(-1, PITCH_BINS)
482
+
483
+ # shape=(batch * time / hop_length, 1)
484
+ bins_stacked = bins.reshape(-1, 1).to(torch.int64)
485
+
486
+ # Use maximum logit over pitch bins as periodicity
487
+ periodicity = probs_stacked.gather(1, bins_stacked)
488
+
489
+ # shape=(batch, time / hop_length)
490
+ return periodicity.reshape(probabilities.size(0), probabilities.size(2))
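Shape sketch for the gather above (random tensors, purely for illustration):

    probs = torch.rand(2, PITCH_BINS, 50)   # (batch, 360, frames)
    bins = probs.argmax(dim=1)              # (batch, frames)
    conf = periodicity(probs, bins)         # (batch, frames): probability at the decoded bin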
491
+
492
+
493
+
494
+
495
+
496
+
497
+
498
+
499
+
500
+
501
+
502
+
503
+
504
+ class CrepeTorch(torch.nn.Module):
505
+
506
+ def __init__(self, model_type='full',model_path=None):
16
507
  super().__init__()
17
- self.total_batches = total_batches
18
- self.progress_callback = progress_callback
19
- def on_predict_begin(self, logs=None):
20
- if self.progress_callback:
21
- self.progress_callback(0,self.total_batches)
22
- def on_predict_batch_end(self, batch, logs=None):
23
- if self.progress_callback:
24
- self.progress_callback(batch,self.total_batches)
25
- def on_predict_end(self, logs=None):
26
- if self.progress_callback:
27
- self.progress_callback(self.total_batches,self.total_batches)
508
+ xx, yy = np.meshgrid(range(360), range(360))
509
+ transition = np.maximum(12 - abs(xx - yy), 0)
510
+ self.viterbi_transition = transition / transition.sum(axis=1, keepdims=True)
511
+ model_type_importance = {'tiny': 4, 'small': 8, 'medium': 16, 'large': 24, 'full': 32}[model_type]
512
+ out_channels = [n * model_type_importance for n in [32, 4, 4, 4, 8, 16]]
513
+ in_channels = [n * model_type_importance for n in [32, 4, 4, 4, 8]]
514
+ in_channels.insert(0,1)
515
+ self.in_features = 64*model_type_importance
516
+ # Shared layer parameters
517
+ kernel_sizes = [(512, 1)] + 5 * [(64, 1)]
518
+ strides = [(4, 1)] + 5 * [(1, 1)]
519
+
520
+ # Overload with eps and momentum conversion given by MMdnn
521
+ batch_norm_fn = partial(torch.nn.BatchNorm2d,eps=0.0010000000474974513,momentum=0.0)
522
+
523
+ # Layer definitions
524
+ self.conv1 = torch.nn.Conv2d(
525
+ in_channels=in_channels[0],
526
+ out_channels=out_channels[0],
527
+ kernel_size=kernel_sizes[0],
528
+ stride=strides[0])
529
+ self.conv1_BN = batch_norm_fn(
530
+ num_features=out_channels[0])
531
+
532
+ self.conv2 = torch.nn.Conv2d(
533
+ in_channels=in_channels[1],
534
+ out_channels=out_channels[1],
535
+ kernel_size=kernel_sizes[1],
536
+ stride=strides[1])
537
+ self.conv2_BN = batch_norm_fn(
538
+ num_features=out_channels[1])
539
+
540
+ self.conv3 = torch.nn.Conv2d(
541
+ in_channels=in_channels[2],
542
+ out_channels=out_channels[2],
543
+ kernel_size=kernel_sizes[2],
544
+ stride=strides[2])
545
+ self.conv3_BN = batch_norm_fn(
546
+ num_features=out_channels[2])
547
+
548
+ self.conv4 = torch.nn.Conv2d(
549
+ in_channels=in_channels[3],
550
+ out_channels=out_channels[3],
551
+ kernel_size=kernel_sizes[3],
552
+ stride=strides[3])
553
+ self.conv4_BN = batch_norm_fn(
554
+ num_features=out_channels[3])
555
+
556
+ self.conv5 = torch.nn.Conv2d(
557
+ in_channels=in_channels[4],
558
+ out_channels=out_channels[4],
559
+ kernel_size=kernel_sizes[4],
560
+ stride=strides[4])
561
+ self.conv5_BN = batch_norm_fn(
562
+ num_features=out_channels[4])
563
+
564
+ self.conv6 = torch.nn.Conv2d(
565
+ in_channels=in_channels[5],
566
+ out_channels=out_channels[5],
567
+ kernel_size=kernel_sizes[5],
568
+ stride=strides[5])
569
+ self.conv6_BN = batch_norm_fn(
570
+ num_features=out_channels[5])
571
+
572
+ self.classifier = torch.nn.Linear(
573
+ in_features=self.in_features,
574
+ out_features=PITCH_BINS)
575
+ if not model_path:
576
+ model_path = hf_hub_download("shethjenil/Audio2Midi_Models",f"crepe_{model_type}.pt")
577
+ self.load_state_dict(torch.load(model_path))
578
+ self.eval()
579
+
580
+ def forward(self, x, embed=False):
581
+ # Forward pass through first five layers
582
+ x = self.embed(x)
583
+
584
+ if embed:
585
+ return x
586
+
587
+ # Forward pass through layer six
588
+ x = self.layer(x, self.conv6, self.conv6_BN)
589
+
590
+ # shape=(batch, self.in_features)
591
+ x = x.permute(0, 2, 1, 3).reshape(-1, self.in_features)
592
+
593
+ # Compute logits
594
+ return torch.sigmoid(self.classifier(x))
595
+
596
+ def embed(self, x):
597
+ """Map input audio to pitch embedding"""
598
+ # shape=(batch, 1, 1024, 1)
599
+ x = x[:, None, :, None]
28
600
 
601
+ # Forward pass through first five layers
602
+ x = self.layer(x, self.conv1, self.conv1_BN, (0, 0, 254, 254))
603
+ x = self.layer(x, self.conv2, self.conv2_BN)
604
+ x = self.layer(x, self.conv3, self.conv3_BN)
605
+ x = self.layer(x, self.conv4, self.conv4_BN)
606
+ x = self.layer(x, self.conv5, self.conv5_BN)
607
+ return x
608
+
609
+ def layer(self, x, conv, batch_norm, padding=(0, 0, 31, 32)):
610
+ """Forward pass through one layer"""
611
+ x = F.pad(x, padding)
612
+ x = conv(x)
613
+ x = F.relu(x)
614
+ x = batch_norm(x)
615
+ return F.max_pool2d(x, (2, 1), (2, 1))
616
+
617
+ def viterbi(self,logits):
618
+ """Sample observations using viterbi decoding"""
619
+ # Normalize logits
620
+ with torch.no_grad():
621
+ probs = torch.nn.functional.softmax(logits, dim=1)
622
+
623
+ # Convert to numpy
624
+ sequences = probs.cpu().numpy()
625
+
626
+ # Perform viterbi decoding
627
+ bins = np.array([
628
+ librosa.sequence.viterbi(sequence, self.viterbi_transition).astype(np.int64)
629
+ for sequence in sequences])
630
+
631
+ # Convert to pytorch
632
+ bins = torch.tensor(bins, device=probs.device)
633
+
634
+ # Convert to frequency in Hz
635
+ return bins, bins_to_frequency(bins)
636
+
637
+ def get_device(self):
638
+ return next(self.parameters()).device
639
+
640
+ def preprocess(self,audio,sample_rate,hop_length=None,batch_size=None,pad=True):
641
+ """Convert audio to model input
642
+
643
+ Arguments
644
+ audio (torch.tensor [shape=(1, time)])
645
+ The audio signals
646
+ sample_rate (int)
647
+ The sampling rate in Hz
648
+ hop_length (int)
649
+ The hop_length in samples
650
+ batch_size (int)
651
+ The number of frames per batch
652
+ pad (bool)
653
+ Whether to zero-pad the audio
654
+
655
+ Returns
656
+ frames (torch.tensor [shape=(1 + int(time // hop_length), 1024)])
657
+ """
658
+ # Default hop length of 10 ms
659
+ hop_length = sample_rate // 100 if hop_length is None else hop_length
660
+
661
+ # Get total number of frames
662
+
663
+ # Maybe pad
664
+ if pad:
665
+ total_frames = 1 + int(audio.size(1) // hop_length)
666
+ audio = torch.nn.functional.pad(
667
+ audio,
668
+ (WINDOW_SIZE // 2, WINDOW_SIZE // 2))
669
+ else:
670
+ total_frames = 1 + int((audio.size(1) - WINDOW_SIZE) // hop_length)
671
+
672
+ # Default to running all frames in a single batch
673
+ batch_size = total_frames if batch_size is None else batch_size
674
+
675
+ # Generate batches
676
+ for i in range(0, total_frames, batch_size):
677
+
678
+ # Batch indices
679
+ start = max(0, i * hop_length)
680
+ end = min(audio.size(1),
681
+ (i + batch_size - 1) * hop_length + WINDOW_SIZE)
682
+
683
+ # Chunk
684
+ frames = torch.nn.functional.unfold(
685
+ audio[:, None, None, start:end],
686
+ kernel_size=(1, WINDOW_SIZE),
687
+ stride=(1, hop_length))
688
+
689
+ # shape=(1 + int(time / hop_length), 1024)
690
+ frames = frames.transpose(1, 2).reshape(-1, WINDOW_SIZE)
691
+
692
+ # Place on device
693
+ frames = frames.to(self.get_device())
694
+
695
+ # Mean-center
696
+ frames -= frames.mean(dim=1, keepdim=True)
697
+
698
+ # Scale
699
+ # Note: during silent frames, this produces very large values. But
700
+ # this seems to be what the network expects.
701
+ frames /= torch.max(torch.tensor(1e-10, device=frames.device),frames.std(dim=1, keepdim=True))
702
+
703
+ yield frames
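Frame-count arithmetic for the generator above, assuming `model` is a CrepeTorch instance (weights are downloaded on construction): one second of 16 kHz audio at the default 10 ms hop yields 101 windows of 1024 samples.

    audio = torch.zeros(1, SAMPLE_RATE)
    frames = next(model.preprocess(audio, SAMPLE_RATE))
    assert frames.shape == (1 + SAMPLE_RATE // 160, WINDOW_SIZE)   # (101, 1024)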
704
+
705
+ def postprocess(self,probabilities,fmin=0.,fmax=MAX_FMAX,return_periodicity=False):
706
+ """Convert model output to F0 and periodicity
707
+
708
+ Arguments
709
+ probabilities (torch.tensor [shape=(1, 360, time / hop_length)])
710
+ The probabilities for each pitch bin inferred by the network
711
+ fmin (float)
712
+ The minimum allowable frequency in Hz
713
+ fmax (float)
714
+ The maximum allowable frequency in Hz
715
+ viterbi (bool)
716
+ Whether to use viterbi decoding
717
+ return_periodicity (bool)
718
+ Whether to also return the network confidence
719
+
720
+ Returns
721
+ pitch (torch.tensor [shape=(1, 1 + int(time // hop_length))])
722
+ periodicity (torch.tensor [shape=(1, 1 + int(time // hop_length))])
723
+ """
724
+ # Sampling is non-differentiable, so remove from graph
725
+ probabilities = probabilities.detach()
726
+
727
+ # Convert frequency range to pitch bin range
728
+ minidx = frequency_to_bins(torch.tensor(fmin))
729
+ maxidx = frequency_to_bins(torch.tensor(fmax),
730
+ torch.ceil)
731
+
732
+ # Remove frequencies outside of allowable range
733
+ probabilities[:, :minidx] = -float('inf')
734
+ probabilities[:, maxidx:] = -float('inf')
735
+
736
+ # Decode pitch bins with viterbi
737
+ bins, pitch = self.viterbi(probabilities)
738
+
739
+ if not return_periodicity:
740
+ return pitch
741
+
742
+ # Compute periodicity from probabilities and decoded pitch bins
743
+ return pitch, periodicity(probabilities, bins)
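With the defaults fmin=50 Hz and fmax=MAX_FMAX, the masked bin range works out as follows (a quick check using the conversion helpers defined earlier):

    minidx = frequency_to_bins(torch.tensor(50.))                  # floor((1200*log2(5) - 1997.4) / 20) == 39
    maxidx = frequency_to_bins(torch.tensor(MAX_FMAX), torch.ceil) # 360, so no upper bins are masked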
744
+
745
+ def predict(self,audio,sample_rate,hop_length=None,fmin=50.,fmax=MAX_FMAX,return_periodicity=False,batch_size=None,pad=True):
746
+ """Performs pitch estimation
747
+
748
+ Arguments
749
+ audio (torch.tensor [shape=(1, time)])
750
+ The audio signal
751
+ sample_rate (int)
752
+ The sampling rate in Hz
753
+ hop_length (int)
754
+ The hop_length in samples
755
+ fmin (float)
756
+ The minimum allowable frequency in Hz
757
+ fmax (float)
758
+ The maximum allowable frequency in Hz
759
+ return_periodicity (bool)
760
+ Whether to also return the network confidence
761
+ batch_size (int)
762
+ The number of frames per batch
763
+ pad (bool)
764
+ Whether to zero-pad the audio
765
+
766
+ Returns
767
+ pitch (torch.tensor [shape=(1, 1 + int(time // hop_length))])
768
+ (Optional) periodicity (torch.tensor
769
+ [shape=(1, 1 + int(time // hop_length))])
770
+ """
771
+
772
+ results = []
773
+ with torch.no_grad():
774
+ print("prediction started")
775
+ for frames in self.preprocess(audio,sample_rate,hop_length,batch_size,pad):
776
+ # shape=(batch, 360, time / hop_length)
777
+ result = self.postprocess(self.forward(frames, embed=False).reshape(audio.size(0), -1, PITCH_BINS).transpose(1, 2),fmin,fmax,return_periodicity)
778
+ if isinstance(result, tuple):
779
+ result = (result[0].to(audio.device),result[1].to(audio.device))
780
+ else:
781
+ result = result.to(audio.device)
782
+ results.append(result)
783
+ print("prediction finished")
784
+ if return_periodicity:
785
+ pitch, periodicity = zip(*results)
786
+ return torch.cat(pitch, 1), torch.cat(periodicity, 1)
787
+ return torch.cat(results, 1)
788
+
789
+ def predict_from_file(self,audio_file,hop_length=None,fmin=50.,fmax=MAX_FMAX,return_periodicity=False,batch_size=None,pad=True):
790
+ audio, sample_rate = librosa.load(audio_file, sr=SAMPLE_RATE)
791
+ return self.predict(torch.from_numpy(audio).unsqueeze(0),sample_rate,hop_length,fmin,fmax,return_periodicity,batch_size,pad)
792
+
793
+ def predict_from_file_to_file(self,audio_file,output_pitch_file,output_periodicity_file=None,hop_length=None,fmin=50.,fmax=MAX_FMAX,batch_size=None,pad=True):
794
+ prediction = self.predict_from_file(audio_file,hop_length,fmin,fmax,output_periodicity_file is not None,batch_size,pad)
795
+ if output_periodicity_file is not None:
796
+ torch.save(prediction[0].detach(), output_pitch_file)
797
+ torch.save(prediction[1].detach(), output_periodicity_file)
798
+ else:
799
+ torch.save(prediction.detach(), output_pitch_file)
800
+
801
+ def predict_from_files_to_files(self,audio_files,output_pitch_files,output_periodicity_files=None,hop_length=None,fmin=50.,fmax=MAX_FMAX,batch_size=None,pad=True):
802
+ if output_periodicity_files is None:
803
+ output_periodicity_files = len(audio_files) * [None]
804
+ for audio_file, output_pitch_file, output_periodicity_file in tqdm(zip(audio_files, output_pitch_files, output_periodicity_files), desc='torchcrepe', dynamic_ncols=True):
805
+ self.predict_from_file_to_file(audio_file,output_pitch_file,output_periodicity_file,hop_length,fmin,fmax,batch_size,pad)
806
+
807
+ def embedding(self,audio,sample_rate,hop_length=None,batch_size=None,pad=True):
808
+ """Embeds audio to the output of CREPE's fifth maxpool layer
809
+
810
+ Arguments
811
+ audio (torch.tensor [shape=(1, time)])
812
+ The audio signals
813
+ sample_rate (int)
814
+ The sampling rate in Hz
815
+ hop_length (int)
816
+ The hop_length in samples
817
+ batch_size (int)
818
+ The number of frames per batch
819
+ pad (bool)
820
+ Whether to zero-pad the audio
821
+
822
+ Returns
823
+ embedding (torch.tensor [shape=(1,
824
+ 1 + int(time // hop_length), 32, -1)])
825
+ """
826
+ # shape=(batch, time / hop_length, 32, embedding_size)
827
+ with torch.no_grad():
828
+ return torch.cat([self.forward(frames, embed=True).reshape(audio.size(0), frames.size(0), 32, -1).to(audio.device) for frames in self.preprocess(audio,sample_rate,hop_length,batch_size,pad)], 1)
829
+
830
+ def embedding_from_file(self,audio_file,hop_length=None,batch_size=None,pad=True):
831
+ audio, sample_rate = librosa.load(audio_file, sr=SAMPLE_RATE)
832
+ return self.embedding(torch.from_numpy(audio).unsqueeze(0),sample_rate,hop_length,batch_size,pad)
833
+
834
+ def embedding_from_file_to_file(self,audio_file,output_file,hop_length=None,batch_size=None,pad=True):
835
+ with torch.no_grad():
836
+ torch.save(self.embedding_from_file(audio_file,hop_length,batch_size,pad).detach(), output_file)
837
+
838
+ def embedding_from_files_to_files(self,audio_files,output_files,hop_length=None,batch_size=None,pad=True):
839
+ for audio_file, output_file in tqdm(zip(audio_files, output_files), desc='torchcrepe', dynamic_ncols=True):
840
+ self.embedding_from_file_to_file(audio_file,output_file,hop_length,batch_size,pad)
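An end-to-end usage sketch for this class (the constructor downloads the corresponding weights from shethjenil/Audio2Midi_Models; 'vocals.wav' is a placeholder path):

    model = CrepeTorch('tiny')
    pitch, confidence = model.predict_from_file('vocals.wav', return_periodicity=True)
    voiced = At(0.5)(pitch, confidence)   # NaN wherever the network is not confident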
841
+
842
+
843
+
844
+
845
+
846
+
847
+
848
+
849
+
850
+
851
+
852
+
853
+ from hmmlearn.hmm import CategoricalHMM
854
+ from typing import Callable
855
+ from numpy.lib.stride_tricks import as_strided
856
+ from pretty_midi_fix import PrettyMIDI , PitchBend , Note ,Instrument
857
+ from huggingface_hub import hf_hub_download
858
+ from torch.utils.data import DataLoader, TensorDataset
29
859
 
30
860
  class Crepe():
861
+
31
862
  def __init__(self,model_type="full",model_path=None):
32
- if not model_path:
33
- model_path = hf_hub_download("shethjenil/Audio2Midi_Models",f"crepe_{model_type}.h5")
34
- model_type_importance = {'tiny': 4, 'small': 8, 'medium': 16, 'large': 24, 'full': 32}[model_type]
35
- filters = [n * model_type_importance for n in [32, 4, 4, 4, 8, 16]]
36
- widths = [512, 64, 64, 64, 64, 64]
37
- strides = [(4, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]
38
- x = Input(shape=(1024,), name='input', dtype='float32')
39
- y = Reshape(target_shape=(1024, 1, 1), name='input-reshape')(x)
40
- layers = [1, 2, 3, 4, 5, 6]
41
- for l, f, w, s in zip(layers, filters, widths, strides):
42
- y = Conv2D(f, (w, 1), strides=s, padding='same', activation='relu', name="conv%d" % l)(y)
43
- y = BatchNormalization(name="conv%d-BN" % l)(y)
44
- y = MaxPool2D(pool_size=(2, 1), strides=None, padding='valid', name="conv%d-maxpool" % l)(y)
45
- y = Dropout(0.25, name="conv%d-dropout" % l)(y)
46
- y = Permute((2, 1, 3), name="transpose")(y)
47
- y = Flatten(name="flatten")(y)
48
- y = Dense(360, activation='sigmoid', name="classifier")(y)
49
- self.model = Model(inputs=x, outputs=y)
50
- self.model.load_weights(model_path)
51
- self.model.compile('adam', 'binary_crossentropy')
863
+ self.model = CrepeTorch(model_type,model_path)
52
864
  self.cents_mapping=(np.linspace(0, 7180, 360) + 1997.3794084376191)
53
865
 
54
866
  def to_local_average_cents(self, salience, center=None):
867
+ if isinstance(salience, torch.Tensor):
868
+ salience = salience.numpy()
869
+
55
870
  if salience.ndim == 1:
56
871
  if center is None:
57
872
  center = int(np.argmax(salience))
@@ -66,6 +881,8 @@ class Crepe():
66
881
  raise Exception("label should be either 1d or 2d ndarray")
67
882
 
68
883
  def to_viterbi_cents(self,salience):
884
+ if isinstance(salience, torch.Tensor):
885
+ salience = salience.numpy()
69
886
  starting = np.ones(360) / 360
70
887
  xx, yy = np.meshgrid(range(360), range(360))
71
888
  transition = np.maximum(12 - abs(xx - yy), 0)
@@ -87,11 +904,21 @@ class Crepe():
87
904
  frames = frames.transpose().copy()
88
905
  frames -= np.mean(frames, axis=1)[:, np.newaxis]
89
906
  frames /= np.clip(np.std(frames, axis=1)[:, np.newaxis], 1e-8, None)
90
- return self.model.predict(frames,batch_size,0,callbacks=[PredictProgressCallback(math_ceil(len(frames) / batch_size),progress_callback)])
91
-
907
+ device = self.model.get_device()
908
+ all_outputs = []
909
+ all_batch = list(DataLoader(TensorDataset(torch.from_numpy(frames)), batch_size=batch_size, shuffle=False))
910
+ total_batch = len(all_batch)
911
+ with torch.no_grad():
912
+ for i , batch in enumerate(all_batch):
913
+ inputs = batch[0].to(device)
914
+ outputs = self.model(inputs)
915
+ all_outputs.append(outputs.cpu())
916
+ if progress_callback:
+ progress_callback(i, total_batch)
917
+ return torch.cat(all_outputs, dim=0)
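A progress callback compatible with the (batch_index, total_batches) signature used in this loop might look like (hypothetical helper, not part of the package):

    def print_progress(done: int, total: int) -> None:
        print(f"crepe: {done}/{total} batches")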
918
+
92
919
  def model_predict(self,audio:np.ndarray,viterbi, center, step_size,progress_callback,batch_size):
93
920
  activation = self.get_activation(audio.astype(np.float32), center, step_size,progress_callback,batch_size)
94
- confidence = activation.max(axis=1)
921
+ confidence = activation.max(axis=1).values # Access the values from the named tuple
95
922
  cents = self.to_viterbi_cents(activation) if viterbi else self.to_local_average_cents(activation)
96
923
  frequency = 10 * 2 ** (cents / 1200)
97
924
  frequency[np.isnan(frequency)] = 0
@@ -99,7 +926,7 @@ class Crepe():
99
926
  return time, frequency, confidence
100
927
 
101
928
  def predict(self,audio_path,viterbi=False, center=True, step_size=10,min_confidence=0.8,batch_size=32,progress_callback: Callable[[int, int], None] = None,output_file= "output.mid"):
102
- time, frequency, confidence = self.model_predict(librosa_load(audio_path, sr=16000, mono=True)[0],viterbi,center,step_size,progress_callback,batch_size)
929
+ time, frequency, confidence = self.model_predict(librosa.load(audio_path, sr=16000, mono=True)[0],viterbi,center,step_size,progress_callback,batch_size)
103
930
  mask = confidence > min_confidence
104
931
  times = time[mask]
105
932
  frequencies = frequency[mask]