audio2midi 0.1.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,2604 @@
1
+ import torch
2
+ import copy
3
+ import os
4
+ import numpy as np
5
+ import pretty_midi_fix
6
+ from librosa.core import resample as librosa_resample
7
+ from scipy.interpolate import interp1d
8
+ from json import load as json_load, dumps as json_dumps
9
+ from math import log as math_log
10
+ from typing import Optional, Union
11
+ from torch import nn
12
+ from librosa import load as librosa_load
13
+ from huggingface_hub import snapshot_download
14
+ from essentia.standard import RhythmExtractor2013
15
+ from transformers.generation import GenerationConfig, GenerationMixin
16
+ from transformers.activations import ACT2FN
17
+ from transformers.modeling_layers import GradientCheckpointingLayer
18
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput
19
+ from transformers.utils import is_torch_flex_attn_available, is_torch_fx_proxy, is_torchdynamo_compiling, TensorType, to_numpy
20
+ from transformers.modeling_utils import PreTrainedModel
21
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
22
+ from transformers.feature_extraction_utils import BatchFeature
23
+ from transformers.processing_utils import ProcessorMixin
24
+ from transformers.configuration_utils import PretrainedConfig
25
+ from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
26
+ from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache
27
+ from transformers.audio_utils import mel_filter_bank, spectrogram
28
+ from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
31
+ from transformers.tokenization_utils import AddedToken, BatchEncoding, PaddingStrategy, PreTrainedTokenizer, TruncationStrategy
32
+
33
+ if is_torch_flex_attn_available():
34
+ from torch.nn.attention.flex_attention import BlockMask
35
+ from transformers.integrations.flex_attention import make_flex_block_causal_mask
36
+
37
+
38
+ class Pop2PianoConfig(PretrainedConfig):
39
+ model_type = "pop2piano"
40
+ keys_to_ignore_at_inference = ["past_key_values"]
41
+ def __init__(
42
+ self,
43
+ vocab_size=2400,
44
+ composer_vocab_size=21,
45
+ d_model=512,
46
+ d_kv=64,
47
+ d_ff=2048,
48
+ num_layers=6,
49
+ num_decoder_layers=None,
50
+ num_heads=8,
51
+ relative_attention_num_buckets=32,
52
+ relative_attention_max_distance=128,
53
+ dropout_rate=0.1,
54
+ layer_norm_epsilon=1e-6,
55
+ initializer_factor=1.0,
56
+ feed_forward_proj="gated-gelu", # noqa
57
+ is_encoder_decoder=True,
58
+ use_cache=True,
59
+ pad_token_id=0,
60
+ eos_token_id=1,
61
+ dense_act_fn="relu",
62
+ **kwargs,
63
+ ):
64
+ self.vocab_size = vocab_size
65
+ self.composer_vocab_size = composer_vocab_size
66
+ self.d_model = d_model
67
+ self.d_kv = d_kv
68
+ self.d_ff = d_ff
69
+ self.num_layers = num_layers
70
+ self.num_decoder_layers = num_decoder_layers if num_decoder_layers is not None else self.num_layers
71
+ self.num_heads = num_heads
72
+ self.relative_attention_num_buckets = relative_attention_num_buckets
73
+ self.relative_attention_max_distance = relative_attention_max_distance
74
+ self.dropout_rate = dropout_rate
75
+ self.layer_norm_epsilon = layer_norm_epsilon
76
+ self.initializer_factor = initializer_factor
77
+ self.feed_forward_proj = feed_forward_proj
78
+ self.use_cache = use_cache
79
+ self.dense_act_fn = dense_act_fn
80
+ self.is_gated_act = self.feed_forward_proj.split("-")[0] == "gated"
81
+ self.hidden_size = self.d_model
82
+ self.num_attention_heads = num_heads
83
+ self.num_hidden_layers = num_layers
84
+
85
+ super().__init__(
86
+ pad_token_id=pad_token_id,
87
+ eos_token_id=eos_token_id,
88
+ is_encoder_decoder=is_encoder_decoder,
89
+ **kwargs,
90
+ )
91
+
92
+
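A minimal sketch of how the configuration class above could be exercised, assuming the package's dependencies are installed; the values checked are the defaults shown in the signature:

config = Pop2PianoConfig()
assert config.num_decoder_layers == config.num_layers   # falls back to num_layers when not given
assert config.is_gated_act                               # "gated-gelu" selects the gated feed-forward variant
assert config.hidden_size == config.d_model == 512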
93
+ # Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->Pop2Piano
94
+ class Pop2PianoLayerNorm(nn.Module):
95
+ def __init__(self, hidden_size, eps=1e-6):
96
+ """
97
+ Construct a layernorm module in the Pop2Piano style. No bias and no subtraction of mean.
98
+ """
99
+ super().__init__()
100
+ self.weight = nn.Parameter(torch.ones(hidden_size))
101
+ self.variance_epsilon = eps
102
+
103
+ def forward(self, hidden_states):
104
+ # Pop2Piano uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
105
+ # Square Layer Normalization https://huggingface.co/papers/1910.07467 thus variance is calculated
106
+ # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
107
+ # half-precision inputs is done in fp32
108
+
109
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
110
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
111
+
112
+ # convert into half-precision if necessary
113
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
114
+ hidden_states = hidden_states.to(self.weight.dtype)
115
+
116
+ return self.weight * hidden_states
117
+
118
+ # from apex.normalization import FusedRMSNorm
119
+ # Pop2PianoLayerNorm = FusedRMSNorm # noqa
120
+ # # Other Approach
121
+
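A small numeric check of the RMS-style normalization above (scale only, no mean subtraction, no bias), assuming the class is available as defined in this file:

import torch

norm = Pop2PianoLayerNorm(hidden_size=4, eps=1e-6)
x = torch.randn(2, 3, 4)
y = norm(x)
rms = y.pow(2).mean(-1).sqrt()                 # with a weight of ones, each row comes out with RMS ~= 1
print(torch.allclose(rms, torch.ones_like(rms), atol=1e-3))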
122
+ # Copied from transformers.models.t5.modeling_t5.T5DenseActDense with T5->Pop2Piano,t5->pop2piano
123
+ class Pop2PianoDenseActDense(nn.Module):
124
+ def __init__(self, config: Pop2PianoConfig):
125
+ super().__init__()
126
+ self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
127
+ self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
128
+ self.dropout = nn.Dropout(config.dropout_rate)
129
+ self.act = ACT2FN[config.dense_act_fn]
130
+
131
+ def forward(self, hidden_states):
132
+ hidden_states = self.wi(hidden_states)
133
+ hidden_states = self.act(hidden_states)
134
+ hidden_states = self.dropout(hidden_states)
135
+ if (
136
+ isinstance(self.wo.weight, torch.Tensor)
137
+ and hidden_states.dtype != self.wo.weight.dtype
138
+ and self.wo.weight.dtype != torch.int8
139
+ ):
140
+ hidden_states = hidden_states.to(self.wo.weight.dtype)
141
+ hidden_states = self.wo(hidden_states)
142
+ return hidden_states
143
+
144
+
145
+ # Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5->Pop2Piano
146
+ class Pop2PianoDenseGatedActDense(nn.Module):
147
+ def __init__(self, config: Pop2PianoConfig):
148
+ super().__init__()
149
+ self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
150
+ self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
151
+ self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
152
+ self.dropout = nn.Dropout(config.dropout_rate)
153
+ self.act = ACT2FN[config.dense_act_fn]
154
+
155
+ def forward(self, hidden_states):
156
+ hidden_gelu = self.act(self.wi_0(hidden_states))
157
+ hidden_linear = self.wi_1(hidden_states)
158
+ hidden_states = hidden_gelu * hidden_linear
159
+ hidden_states = self.dropout(hidden_states)
160
+
161
+ # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
162
+ # See https://github.com/huggingface/transformers/issues/20287
163
+ # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None`
164
+ if (
165
+ isinstance(self.wo.weight, torch.Tensor)
166
+ and hidden_states.dtype != self.wo.weight.dtype
167
+ and self.wo.weight.dtype != torch.int8
168
+ ):
169
+ hidden_states = hidden_states.to(self.wo.weight.dtype)
170
+
171
+ hidden_states = self.wo(hidden_states)
172
+ return hidden_states
173
+
174
+
175
+ # Copied from transformers.models.t5.modeling_t5.T5LayerFF with T5->Pop2Piano
176
+ class Pop2PianoLayerFF(nn.Module):
177
+ def __init__(self, config: Pop2PianoConfig):
178
+ super().__init__()
179
+ if config.is_gated_act:
180
+ self.DenseReluDense = Pop2PianoDenseGatedActDense(config)
181
+ else:
182
+ self.DenseReluDense = Pop2PianoDenseActDense(config)
183
+
184
+ self.layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
185
+ self.dropout = nn.Dropout(config.dropout_rate)
186
+
187
+ def forward(self, hidden_states):
188
+ forwarded_states = self.layer_norm(hidden_states)
189
+ forwarded_states = self.DenseReluDense(forwarded_states)
190
+ hidden_states = hidden_states + self.dropout(forwarded_states)
191
+ return hidden_states
192
+
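A shape check for the feed-forward block above, assuming the default configuration; with feed_forward_proj="gated-gelu" the gated branch is selected and the block is residual and shape-preserving:

import torch

config = Pop2PianoConfig()
ff = Pop2PianoLayerFF(config)
x = torch.randn(1, 10, config.d_model)
print(ff(x).shape)                             # torch.Size([1, 10, 512])
print(type(ff.DenseReluDense).__name__)        # Pop2PianoDenseGatedActDense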
193
+
194
+ # Copied from transformers.models.t5.modeling_t5.T5Attention with T5->Pop2Piano,t5->pop2piano
195
+ class Pop2PianoAttention(nn.Module):
196
+ def __init__(
197
+ self,
198
+ config: Pop2PianoConfig,
199
+ has_relative_attention_bias=False,
200
+ layer_idx: Optional[int] = None,
201
+ ):
202
+ super().__init__()
203
+ self.is_decoder = config.is_decoder
204
+ self.has_relative_attention_bias = has_relative_attention_bias
205
+ self.relative_attention_num_buckets = config.relative_attention_num_buckets
206
+ self.relative_attention_max_distance = config.relative_attention_max_distance
207
+ self.d_model = config.d_model
208
+ self.key_value_proj_dim = config.d_kv
209
+ self.n_heads = config.num_heads
210
+ self.dropout = config.dropout_rate
211
+ self.inner_dim = self.n_heads * self.key_value_proj_dim
212
+ self.layer_idx = layer_idx
213
+ if layer_idx is None and self.is_decoder:
214
+ print(
215
+ f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
216
+ "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
217
+ "when creating this class."
218
+ )
219
+
220
+ # Mesh TensorFlow initialization to avoid scaling before softmax
221
+ self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
222
+ self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
223
+ self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
224
+ self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
225
+
226
+ if self.has_relative_attention_bias:
227
+ self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
228
+ self.pruned_heads = set()
229
+ self.gradient_checkpointing = False
230
+
231
+ def prune_heads(self, heads):
232
+ if len(heads) == 0:
233
+ return
234
+ heads, index = find_pruneable_heads_and_indices(
235
+ heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
236
+ )
237
+ # Prune linear layers
238
+ self.q = prune_linear_layer(self.q, index)
239
+ self.k = prune_linear_layer(self.k, index)
240
+ self.v = prune_linear_layer(self.v, index)
241
+ self.o = prune_linear_layer(self.o, index, dim=1)
242
+ # Update hyper params
243
+ self.n_heads = self.n_heads - len(heads)
244
+ self.inner_dim = self.key_value_proj_dim * self.n_heads
245
+ self.pruned_heads = self.pruned_heads.union(heads)
246
+
247
+ @staticmethod
248
+ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
249
+ """
250
+ Adapted from Mesh Tensorflow:
251
+ https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
252
+
253
+ Translate relative position to a bucket number for relative attention. The relative position is defined as
254
+ memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
255
+ position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
256
+ small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
257
+ positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
258
+ This should allow for more graceful generalization to longer sequences than the model has been trained on
259
+
260
+ Args:
261
+ relative_position: an int32 Tensor
262
+ bidirectional: a boolean - whether the attention is bidirectional
263
+ num_buckets: an integer
264
+ max_distance: an integer
265
+
266
+ Returns:
267
+ a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
268
+ """
269
+ relative_buckets = 0
270
+ if bidirectional:
271
+ num_buckets //= 2
272
+ relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
273
+ relative_position = torch.abs(relative_position)
274
+ else:
275
+ relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
276
+ # now relative_position is in the range [0, inf)
277
+
278
+ # half of the buckets are for exact increments in positions
279
+ max_exact = num_buckets // 2
280
+ is_small = relative_position < max_exact
281
+
282
+ # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
283
+ relative_position_if_large = max_exact + (
284
+ torch.log(relative_position.float() / max_exact)
285
+ / math_log(max_distance / max_exact)
286
+ * (num_buckets - max_exact)
287
+ ).to(torch.long)
288
+ relative_position_if_large = torch.min(
289
+ relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
290
+ )
291
+
292
+ relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
293
+ return relative_buckets
294
+
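A worked example of the bucketing described in the docstring above, using the default 32 buckets and max distance 128 in the bidirectional (encoder) setting:

import torch

rel = torch.tensor([-130, -16, -1, 0, 1, 16, 130])      # memory_position - query_position
buckets = Pop2PianoAttention._relative_position_bucket(
    rel, bidirectional=True, num_buckets=32, max_distance=128
)
# Negative/zero offsets land in buckets [0, 16), positive offsets in [16, 32);
# offsets beyond +-128 saturate at the edge buckets.
print(buckets.tolist())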
295
+ def compute_bias(self, query_length, key_length, device=None, cache_position=None):
296
+ """Compute binned relative position bias"""
297
+ if device is None:
298
+ device = self.relative_attention_bias.weight.device
299
+ if cache_position is None:
300
+ context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
301
+ else:
302
+ context_position = cache_position[:, None].to(device)
303
+ memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
304
+ relative_position = memory_position - context_position # shape (query_length, key_length)
305
+ relative_position_bucket = self._relative_position_bucket(
306
+ relative_position, # shape (query_length, key_length)
307
+ bidirectional=(not self.is_decoder),
308
+ num_buckets=self.relative_attention_num_buckets,
309
+ max_distance=self.relative_attention_max_distance,
310
+ )
311
+ values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
312
+ values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
313
+ return values
314
+
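A quick shape check for compute_bias above; a decoder attention layer that owns the relative-attention embedding produces one bias value per head for every (query, key) pair:

import torch

config = Pop2PianoConfig()
config.is_decoder = True
attn = Pop2PianoAttention(config, has_relative_attention_bias=True, layer_idx=0)
bias = attn.compute_bias(query_length=4, key_length=6)
print(bias.shape)                              # torch.Size([1, 8, 4, 6]) -> (1, num_heads, query_length, key_length)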
315
+ def forward(
316
+ self,
317
+ hidden_states,
318
+ mask=None,
319
+ key_value_states=None,
320
+ position_bias=None,
321
+ past_key_value=None,
322
+ layer_head_mask=None,
323
+ query_length=None,
324
+ use_cache=False,
325
+ output_attentions=False,
326
+ cache_position=None,
327
+ ):
328
+ """
329
+ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
330
+ """
331
+ # Input is (batch_size, seq_length, dim)
332
+ # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder)
333
+ batch_size, seq_length = hidden_states.shape[:2]
334
+
335
+ # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
336
+ is_cross_attention = key_value_states is not None
337
+
338
+ query_states = self.q(hidden_states)
339
+ query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
340
+
341
+ if past_key_value is not None:
342
+ is_updated = past_key_value.is_updated.get(self.layer_idx)
343
+ if is_cross_attention:
344
+ # after the first generated id, we can subsequently re-use all key/value_states from cache
345
+ curr_past_key_value = past_key_value.cross_attention_cache
346
+ else:
347
+ curr_past_key_value = past_key_value.self_attention_cache
348
+
349
+ current_states = key_value_states if is_cross_attention else hidden_states
350
+ if is_cross_attention and past_key_value is not None and is_updated:
351
+ # reuse k,v, cross_attentions
352
+ key_states = curr_past_key_value.key_cache[self.layer_idx]
353
+ value_states = curr_past_key_value.value_cache[self.layer_idx]
354
+ else:
355
+ key_states = self.k(current_states)
356
+ value_states = self.v(current_states)
357
+ key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
358
+ value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
359
+
360
+ if past_key_value is not None:
361
+ # save all key/value_states to cache to be re-used for fast auto-regressive generation
362
+ cache_position = cache_position if not is_cross_attention else None
363
+ key_states, value_states = curr_past_key_value.update(
364
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
365
+ )
366
+ # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
367
+ if is_cross_attention:
368
+ past_key_value.is_updated[self.layer_idx] = True
369
+
370
+ # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
371
+ scores = torch.matmul(query_states, key_states.transpose(3, 2))
372
+
373
+ if position_bias is None:
374
+ key_length = key_states.shape[-2]
375
+ # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
376
+ real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
377
+ if not self.has_relative_attention_bias:
378
+ position_bias = torch.zeros(
379
+ (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
380
+ )
381
+ if self.gradient_checkpointing and self.training:
382
+ position_bias.requires_grad = True
383
+ else:
384
+ position_bias = self.compute_bias(
385
+ real_seq_length, key_length, device=scores.device, cache_position=cache_position
386
+ )
387
+ position_bias = position_bias[:, :, -seq_length:, :]
388
+
389
+ if mask is not None:
390
+ causal_mask = mask[:, :, :, : key_states.shape[-2]]
391
+ position_bias = position_bias + causal_mask
392
+
393
+ if self.pruned_heads:
394
+ mask = torch.ones(position_bias.shape[1])
395
+ mask[list(self.pruned_heads)] = 0
396
+ position_bias_masked = position_bias[:, mask.bool()]
397
+ else:
398
+ position_bias_masked = position_bias
399
+
400
+ scores += position_bias_masked
401
+
402
+ # (batch_size, n_heads, seq_length, key_length)
403
+ attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
404
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
405
+
406
+ # Mask heads if we want to
407
+ if layer_head_mask is not None:
408
+ attn_weights = attn_weights * layer_head_mask
409
+
410
+ attn_output = torch.matmul(attn_weights, value_states)
411
+
412
+ attn_output = attn_output.transpose(1, 2).contiguous()
413
+ attn_output = attn_output.view(batch_size, -1, self.inner_dim)
414
+ attn_output = self.o(attn_output)
415
+
416
+ outputs = (attn_output, past_key_value, position_bias)
417
+
418
+ if output_attentions:
419
+ outputs = outputs + (attn_weights,)
420
+ return outputs
421
+
422
+
423
+ # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->Pop2Piano,t5->pop2piano
424
+ class Pop2PianoLayerSelfAttention(nn.Module):
425
+ def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
426
+ super().__init__()
427
+ self.SelfAttention = Pop2PianoAttention(
428
+ config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx
429
+ )
430
+ self.layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
431
+ self.dropout = nn.Dropout(config.dropout_rate)
432
+
433
+ def forward(
434
+ self,
435
+ hidden_states,
436
+ attention_mask=None,
437
+ position_bias=None,
438
+ layer_head_mask=None,
439
+ past_key_value=None,
440
+ use_cache=False,
441
+ output_attentions=False,
442
+ cache_position=None,
443
+ ):
444
+ normed_hidden_states = self.layer_norm(hidden_states)
445
+ attention_output = self.SelfAttention(
446
+ normed_hidden_states,
447
+ mask=attention_mask,
448
+ position_bias=position_bias,
449
+ layer_head_mask=layer_head_mask,
450
+ past_key_value=past_key_value,
451
+ use_cache=use_cache,
452
+ output_attentions=output_attentions,
453
+ cache_position=cache_position,
454
+ )
455
+ hidden_states = hidden_states + self.dropout(attention_output[0])
456
+ outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
457
+ return outputs
458
+
459
+
460
+ # Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->Pop2Piano,t5->pop2piano
461
+ class Pop2PianoLayerCrossAttention(nn.Module):
462
+ def __init__(self, config, layer_idx: Optional[int] = None):
463
+ super().__init__()
464
+ self.EncDecAttention = Pop2PianoAttention(config, has_relative_attention_bias=False, layer_idx=layer_idx)
465
+ self.layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
466
+ self.dropout = nn.Dropout(config.dropout_rate)
467
+
468
+ def forward(
469
+ self,
470
+ hidden_states,
471
+ key_value_states,
472
+ attention_mask=None,
473
+ position_bias=None,
474
+ layer_head_mask=None,
475
+ past_key_value=None,
476
+ use_cache=False,
477
+ query_length=None,
478
+ output_attentions=False,
479
+ cache_position=None,
480
+ ):
481
+ normed_hidden_states = self.layer_norm(hidden_states)
482
+ attention_output = self.EncDecAttention(
483
+ normed_hidden_states,
484
+ mask=attention_mask,
485
+ key_value_states=key_value_states,
486
+ position_bias=position_bias,
487
+ layer_head_mask=layer_head_mask,
488
+ past_key_value=past_key_value,
489
+ use_cache=use_cache,
490
+ query_length=query_length,
491
+ output_attentions=output_attentions,
492
+ cache_position=cache_position,
493
+ )
494
+ layer_output = hidden_states + self.dropout(attention_output[0])
495
+ outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
496
+ return outputs
497
+
498
+ # Copied from transformers.models.t5.modeling_t5.T5Block with T5->Pop2Piano,t5->pop2piano
499
+ class Pop2PianoBlock(GradientCheckpointingLayer):
500
+ def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
501
+ super().__init__()
502
+ self.is_decoder = config.is_decoder
503
+ self.layer = nn.ModuleList()
504
+ self.layer.append(
505
+ Pop2PianoLayerSelfAttention(
506
+ config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx
507
+ )
508
+ )
509
+ if self.is_decoder:
510
+ self.layer.append(Pop2PianoLayerCrossAttention(config, layer_idx=layer_idx))
511
+
512
+ self.layer.append(Pop2PianoLayerFF(config))
513
+
514
+ def forward(
515
+ self,
516
+ hidden_states,
517
+ attention_mask=None,
518
+ position_bias=None,
519
+ encoder_hidden_states=None,
520
+ encoder_attention_mask=None,
521
+ encoder_decoder_position_bias=None,
522
+ layer_head_mask=None,
523
+ cross_attn_layer_head_mask=None,
524
+ past_key_value=None,
525
+ use_cache=False,
526
+ output_attentions=False,
527
+ return_dict=True,
528
+ cache_position=None,
529
+ ):
530
+ self_attention_outputs = self.layer[0](
531
+ hidden_states,
532
+ attention_mask=attention_mask,
533
+ position_bias=position_bias,
534
+ layer_head_mask=layer_head_mask,
535
+ past_key_value=past_key_value,
536
+ use_cache=use_cache,
537
+ output_attentions=output_attentions,
538
+ cache_position=cache_position,
539
+ )
540
+ hidden_states, past_key_value = self_attention_outputs[:2]
541
+ attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
542
+
543
+ # clamp inf values to enable fp16 training
544
+ if hidden_states.dtype == torch.float16:
545
+ clamp_value = torch.where(
546
+ torch.isinf(hidden_states).any(),
547
+ torch.finfo(hidden_states.dtype).max - 1000,
548
+ torch.finfo(hidden_states.dtype).max,
549
+ )
550
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
551
+
552
+ do_cross_attention = self.is_decoder and encoder_hidden_states is not None
553
+ if do_cross_attention:
554
+ cross_attention_outputs = self.layer[1](
555
+ hidden_states,
556
+ key_value_states=encoder_hidden_states,
557
+ attention_mask=encoder_attention_mask,
558
+ position_bias=encoder_decoder_position_bias,
559
+ layer_head_mask=cross_attn_layer_head_mask,
560
+ past_key_value=past_key_value,
561
+ query_length=cache_position[-1] + 1,
562
+ use_cache=use_cache,
563
+ output_attentions=output_attentions,
564
+ )
565
+ hidden_states, past_key_value = cross_attention_outputs[:2]
566
+
567
+ # clamp inf values to enable fp16 training
568
+ if hidden_states.dtype == torch.float16:
569
+ clamp_value = torch.where(
570
+ torch.isinf(hidden_states).any(),
571
+ torch.finfo(hidden_states.dtype).max - 1000,
572
+ torch.finfo(hidden_states.dtype).max,
573
+ )
574
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
575
+
576
+ # Keep cross-attention outputs and relative position weights
577
+ attention_outputs = attention_outputs + cross_attention_outputs[2:]
578
+
579
+ # Apply Feed Forward layer
580
+ hidden_states = self.layer[-1](hidden_states)
581
+
582
+ # clamp inf values to enable fp16 training
583
+ if hidden_states.dtype == torch.float16:
584
+ clamp_value = torch.where(
585
+ torch.isinf(hidden_states).any(),
586
+ torch.finfo(hidden_states.dtype).max - 1000,
587
+ torch.finfo(hidden_states.dtype).max,
588
+ )
589
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
590
+
591
+ outputs = (hidden_states,)
592
+
593
+ if use_cache:
594
+ outputs = outputs + (past_key_value,) + attention_outputs
595
+ else:
596
+ outputs = outputs + attention_outputs
597
+
598
+ return outputs # hidden-states, past_key_value, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
599
+
600
+ class Pop2PianoConcatEmbeddingToMel(nn.Module):
601
+ """Embedding Matrix for `composer` tokens."""
602
+
603
+ def __init__(self, config):
604
+ super().__init__()
605
+ self.embedding = nn.Embedding(num_embeddings=config.composer_vocab_size, embedding_dim=config.d_model)
606
+
607
+ def forward(self, feature, index_value, embedding_offset):
608
+ index_shifted = index_value - embedding_offset
609
+ composer_embedding = self.embedding(index_shifted).unsqueeze(1)
610
+ inputs_embeds = torch.cat([composer_embedding, feature], dim=1)
611
+ return inputs_embeds
612
+
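A sketch of the concatenation performed above: one composer embedding is prepended per example, so the time axis grows by one step. The token id and offset below are placeholders, not the values shipped in the real generation config:

import torch

config = Pop2PianoConfig()
conditioner = Pop2PianoConcatEmbeddingToMel(config)
features = torch.randn(2, 100, config.d_model)
composer_ids = torch.tensor([2052, 2052])      # hypothetical raw composer token ids
offset = 2052                                  # hypothetical smallest composer token id
out = conditioner(features, composer_ids, offset)
print(out.shape)                               # torch.Size([2, 101, 512])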
613
+ class Pop2PianoPreTrainedModel(PreTrainedModel):
614
+ config_class = Pop2PianoConfig
615
+ base_model_prefix = "transformer"
616
+ is_parallelizable = False
617
+ supports_gradient_checkpointing = True
618
+ _supports_cache_class = True
619
+ _supports_static_cache = False
620
+ _no_split_modules = ["Pop2PianoBlock"]
621
+ _keep_in_fp32_modules = ["wo"]
622
+
623
+ def _init_weights(self, module):
624
+ """Initialize the weights"""
625
+ factor = self.config.initializer_factor # Used for testing weights initialization
626
+ if isinstance(module, Pop2PianoLayerNorm):
627
+ module.weight.data.fill_(factor * 1.0)
628
+ elif isinstance(module, Pop2PianoConcatEmbeddingToMel):
629
+ module.embedding.weight.data.normal_(mean=0.0, std=factor * 1.0)
630
+ elif isinstance(module, Pop2PianoForConditionalGeneration):
631
+ # Mesh TensorFlow embeddings initialization
632
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
633
+ module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
634
+ if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
635
+ module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
636
+ elif isinstance(module, Pop2PianoDenseActDense):
637
+ # Mesh TensorFlow FF initialization
638
+ # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
639
+ # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
640
+ module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
641
+ if hasattr(module.wi, "bias") and module.wi.bias is not None:
642
+ module.wi.bias.data.zero_()
643
+ module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
644
+ if hasattr(module.wo, "bias") and module.wo.bias is not None:
645
+ module.wo.bias.data.zero_()
646
+ elif isinstance(module, Pop2PianoDenseGatedActDense):
647
+ module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
648
+ if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
649
+ module.wi_0.bias.data.zero_()
650
+ module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
651
+ if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
652
+ module.wi_1.bias.data.zero_()
653
+ module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
654
+ if hasattr(module.wo, "bias") and module.wo.bias is not None:
655
+ module.wo.bias.data.zero_()
656
+ elif isinstance(module, Pop2PianoAttention):
657
+ # Mesh TensorFlow attention initialization to avoid scaling before softmax
658
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
659
+ d_model = self.config.d_model
660
+ key_value_proj_dim = self.config.d_kv
661
+ n_heads = self.config.num_heads
662
+ module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
663
+ module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
664
+ module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
665
+ module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
666
+ if module.has_relative_attention_bias:
667
+ module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
668
+
669
+ def _shift_right(self, input_ids):
670
+ decoder_start_token_id = self.config.decoder_start_token_id
671
+ pad_token_id = self.config.pad_token_id
672
+
673
+ if decoder_start_token_id is None:
674
+ raise ValueError(
675
+ "self.model.config.decoder_start_token_id has to be defined. In Pop2Piano it is usually set to the pad_token_id."
676
+ )
677
+
678
+ # shift inputs to the right
679
+ if is_torch_fx_proxy(input_ids):
680
+ # Item assignment is not supported natively for proxies.
681
+ shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
682
+ shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
683
+ else:
684
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
685
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
686
+ shifted_input_ids[..., 0] = decoder_start_token_id
687
+
688
+ if pad_token_id is None:
689
+ raise ValueError("self.model.config.pad_token_id has to be defined.")
690
+ # replace possible -100 values in labels by `pad_token_id`
691
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
692
+
693
+ return shifted_input_ids
694
+
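The same right shift performed by the method above, written out standalone for a single batch of labels; the start token equals pad_token_id (0), as the error message above notes:

import torch

labels = torch.tensor([[5, 6, 7, -100]])
decoder_start_token_id = pad_token_id = 0
shifted = labels.new_zeros(labels.shape)
shifted[..., 1:] = labels[..., :-1].clone()
shifted[..., 0] = decoder_start_token_id
shifted.masked_fill_(shifted == -100, pad_token_id)
print(shifted.tolist())                        # [[0, 5, 6, 7]]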
695
+ class Pop2PianoStack(Pop2PianoPreTrainedModel):
696
+ # Copied from transformers.models.t5.modeling_t5.T5Stack.__init__ with T5->Pop2Piano,t5->pop2piano
697
+ def __init__(self, config, embed_tokens=None):
698
+ super().__init__(config)
699
+
700
+ self.embed_tokens = embed_tokens
701
+ self.is_decoder = config.is_decoder
702
+
703
+ self.block = nn.ModuleList(
704
+ [
705
+ Pop2PianoBlock(config, has_relative_attention_bias=bool(i == 0), layer_idx=i)
706
+ for i in range(config.num_layers)
707
+ ]
708
+ )
709
+ self.final_layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
710
+ self.dropout = nn.Dropout(config.dropout_rate)
711
+
712
+ # Initialize weights and apply final processing
713
+ self.post_init()
714
+ # Model parallel
715
+ self.model_parallel = False
716
+ self.device_map = None
717
+ self.gradient_checkpointing = False
718
+
719
+ # Copied from transformers.models.t5.modeling_t5.T5Stack.get_input_embeddings
720
+ def get_input_embeddings(self):
721
+ return self.embed_tokens
722
+
723
+ # Copied from transformers.models.t5.modeling_t5.T5Stack.set_input_embeddings
724
+ def set_input_embeddings(self, new_embeddings):
725
+ self.embed_tokens = new_embeddings
726
+
727
+ def forward(
728
+ self,
729
+ input_ids=None,
730
+ attention_mask=None,
731
+ encoder_hidden_states=None,
732
+ encoder_attention_mask=None,
733
+ inputs_embeds=None,
734
+ head_mask=None,
735
+ cross_attn_head_mask=None,
736
+ past_key_values=None,
737
+ use_cache=None,
738
+ output_attentions=None,
739
+ output_hidden_states=None,
740
+ return_dict=None,
741
+ cache_position=None,
742
+ ):
743
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
744
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
745
+ output_hidden_states = (
746
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
747
+ )
748
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
749
+
750
+ if input_ids is not None and inputs_embeds is not None:
751
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
752
+ raise ValueError(
753
+ f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
754
+ )
755
+ elif input_ids is not None:
756
+ input_shape = input_ids.size()
757
+ input_ids = input_ids.view(-1, input_shape[-1])
758
+ elif inputs_embeds is not None:
759
+ input_shape = inputs_embeds.size()[:-1]
760
+ else:
761
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
762
+ raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
763
+
764
+ if self.gradient_checkpointing and self.training:
765
+ if use_cache:
766
+ print(
767
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
768
+ )
769
+ use_cache = False
770
+
771
+ if inputs_embeds is None:
772
+ if self.embed_tokens is None:
773
+ raise ValueError("You have to initialize the model with valid token embeddings")
774
+ inputs_embeds = self.embed_tokens(input_ids)
775
+
776
+ batch_size, seq_length = input_shape
777
+
778
+ if use_cache is True:
779
+ if not self.is_decoder:
780
+ raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")
781
+
782
+ # initialize past_key_values
783
+ return_legacy_cache = False
784
+ return_self_attention_cache = False
785
+ if self.is_decoder and (use_cache or past_key_values is not None):
786
+ if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache):
787
+ return_self_attention_cache = True
788
+ past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
789
+ elif not isinstance(past_key_values, EncoderDecoderCache):
790
+ return_legacy_cache = True
791
+ print(
792
+ "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. "
793
+ "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
794
+ "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
795
+ )
796
+ past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
797
+ elif past_key_values is None:
798
+ past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
799
+ elif not self.is_decoder:
800
+ # do not pass cache object down the line for encoder stack
801
+ # it messes indexing later in decoder-stack because cache object is modified in-place
802
+ past_key_values = None
803
+
804
+ past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
805
+ if cache_position is None:
806
+ cache_position = torch.arange(
807
+ past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
808
+ )
809
+
810
+ if attention_mask is None and not is_torchdynamo_compiling():
811
+ # required mask seq length can be calculated via length of past cache
812
+ mask_seq_length = past_key_values_length + seq_length
813
+ attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
814
+
815
+ if self.config.is_decoder:
816
+ causal_mask = self._update_causal_mask(
817
+ attention_mask,
818
+ inputs_embeds,
819
+ cache_position,
820
+ past_key_values.self_attention_cache if past_key_values is not None else None,
821
+ output_attentions,
822
+ )
823
+ else:
824
+ causal_mask = attention_mask[:, None, None, :]
825
+ causal_mask = causal_mask.to(dtype=inputs_embeds.dtype)
826
+ causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min
827
+
828
+ # If a 2D or 3D attention mask is provided for the cross-attention
829
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
830
+ if self.is_decoder and encoder_hidden_states is not None:
831
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
832
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
833
+ if encoder_attention_mask is None:
834
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
835
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
836
+ else:
837
+ encoder_extended_attention_mask = None
838
+
839
+ # Prepare head mask if needed
840
+ head_mask = self.get_head_mask(head_mask, self.config.num_layers)
841
+ cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
842
+ all_hidden_states = () if output_hidden_states else None
843
+ all_attentions = () if output_attentions else None
844
+ all_cross_attentions = () if (output_attentions and self.is_decoder) else None
845
+ position_bias = None
846
+ encoder_decoder_position_bias = None
847
+
848
+ hidden_states = self.dropout(inputs_embeds)
849
+
850
+ for i, layer_module in enumerate(self.block):
851
+ layer_head_mask = head_mask[i]
852
+ cross_attn_layer_head_mask = cross_attn_head_mask[i]
853
+ if output_hidden_states:
854
+ all_hidden_states = all_hidden_states + (hidden_states,)
855
+
856
+ layer_outputs = layer_module(
857
+ hidden_states,
858
+ causal_mask,
859
+ position_bias,
860
+ encoder_hidden_states,
861
+ encoder_extended_attention_mask,
862
+ encoder_decoder_position_bias, # as a positional argument for gradient checkpointing
863
+ layer_head_mask=layer_head_mask,
864
+ cross_attn_layer_head_mask=cross_attn_layer_head_mask,
865
+ past_key_value=past_key_values,
866
+ use_cache=use_cache,
867
+ output_attentions=output_attentions,
868
+ cache_position=cache_position,
869
+ )
870
+
871
+ # layer_outputs is a tuple with:
872
+ # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
873
+ if use_cache is False:
874
+ layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
875
+
876
+ hidden_states, next_decoder_cache = layer_outputs[:2]
877
+
878
+ # We share the position biases between the layers - the first layer stores them
879
+ # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
880
+ # (cross-attention position bias), (cross-attention weights)
881
+ position_bias = layer_outputs[2]
882
+ if self.is_decoder and encoder_hidden_states is not None:
883
+ encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
884
+
885
+ if output_attentions:
886
+ all_attentions = all_attentions + (layer_outputs[3],)
887
+ if self.is_decoder:
888
+ all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
889
+
890
+ hidden_states = self.final_layer_norm(hidden_states)
891
+ hidden_states = self.dropout(hidden_states)
892
+
893
+ # Add last layer
894
+ if output_hidden_states:
895
+ all_hidden_states = all_hidden_states + (hidden_states,)
896
+
897
+ next_cache = next_decoder_cache if use_cache else None
898
+ if return_self_attention_cache:
899
+ next_cache = past_key_values.self_attention_cache
900
+ if return_legacy_cache:
901
+ next_cache = past_key_values.to_legacy_cache()
902
+
903
+ if not return_dict:
904
+ return tuple(
905
+ v
906
+ for v in [
907
+ hidden_states,
908
+ next_cache,
909
+ all_hidden_states,
910
+ all_attentions,
911
+ all_cross_attentions,
912
+ ]
913
+ if v is not None
914
+ )
915
+ return BaseModelOutputWithPastAndCrossAttentions(
916
+ last_hidden_state=hidden_states,
917
+ past_key_values=next_cache,
918
+ hidden_states=all_hidden_states,
919
+ attentions=all_attentions,
920
+ cross_attentions=all_cross_attentions,
921
+ )
922
+
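A minimal encoder-style pass through the stack above over pre-computed embeddings (no token ids), mirroring how the encoder consumes featurized audio further below; the reduced sizes are assumptions to keep the sketch small:

import torch

enc_config = Pop2PianoConfig(num_layers=2, d_ff=256)
enc_config.is_decoder = False
enc_config.use_cache = False                   # caching is only valid for the decoder stack
stack = Pop2PianoStack(enc_config, embed_tokens=None).eval()
x = torch.randn(1, 12, enc_config.d_model)
out = stack(inputs_embeds=x)
print(out.last_hidden_state.shape)             # torch.Size([1, 12, 512])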
923
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
924
+ def _update_causal_mask(
925
+ self,
926
+ attention_mask: Union[torch.Tensor, "BlockMask"],
927
+ input_tensor: torch.Tensor,
928
+ cache_position: torch.Tensor,
929
+ past_key_values: Cache,
930
+ output_attentions: bool = False,
931
+ ):
932
+ if self.config._attn_implementation == "flash_attention_2":
933
+ if attention_mask is not None and (attention_mask == 0.0).any():
934
+ return attention_mask
935
+ return None
936
+ if self.config._attn_implementation == "flex_attention":
937
+ if isinstance(attention_mask, torch.Tensor):
938
+ attention_mask = make_flex_block_causal_mask(attention_mask)
939
+ return attention_mask
940
+
941
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
942
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
943
+ # to infer the attention mask.
944
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
945
+ using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
946
+
947
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
948
+ if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
949
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
950
+ attention_mask,
951
+ inputs_embeds=input_tensor,
952
+ past_key_values_length=past_seen_tokens,
953
+ is_training=self.training,
954
+ ):
955
+ return None
956
+
957
+ dtype = input_tensor.dtype
958
+ sequence_length = input_tensor.shape[1]
959
+ if using_compilable_cache:
960
+ target_length = past_key_values.get_max_cache_shape()
961
+ else:
962
+ target_length = (
963
+ attention_mask.shape[-1]
964
+ if isinstance(attention_mask, torch.Tensor)
965
+ else past_seen_tokens + sequence_length + 1
966
+ )
967
+
968
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
969
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
970
+ attention_mask,
971
+ sequence_length=sequence_length,
972
+ target_length=target_length,
973
+ dtype=dtype,
974
+ cache_position=cache_position,
975
+ batch_size=input_tensor.shape[0],
976
+ )
977
+
978
+ if (
979
+ self.config._attn_implementation == "sdpa"
980
+ and attention_mask is not None
981
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
982
+ and not output_attentions
983
+ ):
984
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
985
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
986
+ # Details: https://github.com/pytorch/pytorch/issues/110213
987
+ min_dtype = torch.finfo(dtype).min
988
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
989
+
990
+ return causal_mask
991
+
992
+ @staticmethod
993
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
994
+ def _prepare_4d_causal_attention_mask_with_cache_position(
995
+ attention_mask: torch.Tensor,
996
+ sequence_length: int,
997
+ target_length: int,
998
+ dtype: torch.dtype,
999
+ cache_position: torch.Tensor,
1000
+ batch_size: int,
1001
+ **kwargs,
1002
+ ):
1003
+ """
1004
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
1005
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
1006
+
1007
+ Args:
1008
+ attention_mask (`torch.Tensor`):
1009
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
1010
+ `(batch_size, 1, query_length, key_value_length)`.
1011
+ sequence_length (`int`):
1012
+ The sequence length being processed.
1013
+ target_length (`int`):
1014
+ The target length: when generating with static cache, the mask should be as long as the static cache,
1015
+ to account for the 0 padding, the part of the cache that is not filled yet.
1016
+ dtype (`torch.dtype`):
1017
+ The dtype to use for the 4D attention mask.
1018
+ cache_position (`torch.Tensor`):
1019
+ Indices depicting the position of the input sequence tokens in the sequence.
1020
+ batch_size (`torch.Tensor`):
1021
+ Batch size.
1022
+ """
1023
+ if attention_mask is not None and attention_mask.dim() == 4:
1024
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
1025
+ causal_mask = attention_mask
1026
+ else:
1027
+ min_dtype = torch.finfo(dtype).min
1028
+ causal_mask = torch.full(
1029
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
1030
+ )
1031
+ if sequence_length != 1:
1032
+ causal_mask = torch.triu(causal_mask, diagonal=1)
1033
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
1034
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
1035
+ if attention_mask is not None:
1036
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
1037
+ mask_length = attention_mask.shape[-1]
1038
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
1039
+ causal_mask.device
1040
+ )
1041
+ padding_mask = padding_mask == 0
1042
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
1043
+ padding_mask, min_dtype
1044
+ )
1045
+
1046
+ return causal_mask
1047
+
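A small example of the 4D mask construction above: three query tokens at absolute positions 2..4 attending over a 5-position target, with no padding; visible positions hold 0 and masked positions hold the dtype minimum:

import torch

mask_4d = Pop2PianoStack._prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask=torch.ones(1, 5),
    sequence_length=3,
    target_length=5,
    dtype=torch.float32,
    cache_position=torch.arange(2, 5),
    batch_size=1,
)
print((mask_4d == 0).squeeze().int())          # row i may attend up to absolute position i + 2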
1048
+ class Pop2PianoForConditionalGeneration(Pop2PianoPreTrainedModel, GenerationMixin):
1049
+ _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
1050
+
1051
+ def __init__(self, config: Pop2PianoConfig):
1052
+ super().__init__(config)
1053
+ self.config = config
1054
+ self.model_dim = config.d_model
1055
+
1056
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
1057
+
1058
+ self.mel_conditioner = Pop2PianoConcatEmbeddingToMel(config)
1059
+
1060
+ encoder_config = copy.deepcopy(config)
1061
+ encoder_config.is_decoder = False
1062
+ encoder_config.use_cache = False
1063
+ encoder_config.is_encoder_decoder = False
1064
+
1065
+ self.encoder = Pop2PianoStack(encoder_config, self.shared)
1066
+
1067
+ decoder_config = copy.deepcopy(config)
1068
+ decoder_config.is_decoder = True
1069
+ decoder_config.is_encoder_decoder = False
1070
+ decoder_config.num_layers = config.num_decoder_layers
1071
+ self.decoder = Pop2PianoStack(decoder_config, self.shared)
1072
+
1073
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
1074
+
1075
+ # Initialize weights and apply final processing
1076
+ self.post_init()
1077
+
1078
+ def get_input_embeddings(self):
1079
+ return self.shared
1080
+
1081
+ def set_input_embeddings(self, new_embeddings):
1082
+ self.shared = new_embeddings
1083
+ self.encoder.set_input_embeddings(new_embeddings)
1084
+ self.decoder.set_input_embeddings(new_embeddings)
1085
+
1086
+ def set_output_embeddings(self, new_embeddings):
1087
+ self.lm_head = new_embeddings
1088
+
1089
+ def get_output_embeddings(self):
1090
+ return self.lm_head
1091
+
1092
+ def get_encoder(self):
1093
+ return self.encoder
1094
+
1095
+ def get_decoder(self):
1096
+ return self.decoder
1097
+
1098
+ def get_mel_conditioner_outputs(
1099
+ self,
1100
+ input_features: torch.FloatTensor,
1101
+ composer: str,
1102
+ generation_config: GenerationConfig,
1103
+ attention_mask: Optional[torch.FloatTensor] = None,
1104
+ ):
1105
+ """
1106
+ This method is used to concatenate mel conditioner tokens at the front of the input_features in order to
1107
+ control the type of MIDI token generated by the model.
1108
+
1109
+ Args:
1110
+ input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
1111
+ input features extracted from the feature extractor.
1112
+ composer (`str`):
1113
+ composer token which determines the type of MIDI tokens to be generated.
1114
+ generation_config (`~generation.GenerationConfig`):
1115
+ The generation config is used to get the composer-feature_token pair.
1116
+ attention_mask (`torch.FloatTensor`, *optional*):
1117
+ For batched generation `input_features` are padded to have the same shape across all examples.
1118
+ `attention_mask` helps to determine which areas were padded and which were not.
1119
+ - 1 for tokens that are **not padded**,
1120
+ - 0 for tokens that are **padded**.
1121
+ """
1122
+ composer_to_feature_token = generation_config.composer_to_feature_token
1123
+ if composer not in composer_to_feature_token.keys():
1124
+ raise ValueError(
1125
+ f"Please choose a composer from {list(composer_to_feature_token.keys())}. Composer received - {composer}"
1126
+ )
1127
+ composer_value = composer_to_feature_token[composer]
1128
+ composer_value = torch.tensor(composer_value, device=self.device)
1129
+ composer_value = composer_value.repeat(input_features.shape[0])
1130
+
1131
+ embedding_offset = min(composer_to_feature_token.values())
1132
+
1133
+ input_features = self.mel_conditioner(
1134
+ feature=input_features,
1135
+ index_value=composer_value,
1136
+ embedding_offset=embedding_offset,
1137
+ )
1138
+ if attention_mask is not None:
1139
+ input_features[~attention_mask[:, 0].bool()] = 0.0
1140
+
1141
+ # since self.mel_conditioner adds a new array at the front of inputs_embeds we need to do the same for attention_mask to keep the shapes same
1142
+ attention_mask = torch.concatenate([attention_mask[:, 0].view(-1, 1), attention_mask], axis=1)
1143
+ return input_features, attention_mask
1144
+
1145
+ return input_features, None
1146
+
1147
+ def forward(
1148
+ self,
1149
+ input_ids: Optional[torch.LongTensor] = None,
1150
+ attention_mask: Optional[torch.FloatTensor] = None,
1151
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1152
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
1153
+ head_mask: Optional[torch.FloatTensor] = None,
1154
+ decoder_head_mask: Optional[torch.FloatTensor] = None,
1155
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1156
+ encoder_outputs: Optional[tuple[tuple[torch.Tensor]]] = None,
1157
+ past_key_values: Optional[tuple[tuple[torch.Tensor]]] = None,
1158
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1159
+ input_features: Optional[torch.FloatTensor] = None,
1160
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1161
+ labels: Optional[torch.LongTensor] = None,
1162
+ use_cache: Optional[bool] = None,
1163
+ output_attentions: Optional[bool] = None,
1164
+ output_hidden_states: Optional[bool] = None,
1165
+ return_dict: Optional[bool] = None,
1166
+ cache_position: Optional[torch.LongTensor] = None,
1167
+ ) -> Union[tuple[torch.FloatTensor], Seq2SeqLMOutput]:
1168
+ r"""
1169
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1170
+ Indices of input sequence tokens in the vocabulary. Pop2Piano is a model with relative position embeddings
1171
+ so you should be able to pad the inputs on both the right and the left. Indices can be obtained using
1172
+ [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for detail.
1173
+ [What are input IDs?](../glossary#input-ids) To know more on how to prepare `input_ids` for pretraining
1174
+ take a look at [Pop2Piano Training](./Pop2Piano#training).
1175
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
1176
+ Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using
1177
+ [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
1178
+ [What are decoder input IDs?](../glossary#decoder-input-ids) Pop2Piano uses the `pad_token_id` as the
1179
+ starting token for `decoder_input_ids` generation. If `past_key_values` is used, optionally only the last
1180
+ `decoder_input_ids` have to be input (see `past_key_values`). To know more on how to prepare `decoder_input_ids` for pretraining take a look at [Pop2Piano Training](./Pop2Piano#training).
1181
+ decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
1182
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
1183
+ be used by default.
1184
+ decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
1185
+ Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
1186
+ 1]`:
1187
+ - 1 indicates the head is **not masked**,
1188
+ - 0 indicates the head is **masked**.
1189
+ cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
1190
+ Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
1191
+ `[0, 1]`:
1192
+ - 1 indicates the head is **not masked**,
1193
+ - 0 indicates the head is **masked**.
1194
+ input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1195
+ Does the same task as `inputs_embeds`. If `inputs_embeds` is not present but `input_features` is present
1196
+ then `input_features` will be considered as `inputs_embeds`.
1197
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1198
+ Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
1199
+ config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
1200
+ labels in `[0, ..., config.vocab_size]`
1201
+ """
1202
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1203
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1204
+
1205
+ if inputs_embeds is not None and input_features is not None:
1206
+ raise ValueError("Both `inputs_embeds` and `input_features` received! Please provide only one of them")
1207
+ elif input_features is not None and inputs_embeds is None:
1208
+ inputs_embeds = input_features
1209
+
1210
+ # Encode if needed (training, first prediction pass)
1211
+ if encoder_outputs is None:
1212
+ # Convert encoder inputs in embeddings if needed
1213
+ encoder_outputs = self.encoder(
1214
+ input_ids=input_ids,
1215
+ attention_mask=attention_mask,
1216
+ inputs_embeds=inputs_embeds,
1217
+ head_mask=head_mask,
1218
+ output_attentions=output_attentions,
1219
+ output_hidden_states=output_hidden_states,
1220
+ return_dict=return_dict,
1221
+ )
1222
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1223
+ encoder_outputs = BaseModelOutput(
1224
+ last_hidden_state=encoder_outputs[0],
1225
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1226
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1227
+ )
1228
+
1229
+ hidden_states = encoder_outputs[0]
1230
+
1231
+ if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
1232
+ # get decoder inputs from shifting lm labels to the right
1233
+ decoder_input_ids = self._shift_right(labels)
1234
+
1235
+ # Decode
1236
+ decoder_outputs = self.decoder(
1237
+ input_ids=decoder_input_ids,
1238
+ attention_mask=decoder_attention_mask,
1239
+ inputs_embeds=decoder_inputs_embeds,
1240
+ past_key_values=past_key_values,
1241
+ encoder_hidden_states=hidden_states,
1242
+ encoder_attention_mask=attention_mask,
1243
+ head_mask=decoder_head_mask,
1244
+ cross_attn_head_mask=cross_attn_head_mask,
1245
+ use_cache=use_cache,
1246
+ output_attentions=output_attentions,
1247
+ output_hidden_states=output_hidden_states,
1248
+ return_dict=return_dict,
1249
+ cache_position=cache_position,
1250
+ )
1251
+
1252
+ sequence_output = decoder_outputs[0]
1253
+
1254
+ if self.config.tie_word_embeddings:
1255
+ # Rescale output before projecting on vocab
1256
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
1257
+ sequence_output = sequence_output * (self.model_dim**-0.5)
1258
+
1259
+ lm_logits = self.lm_head(sequence_output)
1260
+
1261
+ loss = None
1262
+ if labels is not None:
1263
+ loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
1264
+ loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
1265
+
1266
+ if not return_dict:
1267
+ output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
1268
+ return ((loss,) + output) if loss is not None else output
1269
+
1270
+ return Seq2SeqLMOutput(
1271
+ loss=loss,
1272
+ logits=lm_logits,
1273
+ past_key_values=decoder_outputs.past_key_values,
1274
+ decoder_hidden_states=decoder_outputs.hidden_states,
1275
+ decoder_attentions=decoder_outputs.attentions,
1276
+ cross_attentions=decoder_outputs.cross_attentions,
1277
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1278
+ encoder_hidden_states=encoder_outputs.hidden_states,
1279
+ encoder_attentions=encoder_outputs.attentions,
1280
+ )
1281
+
1282
+ @torch.no_grad()
1283
+ def generate(
1284
+ self,
1285
+ input_features,
1286
+ attention_mask=None,
1287
+ composer="composer1",
1288
+ generation_config=None,
1289
+ **kwargs,
1290
+ ):
1291
+ """
1292
+ Generates token ids for midi outputs.
1293
+
1294
+ <Tip warning={true}>
1295
+
1296
+ Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
1297
+ model's default generation configuration. You can override any `generation_config` by passing the corresponding
1298
+ parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. For an overview of generation
1299
+ strategies and code examples, check out the [following guide](./generation_strategies).
1300
+
1301
+ </Tip>
1302
+
1303
+ Parameters:
1304
+ input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1305
+ This is the featurized version of audio generated by `Pop2PianoFeatureExtractor`.
1306
+ attention_mask:
1307
+ For batched generation `input_features` are padded to have the same shape across all examples.
1308
+ `attention_mask` helps to determine which areas were padded and which were not.
1309
+ - 1 for tokens that are **not padded**,
1310
+ - 0 for tokens that are **padded**.
1311
+ composer (`str`, *optional*, defaults to `"composer1"`):
1312
+ This value is passed to `Pop2PianoConcatEmbeddingToMel` to generate different embeddings for each
1313
+ `"composer"`. Please make sure that the composet value is present in `composer_to_feature_token` in
1314
+ `generation_config`. For an example please see
1315
+ https://huggingface.co/sweetcocoa/pop2piano/blob/main/generation_config.json .
1316
+ generation_config (`~generation.GenerationConfig`, *optional*):
1317
+ The generation configuration to be used as base parametrization for the generation call. `**kwargs`
1318
+ passed to generate matching the attributes of `generation_config` will override them. If
1319
+ `generation_config` is not provided, the default will be used, which had the following loading
1320
+ priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
1321
+ configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
1322
+ default values, whose documentation should be checked to parameterize generation.
1323
+ kwargs:
1324
+ Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
1325
+ forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
1326
+ specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
1327
+ Return:
1328
+ [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
1329
+ or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`.
1330
+ Since Pop2Piano is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
1331
+ [`~utils.ModelOutput`] types are:
1332
+ - [`~generation.GenerateEncoderDecoderOutput`],
1333
+ - [`~generation.GenerateBeamEncoderDecoderOutput`]
1334
+ """
1335
+
1336
+ if generation_config is None:
1337
+ generation_config = self.generation_config
1338
+ generation_config.update(**kwargs)
1339
+
1340
+ # check for composer_to_feature_token
1341
+ if not hasattr(generation_config, "composer_to_feature_token"):
1342
+ raise ValueError(
1343
+ "`composer_to_feature_token` was not found! Please refer to "
1344
+ "https://huggingface.co/sweetcocoa/pop2piano/blob/main/generation_config.json"
1345
+ "and parse a dict like that."
1346
+ )
1347
+
1348
+ if len(generation_config.composer_to_feature_token) != self.config.composer_vocab_size:
1349
+ raise ValueError(
1350
+ "config.composer_vocab_size must be same as the number of keys in "
1351
+ f"generation_config.composer_to_feature_token! "
1352
+ f"Found {self.config.composer_vocab_size} vs {len(generation_config.composer_to_feature_token)}."
1353
+ )
1354
+
1355
+ # To control the variation of the generated MIDI tokens we concatenate mel-conditioner tokens (which depend on the composer token)
1356
+ # at the front of input_features.
1357
+ input_features, attention_mask = self.get_mel_conditioner_outputs(
1358
+ input_features=input_features,
1359
+ attention_mask=attention_mask,
1360
+ composer=composer,
1361
+ generation_config=generation_config,
1362
+ )
1363
+
1364
+ return super().generate(
1365
+ inputs=None,
1366
+ inputs_embeds=input_features,
1367
+ attention_mask=attention_mask,
1368
+ generation_config=generation_config,
1369
+ **kwargs,
1370
+ )
1371
+
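+ # A minimal end-to-end sketch of the generate() flow above (kept as a comment, not executed).
+ # It assumes the "sweetcocoa/pop2piano" checkpoint, the Pop2PianoForConditionalGeneration class
+ # this method belongs to, and the Pop2PianoProcessor defined later in this module; `audio` and
+ # `sr` are placeholders for a mono waveform and its sampling rate loaded by the caller.
+ #
+ #     processor = Pop2PianoProcessor.from_pretrained("sweetcocoa/pop2piano")
+ #     model = Pop2PianoForConditionalGeneration.from_pretrained("sweetcocoa/pop2piano")
+ #     inputs = processor(audio, sampling_rate=sr, return_tensors="pt")
+ #     token_ids = model.generate(input_features=inputs["input_features"], composer="composer1")
+ #     midi = processor.batch_decode(token_ids, inputs)["pretty_midi_objects"][0]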
1372
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
1373
+ return self._shift_right(labels)
1374
+
1375
+ def _reorder_cache(self, past_key_values, beam_idx):
1376
+ # if decoder past is not included in output
1377
+ # speedy decoding is disabled and no need to reorder
1378
+ if past_key_values is None:
1379
+ print("You might want to consider setting `use_cache=True` to speed up decoding")
1380
+ return past_key_values
1381
+
1382
+ reordered_decoder_past = ()
1383
+ for layer_past_states in past_key_values:
1384
+ # get the correct batch idx from layer past batch dim
1385
+ # batch dim of `past` is at 2nd position
1386
+ reordered_layer_past_states = ()
1387
+ for layer_past_state in layer_past_states:
1388
+ # need to set correct `past` for each of the four key / value states
1389
+ reordered_layer_past_states = reordered_layer_past_states + (
1390
+ layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
1391
+ )
1392
+
1393
+ if reordered_layer_past_states[0].shape != layer_past_states[0].shape:
1394
+ raise ValueError(
1395
+ f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched"
1396
+ )
1397
+ if len(reordered_layer_past_states) != len(layer_past_states):
1398
+ raise ValueError(
1399
+ f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched"
1400
+ )
1401
+
1402
+ reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
1403
+ return reordered_decoder_past
1404
+
1405
+ class Pop2PianoFeatureExtractor(SequenceFeatureExtractor):
1406
+ r"""
1407
+ Constructs a Pop2Piano feature extractor.
1408
+
1409
+ This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
1410
+ most of the main methods. Users should refer to this superclass for more information regarding those methods.
1411
+
1412
+ This class extracts rhythm and preprocesses the audio before it is passed to the model. First the audio is passed
1413
+ to `RhythmExtractor2013` algorithm which extracts the beat_times, beat positions and estimates their confidence as
1414
+ well as the tempo in bpm; beat_times is then interpolated to get beatsteps. Later we calculate
1415
+ extrapolated_beatsteps from it to be used in the tokenizer. Separately, the audio is resampled to self.sampling_rate
1416
+ and preprocessed, and then a log-mel spectrogram is computed from it to be used in the transformer model.
1417
+
1418
+ Args:
1419
+ sampling_rate (`int`, *optional*, defaults to 22050):
1420
+ Target Sampling rate of audio signal. It's the sampling rate that we forward to the model.
1421
+ padding_value (`int`, *optional*, defaults to 0):
1422
+ Padding value used to pad the audio. Should correspond to silences.
1423
+ window_size (`int`, *optional*, defaults to 4096):
1424
+ Length of the window in samples to which the Fourier transform is applied.
1425
+ hop_length (`int`, *optional*, defaults to 1024):
1426
+ Step size between each window of the waveform, in samples.
1427
+ min_frequency (`float`, *optional*, defaults to 10.0):
1428
+ Lowest frequency that will be used in the log-mel spectrogram.
1429
+ feature_size (`int`, *optional*, defaults to 512):
1430
+ The feature dimension of the extracted features.
1431
+ num_bars (`int`, *optional*, defaults to 2):
1432
+ Determines interval between each sequence.
1433
+ """
1434
+
1435
+ model_input_names = ["input_features", "beatsteps", "extrapolated_beatstep"]
1436
+
1437
+ def __init__(
1438
+ self,
1439
+ sampling_rate: int = 22050,
1440
+ padding_value: int = 0,
1441
+ window_size: int = 4096,
1442
+ hop_length: int = 1024,
1443
+ min_frequency: float = 10.0,
1444
+ feature_size: int = 512,
1445
+ num_bars: int = 2,
1446
+ **kwargs,
1447
+ ):
1448
+ super().__init__(
1449
+ feature_size=feature_size,
1450
+ sampling_rate=sampling_rate,
1451
+ padding_value=padding_value,
1452
+ **kwargs,
1453
+ )
1454
+ self.sampling_rate = sampling_rate
1455
+ self.padding_value = padding_value
1456
+ self.window_size = window_size
1457
+ self.hop_length = hop_length
1458
+ self.min_frequency = min_frequency
1459
+ self.feature_size = feature_size
1460
+ self.num_bars = num_bars
1461
+ self.mel_filters = mel_filter_bank(
1462
+ num_frequency_bins=(self.window_size // 2) + 1,
1463
+ num_mel_filters=self.feature_size,
1464
+ min_frequency=self.min_frequency,
1465
+ max_frequency=float(self.sampling_rate // 2),
1466
+ sampling_rate=self.sampling_rate,
1467
+ norm=None,
1468
+ mel_scale="htk",
1469
+ )
1470
+
1471
+ def mel_spectrogram(self, sequence: np.ndarray):
1472
+ """
1473
+ Generates MelSpectrogram.
1474
+
1475
+ Args:
1476
+ sequence (`np.ndarray`):
1477
+ The sequence of which the mel-spectrogram will be computed.
1478
+ """
1479
+ mel_specs = []
1480
+ for seq in sequence:
1481
+ window = np.hanning(self.window_size + 1)[:-1]
1482
+ mel_specs.append(
1483
+ spectrogram(
1484
+ waveform=seq,
1485
+ window=window,
1486
+ frame_length=self.window_size,
1487
+ hop_length=self.hop_length,
1488
+ power=2.0,
1489
+ mel_filters=self.mel_filters,
1490
+ )
1491
+ )
1492
+ mel_specs = np.array(mel_specs)
1493
+
1494
+ return mel_specs
1495
+
1496
+ def extract_rhythm(self, audio: np.ndarray):
1497
+ """
1498
+ This algorithm (`RhythmExtractor2013`) extracts the beat positions and estimates their confidence as well as
1499
+ tempo in bpm for an audio signal. For more information please visit
1500
+ https://essentia.upf.edu/reference/std_RhythmExtractor2013.html .
1501
+
1502
+ Args:
1503
+ audio(`np.ndarray`):
1504
+ raw audio waveform which is passed to the Rhythm Extractor.
1505
+ """
1506
+ essentia_tracker = RhythmExtractor2013(method="multifeature")
1507
+ bpm, beat_times, confidence, estimates, essentia_beat_intervals = essentia_tracker(audio)
1508
+
1509
+ return bpm, beat_times, confidence, estimates, essentia_beat_intervals
1510
+
1511
+ def interpolate_beat_times(
1512
+ self, beat_times: np.ndarray, steps_per_beat: int, n_extend: int
1513
+ ):
1514
+ """
1515
+ This method takes beat_times and then interpolates that using `scipy.interpolate.interp1d` and the output is
1516
+ then used to convert raw audio to log-mel-spectrogram.
1517
+
1518
+ Args:
1519
+ beat_times (`np.ndarray`):
1520
+ beat_times is passed into `scipy.interpolate.interp1d` for processing.
1521
+ steps_per_beat (`int`):
1522
+ used as a parameter to control the interpolation.
1523
+ n_extend (`int`):
1524
+ used as a parameter to control the interpolation.
1525
+ """
1526
+
1527
+ beat_times_function = interp1d(
1528
+ np.arange(beat_times.size),
1529
+ beat_times,
1530
+ bounds_error=False,
1531
+ fill_value="extrapolate",
1532
+ )
1533
+
1534
+ ext_beats = beat_times_function(
1535
+ np.linspace(0, beat_times.size + n_extend - 1, beat_times.size * steps_per_beat + n_extend)
1536
+ )
1537
+
1538
+ return ext_beats
1539
+
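+ # Worked example for the interpolation above (illustrative, assuming the inputs shown): with
+ # beat_times = np.array([0.0, 0.5, 1.0]), steps_per_beat=2 and n_extend=1, interp1d is queried
+ # on np.linspace(0, 3, 7), returning [0.0, 0.25, 0.5, 0.75, 1.0, 1.25, 1.5], i.e. the beat grid
+ # interpolated to two steps per beat and extrapolated one beat past the last detected beat.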
1540
+ def preprocess_mel(self, audio: np.ndarray, beatstep: np.ndarray):
1541
+ """
1542
+ Preprocessing for log-mel-spectrogram
1543
+
1544
+ Args:
1545
+ audio (`np.ndarray` of shape `(audio_length, )` ):
1546
+ Raw audio waveform to be processed.
1547
+ beatstep (`np.ndarray`):
1548
+ Interpolated values of the raw audio. If beatstep[0] is greater than 0.0, then it will be shifted by
1549
+ the value at beatstep[0].
1550
+ """
1551
+
1552
+ if audio is not None and len(audio.shape) != 1:
1553
+ raise ValueError(
1554
+ f"Expected `audio` to be a single channel audio input of shape `(n, )` but found shape {audio.shape}."
1555
+ )
1556
+ if beatstep[0] > 0.0:
1557
+ beatstep = beatstep - beatstep[0]
1558
+
1559
+ num_steps = self.num_bars * 4
1560
+ num_target_steps = len(beatstep)
1561
+ extrapolated_beatstep = self.interpolate_beat_times(
1562
+ beat_times=beatstep, steps_per_beat=1, n_extend=(self.num_bars + 1) * 4 + 1
1563
+ )
1564
+
1565
+ sample_indices = []
1566
+ max_feature_length = 0
1567
+ for i in range(0, num_target_steps, num_steps):
1568
+ start_idx = i
1569
+ end_idx = min(i + num_steps, num_target_steps)
1570
+ start_sample = int(extrapolated_beatstep[start_idx] * self.sampling_rate)
1571
+ end_sample = int(extrapolated_beatstep[end_idx] * self.sampling_rate)
1572
+ sample_indices.append((start_sample, end_sample))
1573
+ max_feature_length = max(max_feature_length, end_sample - start_sample)
1574
+ padded_batch = []
1575
+ for start_sample, end_sample in sample_indices:
1576
+ feature = audio[start_sample:end_sample]
1577
+ padded_feature = np.pad(
1578
+ feature,
1579
+ ((0, max_feature_length - feature.shape[0]),),
1580
+ "constant",
1581
+ constant_values=0,
1582
+ )
1583
+ padded_batch.append(padded_feature)
1584
+
1585
+ padded_batch = np.asarray(padded_batch)
1586
+ return padded_batch, extrapolated_beatstep
1587
+
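+ # Sketch of the chunking performed above (illustrative numbers): with num_bars=2 each chunk
+ # spans num_steps = 2 * 4 = 8 beatsteps, so an input with 20 beatsteps yields chunks covering
+ # beatsteps [0, 8), [8, 16) and [16, 20); each chunk is cut from the waveform between
+ # extrapolated_beatstep[start] * sampling_rate and extrapolated_beatstep[end] * sampling_rate
+ # and right-padded with zeros to the length of the longest chunk.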
1588
+ def _pad(self, features: np.ndarray, add_zero_line=True):
1589
+ features_shapes = [each_feature.shape for each_feature in features]
1590
+ attention_masks, padded_features = [], []
1591
+ for i, each_feature in enumerate(features):
1592
+ # To pad "input_features".
1593
+ if len(each_feature.shape) == 3:
1594
+ features_pad_value = max([*zip(*features_shapes)][1]) - features_shapes[i][1]
1595
+ attention_mask = np.ones(features_shapes[i][:2], dtype=np.int64)
1596
+ feature_padding = ((0, 0), (0, features_pad_value), (0, 0))
1597
+ attention_mask_padding = (feature_padding[0], feature_padding[1])
1598
+
1599
+ # To pad "beatsteps" and "extrapolated_beatstep".
1600
+ else:
1601
+ each_feature = each_feature.reshape(1, -1)
1602
+ features_pad_value = max([*zip(*features_shapes)][0]) - features_shapes[i][0]
1603
+ attention_mask = np.ones(features_shapes[i], dtype=np.int64).reshape(1, -1)
1604
+ feature_padding = attention_mask_padding = ((0, 0), (0, features_pad_value))
1605
+
1606
+ each_padded_feature = np.pad(each_feature, feature_padding, "constant", constant_values=self.padding_value)
1607
+ attention_mask = np.pad(
1608
+ attention_mask, attention_mask_padding, "constant", constant_values=self.padding_value
1609
+ )
1610
+
1611
+ if add_zero_line:
1612
+ # if it is batched then we separate each example using a zero array
1613
+ zero_array_len = max([*zip(*features_shapes)][1])
1614
+
1615
+ # we concatenate the zero array line here
1616
+ each_padded_feature = np.concatenate(
1617
+ [each_padded_feature, np.zeros([1, zero_array_len, self.feature_size])], axis=0
1618
+ )
1619
+ attention_mask = np.concatenate(
1620
+ [attention_mask, np.zeros([1, zero_array_len], dtype=attention_mask.dtype)], axis=0
1621
+ )
1622
+
1623
+ padded_features.append(each_padded_feature)
1624
+ attention_masks.append(attention_mask)
1625
+
1626
+ padded_features = np.concatenate(padded_features, axis=0).astype(np.float32)
1627
+ attention_masks = np.concatenate(attention_masks, axis=0).astype(np.int64)
1628
+
1629
+ return padded_features, attention_masks
1630
+
1631
+ def pad(
1632
+ self,
1633
+ inputs: BatchFeature,
1634
+ is_batched: bool,
1635
+ return_attention_mask: bool,
1636
+ return_tensors: Optional[Union[str, TensorType]] = None,
1637
+ ):
1638
+ """
1639
+ Pads the inputs to same length and returns attention_mask.
1640
+
1641
+ Args:
1642
+ inputs (`BatchFeature`):
1643
+ Processed audio features.
1644
+ is_batched (`bool`):
1645
+ Whether inputs are batched or not.
1646
+ return_attention_mask (`bool`):
1647
+ Whether to return attention mask or not.
1648
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
1649
+ If set, will return tensors instead of list of python integers. Acceptable values are:
1650
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
1651
+ - `'np'`: Return Numpy `np.ndarray` objects.
1652
+ If nothing is specified, it will return list of `np.ndarray` arrays.
1653
+ Return:
1654
+ `BatchFeature` with attention_mask, attention_mask_beatsteps and attention_mask_extrapolated_beatstep added
1655
+ to it:
1656
+ - **attention_mask** np.ndarray of shape `(batch_size, max_input_features_seq_length)` --
1657
+ Example :
1658
+ 1, 1, 1, 0, 0 (audio 1, also here it is padded to max length of 5 that's why there are 2 zeros at
1659
+ the end indicating they are padded)
1660
+
1661
+ 0, 0, 0, 0, 0 (zero pad to separate audio 1 and 2)
1662
+
1663
+ 1, 1, 1, 1, 1 (audio 2)
1664
+
1665
+ 0, 0, 0, 0, 0 (zero pad to separate audio 2 and 3)
1666
+
1667
+ 1, 1, 1, 1, 1 (audio 3)
1668
+ - **attention_mask_beatsteps** np.ndarray of shape `(batch_size, max_beatsteps_seq_length)`
1669
+ - **attention_mask_extrapolated_beatstep** np.ndarray of shape `(batch_size,
1670
+ max_extrapolated_beatstep_seq_length)`
1671
+ """
1672
+
1673
+ processed_features_dict = {}
1674
+ for feature_name, feature_value in inputs.items():
1675
+ if feature_name == "input_features":
1676
+ padded_feature_values, attention_mask = self._pad(feature_value, add_zero_line=True)
1677
+ processed_features_dict[feature_name] = padded_feature_values
1678
+ if return_attention_mask:
1679
+ processed_features_dict["attention_mask"] = attention_mask
1680
+ else:
1681
+ padded_feature_values, attention_mask = self._pad(feature_value, add_zero_line=False)
1682
+ processed_features_dict[feature_name] = padded_feature_values
1683
+ if return_attention_mask:
1684
+ processed_features_dict[f"attention_mask_{feature_name}"] = attention_mask
1685
+
1686
+ # If we are processing only one example, we should remove the zero array line since we don't need it to
1687
+ # separate examples from each other.
1688
+ if not is_batched and not return_attention_mask:
1689
+ processed_features_dict["input_features"] = processed_features_dict["input_features"][:-1, ...]
1690
+
1691
+ outputs = BatchFeature(processed_features_dict, tensor_type=return_tensors)
1692
+
1693
+ return outputs
1694
+
1695
+ def __call__(
1696
+ self,
1697
+ audio: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]],
1698
+ sampling_rate: Union[int, list[int]],
1699
+ steps_per_beat: int = 2,
1700
+ resample: Optional[bool] = True,
1701
+ return_attention_mask: Optional[bool] = False,
1702
+ return_tensors: Optional[Union[str, TensorType]] = None,
1703
+ **kwargs,
1704
+ ) -> BatchFeature:
1705
+ """
1706
+ Main method to featurize and prepare for the model.
1707
+
1708
+ Args:
1709
+ audio (`np.ndarray`, `List`):
1710
+ The audio or batch of audio to be processed. Each audio can be a numpy array, a list of float values, a
1711
+ list of numpy arrays or a list of list of float values.
1712
+ sampling_rate (`int`):
1713
+ The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass
1714
+ `sampling_rate` at the forward call to prevent silent errors.
1715
+ steps_per_beat (`int`, *optional*, defaults to 2):
1716
+ This is used in interpolating `beat_times`.
1717
+ resample (`bool`, *optional*, defaults to `True`):
1718
+ Determines whether to resample the audio to `sampling_rate` or not before processing. Must be True
1719
+ during inference.
1720
+ return_attention_mask (`bool`, *optional*, defaults to `False`):
1721
+ Denotes if attention_mask for input_features, beatsteps and extrapolated_beatstep will be given as
1722
+ output or not. Automatically set to True for batched inputs.
1723
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
1724
+ If set, will return tensors instead of list of python integers. Acceptable values are:
1725
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
1726
+ - `'np'`: Return Numpy `np.ndarray` objects.
1727
+ If nothing is specified, it will return list of `np.ndarray` arrays.
1728
+ """
1729
+ is_batched = bool(isinstance(audio, (list, tuple)) and isinstance(audio[0], (np.ndarray, tuple, list)))
1730
+ if is_batched:
1731
+ # This enables the user to process files of different sampling_rate at the same time
1732
+ if not isinstance(sampling_rate, list):
1733
+ raise ValueError(
1734
+ "Please give sampling_rate of each audio separately when you are passing multiple raw_audios at the same time. "
1735
+ f"Received {sampling_rate}, expected [audio_1_sr, ..., audio_n_sr]."
1736
+ )
1737
+ return_attention_mask = True if return_attention_mask is None else return_attention_mask
1738
+ else:
1739
+ audio = [audio]
1740
+ sampling_rate = [sampling_rate]
1741
+ return_attention_mask = False if return_attention_mask is None else return_attention_mask
1742
+
1743
+ batch_input_features, batch_beatsteps, batch_ext_beatstep = [], [], []
1744
+ total_len = len(audio)
1745
+ for index, (single_raw_audio, single_sampling_rate) in enumerate(zip(audio, sampling_rate)):
1746
+ bpm, beat_times, confidence, estimates, essentia_beat_intervals = self.extract_rhythm(
1747
+ audio=single_raw_audio
1748
+ )
1749
+ beatsteps = self.interpolate_beat_times(beat_times=beat_times, steps_per_beat=steps_per_beat, n_extend=1)
1750
+ if self.sampling_rate != single_sampling_rate and self.sampling_rate is not None:
1751
+ if resample:
1752
+ # Change sampling_rate to self.sampling_rate
1753
+ single_raw_audio = librosa_resample(
1754
+ single_raw_audio,
1755
+ orig_sr=single_sampling_rate,
1756
+ target_sr=self.sampling_rate,
1757
+ res_type="kaiser_best",
1758
+ )
1759
+ else:
1760
+ print(
1761
+ f"The sampling_rate of the provided audio is different from the target sampling_rate "
1762
+ f"of the Feature Extractor, {self.sampling_rate} vs {single_sampling_rate}. "
1763
+ f"In these cases it is recommended to use `resample=True` in the `__call__` method to "
1764
+ f"get the optimal behaviour."
1765
+ )
1766
+
1767
+ single_sampling_rate = self.sampling_rate
1768
+ start_sample = int(beatsteps[0] * single_sampling_rate)
1769
+ end_sample = int(beatsteps[-1] * single_sampling_rate)
1770
+
1771
+ input_features, extrapolated_beatstep = self.preprocess_mel(
1772
+ single_raw_audio[start_sample:end_sample], beatsteps - beatsteps[0]
1773
+ )
1774
+
1775
+ mel_specs = self.mel_spectrogram(input_features.astype(np.float32))
1776
+
1777
+ # apply np.log to get log mel-spectrograms
1778
+ log_mel_specs = np.log(np.clip(mel_specs, a_min=1e-6, a_max=None))
1779
+
1780
+ input_features = np.transpose(log_mel_specs, (0, -1, -2))
1781
+
1782
+ batch_input_features.append(input_features)
1783
+ batch_beatsteps.append(beatsteps)
1784
+ batch_ext_beatstep.append(extrapolated_beatstep)
1785
+ output = BatchFeature(
1786
+ {
1787
+ "input_features": batch_input_features,
1788
+ "beatsteps": batch_beatsteps,
1789
+ "extrapolated_beatstep": batch_ext_beatstep,
1790
+ }
1791
+ )
1792
+
1793
+ output = self.pad(
1794
+ output,
1795
+ is_batched=is_batched,
1796
+ return_attention_mask=return_attention_mask,
1797
+ return_tensors=return_tensors,
1798
+ )
1799
+
1800
+ return output
1801
+
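+ # Illustrative usage of the __call__ above (kept as a comment; it assumes the
+ # "sweetcocoa/pop2piano" checkpoint ships this feature extractor's config, and `audio`/`sr`
+ # are placeholders for a mono waveform and its sampling rate):
+ #
+ #     feature_extractor = Pop2PianoFeatureExtractor.from_pretrained("sweetcocoa/pop2piano")
+ #     features = feature_extractor(audio, sampling_rate=sr, return_tensors="pt")
+ #     # features holds "input_features" (log-mel chunks), "beatsteps" and "extrapolated_beatstep"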
1802
+ VOCAB_FILES_NAMES = {
1803
+ "vocab": "vocab.json",
1804
+ }
1805
+
1806
+ def token_time_to_note(number, cutoff_time_idx, current_idx):
1807
+ current_idx += number
1808
+ if cutoff_time_idx is not None:
1809
+ current_idx = min(current_idx, cutoff_time_idx)
1810
+
1811
+ return current_idx
1812
+
1813
+ def token_note_to_note(number, current_velocity, default_velocity, note_onsets_ready, current_idx, notes):
1814
+ if note_onsets_ready[number] is not None:
1815
+ # offset with onset
1816
+ onset_idx = note_onsets_ready[number]
1817
+ if onset_idx < current_idx:
1818
+ # Time shift after previous note_on
1819
+ offset_idx = current_idx
1820
+ notes.append([onset_idx, offset_idx, number, default_velocity])
1821
+ onsets_ready = None if current_velocity == 0 else current_idx
1822
+ note_onsets_ready[number] = onsets_ready
1823
+ else:
1824
+ note_onsets_ready[number] = current_idx
1825
+ return notes
1826
+
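+ # Worked example for the two helpers above (illustrative): decoding the relative sequence
+ # TOKEN_VELOCITY(1), TOKEN_NOTE(60), TOKEN_TIME(2), TOKEN_VELOCITY(0), TOKEN_NOTE(60)
+ # records an onset for pitch 60 at the current index, advances the index by 2, and then,
+ # because the velocity has dropped to 0, closes the note by appending
+ # [onset_idx, onset_idx + 2, 60, default_velocity] to `notes`.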
1827
+ class Pop2PianoTokenizer(PreTrainedTokenizer):
1828
+ """
1829
+ Constructs a Pop2Piano tokenizer. This tokenizer does not require training.
1830
+
1831
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
1832
+ this superclass for more information regarding those methods.
1833
+
1834
+ Args:
1835
+ vocab (`str`):
1836
+ Path to the vocab file which contains the vocabulary.
1837
+ default_velocity (`int`, *optional*, defaults to 77):
1838
+ Determines the default velocity to be used while creating midi Notes.
1839
+ num_bars (`int`, *optional*, defaults to 2):
1840
+ Determines the cutoff_time_idx for each token.
1841
+ unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"-1"`):
1842
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
1843
+ token instead.
1844
+ eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to 1):
1845
+ The end of sequence token.
1846
+ pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to 0):
1847
+ A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by
1848
+ attention mechanisms or loss computation.
1849
+ bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to 2):
1850
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
1851
+ """
1852
+
1853
+ model_input_names = ["token_ids", "attention_mask"]
1854
+ vocab_files_names = VOCAB_FILES_NAMES
1855
+
1856
+ def __init__(
1857
+ self,
1858
+ vocab,
1859
+ default_velocity=77,
1860
+ num_bars=2,
1861
+ unk_token="-1",
1862
+ eos_token="1",
1863
+ pad_token="0",
1864
+ bos_token="2",
1865
+ **kwargs,
1866
+ ):
1867
+ unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
1868
+ eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
1869
+ pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
1870
+ bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
1871
+
1872
+ self.default_velocity = default_velocity
1873
+ self.num_bars = num_bars
1874
+
1875
+ # Load the vocab
1876
+ with open(vocab, "rb") as file:
1877
+ self.encoder = json_load(file)
1878
+
1879
+ # create mappings for encoder
1880
+ self.decoder = {v: k for k, v in self.encoder.items()}
1881
+
1882
+ super().__init__(
1883
+ unk_token=unk_token,
1884
+ eos_token=eos_token,
1885
+ pad_token=pad_token,
1886
+ bos_token=bos_token,
1887
+ **kwargs,
1888
+ )
1889
+
1890
+ @property
1891
+ def vocab_size(self):
1892
+ """Returns the vocabulary size of the tokenizer."""
1893
+ return len(self.encoder)
1894
+
1895
+ def get_vocab(self):
1896
+ """Returns the vocabulary of the tokenizer."""
1897
+ return dict(self.encoder, **self.added_tokens_encoder)
1898
+
1899
+ def _convert_id_to_token(self, token_id: int) -> list:
1900
+ """
1901
+ Decodes the token ids generated by the transformer into notes.
1902
+
1903
+ Args:
1904
+ token_id (`int`):
1905
+ This denotes the ids generated by the transformers to be converted to Midi tokens.
1906
+
1907
+ Returns:
1908
+ `List`: A list consists of token_type (`str`) and value (`int`).
1909
+ """
1910
+
1911
+ token_type_value = self.decoder.get(token_id, f"{self.unk_token}_TOKEN_TIME")
1912
+ token_type_value = token_type_value.split("_")
1913
+ token_type, value = "_".join(token_type_value[1:]), int(token_type_value[0])
1914
+
1915
+ return [token_type, value]
1916
+
1917
+ def _convert_token_to_id(self, token, token_type="TOKEN_TIME") -> int:
1918
+ """
1919
+ Encodes the Midi tokens to transformer generated token ids.
1920
+
1921
+ Args:
1922
+ token (`int`):
1923
+ This denotes the token value.
1924
+ token_type (`str`):
1925
+ This denotes the type of the token. There are four types of midi tokens such as "TOKEN_TIME",
1926
+ "TOKEN_VELOCITY", "TOKEN_NOTE" and "TOKEN_SPECIAL".
1927
+
1928
+ Returns:
1929
+ `int`: returns the id of the token.
1930
+ """
1931
+ return self.encoder.get(f"{token}_{token_type}", int(self.unk_token))
1932
+
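+ # Vocabulary entries follow the "<value>_<TOKEN_TYPE>" naming assumed by the two converters
+ # above, e.g. "60_TOKEN_NOTE" or "1_TOKEN_SPECIAL": _convert_token_to_id(60, "TOKEN_NOTE")
+ # looks up self.encoder["60_TOKEN_NOTE"], and _convert_id_to_token maps the id back through
+ # self.decoder and splits the string into ["TOKEN_NOTE", 60].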
1933
+ def relative_batch_tokens_ids_to_notes(
1934
+ self,
1935
+ tokens: np.ndarray,
1936
+ beat_offset_idx: int,
1937
+ bars_per_batch: int,
1938
+ cutoff_time_idx: int,
1939
+ ):
1940
+ """
1941
+ Converts relative tokens to notes which are then used to generate pretty midi object.
1942
+
1943
+ Args:
1944
+ tokens (`np.ndarray`):
1945
+ Tokens to be converted to notes.
1946
+ beat_offset_idx (`int`):
1947
+ Denotes beat offset index for each note in generated Midi.
1948
+ bars_per_batch (`int`):
1949
+ A parameter to control the Midi output generation.
1950
+ cutoff_time_idx (`int`):
1951
+ Denotes the cutoff time index for each note in generated Midi.
1952
+ """
1953
+
1954
+ notes = None
1955
+
1956
+ for index in range(len(tokens)):
1957
+ _tokens = tokens[index]
1958
+ _start_idx = beat_offset_idx + index * bars_per_batch * 4
1959
+ _cutoff_time_idx = cutoff_time_idx + _start_idx
1960
+ _notes = self.relative_tokens_ids_to_notes(
1961
+ _tokens,
1962
+ start_idx=_start_idx,
1963
+ cutoff_time_idx=_cutoff_time_idx,
1964
+ )
1965
+
1966
+ if len(_notes) == 0:
1967
+ pass
1968
+ elif notes is None:
1969
+ notes = _notes
1970
+ else:
1971
+ notes = np.concatenate((notes, _notes), axis=0)
1972
+
1973
+ if notes is None:
1974
+ return []
1975
+ return notes
1976
+
1977
+ def relative_batch_tokens_ids_to_midi(
1978
+ self,
1979
+ tokens: np.ndarray,
1980
+ beatstep: np.ndarray,
1981
+ beat_offset_idx: int = 0,
1982
+ bars_per_batch: int = 2,
1983
+ cutoff_time_idx: int = 12,
1984
+ ):
1985
+ """
1986
+ Converts tokens to Midi. This method calls `relative_batch_tokens_ids_to_notes` method to convert batch tokens
1987
+ to notes then uses `notes_to_midi` method to convert them to Midi.
1988
+
1989
+ Args:
1990
+ tokens (`np.ndarray`):
1991
+ Denotes tokens which alongside beatstep will be converted to Midi.
1992
+ beatstep (`np.ndarray`):
1993
+ We get beatstep from feature extractor which is also used to get Midi.
1994
+ beat_offset_idx (`int`, *optional*, defaults to 0):
1995
+ Denotes beat offset index for each note in generated Midi.
1996
+ bars_per_batch (`int`, *optional*, defaults to 2):
1997
+ A parameter to control the Midi output generation.
1998
+ cutoff_time_idx (`int`, *optional*, defaults to 12):
1999
+ Denotes the cutoff time index for each note in generated Midi.
2000
+ """
2001
+ beat_offset_idx = 0 if beat_offset_idx is None else beat_offset_idx
2002
+ notes = self.relative_batch_tokens_ids_to_notes(
2003
+ tokens=tokens,
2004
+ beat_offset_idx=beat_offset_idx,
2005
+ bars_per_batch=bars_per_batch,
2006
+ cutoff_time_idx=cutoff_time_idx,
2007
+ )
2008
+ midi = self.notes_to_midi(notes, beatstep, offset_sec=beatstep[beat_offset_idx])
2009
+ return midi
2010
+
2011
+ # Taken from the original code
2012
+ # Please see https://github.com/sweetcocoa/pop2piano/blob/fac11e8dcfc73487513f4588e8d0c22a22f2fdc5/midi_tokenizer.py#L257
2013
+ def relative_tokens_ids_to_notes(
2014
+ self, tokens: np.ndarray, start_idx: float, cutoff_time_idx: Optional[float] = None
2015
+ ):
2016
+ """
2017
+ Converts relative tokens to notes which will then be used to create Pretty Midi objects.
2018
+
2019
+ Args:
2020
+ tokens (`np.ndarray`):
2021
+ Relative Tokens which will be converted to notes.
2022
+ start_idx (`float`):
2023
+ A parameter which denotes the starting index.
2024
+ cutoff_time_idx (`float`, *optional*):
2025
+ A parameter used while converting tokens to notes.
2026
+ """
2027
+ words = [self._convert_id_to_token(token) for token in tokens]
2028
+
2029
+ current_idx = start_idx
2030
+ current_velocity = 0
2031
+ note_onsets_ready = [None for i in range(sum([k.endswith("NOTE") for k in self.encoder.keys()]) + 1)]
2032
+ notes = []
2033
+ for token_type, number in words:
2034
+ if token_type == "TOKEN_SPECIAL":
2035
+ if number == 1:
2036
+ break
2037
+ elif token_type == "TOKEN_TIME":
2038
+ current_idx = token_time_to_note(
2039
+ number=number, cutoff_time_idx=cutoff_time_idx, current_idx=current_idx
2040
+ )
2041
+ elif token_type == "TOKEN_VELOCITY":
2042
+ current_velocity = number
2043
+
2044
+ elif token_type == "TOKEN_NOTE":
2045
+ notes = token_note_to_note(
2046
+ number=number,
2047
+ current_velocity=current_velocity,
2048
+ default_velocity=self.default_velocity,
2049
+ note_onsets_ready=note_onsets_ready,
2050
+ current_idx=current_idx,
2051
+ notes=notes,
2052
+ )
2053
+ else:
2054
+ raise ValueError("Token type not understood!")
2055
+
2056
+ for pitch, note_onset in enumerate(note_onsets_ready):
2057
+ # force offset if no offset for each pitch
2058
+ if note_onset is not None:
2059
+ if cutoff_time_idx is None:
2060
+ cutoff = note_onset + 1
2061
+ else:
2062
+ cutoff = max(cutoff_time_idx, note_onset + 1)
2063
+
2064
+ offset_idx = max(current_idx, cutoff)
2065
+ notes.append([note_onset, offset_idx, pitch, self.default_velocity])
2066
+
2067
+ if len(notes) == 0:
2068
+ return []
2069
+ else:
2070
+ notes = np.array(notes)
2071
+ note_order = notes[:, 0] * 128 + notes[:, 1]
2072
+ notes = notes[note_order.argsort()]
2073
+ return notes
2074
+
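+ # Each row of the notes array returned above is [onset_idx, offset_idx, pitch, velocity],
+ # sorted by onset_idx * 128 + offset_idx; the indices are positions on the (extrapolated)
+ # beatstep grid, not seconds, and are only converted to seconds later in notes_to_midi.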
2075
+ def notes_to_midi(self, notes: np.ndarray, beatstep: np.ndarray, offset_sec: float = 0.0):
2076
+ """
2077
+ Converts notes to Midi.
2078
+
2079
+ Args:
2080
+ notes (`np.ndarray`):
2081
+ This is used to create Pretty Midi objects.
2082
+ beatstep (`np.ndarray`):
2083
+ This is the extrapolated beatstep that we get from feature extractor.
2084
+ offset_sec (`float`, *optional*, defaults to 0.0):
2085
+ This represents the offset in seconds which is used while creating each Pretty Midi Note.
2086
+ """
2087
+ new_pm = pretty_midi_fix.PrettyMIDI(resolution=384, initial_tempo=120.0)
2088
+ new_inst = pretty_midi_fix.Instrument(program=0)
2089
+ new_notes = []
2090
+
2091
+ for onset_idx, offset_idx, pitch, velocity in notes:
2092
+ new_note = pretty_midi_fix.Note(
2093
+ velocity=velocity,
2094
+ pitch=pitch,
2095
+ start=beatstep[onset_idx] - offset_sec,
2096
+ end=beatstep[offset_idx] - offset_sec,
2097
+ )
2098
+ new_notes.append(new_note)
2099
+ new_inst.notes = new_notes
2100
+ new_pm.instruments.append(new_inst)
2101
+ new_pm.remove_invalid_notes()
2102
+ return new_pm
2103
+
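+ # Illustrative example for notes_to_midi above: with notes = np.array([[0, 2, 60, 77]]) and
+ # beatstep = np.array([0.0, 0.25, 0.5, 0.75]), the single note becomes pitch 60 at velocity 77,
+ # starting at beatstep[0] = 0.0 s and ending at beatstep[2] = 0.5 s (shifted by offset_sec),
+ # placed on one piano Instrument (program=0) of the returned PrettyMIDI object.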
2104
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
2105
+ """
2106
+ Saves the tokenizer's vocabulary dictionary to the provided save_directory.
2107
+
2108
+ Args:
2109
+ save_directory (`str`):
2110
+ A path to the directory where the vocabulary will be saved. The directory must already exist.
2111
+ filename_prefix (`Optional[str]`, *optional*):
2112
+ A prefix to add to the names of the files saved by the tokenizer.
2113
+ """
2114
+ if not os.path.isdir(save_directory):
2115
+ print(f"Vocabulary path ({save_directory}) should be a directory")
2116
+ return
2117
+
2118
+ # Save the encoder.
2119
+ out_vocab_file = os.path.join(
2120
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab"]
2121
+ )
2122
+ with open(out_vocab_file, "w") as file:
2123
+ file.write(json_dumps(self.encoder))
2124
+
2125
+ return (out_vocab_file,)
2126
+
2127
+ def encode_plus(
2128
+ self,
2129
+ notes: Union[np.ndarray, list[pretty_midi_fix.Note]],
2130
+ truncation_strategy: Optional[TruncationStrategy] = None,
2131
+ max_length: Optional[int] = None,
2132
+ **kwargs,
2133
+ ) -> BatchEncoding:
2134
+ r"""
2135
+ This is the `encode_plus` method for `Pop2PianoTokenizer`. It converts the midi notes to the transformer
2136
+ generated token ids. It only works on a single batch, to process multiple batches please use
2137
+ `batch_encode_plus` or `__call__` method.
2138
+
2139
+ Args:
2140
+ notes (`np.ndarray` of shape `[sequence_length, 4]` or `list` of `pretty_midi_fix.Note` objects):
2141
+ This represents the midi notes. If `notes` is a `np.ndarray`:
2142
+ - Each sequence must have 4 values, they are `onset idx`, `offset idx`, `pitch` and `velocity`.
2143
+ If `notes` is a `list` containing `pretty_midi_fix.Note` objects:
2144
+ - Each sequence must have 4 attributes, they are `start`, `end`, `pitch` and `velocity`.
2145
+ truncation_strategy ([`~tokenization_utils_base.TruncationStrategy`], *optional*):
2146
+ Indicates the truncation strategy that is going to be used during truncation.
2147
+ max_length (`int`, *optional*):
2148
+ Maximum length of the returned list and optionally padding length (see above).
2149
+
2150
+ Returns:
2151
+ `BatchEncoding` containing the tokens ids.
2152
+ """
2153
+ # check if notes is a pretty_midi_fix object or not, if yes then extract the attributes and put them into a numpy
2154
+ # array.
2155
+ if isinstance(notes[0], pretty_midi_fix.Note):
2156
+ notes = np.array(
2157
+ [[each_note.start, each_note.end, each_note.pitch, each_note.velocity] for each_note in notes]
2158
+ ).reshape(-1, 4)
2159
+
2160
+ # to round up all the values to the closest int values.
2161
+ notes = np.round(notes).astype(np.int32)
2162
+ max_time_idx = notes[:, :2].max()
2163
+
2164
+ times = [[] for i in range(max_time_idx + 1)]
2165
+ for onset, offset, pitch, velocity in notes:
2166
+ times[onset].append([pitch, velocity])
2167
+ times[offset].append([pitch, 0])
2168
+
2169
+ tokens = []
2170
+ current_velocity = 0
2171
+ for i, time in enumerate(times):
2172
+ if len(time) == 0:
2173
+ continue
2174
+ tokens.append(self._convert_token_to_id(i, "TOKEN_TIME"))
2175
+ for pitch, velocity in time:
2176
+ velocity = int(velocity > 0)
2177
+ if current_velocity != velocity:
2178
+ current_velocity = velocity
2179
+ tokens.append(self._convert_token_to_id(velocity, "TOKEN_VELOCITY"))
2180
+ tokens.append(self._convert_token_to_id(pitch, "TOKEN_NOTE"))
2181
+
2182
+ total_len = len(tokens)
2183
+
2184
+ # truncation
2185
+ if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
2186
+ tokens, _, _ = self.truncate_sequences(
2187
+ ids=tokens,
2188
+ num_tokens_to_remove=total_len - max_length,
2189
+ truncation_strategy=truncation_strategy,
2190
+ **kwargs,
2191
+ )
2192
+
2193
+ return BatchEncoding({"token_ids": tokens})
2194
+
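+ # Illustrative sketch of encode_plus above: notes = np.array([[0, 2, 60, 77]]) is bucketed per
+ # time index, producing TOKEN_TIME(0), TOKEN_VELOCITY(1), TOKEN_NOTE(60), TOKEN_TIME(2),
+ # TOKEN_VELOCITY(0), TOKEN_NOTE(60); each (value, type) pair is then mapped to its vocabulary
+ # id with _convert_token_to_id before optional truncation.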
2195
+ def batch_encode_plus(
2196
+ self,
2197
+ notes: Union[np.ndarray, list[pretty_midi_fix.Note]],
2198
+ truncation_strategy: Optional[TruncationStrategy] = None,
2199
+ max_length: Optional[int] = None,
2200
+ **kwargs,
2201
+ ) -> BatchEncoding:
2202
+ r"""
2203
+ This is the `batch_encode_plus` method for `Pop2PianoTokenizer`. It converts the midi notes to the transformer
2204
+ generated token ids. It works on multiple batches by calling `encode_plus` multiple times in a loop.
2205
+
2206
+ Args:
2207
+ notes (`np.ndarray` of shape `[batch_size, sequence_length, 4]` or `list` of `pretty_midi_fix.Note` objects):
2208
+ This represents the midi notes. If `notes` is a `np.ndarray`:
2209
+ - Each sequence must have 4 values, they are `onset idx`, `offset idx`, `pitch` and `velocity`.
2210
+ If `notes` is a `list` containing `pretty_midi_fix.Note` objects:
2211
+ - Each sequence must have 4 attributes, they are `start`, `end`, `pitch` and `velocity`.
2212
+ truncation_strategy ([`~tokenization_utils_base.TruncationStrategy`], *optional*):
2213
+ Indicates the truncation strategy that is going to be used during truncation.
2214
+ max_length (`int`, *optional*):
2215
+ Maximum length of the returned list and optionally padding length (see above).
2216
+
2217
+ Returns:
2218
+ `BatchEncoding` containing the tokens ids.
2219
+ """
2220
+
2221
+ encoded_batch_token_ids = []
2222
+ for i in range(len(notes)):
2223
+ encoded_batch_token_ids.append(
2224
+ self.encode_plus(
2225
+ notes[i],
2226
+ truncation_strategy=truncation_strategy,
2227
+ max_length=max_length,
2228
+ **kwargs,
2229
+ )["token_ids"]
2230
+ )
2231
+
2232
+ return BatchEncoding({"token_ids": encoded_batch_token_ids})
2233
+
2234
+ def __call__(
2235
+ self,
2236
+ notes: Union[
2237
+ np.ndarray,
2238
+ list[pretty_midi_fix.Note],
2239
+ list[list[pretty_midi_fix.Note]],
2240
+ ],
2241
+ padding: Union[bool, str, PaddingStrategy] = False,
2242
+ truncation: Union[bool, str, TruncationStrategy] = None,
2243
+ max_length: Optional[int] = None,
2244
+ pad_to_multiple_of: Optional[int] = None,
2245
+ return_attention_mask: Optional[bool] = None,
2246
+ return_tensors: Optional[Union[str, TensorType]] = None,
2247
+ verbose: bool = True,
2248
+ **kwargs,
2249
+ ) -> BatchEncoding:
2250
+ r"""
2251
+ This is the `__call__` method for `Pop2PianoTokenizer`. It converts the midi notes to the transformer generated
2252
+ token ids.
2253
+
2254
+ Args:
2255
+ notes (`np.ndarray` of shape `[batch_size, max_sequence_length, 4]` or `list` of `pretty_midi_fix.Note` objects):
2256
+ This represents the midi notes.
2257
+
2258
+ If `notes` is a `np.ndarray`:
2259
+ - Each sequence must have 4 values, they are `onset idx`, `offset idx`, `pitch` and `velocity`.
2260
+ If `notes` is a `list` containing `pretty_midi_fix.Note` objects:
2261
+ - Each sequence must have 4 attributes, they are `start`, `end`, `pitch` and `velocity`.
2262
+ padding (`bool`, `str` or [`~file_utils.PaddingStrategy`], *optional*, defaults to `False`):
2263
+ Activates and controls padding. Accepts the following values:
2264
+
2265
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
2266
+ sequence if provided).
2267
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
2268
+ acceptable input length for the model if that argument is not provided.
2269
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
2270
+ lengths).
2271
+ truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`):
2272
+ Activates and controls truncation. Accepts the following values:
2273
+
2274
+ - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or
2275
+ to the maximum acceptable input length for the model if that argument is not provided. This will
2276
+ truncate token by token, removing a token from the longest sequence in the pair if a pair of
2277
+ sequences (or a batch of pairs) is provided.
2278
+ - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
2279
+ maximum acceptable input length for the model if that argument is not provided. This will only
2280
+ truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
2281
+ - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the
2282
+ maximum acceptable input length for the model if that argument is not provided. This will only
2283
+ truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
2284
+ - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths
2285
+ greater than the model maximum admissible input size).
2286
+ max_length (`int`, *optional*):
2287
+ Controls the maximum length to use by one of the truncation/padding parameters. If left unset or set to
2288
+ `None`, this will use the predefined model maximum length if a maximum length is required by one of the
2289
+ truncation/padding parameters. If the model has no specific maximum input length (like XLNet)
2290
+ truncation/padding to a maximum length will be deactivated.
2291
+ pad_to_multiple_of (`int`, *optional*):
2292
+ If set will pad the sequence to a multiple of the provided value. This is especially useful to enable
2293
+ the use of Tensor Cores on NVIDIA hardware with compute capability `>= 7.5` (Volta).
2294
+ return_attention_mask (`bool`, *optional*):
2295
+ Whether to return the attention mask. If left to the default, will return the attention mask according
2296
+ to the specific tokenizer's default, defined by the `return_outputs` attribute.
2297
+
2298
+ [What are attention masks?](../glossary#attention-mask)
2299
+ return_tensors (`str` or [`~file_utils.TensorType`], *optional*):
2300
+ If set, will return tensors instead of list of python integers. Acceptable values are:
2301
+
2302
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
2303
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
2304
+ - `'np'`: Return Numpy `np.ndarray` objects.
2305
+ verbose (`bool`, *optional*, defaults to `True`):
2306
+ Whether or not to print more information and warnings.
2307
+
2308
+ Returns:
2309
+ `BatchEncoding` containing the token_ids.
2310
+ """
2311
+
2312
+ # check if it is batched or not
2313
+ # it is batched if it's a list containing a list of `pretty_midi_fix.Note` objects, where the outer list contains all the
2314
+ # batches and the inner list contains all Notes for a single batch. Otherwise if np.ndarray is passed it will be
2315
+ # considered batched if it has shape of `[batch_size, sequence_length, 4]`, i.e. ndim=3.
2316
+ is_batched = notes.ndim == 3 if isinstance(notes, np.ndarray) else isinstance(notes[0], list)
2317
+
2318
+ # get the truncation and padding strategy
2319
+ padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
2320
+ padding=padding,
2321
+ truncation=truncation,
2322
+ max_length=max_length,
2323
+ pad_to_multiple_of=pad_to_multiple_of,
2324
+ verbose=verbose,
2325
+ **kwargs,
2326
+ )
2327
+
2328
+ if is_batched:
2329
+ # If the user has not explicitly mentioned `return_attention_mask` as False, we change it to True
2330
+ return_attention_mask = True if return_attention_mask is None else return_attention_mask
2331
+ token_ids = self.batch_encode_plus(
2332
+ notes=notes,
2333
+ truncation_strategy=truncation_strategy,
2334
+ max_length=max_length,
2335
+ **kwargs,
2336
+ )
2337
+ else:
2338
+ token_ids = self.encode_plus(
2339
+ notes=notes,
2340
+ truncation_strategy=truncation_strategy,
2341
+ max_length=max_length,
2342
+ **kwargs,
2343
+ )
2344
+
2345
+ # since the sequences are already truncated, we are only left with padding
2346
+ token_ids = self.pad(
2347
+ token_ids,
2348
+ padding=padding_strategy,
2349
+ max_length=max_length,
2350
+ pad_to_multiple_of=pad_to_multiple_of,
2351
+ return_attention_mask=return_attention_mask,
2352
+ return_tensors=return_tensors,
2353
+ verbose=verbose,
2354
+ )
2355
+
2356
+ return token_ids
2357
+
2358
+ def batch_decode(
2359
+ self,
2360
+ token_ids,
2361
+ feature_extractor_output: BatchFeature,
2362
+ return_midi: bool = True,
2363
+ ):
2364
+ r"""
2365
+ This is the `batch_decode` method for `Pop2PianoTokenizer`. It converts the token_ids generated by the
2366
+ transformer to midi_notes and returns them.
2367
+
2368
+ Args:
2369
+ token_ids (`Union[np.ndarray, torch.Tensor, tf.Tensor]`):
2370
+ Output token_ids of `Pop2PianoConditionalGeneration` model.
2371
+ feature_extractor_output (`BatchFeature`):
2372
+ Denotes the output of `Pop2PianoFeatureExtractor.__call__`. It must contain `"beatstep"` and
2373
+ `"extrapolated_beatstep"`. Also `"attention_mask_beatsteps"` and
2374
+ `"attention_mask_extrapolated_beatstep"`
2375
+ should be present if they were returned by the feature extractor.
2376
+ return_midi (`bool`, *optional*, defaults to `True`):
2377
+ Whether to return midi object or not.
2378
+ Returns:
2379
+ If `return_midi` is True:
2380
+ - `BatchEncoding` containing both `notes` and `pretty_midi_fix.PrettyMIDI` objects.
2381
+ If `return_midi` is False:
2382
+ - `BatchEncoding` containing `notes`.
2383
+ """
2384
+
2385
+ # check if they have attention_masks(attention_mask, attention_mask_beatsteps, attention_mask_extrapolated_beatstep) or not
2386
+ attention_masks_present = bool(
2387
+ hasattr(feature_extractor_output, "attention_mask")
2388
+ and hasattr(feature_extractor_output, "attention_mask_beatsteps")
2389
+ and hasattr(feature_extractor_output, "attention_mask_extrapolated_beatstep")
2390
+ )
2391
+
2392
+ # if we are processing batched inputs then we must need attention_masks
2393
+ if not attention_masks_present and feature_extractor_output["beatsteps"].shape[0] > 1:
2394
+ raise ValueError(
2395
+ "attention_mask, attention_mask_beatsteps and attention_mask_extrapolated_beatstep must be present "
2396
+ "for batched inputs! But one of them were not present."
2397
+ )
2398
+
2399
+ # check for length mismatch between inputs_embeds, beatsteps and extrapolated_beatstep
2400
+ if attention_masks_present:
2401
+ # since we know about the number of examples in token_ids from attention_mask
2402
+ if (
2403
+ sum(feature_extractor_output["attention_mask"][:, 0] == 0)
2404
+ != feature_extractor_output["beatsteps"].shape[0]
2405
+ or feature_extractor_output["beatsteps"].shape[0]
2406
+ != feature_extractor_output["extrapolated_beatstep"].shape[0]
2407
+ ):
2408
+ raise ValueError(
2409
+ "Length mistamtch between token_ids, beatsteps and extrapolated_beatstep! Found "
2410
+ f"token_ids length - {token_ids.shape[0]}, beatsteps shape - {feature_extractor_output['beatsteps'].shape[0]} "
2411
+ f"and extrapolated_beatsteps shape - {feature_extractor_output['extrapolated_beatstep'].shape[0]}"
2412
+ )
2413
+ if feature_extractor_output["attention_mask"].shape[0] != token_ids.shape[0]:
2414
+ raise ValueError(
2415
+ f"Found attention_mask of length - {feature_extractor_output['attention_mask'].shape[0]} but token_ids of length - {token_ids.shape[0]}"
2416
+ )
2417
+ else:
2418
+ # if there is no attention mask present then it's surely a single example
2419
+ if (
2420
+ feature_extractor_output["beatsteps"].shape[0] != 1
2421
+ or feature_extractor_output["extrapolated_beatstep"].shape[0] != 1
2422
+ ):
2423
+ raise ValueError(
2424
+ "Length mistamtch of beatsteps and extrapolated_beatstep! Since attention_mask is not present the number of examples must be 1, "
2425
+ f"But found beatsteps length - {feature_extractor_output['beatsteps'].shape[0]}, extrapolated_beatsteps length - {feature_extractor_output['extrapolated_beatstep'].shape[0]}."
2426
+ )
2427
+
2428
+ if attention_masks_present:
2429
+ # check for zeros(since token_ids are separated by zero arrays)
2430
+ batch_idx = np.where(feature_extractor_output["attention_mask"][:, 0] == 0)[0]
2431
+ else:
2432
+ batch_idx = [token_ids.shape[0]]
2433
+
2434
+ notes_list = []
2435
+ pretty_midi_fix_objects_list = []
2436
+ start_idx = 0
2437
+ for index, end_idx in enumerate(batch_idx):
2438
+ each_tokens_ids = token_ids[start_idx:end_idx]
2439
+ # check where the whole example ended by searching for eos_token_id and getting the upper bound
2440
+ each_tokens_ids = each_tokens_ids[:, : np.max(np.where(each_tokens_ids == int(self.eos_token))[1]) + 1]
2441
+ beatsteps = feature_extractor_output["beatsteps"][index]
2442
+ extrapolated_beatstep = feature_extractor_output["extrapolated_beatstep"][index]
2443
+
2444
+ # if attention mask is present then mask out real array/tensor
2445
+ if attention_masks_present:
2446
+ attention_mask_beatsteps = feature_extractor_output["attention_mask_beatsteps"][index]
2447
+ attention_mask_extrapolated_beatstep = feature_extractor_output[
2448
+ "attention_mask_extrapolated_beatstep"
2449
+ ][index]
2450
+ beatsteps = beatsteps[: np.max(np.where(attention_mask_beatsteps == 1)[0]) + 1]
2451
+ extrapolated_beatstep = extrapolated_beatstep[
2452
+ : np.max(np.where(attention_mask_extrapolated_beatstep == 1)[0]) + 1
2453
+ ]
2454
+
2455
+ each_tokens_ids = to_numpy(each_tokens_ids)
2456
+ beatsteps = to_numpy(beatsteps)
2457
+ extrapolated_beatstep = to_numpy(extrapolated_beatstep)
2458
+
2459
+ pretty_midi_fix_object = self.relative_batch_tokens_ids_to_midi(
2460
+ tokens=each_tokens_ids,
2461
+ beatstep=extrapolated_beatstep,
2462
+ bars_per_batch=self.num_bars,
2463
+ cutoff_time_idx=(self.num_bars + 1) * 4,
2464
+ )
2465
+
2466
+ for note in pretty_midi_fix_object.instruments[0].notes:
2467
+ note.start += beatsteps[0]
2468
+ note.end += beatsteps[0]
2469
+ notes_list.append(note)
2470
+
2471
+ pretty_midi_fix_objects_list.append(pretty_midi_fix_object)
2472
+ start_idx = end_idx + 1  # skip past the zero separator row that ends this example
2473
+
2474
+ if return_midi:
2475
+ return BatchEncoding({"notes": notes_list, "pretty_midi_objects": pretty_midi_fix_objects_list})
2476
+
2477
+ return BatchEncoding({"notes": notes_list})
2478
+
2479
+ class Pop2PianoProcessor(ProcessorMixin):
2480
+ r"""
2481
+ Constructs a Pop2Piano processor which wraps a Pop2Piano Feature Extractor and Pop2Piano Tokenizer into a single
2482
+ processor.
2483
+
2484
+ [`Pop2PianoProcessor`] offers all the functionalities of [`Pop2PianoFeatureExtractor`] and [`Pop2PianoTokenizer`].
2485
+ See the docstring of [`~Pop2PianoProcessor.__call__`] and [`~Pop2PianoProcessor.decode`] for more information.
2486
+
2487
+ Args:
2488
+ feature_extractor (`Pop2PianoFeatureExtractor`):
2489
+ An instance of [`Pop2PianoFeatureExtractor`]. The feature extractor is a required input.
2490
+ tokenizer (`Pop2PianoTokenizer`):
2491
+ An instance of [`Pop2PianoTokenizer`]. The tokenizer is a required input.
2492
+ """
2493
+
2494
+ attributes = ["feature_extractor", "tokenizer"]
2495
+ feature_extractor_class = "Pop2PianoFeatureExtractor"
2496
+ tokenizer_class = "Pop2PianoTokenizer"
2497
+
2498
+ def __init__(self, feature_extractor, tokenizer):
2499
+ super().__init__(feature_extractor, tokenizer)
2500
+
2501
+ def __call__(
2502
+ self,
2503
+ audio: Union[np.ndarray, list[float], list[np.ndarray]] = None,
2504
+ sampling_rate: Optional[Union[int, list[int]]] = None,
2505
+ steps_per_beat: int = 2,
2506
+ resample: Optional[bool] = True,
2507
+ notes: Union[list, TensorType] = None,
2508
+ padding: Union[bool, str, PaddingStrategy] = False,
2509
+ truncation: Union[bool, str, TruncationStrategy] = None,
2510
+ max_length: Optional[int] = None,
2511
+ pad_to_multiple_of: Optional[int] = None,
2512
+ verbose: bool = True,
2513
+ **kwargs,
2514
+ ) -> Union[BatchFeature, BatchEncoding]:
2515
+ """
2516
+ This method uses [`Pop2PianoFeatureExtractor.__call__`] method to prepare log-mel-spectrograms for the model,
2517
+ and [`Pop2PianoTokenizer.__call__`] to prepare token_ids from notes.
2518
+
2519
+ Please refer to the docstring of the above two methods for more information.
2520
+ """
2521
+
2522
+ # Since Feature Extractor needs both audio and sampling_rate and tokenizer needs both token_ids and
2523
+ # feature_extractor_output, we must check for both.
2524
+ if (audio is None and sampling_rate is None) and (notes is None):
2525
+ raise ValueError(
2526
+ "You have to specify at least audios and sampling_rate in order to use feature extractor or "
2527
+ "notes to use the tokenizer part."
2528
+ )
2529
+
2530
+ if audio is not None and sampling_rate is not None:
2531
+ inputs = self.feature_extractor(
2532
+ audio=audio,
2533
+ sampling_rate=sampling_rate,
2534
+ steps_per_beat=steps_per_beat,
2535
+ resample=resample,
2536
+ **kwargs,
2537
+ )
2538
+
2539
+ if notes is not None:
2540
+ encoded_token_ids = self.tokenizer(
2541
+ notes=notes,
2542
+ padding=padding,
2543
+ truncation=truncation,
2544
+ max_length=max_length,
2545
+ pad_to_multiple_of=pad_to_multiple_of,
2546
+ verbose=verbose,
2547
+ **kwargs,
2548
+ )
2549
+
2550
+ if notes is None:
2551
+ return inputs
2552
+
2553
+ elif audio is None or sampling_rate is None:
2554
+ return encoded_token_ids
2555
+
2556
+ else:
2557
+ inputs["token_ids"] = encoded_token_ids["token_ids"]
2558
+ return inputs
2559
+
2560
+ def batch_decode(
2561
+ self,
2562
+ token_ids,
2563
+ feature_extractor_output: BatchFeature,
2564
+ return_midi: bool = True,
2565
+ ) -> BatchEncoding:
2566
+ """
2567
+ This method uses [`Pop2PianoTokenizer.batch_decode`] method to convert model generated token_ids to midi_notes.
2568
+
2569
+ Please refer to the docstring of the above two methods for more information.
2570
+ """
2571
+
2572
+ return self.tokenizer.batch_decode(
2573
+ token_ids=token_ids, feature_extractor_output=feature_extractor_output, return_midi=return_midi
2574
+ )
2575
+
2576
+ @property
2577
+ def model_input_names(self):
2578
+ tokenizer_input_names = self.tokenizer.model_input_names
2579
+ feature_extractor_input_names = self.feature_extractor.model_input_names
2580
+ return list(dict.fromkeys(tokenizer_input_names + feature_extractor_input_names))
2581
+
2582
+ def save_pretrained(self, save_directory, **kwargs):
2583
+ if os.path.isfile(save_directory):
2584
+ raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
2585
+ os.makedirs(save_directory, exist_ok=True)
2586
+ return super().save_pretrained(save_directory, **kwargs)
2587
+
2588
+ @classmethod
2589
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
2590
+ args = cls._get_arguments_from_pretrained(pretrained_model_name_or_path, **kwargs)
2591
+ return cls(*args)
2592
+
2593
+
2594
+
2595
+ class Pop2Piano:
2596
+ def __init__(self, device="cpu", model_path=None):
2597
+ # resolve the default checkpoint lazily so the download happens only when an instance is created
+ model_path = model_path or snapshot_download("sweetcocoa/pop2piano")
+ self.model = Pop2PianoForConditionalGeneration.from_pretrained(model_path).to(device)
2598
+ self.processor = Pop2PianoProcessor.from_pretrained(model_path)
2599
+
2600
+ def predict(self, audio, composer=1, num_bars=2, num_beams=1, steps_per_beat=2, output_file="output.mid"):
2601
+ data, sr = librosa_load(audio, sr=None)
2602
+ inputs = self.processor(data, sr, steps_per_beat, return_tensors="pt", num_bars=num_bars)
2603
+ token_ids = self.model.generate(input_features=inputs["input_features"], composer="composer" + str(composer), num_beams=num_beams, do_sample=True)
+ midi = self.processor.batch_decode(token_ids, inputs)["pretty_midi_objects"][0]
+ with open(output_file, "wb") as file:
+ midi.write(file)
2604
+ return output_file
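+ # Illustrative usage of the wrapper above (kept as a comment; "song.wav" is a placeholder path
+ # and the first call will download the "sweetcocoa/pop2piano" checkpoint):
+ #
+ #     pop2piano = Pop2Piano(device="cpu")
+ #     midi_path = pop2piano.predict("song.wav", composer=1, output_file="song.mid")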