xinference 1.4.1__py3-none-any.whl → 1.5.0.post1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of xinference has been flagged as possibly problematic.

Files changed (104)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +50 -1
  3. xinference/client/restful/restful_client.py +82 -2
  4. xinference/constants.py +3 -0
  5. xinference/core/chat_interface.py +297 -83
  6. xinference/core/model.py +1 -0
  7. xinference/core/progress_tracker.py +16 -8
  8. xinference/core/supervisor.py +45 -1
  9. xinference/core/worker.py +262 -37
  10. xinference/deploy/cmdline.py +33 -1
  11. xinference/model/audio/core.py +11 -1
  12. xinference/model/audio/megatts.py +105 -0
  13. xinference/model/audio/model_spec.json +24 -1
  14. xinference/model/audio/model_spec_modelscope.json +26 -1
  15. xinference/model/core.py +14 -0
  16. xinference/model/embedding/core.py +6 -1
  17. xinference/model/flexible/core.py +6 -1
  18. xinference/model/image/core.py +6 -1
  19. xinference/model/image/model_spec.json +17 -1
  20. xinference/model/image/model_spec_modelscope.json +17 -1
  21. xinference/model/llm/__init__.py +0 -4
  22. xinference/model/llm/core.py +4 -0
  23. xinference/model/llm/llama_cpp/core.py +40 -16
  24. xinference/model/llm/llm_family.json +415 -84
  25. xinference/model/llm/llm_family.py +24 -1
  26. xinference/model/llm/llm_family_modelscope.json +449 -0
  27. xinference/model/llm/mlx/core.py +16 -2
  28. xinference/model/llm/transformers/__init__.py +14 -0
  29. xinference/model/llm/transformers/core.py +30 -6
  30. xinference/model/llm/transformers/gemma3.py +17 -2
  31. xinference/model/llm/transformers/intern_vl.py +28 -18
  32. xinference/model/llm/transformers/minicpmv26.py +21 -2
  33. xinference/model/llm/transformers/qwen-omni.py +308 -0
  34. xinference/model/llm/transformers/qwen2_audio.py +1 -1
  35. xinference/model/llm/transformers/qwen2_vl.py +20 -4
  36. xinference/model/llm/utils.py +11 -1
  37. xinference/model/llm/vllm/core.py +35 -0
  38. xinference/model/llm/vllm/distributed_executor.py +8 -2
  39. xinference/model/rerank/core.py +6 -1
  40. xinference/model/utils.py +118 -1
  41. xinference/model/video/core.py +6 -1
  42. xinference/thirdparty/megatts3/__init__.py +0 -0
  43. xinference/thirdparty/megatts3/tts/frontend_function.py +175 -0
  44. xinference/thirdparty/megatts3/tts/gradio_api.py +93 -0
  45. xinference/thirdparty/megatts3/tts/infer_cli.py +277 -0
  46. xinference/thirdparty/megatts3/tts/modules/aligner/whisper_small.py +318 -0
  47. xinference/thirdparty/megatts3/tts/modules/ar_dur/ar_dur_predictor.py +362 -0
  48. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/layers.py +64 -0
  49. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/nar_tts_modules.py +73 -0
  50. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rel_transformer.py +403 -0
  51. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rot_transformer.py +649 -0
  52. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/seq_utils.py +342 -0
  53. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/transformer.py +767 -0
  54. xinference/thirdparty/megatts3/tts/modules/llm_dit/cfm.py +309 -0
  55. xinference/thirdparty/megatts3/tts/modules/llm_dit/dit.py +180 -0
  56. xinference/thirdparty/megatts3/tts/modules/llm_dit/time_embedding.py +44 -0
  57. xinference/thirdparty/megatts3/tts/modules/llm_dit/transformer.py +230 -0
  58. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/diag_gaussian.py +67 -0
  59. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/hifigan_modules.py +283 -0
  60. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/seanet_encoder.py +38 -0
  61. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/wavvae_v3.py +60 -0
  62. xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/conv.py +154 -0
  63. xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/lstm.py +51 -0
  64. xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/seanet.py +126 -0
  65. xinference/thirdparty/megatts3/tts/utils/audio_utils/align.py +36 -0
  66. xinference/thirdparty/megatts3/tts/utils/audio_utils/io.py +95 -0
  67. xinference/thirdparty/megatts3/tts/utils/audio_utils/plot.py +90 -0
  68. xinference/thirdparty/megatts3/tts/utils/commons/ckpt_utils.py +171 -0
  69. xinference/thirdparty/megatts3/tts/utils/commons/hparams.py +215 -0
  70. xinference/thirdparty/megatts3/tts/utils/text_utils/dict.json +1 -0
  71. xinference/thirdparty/megatts3/tts/utils/text_utils/ph_tone_convert.py +94 -0
  72. xinference/thirdparty/megatts3/tts/utils/text_utils/split_text.py +90 -0
  73. xinference/thirdparty/megatts3/tts/utils/text_utils/text_encoder.py +280 -0
  74. xinference/types.py +10 -0
  75. xinference/utils.py +54 -0
  76. xinference/web/ui/build/asset-manifest.json +6 -6
  77. xinference/web/ui/build/index.html +1 -1
  78. xinference/web/ui/build/static/css/main.0f6523be.css +2 -0
  79. xinference/web/ui/build/static/css/main.0f6523be.css.map +1 -0
  80. xinference/web/ui/build/static/js/main.58bd483c.js +3 -0
  81. xinference/web/ui/build/static/js/main.58bd483c.js.map +1 -0
  82. xinference/web/ui/node_modules/.cache/babel-loader/3bff8cbe9141f937f4d98879a9771b0f48e0e4e0dbee8e647adbfe23859e7048.json +1 -0
  83. xinference/web/ui/node_modules/.cache/babel-loader/4500b1a622a031011f0a291701e306b87e08cbc749c50e285103536b85b6a914.json +1 -0
  84. xinference/web/ui/node_modules/.cache/babel-loader/51709f5d3e53bcf19e613662ef9b91fb9174942c5518987a248348dd4e1e0e02.json +1 -0
  85. xinference/web/ui/node_modules/.cache/babel-loader/69081049f0c7447544b7cfd73dd13d8846c02fe5febe4d81587e95c89a412d5b.json +1 -0
  86. xinference/web/ui/node_modules/.cache/babel-loader/b8551e9775a01b28ae674125c688febe763732ea969ae344512e64ea01bf632e.json +1 -0
  87. xinference/web/ui/node_modules/.cache/babel-loader/bf2b211b0d1b6465eff512d64c869d748f803c5651a7c24e48de6ea3484a7bfe.json +1 -0
  88. xinference/web/ui/src/locales/en.json +2 -1
  89. xinference/web/ui/src/locales/zh.json +2 -1
  90. {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/METADATA +129 -114
  91. {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/RECORD +96 -60
  92. {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/WHEEL +1 -1
  93. xinference/web/ui/build/static/css/main.b494ae7e.css +0 -2
  94. xinference/web/ui/build/static/css/main.b494ae7e.css.map +0 -1
  95. xinference/web/ui/build/static/js/main.5ca4eea1.js +0 -3
  96. xinference/web/ui/build/static/js/main.5ca4eea1.js.map +0 -1
  97. xinference/web/ui/node_modules/.cache/babel-loader/0f0967acaec5df1d45b80010949c258d64297ebbb0f44b8bb3afcbd45c6f0ec4.json +0 -1
  98. xinference/web/ui/node_modules/.cache/babel-loader/27bcada3ee8f89d21184b359f022fc965f350ffaca52c9814c29f1fc37121173.json +0 -1
  99. xinference/web/ui/node_modules/.cache/babel-loader/68249645124f37d01eef83b1d897e751f895bea919b6fb466f907c1f87cebc84.json +0 -1
  100. xinference/web/ui/node_modules/.cache/babel-loader/e547bbb18abb4a474b675a8d5782d25617566bea0af8caa9b836ce5649e2250a.json +0 -1
  101. /xinference/web/ui/build/static/js/{main.5ca4eea1.js.LICENSE.txt → main.58bd483c.js.LICENSE.txt} +0 -0
  102. {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/entry_points.txt +0 -0
  103. {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info/licenses}/LICENSE +0 -0
  104. {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/top_level.txt +0 -0
xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rot_transformer.py (added, +649 lines)
@@ -0,0 +1,649 @@
+ # Copyright 2025 ByteDance and/or its affiliates.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+ import torch
+ from typing import Optional, Tuple
+ from torch import nn
+ from torch.nn import Parameter, Linear
+ from tts.modules.ar_dur.commons.layers import LayerNorm, Embedding
+ from tts.modules.ar_dur.commons.transformer import TransformerFFNLayer, MultiheadAttention
+ from tts.modules.ar_dur.commons.seq_utils import get_incremental_state, set_incremental_state, softmax, make_positions
+ import torch.nn.functional as F
+
+ DEFAULT_MAX_SOURCE_POSITIONS = 3000
+ DEFAULT_MAX_TARGET_POSITIONS = 3000
+
+
+ class SinusoidalPositionalEmbedding(nn.Module):
+     """This module produces sinusoidal positional embeddings of any length.
+
+     Padding symbols are ignored.
+     """
+
+     def __init__(self, embedding_dim, padding_idx, init_size=1024):
+         super().__init__()
+         self.embedding_dim = embedding_dim
+         self.padding_idx = padding_idx
+         self.weights = SinusoidalPositionalEmbedding.get_embedding(
+             init_size,
+             embedding_dim,
+             padding_idx,
+         )
+         self.register_buffer('_float_tensor', torch.FloatTensor(1))
+
+     @staticmethod
+     def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
+         """Build sinusoidal embeddings.
+
+         This matches the implementation in tensor2tensor, but differs slightly
+         from the description in Section 3.5 of "Attention Is All You Need".
+         """
+         half_dim = embedding_dim // 2
+         emb = math.log(10000) / (half_dim - 1)
+         emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
+         emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
+         emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
+         if embedding_dim % 2 == 1:
+             # zero pad
+             emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
+         if padding_idx is not None:
+             emb[padding_idx, :] = 0
+         return emb
+
+     def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs):
+         """Input is expected to be of size [bsz x seqlen]."""
+         bsz, seq_len = input.shape[:2]
+         max_pos = self.padding_idx + 1 + seq_len
+         if self.weights is None or max_pos > self.weights.size(0):
+             # recompute/expand embeddings if needed
+             self.weights = SinusoidalPositionalEmbedding.get_embedding(
+                 max_pos,
+                 self.embedding_dim,
+                 self.padding_idx,
+             )
+         self.weights = self.weights.to(self._float_tensor)
+
+         if incremental_state is not None:
+             # positions is the same for every token when decoding a single step
+             pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
+             return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
+
+         positions = make_positions(input, self.padding_idx) if positions is None else positions
+         return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
+
+     def max_positions(self):
+         """Maximum number of supported positions."""
+         return int(1e5)  # an arbitrary large number
+
+
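The get_embedding helper above builds the tensor2tensor-style table: half of the channels carry sines and half carry cosines, with geometrically decreasing frequencies. A standalone sketch of the same recipe for an even embedding_dim (the sinusoidal_table name is illustrative and not part of the package):

import math
import torch

def sinusoidal_table(num_positions: int, dim: int) -> torch.Tensor:
    # Mirrors get_embedding above for an even dim; the real code additionally
    # zero-pads one extra column when dim is odd and can zero the padding_idx row.
    half = dim // 2
    freqs = torch.exp(torch.arange(half, dtype=torch.float) * -(math.log(10000) / (half - 1)))
    angles = torch.arange(num_positions, dtype=torch.float).unsqueeze(1) * freqs.unsqueeze(0)
    return torch.cat([angles.sin(), angles.cos()], dim=1)  # (num_positions, dim)

table = sinusoidal_table(1024, 256)
assert table.shape == (1024, 256)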
+ class RotaryEmbeddings(nn.Module):
+     cos: torch.Tensor
+     sin: torch.Tensor
+     theta: torch.Tensor
+
+     def __init__(
+         self,
+         width: int,
+         *,
+         seq_len: int = 40000,
+         base: int = 10000,
+         device: Optional[torch.device] = None,
+     ):
+         """Rotary embeddings (Su et al., 2021) layer. The rotary embedding
+         will be precomputed for up to 'seq_len' positions. The embedding
+         will be recomputed when a longer sequence is found in the input.
+
+         :param width:
+             Rotary embedding dimensionality, must be even.
+         :param seq_len:
+             Number of positions to initially precompute.
+         :param base:
+             The base used for Θ_i, determines the cycle length of the
+             embeddings.
+         :param device: Device on which the module is to be initialized.
+         """
+         super().__init__()
+
+         if width % 2:
+             raise ValueError(f"Width of rotary embeddings must be even, was: {width}")
+
+         # Ignore allocations on the meta device as we don't persist our buffer,
+         # i.e., we don't expect the backing tensor to be replaced with pretrained weights.
+         if device is not None and device.type == "meta":
+             device = None
+         # Θ_i = 10000^(-2(i-1)/d)
+         theta = torch.pow(
+             base, -torch.arange(0, width, 2, dtype=torch.float, device=device) / width
+         )
+         self.register_buffer("theta", theta, persistent=False)
+
+         self._create_rotary_embed(width=width, length=seq_len)
+
+     def _create_rotary_embed(self, *, width: int, length: int):
+         # mΘ
+         position = torch.arange(length, device=self.theta.device).unsqueeze(1)
+         m_theta = position * self.theta.unsqueeze(0)
+
+         # We apply both sin and cos twice (see Eq 15, 34), but the ordering
+         # is changed for compatibility with most common implementations.
+         m_theta = torch.cat([m_theta, m_theta], dim=-1)
+
+         re_cos = m_theta.cos().view([length, width])
+         re_sin = m_theta.sin().view([length, width])
+
+         self.register_buffer("cos", re_cos, persistent=False)
+         self.register_buffer("sin", re_sin, persistent=False)
+
+     def _rotate(self, input: torch.Tensor):
+         """Rotate the input tensor by half of its innermost width.
+
+         input (Tensor): array to rotate.
+         RETURNS (Tensor): rotated array.
+
+         Shapes:
+             input - (..., width)
+             output - (..., width)
+         """
+         half_idx = input.shape[-1] // 2
+         input_1 = -input[..., half_idx:]
+         input_2 = input[..., :half_idx]
+         return torch.cat([input_1, input_2], dim=-1)
+
+     def forward(self, input: torch.Tensor, *, positions: Optional[torch.Tensor] = None):
+         """
+         Apply rotary embeddings to an array.
+
+         :param input: Array to apply the rotary embeddings to.
+         :param positions: positions of the inputs. If no positions are
+             provided, they are assumed to be [0, seq_len).
+         :return: Array with the rotary embeddings applied.
+
+         Shapes:
+             input - (batch_size, num_heads, seq_len, width_per_head)
+             positions - (batch_size, seq_len)
+             output - (batch_size, num_heads, seq_len, width_per_head)
+         """
+         batch_size, _, seq_len, width = input.shape
+
+         if positions is None:
+             # Fastpath: positions from [0..seq_len), avoid indexing.
+             if self.cos.size(-2) < seq_len:
+                 self._create_rotary_embed(width=width, length=seq_len)
+             rot_cos = self.cos[:seq_len, :].view(1, 1, seq_len, width)
+             rot_sin = self.sin[:seq_len, :].view(1, 1, seq_len, width)
+         else:
+             max_len = int(positions.max()) + 1
+             if self.cos.size(-2) < max_len:
+                 self._create_rotary_embed(width=width, length=max_len)
+
+             # Flatten positions to index cos/sin arrays, then unflatten.
+             #
+             # Example shapes:
+             #
+             #     positions_flat - (batch_size * seq_len)
+             #     self.cos - (max_len, width)
+             #     rot_cos - (batch_size, seq_len, width)
+             positions_flat = positions.view(-1)
+             rot_cos = self.cos[positions_flat].view(batch_size, 1, seq_len, width)
+             rot_sin = self.sin[positions_flat].view(batch_size, 1, seq_len, width)
+
+         # Eq 34 with ordering changed for compatibility.
+         return rot_cos * input + rot_sin * self._rotate(input)
+
+
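As a quick orientation to the class above, the sketch below (illustrative, not part of the diff; it assumes the vendored megatts3 directory is on sys.path so the tts.modules import resolves) checks the documented shapes, shows that explicit positions reproduce the fast path, and confirms that the rotation leaves per-vector norms unchanged:

import torch
# Assumes xinference/thirdparty/megatts3 is on sys.path; otherwise use the class above directly.
from tts.modules.ar_dur.commons.rot_transformer import RotaryEmbeddings

rope = RotaryEmbeddings(width=64, seq_len=128)

q = torch.randn(2, 8, 16, 64)                            # (batch, heads, seq_len, head_dim)
q_rot = rope(q)                                          # fast path: positions assumed to be 0..15
assert q_rot.shape == q.shape

positions = torch.arange(16).unsqueeze(0).repeat(2, 1)   # explicit (batch, seq_len) positions
assert torch.allclose(rope(q, positions=positions), q_rot)

# Rotations preserve the norm of each per-position head vector.
assert torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-4)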
+ class RotMultiheadAttention(MultiheadAttention):
+     def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
+                  add_bias_kv=False, add_zero_attn=False, self_attention=False,
+                  encoder_decoder_attention=False):
+         super().__init__(embed_dim, num_heads, kdim=kdim, vdim=vdim, dropout=dropout, bias=bias,
+                          add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, self_attention=self_attention,
+                          encoder_decoder_attention=encoder_decoder_attention)
+         self.rotary_embeds = RotaryEmbeddings(width=embed_dim // num_heads)
+
+     def forward(
+         self,
+         query, key, value,
+         spk_pos_ids_flat=None,
+         key_padding_mask=None,
+         incremental_state=None,
+         need_weights=True,
+         static_kv=False,
+         attn_mask=None,
+         before_softmax=False,
+         need_head_weights=False,
+         enc_dec_attn_constraint_mask=None,
+         reset_attn_weight=None
+     ):
+         """Input shape: Time x Batch x Channel
+
+         Args:
+             key_padding_mask (ByteTensor, optional): mask to exclude
+                 keys that are pads, of shape `(batch, src_len)`, where
+                 padding elements are indicated by 1s.
+             need_weights (bool, optional): return the attention weights,
+                 averaged over heads (default: False).
+             attn_mask (ByteTensor, optional): typically used to
+                 implement causal attention, where the mask prevents the
+                 attention from looking forward in time (default: None).
+             before_softmax (bool, optional): return the raw attention
+                 weights and values before the attention softmax.
+             need_head_weights (bool, optional): return the attention
+                 weights for each head. Implies *need_weights*. Default:
+                 return the average attention weights over all heads.
+         """
+         if need_head_weights:
+             need_weights = True
+
+         tgt_len, bsz, embed_dim = query.size()
+         assert embed_dim == self.embed_dim
+         assert list(query.size()) == [tgt_len, bsz, embed_dim]
+
+         if incremental_state is not None:
+             saved_state = self._get_input_buffer(incremental_state)
+             if 'prev_key' in saved_state:
+                 # previous time steps are cached - no need to recompute
+                 # key and value if they are static
+                 if static_kv:
+                     assert self.encoder_decoder_attention and not self.self_attention
+                     key = value = None
+         else:
+             saved_state = None
+
+         if self.self_attention:
+             # self-attention
+             q, k, v = self.in_proj_qkv(query)
+         elif self.encoder_decoder_attention:
+             # encoder-decoder attention
+             q = self.in_proj_q(query)
+             if key is None:
+                 assert value is None
+                 k = v = None
+             else:
+                 k = self.in_proj_k(key)
+                 v = self.in_proj_v(key)
+         else:
+             q = self.in_proj_q(query)
+             k = self.in_proj_k(key)
+             v = self.in_proj_v(value)
+         q = q * self.scaling
+
+         if self.bias_k is not None:
+             assert self.bias_v is not None
+             k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
+             v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
+             if attn_mask is not None:
+                 attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+             if key_padding_mask is not None:
+                 key_padding_mask = torch.cat(
+                     [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
+
+         q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+         if k is not None:
+             k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+         if v is not None:
+             v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+
+         # Apply rot embedding and store incremental_state
+         q = self.rotary_embeds(q[None, :], positions=spk_pos_ids_flat)[0]
+         if saved_state is not None:
+             # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
+             if 'prev_key' in saved_state:
+                 prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
+                 if static_kv:
+                     k = prev_key
+                 else:
+                     k = torch.cat((prev_key, k), dim=1)
+             if 'prev_value' in saved_state:
+                 prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
+                 if static_kv:
+                     v = prev_value
+                 else:
+                     v = torch.cat((prev_value, v), dim=1)
+             saved_state['prev_key'], saved_state['prev_value'] = k.view(bsz, self.num_heads, -1, self.head_dim), v.view(
+                 bsz, self.num_heads, -1, self.head_dim)
+             self._set_input_buffer(incremental_state, saved_state)
+         if incremental_state is not None:
+             key_pos = torch.arange(k.shape[-2], device=q.device).unsqueeze(0)
+         else:
+             key_pos = spk_pos_ids_flat
+         k = self.rotary_embeds(k[None, :], positions=key_pos)[0]
+
+         src_len = k.size(1)
+
+         # This is part of a workaround to get around fork/join parallelism
+         # not supporting Optional types.
+         if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
+             key_padding_mask = None
+
+         if key_padding_mask is not None:
+             assert key_padding_mask.size(0) == bsz
+             assert key_padding_mask.size(1) == src_len
+
+         if self.add_zero_attn:
+             src_len += 1
+             k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
+             v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
+             if attn_mask is not None:
+                 attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+             if key_padding_mask is not None:
+                 key_padding_mask = torch.cat(
+                     [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)
+
+         attn_weights = torch.bmm(q, k.transpose(1, 2))
+         attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
+         assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
+
+         if attn_mask is not None:
+             if len(attn_mask.shape) == 2:
+                 attn_mask = attn_mask.unsqueeze(0)
+             elif len(attn_mask.shape) == 3:
+                 attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
+                     bsz * self.num_heads, tgt_len, src_len)
+             attn_weights = attn_weights + attn_mask
+
+         if enc_dec_attn_constraint_mask is not None:  # bs x head x L_kv
+             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+             attn_weights = attn_weights.masked_fill(
+                 enc_dec_attn_constraint_mask.unsqueeze(2).bool(),
+                 -1e8,
+             )
+             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+         if key_padding_mask is not None:
+             # don't attend to padding symbols
+             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+             attn_weights = attn_weights.masked_fill(
+                 key_padding_mask.unsqueeze(1).unsqueeze(2),
+                 -1e8,
+             )
+             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+         attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+
+         if before_softmax:
+             return attn_weights, v
+
+         attn_weights_float = softmax(attn_weights, dim=-1)
+         attn_weights = attn_weights_float.type_as(attn_weights)
+         attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
+
+         if reset_attn_weight is not None:
+             if reset_attn_weight:
+                 self.last_attn_probs = attn_probs.detach()
+             else:
+                 assert self.last_attn_probs is not None
+                 attn_probs = self.last_attn_probs
+         attn = torch.bmm(attn_probs, v)
+         assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
+         attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+         attn = self.out_proj(attn)
+
+         if need_weights:
+             attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
+             if not need_head_weights:
+                 # average attention weights over heads
+                 attn_weights = attn_weights.mean(dim=0)
+         else:
+             attn_weights = None
+
+         return attn, (attn_weights, attn_logits)
+
+
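Aside from the rotary step applied to q and k, the forward pass above follows the classic fairseq recipe: scale q, take batched dot products, add the masks, softmax, then weight the values. The sketch below (plain PyTorch, illustrative shapes, not part of the diff) shows that this manual path matches torch.nn.functional.scaled_dot_product_attention, which is the shortcut the RotMultiheadAttention2 variant that follows relies on:

import math
import torch
import torch.nn.functional as F

bsz_heads, tgt_len, src_len, head_dim = 16, 10, 10, 64
q = torch.randn(bsz_heads, tgt_len, head_dim)
k = torch.randn(bsz_heads, src_len, head_dim)
v = torch.randn(bsz_heads, src_len, head_dim)

# Manual path, as in RotMultiheadAttention: pre-scaled q, bmm, softmax, bmm.
scores = torch.bmm(q / math.sqrt(head_dim), k.transpose(1, 2))
manual = torch.bmm(scores.softmax(dim=-1), v)

# Fused path, as in RotMultiheadAttention2 (scaling handled internally).
fused = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)

assert torch.allclose(manual, fused, atol=1e-5)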
+ class RotMultiheadAttention2(MultiheadAttention):
+     def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
+                  add_bias_kv=False, add_zero_attn=False, self_attention=False,
+                  encoder_decoder_attention=False):
+         super().__init__(embed_dim, num_heads, kdim=kdim, vdim=vdim, dropout=dropout, bias=bias,
+                          add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, self_attention=self_attention,
+                          encoder_decoder_attention=encoder_decoder_attention)
+         self.rotary_embeds = RotaryEmbeddings(width=embed_dim // num_heads)
+
+     def forward(
+         self,
+         query, key, value,
+         spk_pos_ids_flat=None,
+         key_padding_mask=None,
+         incremental_state=None,
+         need_weights=True,
+         static_kv=False,
+         attn_mask=None,
+         before_softmax=False,
+         need_head_weights=False,
+         enc_dec_attn_constraint_mask=None,
+         reset_attn_weight=None
+     ):
+         """Input shape: Time x Batch x Channel
+
+         Args:
+             key_padding_mask (ByteTensor, optional): mask to exclude
+                 keys that are pads, of shape `(batch, src_len)`, where
+                 padding elements are indicated by 1s.
+             need_weights (bool, optional): return the attention weights,
+                 averaged over heads (default: False).
+             attn_mask (ByteTensor, optional): typically used to
+                 implement causal attention, where the mask prevents the
+                 attention from looking forward in time (default: None).
+             before_softmax (bool, optional): return the raw attention
+                 weights and values before the attention softmax.
+             need_head_weights (bool, optional): return the attention
+                 weights for each head. Implies *need_weights*. Default:
+                 return the average attention weights over all heads.
+         """
+         if need_head_weights:
+             need_weights = True
+
+         tgt_len, bsz, embed_dim = query.size()
+         assert embed_dim == self.embed_dim
+         assert list(query.size()) == [tgt_len, bsz, embed_dim]
+
+         if incremental_state is not None:
+             saved_state = self._get_input_buffer(incremental_state)
+             if 'prev_key' in saved_state:
+                 # previous time steps are cached - no need to recompute
+                 # key and value if they are static
+                 if static_kv:
+                     assert self.encoder_decoder_attention and not self.self_attention
+                     key = value = None
+         else:
+             saved_state = None
+
+         if self.self_attention:
+             # self-attention
+             q, k, v = self.in_proj_qkv(query)
+         elif self.encoder_decoder_attention:
+             # encoder-decoder attention
+             q = self.in_proj_q(query)
+             if key is None:
+                 assert value is None
+                 k = v = None
+             else:
+                 k = self.in_proj_k(key)
+                 v = self.in_proj_v(key)
+         else:
+             q = self.in_proj_q(query)
+             k = self.in_proj_k(key)
+             v = self.in_proj_v(value)
+
+         if self.bias_k is not None:
+             assert self.bias_v is not None
+             k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
+             v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
+             if attn_mask is not None:
+                 attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+             if key_padding_mask is not None:
+                 key_padding_mask = torch.cat(
+                     [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
+
+         q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+         if k is not None:
+             k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+         if v is not None:
+             v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+
+         # Apply rot embedding and store incremental_state
+         q = self.rotary_embeds(q[None, :], positions=spk_pos_ids_flat)[0]
+         if saved_state is not None:
+             # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
+             if 'prev_key' in saved_state:
+                 prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
+                 if static_kv:
+                     k = prev_key
+                 else:
+                     k = torch.cat((prev_key, k), dim=1)
+             if 'prev_value' in saved_state:
+                 prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
+                 if static_kv:
+                     v = prev_value
+                 else:
+                     v = torch.cat((prev_value, v), dim=1)
+             saved_state['prev_key'], saved_state['prev_value'] = k.view(bsz, self.num_heads, -1, self.head_dim), v.view(
+                 bsz, self.num_heads, -1, self.head_dim)
+             self._set_input_buffer(incremental_state, saved_state)
+         key_pos = torch.arange(k.shape[-2], device=q.device).unsqueeze(0)
+         k = self.rotary_embeds(k[None, :], positions=key_pos)[0]
+
+         src_len = k.size(1)
+
+         # This is part of a workaround to get around fork/join parallelism
+         # not supporting Optional types.
+         if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
+             key_padding_mask = None
+
+         if key_padding_mask is not None:
+             assert key_padding_mask.size(0) == bsz
+             assert key_padding_mask.size(1) == src_len
+
+         if attn_mask is not None:
+             if len(attn_mask.shape) == 2:
+                 attn_mask = attn_mask.unsqueeze(0)
+             elif len(attn_mask.shape) == 3:
+                 attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
+                     bsz * self.num_heads, tgt_len, src_len)
+         attn = torch.nn.functional.scaled_dot_product_attention(
+             q, k, v, attn_mask=attn_mask, dropout_p=0, is_causal=False)
+         assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
+         attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+         attn_logits = None
+         attn_weights = None
+         return attn, (attn_weights, attn_logits)
+
+
+ class RotDecSALayer(nn.Module):
+     def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1,
+                  kernel_size=9, ffn_hidden_size=1024, act='gelu', post_ln=False, bias=True):
+         super().__init__()
+         self.c = c
+         self.dropout = dropout
+         self.layer_norm1 = LayerNorm(c)
+         self.self_attn = RotMultiheadAttention(
+             c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
+         )
+         self.layer_norm2 = LayerNorm(c)
+         self.ffn = TransformerFFNLayer(
+             c, ffn_hidden_size, padding='LEFT', kernel_size=kernel_size,
+             dropout=relu_dropout, act=act, bias=bias)
+         self.post_ln = post_ln
+
+     def forward(
+         self,
+         x,
+         encoder_out=None,
+         encoder_padding_mask=None,
+         incremental_state=None,
+         self_attn_mask=None,
+         self_attn_padding_mask=None,
+         attn_out=None,
+         reset_attn_weight=None,
+         spk_pos_ids_flat=None,
+         **kwargs,
+     ):
+         layer_norm_training = kwargs.get('layer_norm_training', None)
+         if layer_norm_training is not None:
+             self.layer_norm1.training = layer_norm_training
+             self.layer_norm2.training = layer_norm_training
+         residual = x
+         if not self.post_ln:
+             x = self.layer_norm1(x)
+
+         x, (attn_weights, _) = self.self_attn(
+             query=x,
+             key=x,
+             value=x,
+             key_padding_mask=self_attn_padding_mask,
+             incremental_state=incremental_state,
+             attn_mask=self_attn_mask,
+             spk_pos_ids_flat=spk_pos_ids_flat
+         )
+         x = F.dropout(x, self.dropout, training=self.training)
+         x = residual + x
+         if self.post_ln:
+             x = self.layer_norm1(x)
+
+         residual = x
+         if not self.post_ln:
+             x = self.layer_norm2(x)
+         x = self.ffn(x, incremental_state=incremental_state)
+         x = F.dropout(x, self.dropout, training=self.training)
+         x = residual + x
+         if self.post_ln:
+             x = self.layer_norm2(x)
+         return x, attn_weights
+
+     def clear_buffer(self, input, encoder_out=None, encoder_padding_mask=None, incremental_state=None):
+         self.encoder_attn.clear_buffer(incremental_state)
+         self.ffn.clear_buffer(incremental_state)
+
+     def set_buffer(self, name, tensor, incremental_state):
+         return set_incremental_state(self, incremental_state, name, tensor)
+
+
+ class RotDecSALayer2(RotDecSALayer):
+     def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1, kernel_size=9,
+                  ffn_hidden_size=1024, act='gelu', post_ln=False):
+         super().__init__(c, num_heads, dropout, attention_dropout, relu_dropout, kernel_size, ffn_hidden_size, act,
+                          post_ln)
+         self.self_attn = RotMultiheadAttention2(
+             c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
+         )
+
+
+ class RotTransformerDecoderLayer(nn.Module):
+     def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=8, ffn_hidden_size=1024, post_ln=False,
+                  op_version=1, bias=True):
+         super().__init__()
+         self.hidden_size = hidden_size
+         self.dropout = dropout
+         self.num_heads = num_heads
+         if op_version == 1:
+             self.op = RotDecSALayer(
+                 hidden_size, num_heads, dropout=dropout,
+                 attention_dropout=0.0, relu_dropout=dropout,
+                 kernel_size=kernel_size, ffn_hidden_size=ffn_hidden_size,
+                 post_ln=post_ln, bias=bias)
+         else:
+             self.op = RotDecSALayer2(
+                 hidden_size, num_heads, dropout=dropout,
+                 attention_dropout=0.0, relu_dropout=dropout,
+                 kernel_size=kernel_size, ffn_hidden_size=ffn_hidden_size,
+                 post_ln=post_ln)
+
+     def forward(self, x, **kwargs):
+         return self.op(x, **kwargs)
+
+     def clear_buffer(self, *args):
+         return self.op.clear_buffer(*args)
+
+     def set_buffer(self, *args):
+         return self.op.set_buffer(*args)
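RotDecSALayer arranges the attention above into a standard residual block (pre-LN by default, post-LN when post_ln=True): normalize, self-attend, dropout, add the residual, then normalize, feed-forward, dropout, add the residual; RotTransformerDecoderLayer merely selects one of the two attention variants via op_version and forwards everything to it. A minimal pre-LN sketch of the same wiring with stock PyTorch modules as stand-ins (nn.MultiheadAttention here has no rotary embedding and a plain MLP replaces the conv-based TransformerFFNLayer; names are illustrative, not the package's implementation):

import torch
from torch import nn
import torch.nn.functional as F

class PreLNBlockSketch(nn.Module):
    # Stand-ins only: no rotary embedding, no incremental (KV cache) state.
    def __init__(self, c=256, num_heads=8, ffn_hidden_size=1024, dropout=0.1):
        super().__init__()
        self.dropout = dropout
        self.layer_norm1 = nn.LayerNorm(c)
        self.self_attn = nn.MultiheadAttention(c, num_heads, dropout=0.0)
        self.layer_norm2 = nn.LayerNorm(c)
        self.ffn = nn.Sequential(
            nn.Linear(c, ffn_hidden_size), nn.GELU(), nn.Linear(ffn_hidden_size, c))

    def forward(self, x, attn_mask=None):            # x: (seq_len, batch, c), as in the module above
        residual = x
        x = self.layer_norm1(x)
        x, _ = self.self_attn(x, x, x, attn_mask=attn_mask)
        x = residual + F.dropout(x, self.dropout, training=self.training)

        residual = x
        x = self.layer_norm2(x)
        x = self.ffn(x)
        return residual + F.dropout(x, self.dropout, training=self.training)

x = torch.randn(10, 2, 256)                          # (seq_len, batch, hidden)
out = PreLNBlockSketch()(x)
assert out.shape == x.shape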