audio2midi 0.1.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- audio2midi/__init__.py +0 -0
- audio2midi/basic_pitch_pitch_detector.py +783 -0
- audio2midi/crepe_pitch_detector.py +130 -0
- audio2midi/librosa_pitch_detector.py +153 -0
- audio2midi/melodia_pitch_detector.py +58 -0
- audio2midi/pop2piano.py +2604 -0
- audio2midi/py.typed +0 -0
- audio2midi/violin_pitch_detector.py +1281 -0
- audio2midi-0.1.0.dist-info/METADATA +100 -0
- audio2midi-0.1.0.dist-info/RECORD +11 -0
- audio2midi-0.1.0.dist-info/WHEEL +5 -0
audio2midi/pop2piano.py
ADDED
@@ -0,0 +1,2604 @@
|
|
1
|
+
import torch
|
2
|
+
import copy
|
3
|
+
import os
|
4
|
+
import numpy as np
|
5
|
+
import pretty_midi_fix
|
6
|
+
from librosa.core import resample as librosa_resample
|
7
|
+
from scipy.interpolate import interp1d
|
8
|
+
from json import load as json_load , dumps as json_dumps
|
9
|
+
from math import log as math_log
|
10
|
+
from typing import Optional, Union
|
11
|
+
from torch import nn
|
12
|
+
from librosa import load as librosa_load
|
13
|
+
from huggingface_hub import snapshot_download
|
14
|
+
from essentia.standard import RhythmExtractor2013
|
15
|
+
from transformers.generation import GenerationConfig , GenerationMixin
|
16
|
+
from transformers.activations import ACT2FN
|
17
|
+
from transformers.modeling_layers import GradientCheckpointingLayer
|
18
|
+
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput
|
19
|
+
from transformers.utils import is_torch_flex_attn_available, is_torch_fx_proxy, is_torchdynamo_compiling , TensorType, to_numpy
|
20
|
+
from transformers.modeling_utils import PreTrainedModel
|
21
|
+
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
22
|
+
from transformers.feature_extraction_utils import BatchFeature
|
23
|
+
from transformers.processing_utils import ProcessorMixin
|
24
|
+
from transformers.configuration_utils import PretrainedConfig
|
25
|
+
from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
26
|
+
from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache
|
27
|
+
from transformers.audio_utils import mel_filter_bank, spectrogram
|
28
|
+
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
|
29
|
+
from transformers.feature_extraction_utils import BatchFeature
|
30
|
+
from transformers.feature_extraction_utils import BatchFeature
|
31
|
+
from transformers.tokenization_utils import AddedToken, BatchEncoding, PaddingStrategy, PreTrainedTokenizer, TruncationStrategy
|
32
|
+
|
33
|
+
if is_torch_flex_attn_available():
|
34
|
+
from torch.nn.attention.flex_attention import BlockMask
|
35
|
+
from transformers.integrations.flex_attention import make_flex_block_causal_mask
|
36
|
+
|
37
|
+
|
38
|
+
class Pop2PianoConfig(PretrainedConfig):
    """
    Hyperparameter container for the Pop2Piano model.

    The model-shape arguments are stored as attributes of the same name. The constructor additionally derives
    `is_gated_act` from `feed_forward_proj` and publishes `hidden_size`, `num_attention_heads` and
    `num_hidden_layers` as aliases of `d_model`, `num_heads` and `num_layers`. `pad_token_id`, `eos_token_id`,
    `is_encoder_decoder` and any extra keyword arguments are handed to `PretrainedConfig.__init__`.

    Args:
        vocab_size: stored as `vocab_size`.
        composer_vocab_size: stored as `composer_vocab_size`.
        d_model: model width; also exposed as `hidden_size`.
        d_kv: stored as `d_kv`.
        d_ff: stored as `d_ff`.
        num_layers: stored as `num_layers`; also exposed as `num_hidden_layers`.
        num_decoder_layers: stored as `num_decoder_layers`; falls back to `num_layers` when `None`.
        num_heads: stored as `num_heads`; also exposed as `num_attention_heads`.
        relative_attention_num_buckets: stored as `relative_attention_num_buckets`.
        relative_attention_max_distance: stored as `relative_attention_max_distance`.
        dropout_rate: stored as `dropout_rate`.
        layer_norm_epsilon: stored as `layer_norm_epsilon`.
        initializer_factor: stored as `initializer_factor`.
        feed_forward_proj: stored as `feed_forward_proj`; the text before the first "-" decides `is_gated_act`.
        is_encoder_decoder: forwarded to the parent constructor.
        use_cache: stored as `use_cache`.
        pad_token_id: forwarded to the parent constructor.
        eos_token_id: forwarded to the parent constructor.
        dense_act_fn: stored as `dense_act_fn`.
        **kwargs: forwarded untouched to the parent constructor.
    """

    model_type = "pop2piano"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=2400,
        composer_vocab_size=21,
        d_model=512,
        d_kv=64,
        d_ff=2048,
        num_layers=6,
        num_decoder_layers=None,
        num_heads=8,
        relative_attention_num_buckets=32,
        relative_attention_max_distance=128,
        dropout_rate=0.1,
        layer_norm_epsilon=1e-6,
        initializer_factor=1.0,
        feed_forward_proj="gated-gelu",
        is_encoder_decoder=True,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        dense_act_fn="relu",
        **kwargs,
    ):
        # Vocabulary sizes.
        self.vocab_size = vocab_size
        self.composer_vocab_size = composer_vocab_size

        # Core dimensions.
        self.d_model = d_model
        self.d_kv = d_kv
        self.d_ff = d_ff

        # Depth: the decoder mirrors the encoder unless a depth is given explicitly.
        self.num_layers = num_layers
        if num_decoder_layers is None:
            self.num_decoder_layers = self.num_layers
        else:
            self.num_decoder_layers = num_decoder_layers

        # Attention settings.
        self.num_heads = num_heads
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.relative_attention_max_distance = relative_attention_max_distance

        # Regularisation / numerics / initialisation.
        self.dropout_rate = dropout_rate
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_factor = initializer_factor

        # Feed-forward flavour.
        self.feed_forward_proj = feed_forward_proj
        self.use_cache = use_cache
        self.dense_act_fn = dense_act_fn
        projection_kind, _, _ = self.feed_forward_proj.partition("-")
        self.is_gated_act = projection_kind == "gated"

        # Aliases under the generic attribute names.
        self.hidden_size = self.d_model
        self.num_attention_heads = num_heads
        self.num_hidden_layers = num_layers

        super().__init__(
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            **kwargs,
        )
|
91
|
+
|
92
|
+
|
93
|
+
# Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->Pop2Piano
|
94
|
+
class Pop2PianoLayerNorm(nn.Module):
    """Scale-only layer normalization: divides by the root mean square of the last axis, no mean, no bias."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        Build the normalization layer.

        Args:
            hidden_size: length of the learnable scale vector (initialised to ones).
            eps: constant added to the mean of squares before the reciprocal square root.
        """
        super().__init__()
        self.variance_epsilon = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states):
        """Return `weight * hidden_states / sqrt(mean(hidden_states ** 2) + eps)` over the last axis."""
        # The mean of squares is accumulated in fp32 even when the input is half precision.
        mean_square = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        normalized = hidden_states * torch.rsqrt(mean_square + self.variance_epsilon)

        # The product above is fp32 for half inputs; cast back only when the scale itself is half precision.
        scale = self.weight
        if scale.dtype in (torch.float16, torch.bfloat16):
            normalized = normalized.to(scale.dtype)

        return scale * normalized
|
117
|
+
|
118
|
+
# from apex.normalization import FusedRMSNorm
|
119
|
+
# Pop2PianoLayerNorm = FusedRMSNorm # noqa
|
120
|
+
# # Other Approach
|
121
|
+
|
122
|
+
# Copied from transformers.models.t5.modeling_t5.T5DenseActDense with T5->Pop2Piano,t5->pop2piano
|
123
|
+
class Pop2PianoDenseActDense(nn.Module):
    """Feed-forward sub-block: linear (`wi`) -> activation -> dropout -> linear (`wo`), all without bias."""

    def __init__(self, config: Pop2PianoConfig):
        """Create the two projections, the dropout and the activation named by `config.dense_act_fn`."""
        super().__init__()
        self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.act = ACT2FN[config.dense_act_fn]

    def forward(self, hidden_states):
        """Apply the feed-forward transformation and return the projected tensor."""
        activated = self.dropout(self.act(self.wi(hidden_states)))

        # Align the activations with the dtype of `wo` before the output projection,
        # except when `wo` holds int8 weights.
        out_weight = self.wo.weight
        needs_cast = (
            isinstance(out_weight, torch.Tensor)
            and out_weight.dtype != torch.int8
            and activated.dtype != out_weight.dtype
        )
        if needs_cast:
            activated = activated.to(out_weight.dtype)

        return self.wo(activated)
|
143
|
+
|
144
|
+
|
145
|
+
# Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5->Pop2Piano
|
146
|
+
class Pop2PianoDenseGatedActDense(nn.Module):
    """Gated feed-forward sub-block: `act(wi_0(x)) * wi_1(x)` -> dropout -> `wo`, all without bias."""

    def __init__(self, config: Pop2PianoConfig):
        """Create the gate/value projections, the output projection, the dropout and the activation."""
        super().__init__()
        self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.act = ACT2FN[config.dense_act_fn]

    def forward(self, hidden_states):
        """Apply the gated feed-forward transformation and return the projected tensor."""
        gate = self.act(self.wi_0(hidden_states))
        value = self.wi_1(hidden_states)
        gated = self.dropout(gate * value)

        # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
        # See https://github.com/huggingface/transformers/issues/20287
        # The int8 exclusion covers users who force `_keep_in_fp32_modules` to `None`.
        out_weight = self.wo.weight
        needs_cast = (
            isinstance(out_weight, torch.Tensor)
            and out_weight.dtype != torch.int8
            and gated.dtype != out_weight.dtype
        )
        if needs_cast:
            gated = gated.to(out_weight.dtype)

        return self.wo(gated)
|
173
|
+
|
174
|
+
|
175
|
+
# Copied from transformers.models.t5.modeling_t5.T5LayerFF with T5->Pop2Piano
|
176
|
+
class Pop2PianoLayerFF(nn.Module):
    """Pre-norm residual wrapper around the (gated or plain) feed-forward sub-block."""

    def __init__(self, config: Pop2PianoConfig):
        """Pick the feed-forward variant from `config.is_gated_act` and build the norm and dropout."""
        super().__init__()
        dense_cls = Pop2PianoDenseGatedActDense if config.is_gated_act else Pop2PianoDenseActDense
        self.DenseReluDense = dense_cls(config)

        self.layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        """Return `hidden_states + dropout(feed_forward(layer_norm(hidden_states)))`."""
        residual = hidden_states
        transformed = self.DenseReluDense(self.layer_norm(hidden_states))
        return residual + self.dropout(transformed)
|
192
|
+
|
193
|
+
|
194
|
+
# Copied from transformers.models.t5.modeling_t5.T5Attention with T5->Pop2Piano,t5->pop2piano
|
195
|
+
class Pop2PianoAttention(nn.Module):
    """
    Multi-head attention with an optional learned relative position bias.

    The same module serves as self-attention (no `key_value_states`) and as cross-attention
    (`key_value_states` given). Scores are not scaled before the softmax.
    """

    def __init__(
        self,
        config: "Pop2PianoConfig",
        has_relative_attention_bias=False,
        layer_idx: Optional[int] = None,
    ):
        """
        Args:
            config: object providing `is_decoder`, `relative_attention_num_buckets`,
                `relative_attention_max_distance`, `d_model`, `d_kv`, `num_heads` and `dropout_rate`.
            has_relative_attention_bias: when true, a `relative_attention_bias` embedding is created and used
                to compute the position bias.
            layer_idx: index used to address this layer's slot in the key/value cache.
        """
        super().__init__()
        self.is_decoder = config.is_decoder
        self.has_relative_attention_bias = has_relative_attention_bias
        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.relative_attention_max_distance = config.relative_attention_max_distance
        self.d_model = config.d_model
        self.key_value_proj_dim = config.d_kv
        self.n_heads = config.num_heads
        self.dropout = config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim
        self.layer_idx = layer_idx
        if layer_idx is None and self.is_decoder:
            # Emitted through the warnings machinery (not stdout) so callers can filter or capture it.
            import warnings

            warnings.warn(
                f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not "
                "recommended and will lead to errors during the forward call, if caching is used. Please make "
                "sure to provide a `layer_idx` when creating this class.",
                stacklevel=2,
            )

        # Mesh TensorFlow initialization to avoid scaling before softmax
        self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)

        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
        self.pruned_heads = set()
        self.gradient_checkpointing = False

    def prune_heads(self, heads):
        """Remove the given attention heads from the q/k/v/o projections and update the head bookkeeping."""
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
        )
        # Prune linear layers
        self.q = prune_linear_layer(self.q, index)
        self.k = prune_linear_layer(self.k, index)
        self.v = prune_linear_layer(self.v, index)
        self.o = prune_linear_layer(self.o, index, dim=1)
        # Update hyper params
        self.n_heads = self.n_heads - len(heads)
        self.inner_dim = self.key_value_proj_dim * self.n_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on

        Args:
            relative_position: an integer Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing integer values in the range
            [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            # Half of the buckets are reserved for positive offsets, the other half for non-positive ones.
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math_log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

    def compute_bias(self, query_length, key_length, device=None, cache_position=None):
        """
        Compute binned relative position bias.

        Query positions are `cache_position` when given, otherwise `0..query_length-1`; key positions are
        `0..key_length-1`. Returns a tensor of shape (1, num_heads, query_length, key_length).
        """
        if device is None:
            device = self.relative_attention_bias.weight.device
        if cache_position is None:
            context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
        else:
            context_position = cache_position[:, None].to(device)
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values

    def forward(
        self,
        hidden_states,
        mask=None,
        key_value_states=None,
        position_bias=None,
        past_key_value=None,
        layer_head_mask=None,
        query_length=None,
        use_cache=False,
        output_attentions=False,
        cache_position=None,
    ):
        """
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).

        Returns:
            A tuple `(attn_output, past_key_value, position_bias)`, extended with the attention weights when
            `output_attentions` is true. `position_bias` already includes `mask` when one was passed and the
            bias was computed here.
        """
        # Input is (batch_size, seq_length, dim)
        # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder)
        batch_size, seq_length = hidden_states.shape[:2]

        # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
        is_cross_attention = key_value_states is not None

        query_states = self.q(hidden_states)
        query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

        is_updated = False
        curr_past_key_value = None
        if past_key_value is not None:
            is_updated = past_key_value.is_updated.get(self.layer_idx)
            if is_cross_attention:
                # after the first generated id, we can subsequently re-use all key/value_states from cache
                curr_past_key_value = past_key_value.cross_attention_cache
            else:
                curr_past_key_value = past_key_value.self_attention_cache

        current_states = key_value_states if is_cross_attention else hidden_states
        if is_cross_attention and is_updated:
            # reuse k,v, cross_attentions
            key_states = curr_past_key_value.key_cache[self.layer_idx]
            value_states = curr_past_key_value.value_cache[self.layer_idx]
        else:
            key_states = self.k(current_states)
            value_states = self.v(current_states)
            key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
            value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

            if past_key_value is not None:
                # save all key/value_states to cache to be re-used for fast auto-regressive generation
                cache_position = cache_position if not is_cross_attention else None
                key_states, value_states = curr_past_key_value.update(
                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
                )
                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                if is_cross_attention:
                    past_key_value.is_updated[self.layer_idx] = True

        # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
        scores = torch.matmul(query_states, key_states.transpose(3, 2))

        if position_bias is None:
            key_length = key_states.shape[-2]
            if not self.has_relative_attention_bias:
                position_bias = torch.zeros(
                    (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
                )
                if self.gradient_checkpointing and self.training:
                    position_bias.requires_grad = True
            else:
                # The query length (with past) is only needed for the learned bias, so it is resolved here
                # rather than unconditionally; indexing a missing `cache_position` used to raise TypeError.
                if query_length is not None:
                    real_seq_length = query_length
                elif cache_position is not None:
                    # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
                    real_seq_length = cache_position[-1] + 1
                else:
                    # No positional information supplied: the keys already span past + current tokens.
                    real_seq_length = key_length
                position_bias = self.compute_bias(
                    real_seq_length, key_length, device=scores.device, cache_position=cache_position
                )
                position_bias = position_bias[:, :, -seq_length:, :]

            if mask is not None:
                causal_mask = mask[:, :, :, : key_states.shape[-2]]
                position_bias = position_bias + causal_mask

        if self.pruned_heads:
            # Keep only the bias rows of heads that have not been pruned.
            head_keep = torch.ones(position_bias.shape[1])
            head_keep[list(self.pruned_heads)] = 0
            position_bias_masked = position_bias[:, head_keep.bool()]
        else:
            position_bias_masked = position_bias

        scores += position_bias_masked

        # (batch_size, n_heads, seq_length, key_length)
        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        # Mask heads if we want to
        if layer_head_mask is not None:
            attn_weights = attn_weights * layer_head_mask

        attn_output = torch.matmul(attn_weights, value_states)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, -1, self.inner_dim)
        attn_output = self.o(attn_output)

        outputs = (attn_output, past_key_value, position_bias)

        if output_attentions:
            outputs = outputs + (attn_weights,)
        return outputs
|
421
|
+
|
422
|
+
|
423
|
+
# Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->Pop2Piano,t5->pop2piano
|
424
|
+
class Pop2PianoLayerSelfAttention(nn.Module):
    """Pre-norm residual wrapper around a self-attention module."""

    def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
        """Build the attention module, its layer norm and the residual dropout from `config`."""
        super().__init__()
        self.SelfAttention = Pop2PianoAttention(
            config,
            has_relative_attention_bias=has_relative_attention_bias,
            layer_idx=layer_idx,
        )
        self.layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        cache_position=None,
    ):
        """
        Normalize, attend, and add the result back onto `hidden_states`.

        Returns a tuple whose first item is the updated hidden states, followed by every remaining item
        returned by the attention module.
        """
        attn_result, *remaining = self.SelfAttention(
            self.layer_norm(hidden_states),
            mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
        updated = hidden_states + self.dropout(attn_result)
        # Trailing items (cache, position bias, optional attention weights) are passed through unchanged.
        return (updated, *remaining)
|
458
|
+
|
459
|
+
|
460
|
+
# Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->Pop2Piano,t5->pop2piano
|
461
|
+
class Pop2PianoLayerCrossAttention(nn.Module):
    """Pre-norm residual wrapper around an attention module that attends over `key_value_states`."""

    def __init__(self, config, layer_idx: Optional[int] = None):
        """Build the attention module (without relative bias), its layer norm and the residual dropout."""
        super().__init__()
        self.EncDecAttention = Pop2PianoAttention(
            config,
            has_relative_attention_bias=False,
            layer_idx=layer_idx,
        )
        self.layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        key_value_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        query_length=None,
        output_attentions=False,
        cache_position=None,
    ):
        """
        Normalize `hidden_states`, attend over `key_value_states`, and add the result back.

        Returns a tuple whose first item is the updated hidden states, followed by every remaining item
        returned by the attention module.
        """
        attn_result, *remaining = self.EncDecAttention(
            self.layer_norm(hidden_states),
            mask=attention_mask,
            key_value_states=key_value_states,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            query_length=query_length,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
        layer_output = hidden_states + self.dropout(attn_result)
        # Trailing items (cache, position bias, optional attention weights) are passed through unchanged.
        return (layer_output, *remaining)
|
497
|
+
|
498
|
+
# Copied from transformers.models.t5.modeling_t5.T5Block with T5->Pop2Piano,t5->pop2piano
|
499
|
+
class Pop2PianoBlock(GradientCheckpointingLayer):
    """
    One transformer block: self-attention, cross-attention (only when `config.is_decoder`), then feed-forward.
    """

    def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
        """Assemble the sub-layers in execution order inside `self.layer`."""
        super().__init__()
        self.is_decoder = config.is_decoder

        sublayers = [
            Pop2PianoLayerSelfAttention(
                config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx
            )
        ]
        if self.is_decoder:
            sublayers.append(Pop2PianoLayerCrossAttention(config, layer_idx=layer_idx))
        sublayers.append(Pop2PianoLayerFF(config))
        self.layer = nn.ModuleList(sublayers)

    @staticmethod
    def _clamp_fp16(hidden_states):
        """Clamp float16 activations to the finite fp16 range (tighter when an inf is present); others pass through."""
        if hidden_states.dtype != torch.float16:
            return hidden_states
        fp16_max = torch.finfo(hidden_states.dtype).max
        clamp_value = torch.where(
            torch.isinf(hidden_states).any(),
            fp16_max - 1000,
            fp16_max,
        )
        return torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        layer_head_mask=None,
        cross_attn_layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        return_dict=True,
        cache_position=None,
    ):
        """
        Run the block's sub-layers in order.

        Returns:
            A tuple starting with the hidden states, followed by `past_key_value` when `use_cache` is true,
            then the self-attention extras (position bias, optional weights) and, when cross-attention ran,
            the cross-attention extras.
        """
        self_attn_result = self.layer[0](
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
        hidden_states, past_key_value = self_attn_result[:2]
        # Keep self-attention outputs and relative position weights
        attention_outputs = self_attn_result[2:]

        # clamp inf values to enable fp16 training
        hidden_states = self._clamp_fp16(hidden_states)

        if self.is_decoder and encoder_hidden_states is not None:
            cross_attn_result = self.layer[1](
                hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                position_bias=encoder_decoder_position_bias,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=past_key_value,
                # cache positions are 0-indexed, so the last one plus one is the query length
                query_length=cache_position[-1] + 1,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            hidden_states, past_key_value = cross_attn_result[:2]

            # clamp inf values to enable fp16 training
            hidden_states = self._clamp_fp16(hidden_states)

            # Keep cross-attention outputs and relative position weights
            attention_outputs = attention_outputs + cross_attn_result[2:]

        # Apply Feed Forward layer
        hidden_states = self.layer[-1](hidden_states)

        # clamp inf values to enable fp16 training
        hidden_states = self._clamp_fp16(hidden_states)

        # hidden-states, past_key_value, (self-attention position bias), (self-attention weights),
        # (cross-attention position bias), (cross-attention weights)
        if use_cache:
            return (hidden_states, past_key_value) + attention_outputs
        return (hidden_states,) + attention_outputs
|
599
|
+
|
600
|
+
class Pop2PianoConcatEmbeddingToMel(nn.Module):
    """Embedding Matrix for `composer` tokens."""

    def __init__(self, config):
        """Create an embedding table with `config.composer_vocab_size` rows of width `config.d_model`."""
        super().__init__()
        self.embedding = nn.Embedding(config.composer_vocab_size, config.d_model)

    def forward(self, feature, index_value, embedding_offset):
        """
        Prepend the composer embedding to `feature` along dimension 1.

        `index_value - embedding_offset` is looked up in the embedding table, given a new axis at
        position 1, and concatenated in front of `feature`.
        """
        table_index = index_value - embedding_offset
        composer_token = self.embedding(table_index).unsqueeze(1)
        return torch.cat((composer_token, feature), dim=1)
|
612
|
+
|
613
|
+
class Pop2PianoPreTrainedModel(PreTrainedModel):
|
614
|
+
config_class = Pop2PianoConfig
|
615
|
+
base_model_prefix = "transformer"
|
616
|
+
is_parallelizable = False
|
617
|
+
supports_gradient_checkpointing = True
|
618
|
+
_supports_cache_class = True
|
619
|
+
_supports_static_cache = False
|
620
|
+
_no_split_modules = ["Pop2PianoBlock"]
|
621
|
+
_keep_in_fp32_modules = ["wo"]
|
622
|
+
|
623
|
+
def _init_weights(self, module):
|
624
|
+
"""Initialize the weights"""
|
625
|
+
factor = self.config.initializer_factor # Used for testing weights initialization
|
626
|
+
if isinstance(module, Pop2PianoLayerNorm):
|
627
|
+
module.weight.data.fill_(factor * 1.0)
|
628
|
+
elif isinstance(module, Pop2PianoConcatEmbeddingToMel):
|
629
|
+
module.embedding.weight.data.normal_(mean=0.0, std=factor * 1.0)
|
630
|
+
elif isinstance(module, Pop2PianoForConditionalGeneration):
|
631
|
+
# Mesh TensorFlow embeddings initialization
|
632
|
+
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
|
633
|
+
module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
|
634
|
+
if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
|
635
|
+
module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
|
636
|
+
elif isinstance(module, Pop2PianoDenseActDense):
|
637
|
+
# Mesh TensorFlow FF initialization
|
638
|
+
# See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
|
639
|
+
# and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
|
640
|
+
module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
|
641
|
+
if hasattr(module.wi, "bias") and module.wi.bias is not None:
|
642
|
+
module.wi.bias.data.zero_()
|
643
|
+
module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
|
644
|
+
if hasattr(module.wo, "bias") and module.wo.bias is not None:
|
645
|
+
module.wo.bias.data.zero_()
|
646
|
+
elif isinstance(module, Pop2PianoDenseGatedActDense):
|
647
|
+
module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
|
648
|
+
if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
|
649
|
+
module.wi_0.bias.data.zero_()
|
650
|
+
module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
|
651
|
+
if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
|
652
|
+
module.wi_1.bias.data.zero_()
|
653
|
+
module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
|
654
|
+
if hasattr(module.wo, "bias") and module.wo.bias is not None:
|
655
|
+
module.wo.bias.data.zero_()
|
656
|
+
elif isinstance(module, Pop2PianoAttention):
|
657
|
+
# Mesh TensorFlow attention initialization to avoid scaling before softmax
|
658
|
+
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
|
659
|
+
d_model = self.config.d_model
|
660
|
+
key_value_proj_dim = self.config.d_kv
|
661
|
+
n_heads = self.config.num_heads
|
662
|
+
module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
|
663
|
+
module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
|
664
|
+
module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
|
665
|
+
module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
|
666
|
+
if module.has_relative_attention_bias:
|
667
|
+
module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
|
668
|
+
|
669
|
+
def _shift_right(self, input_ids):
    """Build decoder inputs by shifting `input_ids` one step to the right.

    The first position of every sequence is set to
    `config.decoder_start_token_id`, the last token of `input_ids` is
    dropped, and any `-100` entries in the shifted result are replaced by
    `config.pad_token_id`.

    Raises:
        ValueError: if `decoder_start_token_id` or `pad_token_id` is not
            set on the config.
    """
    start_id = self.config.decoder_start_token_id
    pad_id = self.config.pad_token_id

    if start_id is None:
        raise ValueError(
            "self.model.config.decoder_start_token_id has to be defined. In Pop2Piano it is usually set to the pad_token_id."
        )

    if not is_torch_fx_proxy(input_ids):
        # Regular tensors: allocate once and fill by slice assignment.
        shifted = input_ids.new_zeros(input_ids.shape)
        shifted[..., 1:] = input_ids[..., :-1].clone()
        shifted[..., 0] = start_id
    else:
        # Item assignment is not supported natively for proxies, so the
        # start column is built separately and concatenated in front.
        start_column = torch.full(input_ids.shape[:-1] + (1,), start_id)
        shifted = torch.cat([start_column, input_ids[..., :-1]], dim=-1)

    if pad_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")

    # -100 marks ignored label positions; it is not a valid token id.
    shifted.masked_fill_(shifted == -100, pad_id)
    return shifted
|
694
|
+
|
695
|
+
class Pop2PianoStack(Pop2PianoPreTrainedModel):
|
696
|
+
# Copied from transformers.models.t5.modeling_t5.T5Stack.__init__ with T5->Pop2Piano,t5->pop2piano
|
697
|
+
def __init__(self, config, embed_tokens=None):
    """Assemble a stack of `config.num_layers` Pop2Piano blocks.

    Args:
        config: model configuration; `is_decoder`, `num_layers`, `d_model`,
            `layer_norm_epsilon` and `dropout_rate` are read here.
        embed_tokens: optional embedding module used by `forward` to turn
            `input_ids` into embeddings.
    """
    super().__init__(config)

    self.embed_tokens = embed_tokens
    self.is_decoder = config.is_decoder

    # Only the block at index 0 is created with a relative attention bias.
    layers = []
    for layer_idx in range(config.num_layers):
        layers.append(
            Pop2PianoBlock(config, has_relative_attention_bias=(layer_idx == 0), layer_idx=layer_idx)
        )
    self.block = nn.ModuleList(layers)

    self.final_layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
    self.dropout = nn.Dropout(config.dropout_rate)

    # Initialize weights and apply final processing
    self.post_init()

    # Model-parallel and gradient-checkpointing flags start disabled.
    self.model_parallel = False
    self.device_map = None
    self.gradient_checkpointing = False
|
718
|
+
|
719
|
+
# Copied from transformers.models.t5.modeling_t5.T5Stack.get_input_embeddings
|
720
|
+
def get_input_embeddings(self):
    """Return the embedding module stored in `self.embed_tokens`."""
    token_embeddings = self.embed_tokens
    return token_embeddings
|
722
|
+
|
723
|
+
# Copied from transformers.models.t5.modeling_t5.T5Stack.set_input_embeddings
|
724
|
+
def set_input_embeddings(self, new_embeddings):
    """Replace the embedding module stored in `self.embed_tokens`."""
    setattr(self, "embed_tokens", new_embeddings)
|
726
|
+
|
727
|
+
def forward(
    self,
    input_ids=None,
    attention_mask=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    inputs_embeds=None,
    head_mask=None,
    cross_attn_head_mask=None,
    past_key_values=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    cache_position=None,
):
    """Run the inputs through every block of the stack.

    Exactly one of `input_ids` / `inputs_embeds` must be given. Flags left
    as `None` (`use_cache`, `output_attentions`, `output_hidden_states`,
    `return_dict`) fall back to the values stored on `self.config`.

    Cache handling (decoder only):
        - a plain `Cache` is wrapped into an `EncoderDecoderCache` and only
          its self-attention part is returned;
        - a legacy tuple, or no cache at all while `use_cache` is set, is
          converted with `EncoderDecoderCache.from_legacy_cache` and the
          result is returned in legacy format;
        - an `EncoderDecoderCache` is used as is.
        For an encoder stack `past_key_values` is always discarded.

    Returns:
        `BaseModelOutputWithPastAndCrossAttentions` when `return_dict` is
        true; otherwise a tuple holding the non-`None` values among
        (hidden states, cache, all hidden states, self-attentions,
        cross-attentions), in that order.

    Raises:
        ValueError: if both or neither of `input_ids` / `inputs_embeds` are
            given, if ids are given but the stack has no `embed_tokens`, or
            if `use_cache` is `True` on a non-decoder stack.
    """
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if input_ids is not None and inputs_embeds is not None:
        err_msg_prefix = "decoder_" if self.is_decoder else ""
        raise ValueError(
            f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
        )
    elif input_ids is not None:
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])
    elif inputs_embeds is not None:
        input_shape = inputs_embeds.size()[:-1]
    else:
        err_msg_prefix = "decoder_" if self.is_decoder else ""
        raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")

    if self.gradient_checkpointing and self.training:
        if use_cache:
            print(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

    if inputs_embeds is None:
        if self.embed_tokens is None:
            raise ValueError("You have to initialize the model with valid token embeddings")
        inputs_embeds = self.embed_tokens(input_ids)

    batch_size, seq_length = input_shape

    if use_cache is True:
        if not self.is_decoder:
            raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")

    # initialize past_key_values
    return_legacy_cache = False
    return_self_attention_cache = False
    if self.is_decoder and (use_cache or past_key_values is not None):
        if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache):
            return_self_attention_cache = True
            past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
        elif not isinstance(past_key_values, EncoderDecoderCache):
            # This branch handles both a legacy tuple and `None` (caching
            # requested, nothing supplied). The returned cache stays in
            # legacy format in both cases, but the deprecation notice is
            # only relevant when the caller really passed a tuple.
            return_legacy_cache = True
            if past_key_values is not None:
                print(
                    "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. "
                    "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
                    "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
                )
            past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
    elif not self.is_decoder:
        # do not pass cache object down the line for encoder stack
        # it messes indexing later in decoder-stack because cache object is modified in-place
        past_key_values = None

    past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
    if cache_position is None:
        cache_position = torch.arange(
            past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
        )

    if attention_mask is None and not is_torchdynamo_compiling():
        # required mask seq length can be calculated via length of past cache
        mask_seq_length = past_key_values_length + seq_length
        attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)

    if self.config.is_decoder:
        causal_mask = self._update_causal_mask(
            attention_mask,
            inputs_embeds,
            cache_position,
            past_key_values.self_attention_cache if past_key_values is not None else None,
            output_attentions,
        )
    else:
        # Encoder: turn the 0/1 padding mask into an additive mask that is
        # 0 where attention is allowed and the dtype minimum elsewhere.
        causal_mask = attention_mask[:, None, None, :]
        causal_mask = causal_mask.to(dtype=inputs_embeds.dtype)
        causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min

    # If a 2D or 3D attention mask is provided for the cross-attention
    # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
    if self.is_decoder and encoder_hidden_states is not None:
        encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
        encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
        if encoder_attention_mask is None:
            encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
        encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
    else:
        encoder_extended_attention_mask = None

    # Prepare head mask if needed
    head_mask = self.get_head_mask(head_mask, self.config.num_layers)
    cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
    all_hidden_states = () if output_hidden_states else None
    all_attentions = () if output_attentions else None
    all_cross_attentions = () if (output_attentions and self.is_decoder) else None
    position_bias = None
    encoder_decoder_position_bias = None
    # Pre-bound so a stack without any block still reaches the return path.
    next_decoder_cache = None

    hidden_states = self.dropout(inputs_embeds)

    for i, layer_module in enumerate(self.block):
        layer_head_mask = head_mask[i]
        cross_attn_layer_head_mask = cross_attn_head_mask[i]
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        layer_outputs = layer_module(
            hidden_states,
            causal_mask,
            position_bias,
            encoder_hidden_states,
            encoder_extended_attention_mask,
            encoder_decoder_position_bias,  # as a positional argument for gradient checkpointing
            layer_head_mask=layer_head_mask,
            cross_attn_layer_head_mask=cross_attn_layer_head_mask,
            past_key_value=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )

        # layer_outputs is a tuple with:
        # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
        if use_cache is False:
            # Insert a placeholder so the indices below are the same with and without caching.
            layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]

        hidden_states, next_decoder_cache = layer_outputs[:2]

        # We share the position biases between the layers - the first layer store them
        # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
        # (cross-attention position bias), (cross-attention weights)
        position_bias = layer_outputs[2]
        if self.is_decoder and encoder_hidden_states is not None:
            encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]

        if output_attentions:
            all_attentions = all_attentions + (layer_outputs[3],)
            if self.is_decoder:
                all_cross_attentions = all_cross_attentions + (layer_outputs[5],)

    hidden_states = self.final_layer_norm(hidden_states)
    hidden_states = self.dropout(hidden_states)

    # Add last layer
    if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states,)

    next_cache = next_decoder_cache if use_cache else None
    if return_self_attention_cache:
        next_cache = past_key_values.self_attention_cache
    if return_legacy_cache:
        next_cache = past_key_values.to_legacy_cache()

    if not return_dict:
        return tuple(
            v
            for v in [
                hidden_states,
                next_cache,
                all_hidden_states,
                all_attentions,
                all_cross_attentions,
            ]
            if v is not None
        )
    return BaseModelOutputWithPastAndCrossAttentions(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_attentions,
        cross_attentions=all_cross_attentions,
    )
|
922
|
+
|
923
|
+
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
|
924
|
+
def _update_causal_mask(
    self,
    attention_mask: Union[torch.Tensor, "BlockMask"],
    input_tensor: torch.Tensor,
    cache_position: torch.Tensor,
    past_key_values: Cache,
    output_attentions: bool = False,
):
    """Produce the self-attention mask for the configured attention backend.

    - `flash_attention_2`: the given mask is returned unchanged when it
      contains at least one zero, otherwise `None`.
    - `flex_attention`: tensor masks are converted with
      `make_flex_block_causal_mask`; anything else is returned as is.
    - otherwise: a 4D causal mask is built with
      `_prepare_4d_causal_attention_mask_with_cache_position`, unless the
      SDPA backend reports that the mask can be skipped, in which case
      `None` is returned.
    """
    attn_impl = self.config._attn_implementation

    if attn_impl == "flash_attention_2":
        has_masked_position = attention_mask is not None and (attention_mask == 0.0).any()
        return attention_mask if has_masked_position else None

    if attn_impl == "flex_attention":
        if isinstance(attention_mask, torch.Tensor):
            attention_mask = make_flex_block_causal_mask(attention_mask)
        return attention_mask

    # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
    # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
    # to infer the attention mask.
    if past_key_values is None:
        past_seen_tokens = 0
        using_compilable_cache = False
    else:
        past_seen_tokens = past_key_values.get_seq_length()
        using_compilable_cache = past_key_values.is_compileable

    # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
    sdpa_without_weights = attn_impl == "sdpa" and not output_attentions
    if sdpa_without_weights and not using_compilable_cache:
        can_skip_mask = AttentionMaskConverter._ignore_causal_mask_sdpa(
            attention_mask,
            inputs_embeds=input_tensor,
            past_key_values_length=past_seen_tokens,
            is_training=self.training,
        )
        if can_skip_mask:
            return None

    dtype = input_tensor.dtype
    sequence_length = input_tensor.shape[1]
    if using_compilable_cache:
        target_length = past_key_values.get_max_cache_shape()
    elif isinstance(attention_mask, torch.Tensor):
        target_length = attention_mask.shape[-1]
    else:
        target_length = past_seen_tokens + sequence_length + 1

    # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
    causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask,
        sequence_length=sequence_length,
        target_length=target_length,
        dtype=dtype,
        cache_position=cache_position,
        batch_size=input_tensor.shape[0],
    )

    if (
        sdpa_without_weights
        and attention_mask is not None
        and attention_mask.device.type in ["cuda", "xpu", "npu"]
    ):
        # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
        # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
        # Details: https://github.com/pytorch/pytorch/issues/110213
        causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, torch.finfo(dtype).min)

    return causal_mask
|
991
|
+
|
992
|
+
@staticmethod
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    cache_position: torch.Tensor,
    batch_size: int,
    **kwargs,
):
    """
    Build an additive causal mask of shape `(batch_size, 1, sequence_length, target_length)`.

    Allowed positions hold 0 and blocked positions hold the minimum value of `dtype`. A mask that is already 4D is
    returned untouched.

    Args:
        attention_mask (`torch.Tensor`):
            `None`, a 2D padding mask of shape `(batch_size, key_value_length)`, or an already prepared 4D mask.
        sequence_length (`int`):
            Number of query positions being processed.
        target_length (`int`):
            Number of key positions the mask has to cover.
        dtype (`torch.dtype`):
            The dtype of the returned mask.
        cache_position (`torch.Tensor`):
            Position of every query token in the full sequence; the mask is created on its device.
        batch_size (`int`):
            Batch size the mask is expanded to.
    """
    if attention_mask is not None and attention_mask.dim() == 4:
        # Assumed to be inverted and correctly sized already.
        return attention_mask

    device = cache_position.device
    min_dtype = torch.finfo(dtype).min

    base_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
    if sequence_length != 1:
        base_mask = torch.triu(base_mask, diagonal=1)
    # Keep the blocking value only for keys located after the query position.
    key_positions = torch.arange(target_length, device=device)
    base_mask *= key_positions > cache_position.reshape(-1, 1)
    causal_mask = base_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

    if attention_mask is None:
        return causal_mask

    causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
    mask_length = attention_mask.shape[-1]
    covered = causal_mask[:, :, :, :mask_length]
    # A sum of exactly 0 means: causally visible (0) but padded (0) -> must be blocked.
    padding_mask = (covered + attention_mask[:, None, None, :].to(causal_mask.device)) == 0
    causal_mask[:, :, :, :mask_length] = covered.masked_fill(padding_mask, min_dtype)
    return causal_mask
|
1047
|
+
|
1048
|
+
class Pop2PianoForConditionalGeneration(Pop2PianoPreTrainedModel, GenerationMixin):
|
1049
|
+
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
|
1050
|
+
|
1051
|
+
def __init__(self, config: Pop2PianoConfig):
    """Create the shared embedding, mel conditioner, encoder, decoder and LM head.

    The encoder and the decoder each receive their own deep copy of
    `config`, so the flag changes made for one stack never affect the
    other stack or the caller's config object.
    """
    super().__init__(config)
    self.config = config
    self.model_dim = config.d_model

    # Token embedding handed to both stacks below.
    self.shared = nn.Embedding(config.vocab_size, config.d_model)

    self.mel_conditioner = Pop2PianoConcatEmbeddingToMel(config)

    # Encoder: non-causal, never caches.
    enc_cfg = copy.deepcopy(config)
    enc_cfg.is_decoder = False
    enc_cfg.use_cache = False
    enc_cfg.is_encoder_decoder = False
    self.encoder = Pop2PianoStack(enc_cfg, self.shared)

    # Decoder: depth comes from `num_decoder_layers`.
    dec_cfg = copy.deepcopy(config)
    dec_cfg.is_decoder = True
    dec_cfg.is_encoder_decoder = False
    dec_cfg.num_layers = config.num_decoder_layers
    self.decoder = Pop2PianoStack(dec_cfg, self.shared)

    self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

    # Initialize weights and apply final processing
    self.post_init()
|
1077
|
+
|
1078
|
+
def get_input_embeddings(self):
    """Return the embedding module stored in `self.shared`."""
    shared_embeddings = self.shared
    return shared_embeddings
|
1080
|
+
|
1081
|
+
def set_input_embeddings(self, new_embeddings):
    """Install `new_embeddings` on the model, then on the encoder and the decoder."""
    self.shared = new_embeddings
    for stack in (self.encoder, self.decoder):
        stack.set_input_embeddings(new_embeddings)
|
1085
|
+
|
1086
|
+
def set_output_embeddings(self, new_embeddings):
    """Replace the output projection stored in `self.lm_head`."""
    setattr(self, "lm_head", new_embeddings)
|
1088
|
+
|
1089
|
+
def get_output_embeddings(self):
    """Return the output projection stored in `self.lm_head`."""
    output_projection = self.lm_head
    return output_projection
|
1091
|
+
|
1092
|
+
def get_encoder(self):
    """Return the encoder stack stored in `self.encoder`."""
    encoder_stack = self.encoder
    return encoder_stack
|
1094
|
+
|
1095
|
+
def get_decoder(self):
    """Return the decoder stack stored in `self.decoder`."""
    decoder_stack = self.decoder
    return decoder_stack
|
1097
|
+
|
1098
|
+
def get_mel_conditioner_outputs(
    self,
    input_features: torch.FloatTensor,
    composer: str,
    generation_config: GenerationConfig,
    attention_mask: Optional[torch.FloatTensor] = None,
):
    """
    Run `input_features` through `self.mel_conditioner` together with the token id of the requested composer, so
    that the composer can steer the kind of MIDI tokens the model generates.

    Args:
        input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            input features extracted from the feature extractor.
        composer (`str`):
            composer name; must be a key of `generation_config.composer_to_feature_token`.
        generation_config (`~generation.GenerationConfig`):
            Only its `composer_to_feature_token` mapping is read here.
        attention_mask (`torch.FloatTensor`, *optional*):
            Padding mask for batched generation:
            - 1 for tokens that are **not padded**,
            - 0 for tokens that are **padded**.

    Returns:
        A `(input_features, attention_mask)` pair. When a mask is given, examples whose first mask entry is 0 are
        zeroed out and the mask gets one extra leading column (a copy of its first column) to stay aligned with
        the conditioned features. Without a mask the second element is `None`.

    Raises:
        ValueError: if `composer` is not a known composer.
    """
    token_by_composer = generation_config.composer_to_feature_token
    if composer not in token_by_composer:
        raise ValueError(
            f"Please choose a composer from {list(token_by_composer)}. Composer received - {composer}"
        )

    # One copy of the composer token id per example in the batch.
    batch_size = input_features.shape[0]
    composer_ids = torch.tensor(token_by_composer[composer], device=self.device).repeat(batch_size)

    conditioned = self.mel_conditioner(
        feature=input_features,
        index_value=composer_ids,
        embedding_offset=min(token_by_composer.values()),
    )

    if attention_mask is None:
        return conditioned, None

    first_column = attention_mask[:, 0]
    conditioned[~first_column.bool()] = 0.0

    # since self.mel_conditioner adds a new array at the front of inputs_embeds we need to do the same for attention_mask to keep the shapes same
    attention_mask = torch.cat([first_column.view(-1, 1), attention_mask], dim=1)
    return conditioned, attention_mask
|
1146
|
+
|
1147
|
+
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.FloatTensor] = None,
    decoder_input_ids: Optional[torch.LongTensor] = None,
    decoder_attention_mask: Optional[torch.BoolTensor] = None,
    head_mask: Optional[torch.FloatTensor] = None,
    decoder_head_mask: Optional[torch.FloatTensor] = None,
    cross_attn_head_mask: Optional[torch.Tensor] = None,
    encoder_outputs: Optional[tuple[tuple[torch.Tensor]]] = None,
    past_key_values: Optional[tuple[tuple[torch.Tensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    input_features: Optional[torch.FloatTensor] = None,
    decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> Union[tuple[torch.FloatTensor], Seq2SeqLMOutput]:
    r"""
    Encode the inputs (unless `encoder_outputs` is supplied), decode, project onto the vocabulary and, when `labels`
    are given, compute the cross-entropy loss.

    input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
        Indices of input sequence tokens in the vocabulary, forwarded to the encoder.
    decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
        Indices of decoder input sequence tokens in the vocabulary. When omitted together with
        `decoder_inputs_embeds` while `labels` are given, they are derived by shifting `labels` to the right.
    decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
        Attention mask forwarded to the decoder.
    decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
        Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
        1]`:
        - 1 indicates the head is **not masked**,
        - 0 indicates the head is **masked**.
    cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
        Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
        `[0, 1]`:
        - 1 indicates the head is **not masked**,
        - 0 indicates the head is **masked**.
    input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
        Does the same task as `inputs_embeds`. If `inputs_embeds` is not present but `input_features` is present
        then `input_features` will be considered as `inputs_embeds`. Passing both raises a `ValueError`.
    labels (`torch.LongTensor`, *optional*):
        Target token ids for the language-modeling loss; they are flattened and compared against the flattened
        logits. Positions set to `-100` are ignored by the loss.
    """
    use_cache = self.config.use_cache if use_cache is None else use_cache
    return_dict = self.config.use_return_dict if return_dict is None else return_dict

    # `input_features` is an alias for `inputs_embeds`; accept only one of them.
    if input_features is not None:
        if inputs_embeds is not None:
            raise ValueError("Both `inputs_embeds` and `input_features` received! Please provide only one of them")
        inputs_embeds = input_features

    # Encode if needed (training, first prediction pass)
    if encoder_outputs is None:
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
    elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
        # Wrap a plain tuple so attribute access works when building the result below.
        encoder_outputs = BaseModelOutput(
            last_hidden_state=encoder_outputs[0],
            hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
            attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
        )

    encoder_hidden = encoder_outputs[0]

    needs_shifted_labels = labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None
    if needs_shifted_labels:
        # get decoder inputs from shifting lm labels to the right
        decoder_input_ids = self._shift_right(labels)

    # Decode
    decoder_outputs = self.decoder(
        input_ids=decoder_input_ids,
        attention_mask=decoder_attention_mask,
        inputs_embeds=decoder_inputs_embeds,
        past_key_values=past_key_values,
        encoder_hidden_states=encoder_hidden,
        encoder_attention_mask=attention_mask,
        head_mask=decoder_head_mask,
        cross_attn_head_mask=cross_attn_head_mask,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    sequence_output = decoder_outputs[0]
    if self.config.tie_word_embeddings:
        # Rescale output before projecting on vocab
        # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
        sequence_output = sequence_output * (self.model_dim**-0.5)

    lm_logits = self.lm_head(sequence_output)

    loss = None
    if labels is not None:
        flat_logits = lm_logits.view(-1, lm_logits.size(-1))
        criterion = nn.CrossEntropyLoss(ignore_index=-100)
        loss = criterion(flat_logits, labels.view(-1))

    if return_dict:
        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
    return output if loss is None else (loss,) + output
|
1281
|
+
|
1282
|
+
@torch.no_grad()
def generate(
    self,
    input_features,
    attention_mask=None,
    composer="composer1",
    generation_config=None,
    **kwargs,
):
    """
    Generates token ids for midi outputs.

    <Tip warning={true}>

    Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
    model's default generation configuration. You can override any `generation_config` by passing the corresponding
    parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. For an overview of generation
    strategies and code examples, check out the [following guide](./generation_strategies).

    </Tip>

    Parameters:
        input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            This is the featurized version of audio generated by `Pop2PianoFeatureExtractor`.
        attention_mask:
            For batched generation `input_features` are padded to have the same shape across all examples.
            `attention_mask` helps to determine which areas were padded and which were not.
            - 1 for tokens that are **not padded**,
            - 0 for tokens that are **padded**.
        composer (`str`, *optional*, defaults to `"composer1"`):
            This value is passed to `Pop2PianoConcatEmbeddingToMel` to generate different embeddings for each
            `"composer"`. Please make sure that the composer value is present in `composer_to_feature_token` in
            `generation_config`. For an example please see
            https://huggingface.co/sweetcocoa/pop2piano/blob/main/generation_config.json .
        generation_config (`~generation.GenerationConfig`, *optional*):
            The generation configuration to be used as base parametrization for the generation call. `**kwargs`
            passed to generate matching the attributes of `generation_config` will override them. If
            `generation_config` is not provided, the default will be used, which had the following loading
            priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
            configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
            default values, whose documentation should be checked to parameterize generation. The object passed
            in (or the model default) is copied before `kwargs` are applied, so it is never modified by this call.
        kwargs:
            Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
            forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
            specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.

    Return:
        [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
        or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`.
            Since Pop2Piano is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
            [`~utils.ModelOutput`] types are:
                - [`~generation.GenerateEncoderDecoderOutput`],
                - [`~generation.GenerateBeamEncoderDecoderOutput`]

    Raises:
        ValueError: If `generation_config` has no `composer_to_feature_token` attribute, or if the number of
            entries in it differs from `config.composer_vocab_size`.
    """

    if generation_config is None:
        generation_config = self.generation_config
    # Apply the per-call overrides to a private copy. Updating the object in place would make one-off
    # kwargs (e.g. `num_beams=4`) stick to the model's default config and leak into every later call.
    generation_config = copy.deepcopy(generation_config)
    generation_config.update(**kwargs)

    # check for composer_to_feature_token
    if not hasattr(generation_config, "composer_to_feature_token"):
        raise ValueError(
            "`composer_to_feature_token` was not found! Please refer to "
            "https://huggingface.co/sweetcocoa/pop2piano/blob/main/generation_config.json "
            "and parse a dict like that."
        )

    if len(generation_config.composer_to_feature_token) != self.config.composer_vocab_size:
        raise ValueError(
            "config.composer_vocab_size must be same as the number of keys in "
            f"generation_config.composer_to_feature_token! "
            f"Found {self.config.composer_vocab_size} vs {len(generation_config.composer_to_feature_token)}."
        )

    # to control the variation of generated MIDI tokens we concatenate mel-conditioner tokens(which depends on composer_token)
    # at the front of input_features.
    input_features, attention_mask = self.get_mel_conditioner_outputs(
        input_features=input_features,
        attention_mask=attention_mask,
        composer=composer,
        generation_config=generation_config,
    )

    # `kwargs` is forwarded as well so that model-specific arguments (which the generation config does not
    # consume) still reach the parent implementation.
    return super().generate(
        inputs=None,
        inputs_embeds=input_features,
        attention_mask=attention_mask,
        generation_config=generation_config,
        **kwargs,
    )
|
1371
|
+
|
1372
|
+
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
    """
    Build decoder input ids from `labels` by delegating to `self._shift_right`.

    Args:
        labels (`torch.Tensor`):
            Target token ids.

    Returns:
        Whatever `self._shift_right(labels)` returns.
    """
    decoder_input_ids = self._shift_right(labels)
    return decoder_input_ids
|
1374
|
+
|
1375
|
+
def _reorder_cache(self, past_key_values, beam_idx):
|
1376
|
+
# if decoder past is not included in output
|
1377
|
+
# speedy decoding is disabled and no need to reorder
|
1378
|
+
if past_key_values is None:
|
1379
|
+
print("You might want to consider setting `use_cache=True` to speed up decoding")
|
1380
|
+
return past_key_values
|
1381
|
+
|
1382
|
+
reordered_decoder_past = ()
|
1383
|
+
for layer_past_states in past_key_values:
|
1384
|
+
# get the correct batch idx from layer past batch dim
|
1385
|
+
# batch dim of `past` is at 2nd position
|
1386
|
+
reordered_layer_past_states = ()
|
1387
|
+
for layer_past_state in layer_past_states:
|
1388
|
+
# need to set correct `past` for each of the four key / value states
|
1389
|
+
reordered_layer_past_states = reordered_layer_past_states + (
|
1390
|
+
layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
|
1391
|
+
)
|
1392
|
+
|
1393
|
+
if reordered_layer_past_states[0].shape != layer_past_states[0].shape:
|
1394
|
+
raise ValueError(
|
1395
|
+
f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched"
|
1396
|
+
)
|
1397
|
+
if len(reordered_layer_past_states) != len(layer_past_states):
|
1398
|
+
raise ValueError(
|
1399
|
+
f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched"
|
1400
|
+
)
|
1401
|
+
|
1402
|
+
reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
|
1403
|
+
return reordered_decoder_past
|
1404
|
+
|
1405
|
+
class Pop2PianoFeatureExtractor(SequenceFeatureExtractor):
    r"""
    Constructs a Pop2Piano feature extractor.

    This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
    most of the main methods. Users should refer to this superclass for more information regarding those methods.

    This class extracts rhythm and preprocesses the audio before it is passed to the model. First the audio is passed
    to the `RhythmExtractor2013` algorithm which extracts the beat_times, beat positions and estimates their confidence
    as well as tempo in bpm, then beat_times is interpolated to get beatsteps. Later we calculate
    extrapolated_beatsteps from it to be used in the tokenizer. On the other hand the audio is resampled to
    self.sampling_rate and preprocessed, and then the log mel spectrogram is computed from that to be used in our
    transformer model.

    Args:
        sampling_rate (`int`, *optional*, defaults to 22050):
            Target Sampling rate of audio signal. It's the sampling rate that we forward to the model.
        padding_value (`int`, *optional*, defaults to 0):
            Padding value used to pad the audio. Should correspond to silences.
        window_size (`int`, *optional*, defaults to 4096):
            Length of the window in samples to which the Fourier transform is applied.
        hop_length (`int`, *optional*, defaults to 1024):
            Step size between each window of the waveform, in samples.
        min_frequency (`float`, *optional*, defaults to 10.0):
            Lowest frequency that will be used in the log-mel spectrogram.
        feature_size (`int`, *optional*, defaults to 512):
            The feature dimension of the extracted features.
        num_bars (`int`, *optional*, defaults to 2):
            Determines interval between each sequence.
    """

    model_input_names = ["input_features", "beatsteps", "extrapolated_beatstep"]

    def __init__(
        self,
        sampling_rate: int = 22050,
        padding_value: int = 0,
        window_size: int = 4096,
        hop_length: int = 1024,
        min_frequency: float = 10.0,
        feature_size: int = 512,
        num_bars: int = 2,
        **kwargs,
    ):
        super().__init__(
            feature_size=feature_size,
            sampling_rate=sampling_rate,
            padding_value=padding_value,
            **kwargs,
        )
        self.sampling_rate = sampling_rate
        self.padding_value = padding_value
        self.window_size = window_size
        self.hop_length = hop_length
        self.min_frequency = min_frequency
        self.feature_size = feature_size
        self.num_bars = num_bars
        # Mel filter bank built once here and reused by `mel_spectrogram` for every segment.
        self.mel_filters = mel_filter_bank(
            num_frequency_bins=(self.window_size // 2) + 1,
            num_mel_filters=self.feature_size,
            min_frequency=self.min_frequency,
            max_frequency=float(self.sampling_rate // 2),
            sampling_rate=self.sampling_rate,
            norm=None,
            mel_scale="htk",
        )

    def mel_spectrogram(self, sequence: np.ndarray):
        """
        Generates MelSpectrogram.

        Args:
            sequence (`np.ndarray`):
                The sequence of which the mel-spectrogram will be computed. It is iterated over its first axis and
                one spectrogram is computed per item.

        Returns:
            `np.ndarray`: The per-item spectrograms stacked with `np.array`.
        """
        # The window depends only on `window_size`, so it is computed once instead of once per item.
        # NOTE(review): assumes `spectrogram` does not modify the window it receives - confirm.
        window = np.hanning(self.window_size + 1)[:-1]
        mel_specs = [
            spectrogram(
                waveform=seq,
                window=window,
                frame_length=self.window_size,
                hop_length=self.hop_length,
                power=2.0,
                mel_filters=self.mel_filters,
            )
            for seq in sequence
        ]

        return np.array(mel_specs)

    def extract_rhythm(self, audio: np.ndarray):
        """
        This algorithm(`RhythmExtractor2013`) extracts the beat positions and estimates their confidence as well as
        tempo in bpm for an audio signal. For more information please visit
        https://essentia.upf.edu/reference/std_RhythmExtractor2013.html .

        Args:
            audio(`np.ndarray`):
                raw audio waveform which is passed to the Rhythm Extractor.

        Returns:
            The five values produced by the tracker, in order: `bpm`, `beat_times`, `confidence`, `estimates` and
            `essentia_beat_intervals`.
        """
        essentia_tracker = RhythmExtractor2013(method="multifeature")
        bpm, beat_times, confidence, estimates, essentia_beat_intervals = essentia_tracker(audio)

        return bpm, beat_times, confidence, estimates, essentia_beat_intervals

    def interpolate_beat_times(self, beat_times: np.ndarray, steps_per_beat: int, n_extend: int):
        """
        This method takes beat_times and then interpolates that using `scipy.interpolate.interp1d` and the output is
        then used to convert raw audio to log-mel-spectrogram.

        Args:
            beat_times (`np.ndarray`):
                beat_times is passed into `scipy.interpolate.interp1d` for processing.
            steps_per_beat (`int`):
                used as a parameter to control the interpolation: the number of query points is
                `beat_times.size * steps_per_beat + n_extend`.
            n_extend (`int`):
                used as a parameter to control the interpolation: the query range ends at
                `beat_times.size + n_extend - 1`, beyond the last known beat, so those values are extrapolated.

        Returns:
            `np.ndarray`: The interpolated (and extrapolated) beat times.
        """

        # Map beat index -> beat time; out-of-range indices are extrapolated instead of raising.
        beat_times_function = interp1d(
            np.arange(beat_times.size),
            beat_times,
            bounds_error=False,
            fill_value="extrapolate",
        )

        ext_beats = beat_times_function(
            np.linspace(0, beat_times.size + n_extend - 1, beat_times.size * steps_per_beat + n_extend)
        )

        return ext_beats

    def preprocess_mel(self, audio: np.ndarray, beatstep: np.ndarray):
        """
        Preprocessing for log-mel-spectrogram

        Cuts `audio` into consecutive segments of `num_bars * 4` beatsteps each and zero-pads every segment on the
        right to the length of the longest one.

        Args:
            audio (`np.ndarray` of shape `(audio_length, )` ):
                Raw audio waveform to be processed.
            beatstep (`np.ndarray`):
                Interpolated values of the raw audio. If beatstep[0] is greater than 0.0, then it will be shifted by
                the value at beatstep[0].

        Returns:
            `tuple`: `(padded_batch, extrapolated_beatstep)` where `padded_batch` stacks the padded segments and
            `extrapolated_beatstep` is `beatstep` extended by `(num_bars + 1) * 4 + 1` extrapolated steps.

        Raises:
            ValueError: If `audio` is not one-dimensional.
        """

        if audio is not None and len(audio.shape) != 1:
            raise ValueError(
                f"Expected `audio` to be a single channel audio input of shape `(n, )` but found shape {audio.shape}."
            )
        if beatstep[0] > 0.0:
            beatstep = beatstep - beatstep[0]

        num_steps = self.num_bars * 4
        num_target_steps = len(beatstep)
        # The extension guarantees that `extrapolated_beatstep[num_target_steps]` exists for the last segment.
        extrapolated_beatstep = self.interpolate_beat_times(
            beat_times=beatstep, steps_per_beat=1, n_extend=(self.num_bars + 1) * 4 + 1
        )

        # First pass: convert beatstep boundaries to sample boundaries and find the longest segment.
        sample_indices = []
        max_feature_length = 0
        for start_idx in range(0, num_target_steps, num_steps):
            end_idx = min(start_idx + num_steps, num_target_steps)
            start_sample = int(extrapolated_beatstep[start_idx] * self.sampling_rate)
            end_sample = int(extrapolated_beatstep[end_idx] * self.sampling_rate)
            sample_indices.append((start_sample, end_sample))
            max_feature_length = max(max_feature_length, end_sample - start_sample)

        # Second pass: slice the audio and right-pad every segment with zeros to the common length.
        padded_batch = []
        for start_sample, end_sample in sample_indices:
            feature = audio[start_sample:end_sample]
            padded_feature = np.pad(
                feature,
                ((0, max_feature_length - feature.shape[0]),),
                "constant",
                constant_values=0,
            )
            padded_batch.append(padded_feature)

        padded_batch = np.asarray(padded_batch)
        return padded_batch, extrapolated_beatstep

    def _pad(self, features: np.ndarray, add_zero_line=True):
        """
        Pads every item of `features` to a common length and builds the matching attention masks.

        Args:
            features:
                Sequence of arrays. 3-dimensional items are padded along their second axis; any other item is
                reshaped to `(1, -1)` and padded along that flattened axis.
            add_zero_line (`bool`, *optional*, defaults to `True`):
                Whether to append one all-zero row (and an all-zero mask row) after every item. The appended row
                has shape `(1, max_length, feature_size)`, so this is only meaningful for 3-dimensional items.

        Returns:
            `tuple`: `(padded_features, attention_masks)`, both concatenated along axis 0, as `np.float32` and
            `np.int64` respectively. Features are padded with `self.padding_value`; masks are padded with 0.
        """
        features_shapes = [each_feature.shape for each_feature in features]
        # Transposed shapes: `shape_dims[k]` holds the size of axis k for every item.
        shape_dims = [*zip(*features_shapes)]
        attention_masks, padded_features = [], []
        for i, each_feature in enumerate(features):
            # To pad "input_features".
            if len(each_feature.shape) == 3:
                features_pad_value = max(shape_dims[1]) - features_shapes[i][1]
                attention_mask = np.ones(features_shapes[i][:2], dtype=np.int64)
                feature_padding = ((0, 0), (0, features_pad_value), (0, 0))
                attention_mask_padding = (feature_padding[0], feature_padding[1])

            # To pad "beatsteps" and "extrapolated_beatstep".
            else:
                each_feature = each_feature.reshape(1, -1)
                features_pad_value = max(shape_dims[0]) - features_shapes[i][0]
                attention_mask = np.ones(features_shapes[i], dtype=np.int64).reshape(1, -1)
                feature_padding = attention_mask_padding = ((0, 0), (0, features_pad_value))

            each_padded_feature = np.pad(each_feature, feature_padding, "constant", constant_values=self.padding_value)
            # Padded positions must be marked with 0 in the mask regardless of the value used to pad the
            # features, otherwise a non-zero `padding_value` would flag padding as real data.
            attention_mask = np.pad(attention_mask, attention_mask_padding, "constant", constant_values=0)

            if add_zero_line:
                # if it is batched then we separate each examples using zero array
                zero_array_len = max(shape_dims[1])

                # we concatenate the zero array line here
                each_padded_feature = np.concatenate(
                    [each_padded_feature, np.zeros([1, zero_array_len, self.feature_size])], axis=0
                )
                attention_mask = np.concatenate(
                    [attention_mask, np.zeros([1, zero_array_len], dtype=attention_mask.dtype)], axis=0
                )

            padded_features.append(each_padded_feature)
            attention_masks.append(attention_mask)

        padded_features = np.concatenate(padded_features, axis=0).astype(np.float32)
        attention_masks = np.concatenate(attention_masks, axis=0).astype(np.int64)

        return padded_features, attention_masks

    def pad(
        self,
        inputs: BatchFeature,
        is_batched: bool,
        return_attention_mask: bool,
        return_tensors: Optional[Union[str, TensorType]] = None,
    ):
        """
        Pads the inputs to same length and returns attention_mask.

        Args:
            inputs (`BatchFeature`):
                Processed audio features.
            is_batched (`bool`):
                Whether inputs are batched or not.
            return_attention_mask (`bool`):
                Whether to return attention mask or not.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors instead of list of python integers. Acceptable values are:
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return Numpy `np.ndarray` objects.
                If nothing is specified, it will return list of `np.ndarray` arrays.
        Return:
            `BatchFeature` with attention_mask, attention_mask_beatsteps and attention_mask_extrapolated_beatstep added
            to it:
            - **attention_mask** np.ndarray of shape `(batch_size, max_input_features_seq_length)` --
                Example :
                    1, 1, 1, 0, 0 (audio 1, also here it is padded to max length of 5 that's why there are 2 zeros at
                    the end indicating they are padded)

                    0, 0, 0, 0, 0 (zero pad to separate audio 1 and 2)

                    1, 1, 1, 1, 1 (audio 2)

                    0, 0, 0, 0, 0 (zero pad to separate audio 2 and 3)

                    1, 1, 1, 1, 1 (audio 3)
            - **attention_mask_beatsteps** np.ndarray of shape `(batch_size, max_beatsteps_seq_length)`
            - **attention_mask_extrapolated_beatstep** np.ndarray of shape `(batch_size,
              max_extrapolated_beatstep_seq_length)`
        """

        processed_features_dict = {}
        for feature_name, feature_value in inputs.items():
            # Only "input_features" gets the zero separator row between examples.
            is_input_features = feature_name == "input_features"
            padded_feature_values, attention_mask = self._pad(feature_value, add_zero_line=is_input_features)
            processed_features_dict[feature_name] = padded_feature_values
            if return_attention_mask:
                mask_name = "attention_mask" if is_input_features else f"attention_mask_{feature_name}"
                processed_features_dict[mask_name] = attention_mask

        # If we are processing only one example, we should remove the zero array line since we don't need it to
        # separate examples from each other.
        if not is_batched and not return_attention_mask:
            processed_features_dict["input_features"] = processed_features_dict["input_features"][:-1, ...]

        outputs = BatchFeature(processed_features_dict, tensor_type=return_tensors)

        return outputs

    def __call__(
        self,
        audio: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]],
        sampling_rate: Union[int, list[int]],
        steps_per_beat: int = 2,
        resample: Optional[bool] = True,
        return_attention_mask: Optional[bool] = False,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Main method to featurize and prepare for the model.

        Args:
            audio (`np.ndarray`, `List`):
                The audio or batch of audio to be processed. Each audio can be a numpy array, a list of float values, a
                list of numpy arrays or a list of list of float values.
            sampling_rate (`int` or `list[int]`):
                The sampling rate at which the `audio` input was sampled. For batched input this must be a list
                with one sampling rate per audio.
            steps_per_beat (`int`, *optional*, defaults to 2):
                This is used in interpolating `beat_times`.
            resample (`bool`, *optional*, defaults to `True`):
                Determines whether to resample the audio to `sampling_rate` or not before processing. Must be True
                during inference.
            return_attention_mask (`bool` *optional*, defaults to `False`):
                Denotes if attention_mask for input_features, beatsteps and extrapolated_beatstep will be given as
                output or not. If `None` is passed explicitly, it resolves to `True` for batched inputs and to
                `False` otherwise.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors instead of list of python integers. Acceptable values are:
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return Numpy `np.ndarray` objects.
                If nothing is specified, it will return list of `np.ndarray` arrays.

        Returns:
            `BatchFeature`: Padded `input_features`, `beatsteps` and `extrapolated_beatstep` (plus the attention
            masks when requested), as produced by `self.pad`.

        Raises:
            ValueError: For batched input, if `sampling_rate` is not a list or its length differs from the number
                of audios.
        """
        is_batched = bool(isinstance(audio, (list, tuple)) and isinstance(audio[0], (np.ndarray, tuple, list)))
        if is_batched:
            # This enables the user to process files of different sampling_rate at same time
            if not isinstance(sampling_rate, list):
                raise ValueError(
                    "Please give sampling_rate of each audio separately when you are passing multiple raw_audios at the same time. "
                    f"Received {sampling_rate}, expected [audio_1_sr, ..., audio_n_sr]."
                )
            # `zip` below would silently drop the unmatched audios or sampling rates, so fail loudly instead.
            if len(sampling_rate) != len(audio):
                raise ValueError(
                    "Expected exactly one sampling_rate per audio, but received "
                    f"{len(sampling_rate)} sampling rates for {len(audio)} audios."
                )
            return_attention_mask = True if return_attention_mask is None else return_attention_mask
        else:
            audio = [audio]
            sampling_rate = [sampling_rate]
            return_attention_mask = False if return_attention_mask is None else return_attention_mask

        batch_input_features, batch_beatsteps, batch_ext_beatstep = [], [], []
        for single_raw_audio, single_sampling_rate in zip(audio, sampling_rate):
            # Only the beat times are needed; the other tracker outputs are discarded.
            _, beat_times, _, _, _ = self.extract_rhythm(audio=single_raw_audio)
            beatsteps = self.interpolate_beat_times(beat_times=beat_times, steps_per_beat=steps_per_beat, n_extend=1)

            if self.sampling_rate != single_sampling_rate and self.sampling_rate is not None:
                if resample:
                    # Change sampling_rate to self.sampling_rate
                    single_raw_audio = librosa_resample(
                        single_raw_audio,
                        orig_sr=single_sampling_rate,
                        target_sr=self.sampling_rate,
                        res_type="kaiser_best",
                    )
                else:
                    print(
                        f"The sampling_rate of the provided audio is different from the target sampling_rate "
                        f"of the Feature Extractor, {self.sampling_rate} vs {single_sampling_rate}. "
                        f"In these cases it is recommended to use `resample=True` in the `__call__` method to "
                        f"get the optimal behaviour."
                    )

            # From here on the audio is treated as being sampled at the extractor's target rate.
            single_sampling_rate = self.sampling_rate
            start_sample = int(beatsteps[0] * single_sampling_rate)
            end_sample = int(beatsteps[-1] * single_sampling_rate)

            # Keep only the audio between the first and last beatstep, with beatsteps re-based to start at 0.
            input_features, extrapolated_beatstep = self.preprocess_mel(
                single_raw_audio[start_sample:end_sample], beatsteps - beatsteps[0]
            )

            mel_specs = self.mel_spectrogram(input_features.astype(np.float32))

            # apply np.log to get log mel-spectrograms (values are clipped at 1e-6 to avoid log(0))
            log_mel_specs = np.log(np.clip(mel_specs, a_min=1e-6, a_max=None))

            # swap the last two axes of every spectrogram
            input_features = np.transpose(log_mel_specs, (0, -1, -2))

            batch_input_features.append(input_features)
            batch_beatsteps.append(beatsteps)
            batch_ext_beatstep.append(extrapolated_beatstep)

        output = BatchFeature(
            {
                "input_features": batch_input_features,
                "beatsteps": batch_beatsteps,
                "extrapolated_beatstep": batch_ext_beatstep,
            }
        )

        output = self.pad(
            output,
            is_batched=is_batched,
            return_attention_mask=return_attention_mask,
            return_tensors=return_tensors,
        )

        return output
|
1801
|
+
|
1802
|
+
# File names of the vocabulary resources used by the tokenizer, keyed by resource name.
VOCAB_FILES_NAMES = {"vocab": "vocab.json"}
|
1805
|
+
|
1806
|
+
def token_time_to_note(number, cutoff_time_idx, current_idx):
    """
    Advance the current time index by `number`, capped at `cutoff_time_idx` when one is given.

    Args:
        number: Amount added to `current_idx`.
        cutoff_time_idx: Upper bound for the result, or `None` for no bound.
        current_idx: Time index before the shift.

    Returns:
        The shifted (and possibly capped) time index.
    """
    current_idx += number
    if cutoff_time_idx is None:
        return current_idx
    return min(current_idx, cutoff_time_idx)
|
1812
|
+
|
1813
|
+
def token_note_to_note(number, current_velocity, default_velocity, note_onsets_ready, current_idx, notes):
    """
    Process one note token: either register its onset or close the pending note and record it.

    Args:
        number: Note number; used as index into `note_onsets_ready`.
        current_velocity: Velocity currently in effect; `0` means the closed note is not re-armed.
        default_velocity: Velocity stored with every recorded note.
        note_onsets_ready: Per-note pending onset index (`None` when no onset is pending). Updated in place.
        current_idx: Current time index.
        notes: List of `[onset_idx, offset_idx, number, default_velocity]` entries. Appended to in place.

    Returns:
        The same `notes` list.
    """
    pending_onset = note_onsets_ready[number]

    # No onset pending for this note yet: remember where it starts.
    if pending_onset is None:
        note_onsets_ready[number] = current_idx
        return notes

    # An onset is pending and time has advanced since then: close the note at the current index.
    if pending_onset < current_idx:
        notes.append([pending_onset, current_idx, number, default_velocity])
        # With a non-zero velocity the note immediately starts again at the current index.
        note_onsets_ready[number] = current_idx if current_velocity != 0 else None

    return notes
|
1826
|
+
|
1827
|
+
class Pop2PianoTokenizer(PreTrainedTokenizer):
|
1828
|
+
"""
|
1829
|
+
Constructs a Pop2Piano tokenizer. This tokenizer does not require training.
|
1830
|
+
|
1831
|
+
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
|
1832
|
+
this superclass for more information regarding those methods.
|
1833
|
+
|
1834
|
+
Args:
|
1835
|
+
vocab (`str`):
|
1836
|
+
Path to the vocab file which contains the vocabulary.
|
1837
|
+
default_velocity (`int`, *optional*, defaults to 77):
|
1838
|
+
Determines the default velocity to be used while creating midi Notes.
|
1839
|
+
num_bars (`int`, *optional*, defaults to 2):
|
1840
|
+
Determines cutoff_time_idx in for each token.
|
1841
|
+
unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"-1"`):
|
1842
|
+
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
1843
|
+
token instead.
|
1844
|
+
eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to 1):
|
1845
|
+
The end of sequence token.
|
1846
|
+
pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to 0):
|
1847
|
+
A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by
|
1848
|
+
attention mechanisms or loss computation.
|
1849
|
+
bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to 2):
|
1850
|
+
The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
|
1851
|
+
"""
|
1852
|
+
|
1853
|
+
model_input_names = ["token_ids", "attention_mask"]
|
1854
|
+
vocab_files_names = VOCAB_FILES_NAMES
|
1855
|
+
|
1856
|
+
def __init__(
    self,
    vocab,
    default_velocity=77,
    num_bars=2,
    unk_token="-1",
    eos_token="1",
    pad_token="0",
    bos_token="2",
    **kwargs,
):
    """
    Load the vocabulary file and initialise the tokenizer.

    Args:
        vocab: Path to the JSON vocabulary file.
        default_velocity: Stored as `self.default_velocity`.
        num_bars: Stored as `self.num_bars`.
        unk_token, eos_token, pad_token, bos_token: Special tokens; plain strings are wrapped in `AddedToken`
            with `lstrip=False` and `rstrip=False`, other values are passed through unchanged.
        **kwargs: Forwarded to the parent constructor.
    """

    def _as_added_token(token):
        # Strings become `AddedToken`s without stripping; anything else is used as given.
        if isinstance(token, str):
            return AddedToken(token, lstrip=False, rstrip=False)
        return token

    unk_token, eos_token, pad_token, bos_token = (
        _as_added_token(token) for token in (unk_token, eos_token, pad_token, bos_token)
    )

    self.default_velocity = default_velocity
    self.num_bars = num_bars

    # Load the vocab
    with open(vocab, "rb") as vocab_file:
        self.encoder = json_load(vocab_file)

    # Reverse mapping (value -> key) of the loaded vocabulary; built before the parent constructor runs.
    self.decoder = {token_id: token for token, token_id in self.encoder.items()}

    super().__init__(
        unk_token=unk_token,
        eos_token=eos_token,
        pad_token=pad_token,
        bos_token=bos_token,
        **kwargs,
    )
|
1889
|
+
|
1890
|
+
@property
def vocab_size(self):
    """Number of entries in the loaded vocabulary (`self.encoder`)."""
    vocabulary = self.encoder
    return len(vocabulary)
|
1894
|
+
|
1895
|
+
def get_vocab(self):
    """Return a new dict with the base vocabulary overlaid by `self.added_tokens_encoder`."""
    vocabulary = dict(self.encoder)
    # Added tokens take precedence over base entries with the same key.
    vocabulary.update(**self.added_tokens_encoder)
    return vocabulary
|
1898
|
+
|
1899
|
+
def _convert_id_to_token(self, token_id: int) -> list:
|
1900
|
+
"""
|
1901
|
+
Decodes the token ids generated by the transformer into notes.
|
1902
|
+
|
1903
|
+
Args:
|
1904
|
+
token_id (`int`):
|
1905
|
+
This denotes the ids generated by the transformers to be converted to Midi tokens.
|
1906
|
+
|
1907
|
+
Returns:
|
1908
|
+
`List`: A list consists of token_type (`str`) and value (`int`).
|
1909
|
+
"""
|
1910
|
+
|
1911
|
+
token_type_value = self.decoder.get(token_id, f"{self.unk_token}_TOKEN_TIME")
|
1912
|
+
token_type_value = token_type_value.split("_")
|
1913
|
+
token_type, value = "_".join(token_type_value[1:]), int(token_type_value[0])
|
1914
|
+
|
1915
|
+
return [token_type, value]
|
1916
|
+
|
1917
|
+
def _convert_token_to_id(self, token, token_type="TOKEN_TIME") -> int:
|
1918
|
+
"""
|
1919
|
+
Encodes the Midi tokens to transformer generated token ids.
|
1920
|
+
|
1921
|
+
Args:
|
1922
|
+
token (`int`):
|
1923
|
+
This denotes the token value.
|
1924
|
+
token_type (`str`):
|
1925
|
+
This denotes the type of the token. There are four types of midi tokens such as "TOKEN_TIME",
|
1926
|
+
"TOKEN_VELOCITY", "TOKEN_NOTE" and "TOKEN_SPECIAL".
|
1927
|
+
|
1928
|
+
Returns:
|
1929
|
+
`int`: returns the id of the token.
|
1930
|
+
"""
|
1931
|
+
return self.encoder.get(f"{token}_{token_type}", int(self.unk_token))
|
1932
|
+
|
1933
|
+
def relative_batch_tokens_ids_to_notes(
    self,
    tokens: np.ndarray,
    beat_offset_idx: int,
    bars_per_batch: int,
    cutoff_time_idx: int,
):
    """
    Converts relative tokens to notes which are then used to generate pretty midi object.

    Every item of `tokens` is decoded with `relative_tokens_ids_to_notes`, using a start index that advances by
    `bars_per_batch * 4` per item, and the non-empty results are concatenated along axis 0.

    Args:
        tokens (`np.ndarray`):
            Tokens to be converted to notes.
        beat_offset_idx (`int`):
            Denotes beat offset index for each note in generated Midi.
        bars_per_batch (`int`):
            A parameter to control the Midi output generation.
        cutoff_time_idx (`int`):
            Denotes the cutoff time index for each note in generated Midi.

    Returns:
        The concatenated notes, or an empty list when no item produced any note.
    """

    collected = None

    for position, item_tokens in enumerate(tokens):
        item_start_idx = beat_offset_idx + position * bars_per_batch * 4
        item_cutoff_idx = cutoff_time_idx + item_start_idx
        item_notes = self.relative_tokens_ids_to_notes(
            item_tokens,
            start_idx=item_start_idx,
            cutoff_time_idx=item_cutoff_idx,
        )

        # Items that decoded to nothing contribute nothing.
        if len(item_notes) == 0:
            continue

        if collected is None:
            collected = item_notes
        else:
            collected = np.concatenate((collected, item_notes), axis=0)

    return [] if collected is None else collected
|
1976
|
+
|
1977
|
+
def relative_batch_tokens_ids_to_midi(
    self,
    tokens: np.ndarray,
    beatstep: np.ndarray,
    beat_offset_idx: int = 0,
    bars_per_batch: int = 2,
    cutoff_time_idx: int = 12,
):
    """
    Turn a batch of token ids into a Midi object.

    The tokens are first decoded with `relative_batch_tokens_ids_to_notes`; the resulting notes are
    then rendered with `notes_to_midi`, using `beatstep[beat_offset_idx]` as the time offset.

    Args:
        tokens (`np.ndarray`):
            Token ids to decode.
        beatstep (`np.ndarray`):
            Beat step values, indexed by the note onset/offset indices.
        beat_offset_idx (`int`, *optional*, defaults to 0):
            Start index of the first row; `None` is treated as 0.
        bars_per_batch (`int`, *optional*, defaults to 2):
            Forwarded to `relative_batch_tokens_ids_to_notes`.
        cutoff_time_idx (`int`, *optional*, defaults to 12):
            Forwarded to `relative_batch_tokens_ids_to_notes`.

    Returns:
        Whatever `notes_to_midi` returns for the decoded notes.
    """
    if beat_offset_idx is None:
        beat_offset_idx = 0

    decoded_notes = self.relative_batch_tokens_ids_to_notes(
        tokens=tokens,
        beat_offset_idx=beat_offset_idx,
        bars_per_batch=bars_per_batch,
        cutoff_time_idx=cutoff_time_idx,
    )
    return self.notes_to_midi(decoded_notes, beatstep, offset_sec=beatstep[beat_offset_idx])
|
2010
|
+
|
2011
|
+
# Taken from the original code
|
2012
|
+
# Please see https://github.com/sweetcocoa/pop2piano/blob/fac11e8dcfc73487513f4588e8d0c22a22f2fdc5/midi_tokenizer.py#L257
|
2013
|
+
def relative_tokens_ids_to_notes(
    self, tokens: np.ndarray, start_idx: float, cutoff_time_idx: Optional[float] = None
):
    """
    Decode one sequence of relative token ids into note rows.

    Args:
        tokens (`np.ndarray`):
            Token ids of a single sequence.
        start_idx (`float`):
            Index the decoding position starts from.
        cutoff_time_idx (`float`, *optional*):
            Forwarded to `token_time_to_note`, and used as a lower bound for the end index of notes
            that are still open when decoding stops.

    Returns:
        An empty list when no note was produced, otherwise an array whose rows are
        `[onset_idx, offset_idx, pitch, velocity]`, ordered by `onset_idx * 128 + offset_idx`.

    Raises:
        ValueError: if a decoded token has an unknown token type.
    """
    # Every id is converted up front, including ids that follow the end marker.
    decoded_tokens = [self._convert_id_to_token(token_id) for token_id in tokens]

    position = start_idx
    velocity = 0
    # One slot per "...NOTE" vocabulary entry, plus one.
    slot_count = sum(1 for key in self.encoder if key.endswith("NOTE")) + 1
    pending_onsets = [None] * slot_count
    collected = []

    for kind, value in decoded_tokens:
        if kind == "TOKEN_TIME":
            position = token_time_to_note(
                number=value, cutoff_time_idx=cutoff_time_idx, current_idx=position
            )
        elif kind == "TOKEN_NOTE":
            collected = token_note_to_note(
                number=value,
                current_velocity=velocity,
                default_velocity=self.default_velocity,
                note_onsets_ready=pending_onsets,
                current_idx=position,
                notes=collected,
            )
        elif kind == "TOKEN_VELOCITY":
            velocity = value
        elif kind == "TOKEN_SPECIAL":
            # Special token 1 stops decoding; other special tokens are ignored.
            if value == 1:
                break
        else:
            raise ValueError("Token type not understood!")

    # Any pitch that still has an onset registered gets a forced end index.
    for pitch, onset in enumerate(pending_onsets):
        if onset is None:
            continue
        earliest_end = onset + 1
        if cutoff_time_idx is not None:
            earliest_end = max(cutoff_time_idx, earliest_end)
        collected.append([onset, max(position, earliest_end), pitch, self.default_velocity])

    if len(collected) == 0:
        return []

    note_array = np.array(collected)
    sort_key = note_array[:, 0] * 128 + note_array[:, 1]
    return note_array[sort_key.argsort()]
|
2074
|
+
|
2075
|
+
def notes_to_midi(self, notes: np.ndarray, beatstep: np.ndarray, offset_sec: int = 0.0):
    """
    Build a Pretty Midi object holding one instrument with the given notes.

    Args:
        notes (`np.ndarray`):
            Rows of `[onset_idx, offset_idx, pitch, velocity]`.
        beatstep (`np.ndarray`):
            Lookup table that maps onset/offset indices to times.
        offset_sec (`int`, *optional*, defaults to 0.0):
            Value subtracted from every start and end time.

    Returns:
        The `PrettyMIDI` object, after `remove_invalid_notes()` has been applied to it.
    """
    midi = pretty_midi_fix.PrettyMIDI(resolution=384, initial_tempo=120.0)
    instrument = pretty_midi_fix.Instrument(program=0)

    instrument.notes = [
        pretty_midi_fix.Note(
            velocity=velocity,
            pitch=pitch,
            start=beatstep[onset_idx] - offset_sec,
            end=beatstep[offset_idx] - offset_sec,
        )
        for onset_idx, offset_idx, pitch, velocity in notes
    ]

    midi.instruments.append(instrument)
    midi.remove_invalid_notes()
    return midi
|
2103
|
+
|
2104
|
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
    """
    Write the tokenizer's `encoder` dictionary as JSON into `save_directory`.

    Args:
        save_directory (`str`):
            Existing directory to write the vocabulary file into. If it is not a directory a
            message is printed and nothing is written.
        filename_prefix (`Optional[str]`, *optional*):
            When given, `"<prefix>-"` is prepended to the vocabulary file name.

    Returns:
        A one-element tuple with the path of the written file, or `None` when `save_directory` is
        not a directory.
    """
    if not os.path.isdir(save_directory):
        print(f"Vocabulary path ({save_directory}) should be a directory")
        return

    if filename_prefix:
        file_name = filename_prefix + "-" + VOCAB_FILES_NAMES["vocab"]
    else:
        file_name = "" + VOCAB_FILES_NAMES["vocab"]
    out_vocab_file = os.path.join(save_directory, file_name)

    serialized = json_dumps(self.encoder)
    with open(out_vocab_file, "w") as handle:
        handle.write(serialized)

    return (out_vocab_file,)
|
2126
|
+
|
2127
|
+
def encode_plus(
    self,
    notes: Union[np.ndarray, list[pretty_midi_fix.Note]],
    truncation_strategy: Optional[TruncationStrategy] = None,
    max_length: Optional[int] = None,
    **kwargs,
) -> BatchEncoding:
    r"""
    Encode the midi notes of a single example into token ids. Use `batch_encode_plus` or `__call__`
    for several examples.

    Args:
        notes (`np.ndarray` of shape `[sequence_length, 4]` or `list` of `pretty_midi_fix.Note` objects):
            The notes to encode. Array rows are `onset idx`, `offset idx`, `pitch`, `velocity`;
            `Note` objects contribute their `start`, `end`, `pitch` and `velocity` attributes.
        truncation_strategy ([`~tokenization_utils_base.TruncationStrategy`], *optional*):
            Strategy handed to `truncate_sequences` when the result is longer than `max_length`.
        max_length (`int`, *optional*):
            Maximum number of token ids to keep; no truncation happens when it is unset.

    Returns:
        `BatchEncoding` with the ids stored under `"token_ids"`.
    """
    # `Note` objects are flattened into the same 4-column layout as the array input.
    if isinstance(notes[0], pretty_midi_fix.Note):
        rows = [[note.start, note.end, note.pitch, note.velocity] for note in notes]
        notes = np.array(rows).reshape(-1, 4)

    # Round everything to the nearest integer so the time columns can be used as list indices.
    notes = np.round(notes).astype(np.int32)
    last_time_idx = notes[:, :2].max()

    # events_by_time[t] lists the [pitch, velocity] events at index t; an offset is velocity 0.
    events_by_time = [[] for _ in range(last_time_idx + 1)]
    for onset, offset, pitch, velocity in notes:
        events_by_time[onset].append([pitch, velocity])
        events_by_time[offset].append([pitch, 0])

    token_ids = []
    velocity_state = 0
    for time_idx, events in enumerate(events_by_time):
        if len(events) == 0:
            continue
        token_ids.append(self._convert_token_to_id(time_idx, "TOKEN_TIME"))
        for pitch, velocity in events:
            # Velocity is reduced to on/off and only emitted when it changes.
            velocity = int(velocity > 0)
            if velocity != velocity_state:
                velocity_state = velocity
                token_ids.append(self._convert_token_to_id(velocity, "TOKEN_VELOCITY"))
            token_ids.append(self._convert_token_to_id(pitch, "TOKEN_NOTE"))

    encoded_length = len(token_ids)

    needs_truncation = (
        truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE
        and max_length
        and encoded_length > max_length
    )
    if needs_truncation:
        token_ids, _, _ = self.truncate_sequences(
            ids=token_ids,
            num_tokens_to_remove=encoded_length - max_length,
            truncation_strategy=truncation_strategy,
            **kwargs,
        )

    return BatchEncoding({"token_ids": token_ids})
|
2194
|
+
|
2195
|
+
def batch_encode_plus(
    self,
    notes: Union[np.ndarray, list[pretty_midi_fix.Note]],
    truncation_strategy: Optional[TruncationStrategy] = None,
    max_length: Optional[int] = None,
    **kwargs,
) -> BatchEncoding:
    r"""
    Encode several examples by running `encode_plus` on each of them in order.

    Args:
        notes (`np.ndarray` of shape `[batch_size, sequence_length, 4]` or `list` of `pretty_midi_fix.Note` objects):
            The examples to encode; each element is passed to `encode_plus` unchanged.
        truncation_strategy ([`~tokenization_utils_base.TruncationStrategy`], *optional*):
            Forwarded to `encode_plus`.
        max_length (`int`, *optional*):
            Forwarded to `encode_plus`.

    Returns:
        `BatchEncoding` whose `"token_ids"` entry is a list with one id list per example.
    """
    per_example_ids = [
        self.encode_plus(
            example_notes,
            truncation_strategy=truncation_strategy,
            max_length=max_length,
            **kwargs,
        )["token_ids"]
        for example_notes in notes
    ]

    return BatchEncoding({"token_ids": per_example_ids})
|
2233
|
+
|
2234
|
+
def __call__(
    self,
    notes: Union[
        np.ndarray,
        list[pretty_midi_fix.Note],
        list[list[pretty_midi_fix.Note]],
    ],
    padding: Union[bool, str, PaddingStrategy] = False,
    truncation: Union[bool, str, TruncationStrategy] = None,
    max_length: Optional[int] = None,
    pad_to_multiple_of: Optional[int] = None,
    return_attention_mask: Optional[bool] = None,
    return_tensors: Optional[Union[str, TensorType]] = None,
    verbose: bool = True,
    **kwargs,
) -> BatchEncoding:
    r"""
    Encode midi notes into token ids, then pad the result.

    Args:
        notes (`np.ndarray` of shape `[batch_size, max_sequence_length, 4]` or `list` of `pretty_midi_fix.Note` objects):
            The midi notes. Array rows are `onset idx`, `offset idx`, `pitch`, `velocity`; `Note`
            objects provide `start`, `end`, `pitch` and `velocity`. The input counts as batched when
            it is a 3-dimensional array or a list whose first element is itself a list.
        padding (`bool`, `str` or [`~file_utils.PaddingStrategy`], *optional*, defaults to `False`):
            Padding option, resolved by `_get_padding_truncation_strategies`.
        truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*):
            Truncation option, resolved by `_get_padding_truncation_strategies`.
        max_length (`int`, *optional*):
            Maximum length used by truncation and padding.
        pad_to_multiple_of (`int`, *optional*):
            Forwarded to the strategy resolution and to `pad`.
        return_attention_mask (`bool`, *optional*):
            Forwarded to `pad`. For batched input, `None` is replaced by `True`.
        return_tensors (`str` or [`~file_utils.TensorType`], *optional*):
            Forwarded to `pad`.
        verbose (`bool`, *optional*, defaults to `True`):
            Forwarded to the strategy resolution and to `pad`.

    Returns:
        `BatchEncoding` containing the token_ids, as returned by `pad`.
    """
    # Decide between the single-example and the batched encoder.
    if isinstance(notes, np.ndarray):
        is_batched = notes.ndim == 3
    else:
        is_batched = isinstance(notes[0], list)

    resolved = self._get_padding_truncation_strategies(
        padding=padding,
        truncation=truncation,
        max_length=max_length,
        pad_to_multiple_of=pad_to_multiple_of,
        verbose=verbose,
        **kwargs,
    )
    padding_strategy, truncation_strategy, max_length, kwargs = resolved

    if is_batched:
        # Unless the caller explicitly decided otherwise, batched input gets an attention mask.
        if return_attention_mask is None:
            return_attention_mask = True
        encode = self.batch_encode_plus
    else:
        encode = self.encode_plus

    encoded = encode(
        notes=notes,
        truncation_strategy=truncation_strategy,
        max_length=max_length,
        **kwargs,
    )

    # Truncation already happened while encoding, so only padding is left to do.
    return self.pad(
        encoded,
        padding=padding_strategy,
        max_length=max_length,
        pad_to_multiple_of=pad_to_multiple_of,
        return_attention_mask=return_attention_mask,
        return_tensors=return_tensors,
        verbose=verbose,
    )
|
2357
|
+
|
2358
|
+
def batch_decode(
    self,
    token_ids,
    feature_extractor_output: BatchFeature,
    return_midi: bool = True,
):
    r"""
    Convert the token_ids generated by the transformer into midi notes.

    Args:
        token_ids (`Union[np.ndarray, torch.Tensor, tf.Tensor]`):
            Output token_ids of `Pop2PianoConditionalGeneration` model.
        feature_extractor_output (`BatchFeature`):
            Denotes the output of `Pop2PianoFeatureExtractor.__call__`. It must contain `"beatsteps"` and
            `"extrapolated_beatstep"`. `"attention_mask"`, `"attention_mask_beatsteps"` and
            `"attention_mask_extrapolated_beatstep"` are required when more than one example is decoded.
        return_midi (`bool`, *optional*, defaults to `True`):
            Whether to also return the midi objects.

    Returns:
        If `return_midi` is True:
            - `BatchEncoding` with `"notes"` and `"pretty_midi_objects"`.
        If `return_midi` is False:
            - `BatchEncoding` with `"notes"` only.

    Raises:
        ValueError: if the masks are missing for batched input, or if the number of examples implied by
            `token_ids`, the attention mask, `beatsteps` and `extrapolated_beatstep` disagree.
    """

    # The masks are only used when all three of them are available.
    attention_masks_present = bool(
        hasattr(feature_extractor_output, "attention_mask")
        and hasattr(feature_extractor_output, "attention_mask_beatsteps")
        and hasattr(feature_extractor_output, "attention_mask_extrapolated_beatstep")
    )

    # Batched inputs cannot be split into examples without the masks.
    if not attention_masks_present and feature_extractor_output["beatsteps"].shape[0] > 1:
        raise ValueError(
            "attention_mask, attention_mask_beatsteps and attention_mask_extrapolated_beatstep must be present "
            "for batched inputs! But one of them were not present."
        )

    # Check that token_ids, beatsteps and extrapolated_beatstep describe the same number of examples.
    if attention_masks_present:
        # The number of examples equals the number of all-zero separator rows in attention_mask.
        if (
            sum(feature_extractor_output["attention_mask"][:, 0] == 0)
            != feature_extractor_output["beatsteps"].shape[0]
            or feature_extractor_output["beatsteps"].shape[0]
            != feature_extractor_output["extrapolated_beatstep"].shape[0]
        ):
            raise ValueError(
                "Length mistamtch between token_ids, beatsteps and extrapolated_beatstep! Found "
                f"token_ids length - {token_ids.shape[0]}, beatsteps shape - {feature_extractor_output['beatsteps'].shape[0]} "
                f"and extrapolated_beatsteps shape - {feature_extractor_output['extrapolated_beatstep'].shape[0]}"
            )
        if feature_extractor_output["attention_mask"].shape[0] != token_ids.shape[0]:
            raise ValueError(
                f"Found attention_mask of length - {feature_extractor_output['attention_mask'].shape[0]} but token_ids of length - {token_ids.shape[0]}"
            )
    else:
        # Without an attention mask there can only be a single example.
        if (
            feature_extractor_output["beatsteps"].shape[0] != 1
            or feature_extractor_output["extrapolated_beatstep"].shape[0] != 1
        ):
            raise ValueError(
                "Length mistamtch of beatsteps and extrapolated_beatstep! Since attention_mask is not present the number of examples must be 1, "
                f"But found beatsteps length - {feature_extractor_output['beatsteps'].shape[0]}, extrapolated_beatsteps length - {feature_extractor_output['extrapolated_beatstep'].shape[0]}."
            )

    if attention_masks_present:
        # Rows whose first mask entry is 0 are the separators between examples; their (absolute)
        # row indices mark where each example ends.
        batch_idx = np.where(feature_extractor_output["attention_mask"][:, 0] == 0)[0]
    else:
        batch_idx = [token_ids.shape[0]]

    notes_list = []
    pretty_midi_fix_objects_list = []
    start_idx = 0
    for index, end_idx in enumerate(batch_idx):
        each_tokens_ids = token_ids[start_idx:end_idx]
        # Cut every row after the right-most eos token found in this example.
        each_tokens_ids = each_tokens_ids[:, : np.max(np.where(each_tokens_ids == int(self.eos_token))[1]) + 1]
        beatsteps = feature_extractor_output["beatsteps"][index]
        extrapolated_beatstep = feature_extractor_output["extrapolated_beatstep"][index]

        # When masks are present, drop everything after the last position marked 1.
        if attention_masks_present:
            attention_mask_beatsteps = feature_extractor_output["attention_mask_beatsteps"][index]
            attention_mask_extrapolated_beatstep = feature_extractor_output[
                "attention_mask_extrapolated_beatstep"
            ][index]
            beatsteps = beatsteps[: np.max(np.where(attention_mask_beatsteps == 1)[0]) + 1]
            extrapolated_beatstep = extrapolated_beatstep[
                : np.max(np.where(attention_mask_extrapolated_beatstep == 1)[0]) + 1
            ]

        each_tokens_ids = to_numpy(each_tokens_ids)
        beatsteps = to_numpy(beatsteps)
        extrapolated_beatstep = to_numpy(extrapolated_beatstep)

        pretty_midi_fix_object = self.relative_batch_tokens_ids_to_midi(
            tokens=each_tokens_ids,
            beatstep=extrapolated_beatstep,
            bars_per_batch=self.num_bars,
            cutoff_time_idx=(self.num_bars + 1) * 4,
        )

        # Shift the notes by the first beat step of this example.
        for note in pretty_midi_fix_object.instruments[0].notes:
            note.start += beatsteps[0]
            note.end += beatsteps[0]
            notes_list.append(note)

        pretty_midi_fix_objects_list.append(pretty_midi_fix_object)
        # `end_idx` is an absolute row index, so the next example starts right after the separator
        # row. (Accumulating with `+=` skipped rows from the third example onwards.)
        start_idx = end_idx + 1

    if return_midi:
        return BatchEncoding({"notes": notes_list, "pretty_midi_objects": pretty_midi_fix_objects_list})

    return BatchEncoding({"notes": notes_list})
|
2478
|
+
|
2479
|
+
class Pop2PianoProcessor(ProcessorMixin):
    r"""
    Constructs a Pop2Piano processor which wraps a Pop2Piano Feature Extractor and Pop2Piano Tokenizer into a single
    processor.

    [`Pop2PianoProcessor`] offers all the functionalities of [`Pop2PianoFeatureExtractor`] and [`Pop2PianoTokenizer`].
    See the docstring of [`~Pop2PianoProcessor.__call__`] and [`~Pop2PianoProcessor.batch_decode`] for more
    information.

    Args:
        feature_extractor (`Pop2PianoFeatureExtractor`):
            An instance of [`Pop2PianoFeatureExtractor`]. The feature extractor is a required input.
        tokenizer (`Pop2PianoTokenizer`):
            An instance of [`Pop2PianoTokenizer`]. The tokenizer is a required input.
    """

    attributes = ["feature_extractor", "tokenizer"]
    feature_extractor_class = "Pop2PianoFeatureExtractor"
    tokenizer_class = "Pop2PianoTokenizer"

    def __init__(self, feature_extractor, tokenizer):
        super().__init__(feature_extractor, tokenizer)

    def __call__(
        self,
        audio: Union[np.ndarray, list[float], list[np.ndarray]] = None,
        sampling_rate: Optional[Union[int, list[int]]] = None,
        steps_per_beat: int = 2,
        resample: Optional[bool] = True,
        notes: Union[list, TensorType] = None,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = None,
        max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        verbose: bool = True,
        **kwargs,
    ) -> Union[BatchFeature, BatchEncoding]:
        """
        Run the feature extractor on `audio` and/or the tokenizer on `notes`.

        The feature extractor is used when both `audio` and `sampling_rate` are given; the tokenizer is used
        when `notes` is given. `**kwargs` are forwarded to whichever of the two is called.

        Returns:
            The feature extractor output when only audio is processed, the tokenizer output when only notes
            are processed, and the feature extractor output with an added `"token_ids"` entry when both are.

        Raises:
            ValueError: if neither a complete (`audio`, `sampling_rate`) pair nor `notes` is provided.
        """
        # The feature extractor needs audio AND sampling_rate; a half-specified pair cannot be used.
        use_feature_extractor = audio is not None and sampling_rate is not None

        # Previously only "both missing" was rejected, so passing just one of audio/sampling_rate
        # without notes fell through to `return inputs` with `inputs` unbound.
        if not use_feature_extractor and notes is None:
            raise ValueError(
                "You have to specify at least audios and sampling_rate in order to use feature extractor or "
                "notes to use the tokenizer part."
            )

        if use_feature_extractor:
            inputs = self.feature_extractor(
                audio=audio,
                sampling_rate=sampling_rate,
                steps_per_beat=steps_per_beat,
                resample=resample,
                **kwargs,
            )

        if notes is not None:
            encoded_token_ids = self.tokenizer(
                notes=notes,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                pad_to_multiple_of=pad_to_multiple_of,
                verbose=verbose,
                **kwargs,
            )

        if notes is None:
            return inputs

        if not use_feature_extractor:
            return encoded_token_ids

        inputs["token_ids"] = encoded_token_ids["token_ids"]
        return inputs

    def batch_decode(
        self,
        token_ids,
        feature_extractor_output: BatchFeature,
        return_midi: bool = True,
    ) -> BatchEncoding:
        """
        Convert model generated token_ids to midi notes by delegating to [`Pop2PianoTokenizer.batch_decode`].

        Please refer to the docstring of that method for more information.
        """

        return self.tokenizer.batch_decode(
            token_ids=token_ids, feature_extractor_output=feature_extractor_output, return_midi=return_midi
        )

    @property
    def model_input_names(self):
        """Tokenizer input names followed by feature extractor input names, without duplicates."""
        tokenizer_input_names = self.tokenizer.model_input_names
        feature_extractor_input_names = self.feature_extractor.model_input_names
        # dict.fromkeys removes duplicates while keeping first-seen order.
        return list(dict.fromkeys(tokenizer_input_names + feature_extractor_input_names))

    def save_pretrained(self, save_directory, **kwargs):
        """Create `save_directory` if needed and save the processor into it.

        Raises:
            ValueError: if `save_directory` points to an existing file.
        """
        if os.path.isfile(save_directory):
            raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
        os.makedirs(save_directory, exist_ok=True)
        return super().save_pretrained(save_directory, **kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Instantiate the processor from the components loaded for `pretrained_model_name_or_path`."""
        args = cls._get_arguments_from_pretrained(pretrained_model_name_or_path, **kwargs)
        return cls(*args)
|
2592
|
+
|
2593
|
+
|
2594
|
+
|
2595
|
+
class Pop2Piano:
    """Convenience wrapper that loads the Pop2Piano model and processor and writes a MIDI file."""

    def __init__(self, device="cpu", model_path=None):
        """
        Load the model and the processor.

        Args:
            device: Device the model is moved to.
            model_path: Local path of the pretrained files. When `None` (default), the
                "sweetcocoa/pop2piano" snapshot is fetched with `snapshot_download`.
        """
        # Resolve the default lazily: as a default-argument expression the download ran as soon as
        # the class was defined (i.e. on import), even if a model_path was supplied or the class
        # was never instantiated.
        if model_path is None:
            model_path = snapshot_download("sweetcocoa/pop2piano")
        self.model = Pop2PianoForConditionalGeneration.from_pretrained(model_path).to(device)
        self.processor = Pop2PianoProcessor.from_pretrained(model_path)

    def predict(self, audio, composer=1, num_bars=2, num_beams=1, steps_per_beat=2, output_file="output.mid"):
        """
        Generate a MIDI file from an audio file.

        Args:
            audio: Audio input passed to `librosa_load` (loaded at its native sampling rate).
            composer: Composer number; passed to `generate` as `"composer" + str(composer)`.
            num_bars: Forwarded to the processor.
            num_beams: Forwarded to `generate`.
            steps_per_beat: Forwarded to the processor.
            output_file: Path the MIDI data is written to.

        Returns:
            `output_file`.
        """
        data, sr = librosa_load(audio, sr=None)
        inputs = self.processor(data, sr, steps_per_beat, return_tensors="pt", num_bars=num_bars)

        # Keep the features on the same device as the model so a non-CPU `device` works.
        input_features = inputs["input_features"].to(self.model.device)
        generated = self.model.generate(
            num_beams=num_beams,
            do_sample=True,
            input_features=input_features,
            composer="composer" + str(composer),
        )
        # Decoding indexes the ids with numpy, so bring them back to host memory first.
        generated = generated.cpu()

        midi = self.processor.batch_decode(generated, inputs)["pretty_midi_objects"][0]
        # Use a context manager so the output file is flushed and closed.
        with open(output_file, "wb") as midi_file:
            midi.write(midi_file)
        return output_file
|