xinference 1.8.1rc1__py3-none-any.whl → 1.9.1__py3-none-any.whl

This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only and reflects the changes between the two versions.

Potentially problematic release: this version of xinference might be problematic.
Files changed (108)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +2 -1
  3. xinference/core/model.py +8 -4
  4. xinference/core/supervisor.py +2 -3
  5. xinference/core/worker.py +7 -5
  6. xinference/deploy/cmdline.py +2 -0
  7. xinference/deploy/local.py +5 -0
  8. xinference/deploy/test/test_cmdline.py +1 -1
  9. xinference/deploy/worker.py +6 -0
  10. xinference/model/audio/cosyvoice.py +0 -1
  11. xinference/model/audio/model_spec.json +44 -20
  12. xinference/model/core.py +3 -0
  13. xinference/model/embedding/flag/core.py +5 -0
  14. xinference/model/embedding/llama_cpp/core.py +22 -19
  15. xinference/model/embedding/sentence_transformers/core.py +18 -4
  16. xinference/model/embedding/vllm/core.py +36 -9
  17. xinference/model/image/cache_manager.py +56 -0
  18. xinference/model/image/core.py +9 -0
  19. xinference/model/image/model_spec.json +178 -1
  20. xinference/model/image/stable_diffusion/core.py +155 -23
  21. xinference/model/llm/cache_manager.py +17 -3
  22. xinference/model/llm/harmony.py +245 -0
  23. xinference/model/llm/llama_cpp/core.py +41 -40
  24. xinference/model/llm/llm_family.json +688 -11
  25. xinference/model/llm/llm_family.py +1 -1
  26. xinference/model/llm/sglang/core.py +108 -5
  27. xinference/model/llm/transformers/core.py +20 -18
  28. xinference/model/llm/transformers/gemma3.py +1 -1
  29. xinference/model/llm/transformers/gpt_oss.py +91 -0
  30. xinference/model/llm/transformers/multimodal/core.py +1 -1
  31. xinference/model/llm/transformers/multimodal/gemma3.py +1 -1
  32. xinference/model/llm/transformers/multimodal/glm4_1v.py +2 -2
  33. xinference/model/llm/transformers/multimodal/ovis2.py +1 -1
  34. xinference/model/llm/transformers/multimodal/qwen-omni.py +7 -8
  35. xinference/model/llm/transformers/multimodal/qwen2_vl.py +9 -6
  36. xinference/model/llm/transformers/utils.py +1 -33
  37. xinference/model/llm/utils.py +61 -7
  38. xinference/model/llm/vllm/core.py +44 -8
  39. xinference/model/rerank/__init__.py +66 -23
  40. xinference/model/rerank/cache_manager.py +35 -0
  41. xinference/model/rerank/core.py +87 -339
  42. xinference/model/rerank/custom.py +33 -8
  43. xinference/model/rerank/model_spec.json +251 -212
  44. xinference/model/rerank/rerank_family.py +137 -0
  45. xinference/model/rerank/sentence_transformers/__init__.py +13 -0
  46. xinference/model/rerank/sentence_transformers/core.py +337 -0
  47. xinference/model/rerank/vllm/__init__.py +13 -0
  48. xinference/model/rerank/vllm/core.py +156 -0
  49. xinference/model/utils.py +108 -0
  50. xinference/model/video/model_spec.json +95 -1
  51. xinference/thirdparty/cosyvoice/bin/export_jit.py +3 -4
  52. xinference/thirdparty/cosyvoice/bin/export_onnx.py +49 -126
  53. xinference/thirdparty/cosyvoice/bin/{inference.py → inference_deprecated.py} +1 -0
  54. xinference/thirdparty/cosyvoice/bin/train.py +23 -3
  55. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +8 -4
  56. xinference/thirdparty/cosyvoice/cli/frontend.py +4 -4
  57. xinference/thirdparty/cosyvoice/cli/model.py +53 -75
  58. xinference/thirdparty/cosyvoice/dataset/dataset.py +5 -18
  59. xinference/thirdparty/cosyvoice/dataset/processor.py +24 -25
  60. xinference/thirdparty/cosyvoice/flow/decoder.py +24 -433
  61. xinference/thirdparty/cosyvoice/flow/flow.py +6 -14
  62. xinference/thirdparty/cosyvoice/flow/flow_matching.py +33 -145
  63. xinference/thirdparty/cosyvoice/hifigan/generator.py +169 -1
  64. xinference/thirdparty/cosyvoice/llm/llm.py +108 -17
  65. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +14 -115
  66. xinference/thirdparty/cosyvoice/utils/common.py +20 -0
  67. xinference/thirdparty/cosyvoice/utils/executor.py +8 -4
  68. xinference/thirdparty/cosyvoice/utils/file_utils.py +45 -1
  69. xinference/thirdparty/cosyvoice/utils/losses.py +37 -0
  70. xinference/thirdparty/cosyvoice/utils/mask.py +35 -1
  71. xinference/thirdparty/cosyvoice/utils/train_utils.py +24 -6
  72. xinference/thirdparty/cosyvoice/vllm/cosyvoice2.py +103 -0
  73. xinference/types.py +2 -0
  74. xinference/ui/gradio/chat_interface.py +2 -0
  75. xinference/ui/gradio/media_interface.py +353 -7
  76. xinference/ui/web/ui/build/asset-manifest.json +3 -3
  77. xinference/ui/web/ui/build/index.html +1 -1
  78. xinference/ui/web/ui/build/static/js/main.1086c759.js +3 -0
  79. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +1 -0
  80. xinference/ui/web/ui/node_modules/.cache/babel-loader/28012da921a51f1082549956d3ae82acd769a754b22afda9acddd98a4daf9ea4.json +1 -0
  81. xinference/ui/web/ui/node_modules/.cache/babel-loader/3c5758bd12fa334294b1de0ff6b1a4bac8d963c45472eab9dc3e530d82aa6b3f.json +1 -0
  82. xinference/ui/web/ui/node_modules/.cache/babel-loader/475936ebe725eca62a6f52ce182c06a19b2cef4df9545a05ed0591ee0c539d43.json +1 -0
  83. xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +1 -0
  84. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +1 -0
  85. xinference/ui/web/ui/node_modules/.cache/babel-loader/aee5aaba26f2b1e816a3ea9efa68bad8b95695a3d80adcfd8dd57a7bb17ac71a.json +1 -0
  86. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +1 -0
  87. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +1 -0
  88. xinference/ui/web/ui/src/locales/en.json +2 -0
  89. xinference/ui/web/ui/src/locales/ja.json +2 -0
  90. xinference/ui/web/ui/src/locales/ko.json +2 -0
  91. xinference/ui/web/ui/src/locales/zh.json +2 -0
  92. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/METADATA +15 -10
  93. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/RECORD +98 -89
  94. xinference/ui/web/ui/build/static/js/main.b969199a.js +0 -3
  95. xinference/ui/web/ui/build/static/js/main.b969199a.js.map +0 -1
  96. xinference/ui/web/ui/node_modules/.cache/babel-loader/1409a96b9f9f9f5de99a89ab0f738f6da62b449521b0a8d3e4efcf7f5c23534d.json +0 -1
  97. xinference/ui/web/ui/node_modules/.cache/babel-loader/3d2a89f0eccc1f90fc5036c9a1d587c2120e6a6b128aae31d1db7d6bad52722b.json +0 -1
  98. xinference/ui/web/ui/node_modules/.cache/babel-loader/43b889c3a8e2634092ade463d52481c7c5581c72ded8f23bc5f012ea0ef8cea5.json +0 -1
  99. xinference/ui/web/ui/node_modules/.cache/babel-loader/5d47532fb42128280d87f57c8a0b02bc1930f7ef764aa7e90579247df18bba83.json +0 -1
  100. xinference/ui/web/ui/node_modules/.cache/babel-loader/830882bb275468a969614824a9ab8983f874b4581f2eb625e9c66426cdc65e5b.json +0 -1
  101. xinference/ui/web/ui/node_modules/.cache/babel-loader/8e5cb82c2ff3299c6a44563fe6b1c5515c9750613c51bb63abee0b1d70fc5019.json +0 -1
  102. xinference/ui/web/ui/node_modules/.cache/babel-loader/9df08abcb5a7c1e48a4eb25c5d5f5d7253ea6854a4397e6d74d1fd75a14acda1.json +0 -1
  103. xinference/ui/web/ui/node_modules/.cache/babel-loader/b99034986a06445701accc7a4914bb9320947435e8d4e15793392ca4f679316c.json +0 -1
  104. /xinference/ui/web/ui/build/static/js/{main.b969199a.js.LICENSE.txt → main.1086c759.js.LICENSE.txt} +0 -0
  105. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/WHEEL +0 -0
  106. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/entry_points.txt +0 -0
  107. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/licenses/LICENSE +0 -0
  108. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/top_level.txt +0 -0
xinference/thirdparty/cosyvoice/flow/decoder.py

@@ -11,16 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Tuple, Optional, Dict, Any
+from typing import Tuple
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import pack, rearrange, repeat
-from diffusers.models.attention_processor import Attention, AttnProcessor2_0, inspect, logger, deprecate
 from cosyvoice.utils.common import mask_to_bias
 from cosyvoice.utils.mask import add_optional_chunk_mask
 from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
-from matcha.models.components.transformer import BasicTransformerBlock, maybe_allow_in_graph
+from matcha.models.components.transformer import BasicTransformerBlock


 class Transpose(torch.nn.Module):
@@ -29,7 +28,7 @@ class Transpose(torch.nn.Module):
         self.dim0 = dim0
         self.dim1 = dim1

-    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = torch.transpose(x, self.dim0, self.dim1)
         return x

@@ -57,15 +56,10 @@ class CausalConv1d(torch.nn.Conv1d):
         assert stride == 1
         self.causal_padding = kernel_size - 1

-    def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
-        if cache.size(2) == 0:
-            x = F.pad(x, (self.causal_padding, 0), value=0.0)
-        else:
-            assert cache.size(2) == self.causal_padding
-            x = torch.concat([cache, x], dim=2)
-        cache = x[:, :, -self.causal_padding:]
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = F.pad(x, (self.causal_padding, 0), value=0.0)
         x = super(CausalConv1d, self).forward(x)
-        return x, cache
+        return x


 class CausalBlock1D(Block1D):
@@ -79,11 +73,9 @@ class CausalBlock1D(Block1D):
             nn.Mish(),
         )

-    def forward(self, x: torch.Tensor, mask: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
-        output, cache = self.block[0](x * mask, cache)
-        for i in range(1, len(self.block)):
-            output = self.block[i](output)
-        return output * mask, cache
+    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        output = self.block(x * mask)
+        return output * mask


 class CausalResnetBlock1D(ResnetBlock1D):
@@ -92,303 +84,6 @@ class CausalResnetBlock1D(ResnetBlock1D):
         self.block1 = CausalBlock1D(dim, dim_out)
         self.block2 = CausalBlock1D(dim_out, dim_out)

-    def forward(self, x: torch.Tensor, mask: torch.Tensor, time_emb: torch.Tensor,
-                block1_cache: torch.Tensor = torch.zeros(0, 0, 0), block2_cache: torch.Tensor = torch.zeros(0, 0, 0)
-                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        h, block1_cache = self.block1(x, mask, block1_cache)
-        h += self.mlp(time_emb).unsqueeze(-1)
-        h, block2_cache = self.block2(h, mask, block2_cache)
-        output = h + self.res_conv(x * mask)
-        return output, block1_cache, block2_cache
-
-
-class CausalAttnProcessor2_0(AttnProcessor2_0):
-    r"""
-    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
-    """
-
-    def __init__(self):
-        super(CausalAttnProcessor2_0, self).__init__()
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        temb: Optional[torch.FloatTensor] = None,
-        cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
-        *args,
-        **kwargs,
-    ) -> Tuple[torch.FloatTensor, torch.Tensor]:
-        if len(args) > 0 or kwargs.get("scale", None) is not None:
-            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. \
-                `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
-            deprecate("scale", "1.0.0", deprecation_message)
-
-        residual = hidden_states
-        if attn.spatial_norm is not None:
-            hidden_states = attn.spatial_norm(hidden_states, temb)
-
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-
-        if attention_mask is not None:
-            # NOTE do not use attn.prepare_attention_mask as we have already provided the correct attention_mask
-            # scaled_dot_product_attention expects attention_mask shape to be
-            # (batch, heads, source_length, target_length)
-            attention_mask = attention_mask.unsqueeze(dim=1).repeat(1, attn.heads, 1, 1)
-
-        if attn.group_norm is not None:
-            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
-        query = attn.to_q(hidden_states)
-
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.norm_cross:
-            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
-        key_cache = attn.to_k(encoder_hidden_states)
-        value_cache = attn.to_v(encoder_hidden_states)
-        # NOTE here we judge cache.size(0) instead of cache.size(1), because init_cache has size (2, 0, 512, 2)
-        if cache.size(0) != 0:
-            key = torch.concat([cache[:, :, :, 0], key_cache], dim=1)
-            value = torch.concat([cache[:, :, :, 1], value_cache], dim=1)
-        else:
-            key, value = key_cache, value_cache
-        cache = torch.stack([key_cache, value_cache], dim=3)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-        if attn.residual_connection:
-            hidden_states = hidden_states + residual
-
-        hidden_states = hidden_states / attn.rescale_output_factor
-
-        return hidden_states, cache
-
-
-@maybe_allow_in_graph
-class CausalAttention(Attention):
-    def __init__(
-        self,
-        query_dim: int,
-        cross_attention_dim: Optional[int] = None,
-        heads: int = 8,
-        dim_head: int = 64,
-        dropout: float = 0.0,
-        bias: bool = False,
-        upcast_attention: bool = False,
-        upcast_softmax: bool = False,
-        cross_attention_norm: Optional[str] = None,
-        cross_attention_norm_num_groups: int = 32,
-        qk_norm: Optional[str] = None,
-        added_kv_proj_dim: Optional[int] = None,
-        norm_num_groups: Optional[int] = None,
-        spatial_norm_dim: Optional[int] = None,
-        out_bias: bool = True,
-        scale_qk: bool = True,
-        only_cross_attention: bool = False,
-        eps: float = 1e-5,
-        rescale_output_factor: float = 1.0,
-        residual_connection: bool = False,
-        _from_deprecated_attn_block: bool = False,
-        processor: Optional["AttnProcessor2_0"] = None,
-        out_dim: int = None,
-    ):
-        super(CausalAttention, self).__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax,
-                                              cross_attention_norm, cross_attention_norm_num_groups, qk_norm, added_kv_proj_dim, norm_num_groups,
-                                              spatial_norm_dim, out_bias, scale_qk, only_cross_attention, eps, rescale_output_factor, residual_connection,
-                                              _from_deprecated_attn_block, processor, out_dim)
-        processor = CausalAttnProcessor2_0()
-        self.set_processor(processor)
-
-    def forward(
-        self,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
-        **cross_attention_kwargs,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        r"""
-        The forward method of the `Attention` class.
-
-        Args:
-            hidden_states (`torch.Tensor`):
-                The hidden states of the query.
-            encoder_hidden_states (`torch.Tensor`, *optional*):
-                The hidden states of the encoder.
-            attention_mask (`torch.Tensor`, *optional*):
-                The attention mask to use. If `None`, no mask is applied.
-            **cross_attention_kwargs:
-                Additional keyword arguments to pass along to the cross attention.
-
-        Returns:
-            `torch.Tensor`: The output of the attention layer.
-        """
-        # The `Attention` class can call different attention processors / attention functions
-        # here we simply pass along all tensors to the selected processor class
-        # For standard processors that are defined here, `**cross_attention_kwargs` is empty
-
-        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
-        unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters]
-        if len(unused_kwargs) > 0:
-            logger.warning(
-                f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
-            )
-        cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
-
-        return self.processor(
-            self,
-            hidden_states,
-            encoder_hidden_states=encoder_hidden_states,
-            attention_mask=attention_mask,
-            cache=cache,
-            **cross_attention_kwargs,
-        )
-
-
-@maybe_allow_in_graph
-class CausalBasicTransformerBlock(BasicTransformerBlock):
-    def __init__(
-        self,
-        dim: int,
-        num_attention_heads: int,
-        attention_head_dim: int,
-        dropout=0.0,
-        cross_attention_dim: Optional[int] = None,
-        activation_fn: str = "geglu",
-        num_embeds_ada_norm: Optional[int] = None,
-        attention_bias: bool = False,
-        only_cross_attention: bool = False,
-        double_self_attention: bool = False,
-        upcast_attention: bool = False,
-        norm_elementwise_affine: bool = True,
-        norm_type: str = "layer_norm",
-        final_dropout: bool = False,
-    ):
-        super(CausalBasicTransformerBlock, self).__init__(dim, num_attention_heads, attention_head_dim, dropout,
-                                                          cross_attention_dim, activation_fn, num_embeds_ada_norm,
-                                                          attention_bias, only_cross_attention, double_self_attention,
-                                                          upcast_attention, norm_elementwise_affine, norm_type, final_dropout)
-        self.attn1 = CausalAttention(
-            query_dim=dim,
-            heads=num_attention_heads,
-            dim_head=attention_head_dim,
-            dropout=dropout,
-            bias=attention_bias,
-            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
-            upcast_attention=upcast_attention,
-        )
-
-    def forward(
-        self,
-        hidden_states: torch.FloatTensor,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        timestep: Optional[torch.LongTensor] = None,
-        cross_attention_kwargs: Dict[str, Any] = None,
-        class_labels: Optional[torch.LongTensor] = None,
-        cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # Notice that normalization is always applied before the real computation in the following blocks.
-        # 1. Self-Attention
-        if self.use_ada_layer_norm:
-            norm_hidden_states = self.norm1(hidden_states, timestep)
-        elif self.use_ada_layer_norm_zero:
-            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
-                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
-            )
-        else:
-            norm_hidden_states = self.norm1(hidden_states)
-
-        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
-
-        attn_output, cache = self.attn1(
-            norm_hidden_states,
-            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
-            attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
-            cache=cache,
-            **cross_attention_kwargs,
-        )
-        if self.use_ada_layer_norm_zero:
-            attn_output = gate_msa.unsqueeze(1) * attn_output
-        hidden_states = attn_output + hidden_states
-
-        # 2. Cross-Attention
-        if self.attn2 is not None:
-            norm_hidden_states = (
-                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
-            )
-
-            attn_output = self.attn2(
-                norm_hidden_states,
-                encoder_hidden_states=encoder_hidden_states,
-                attention_mask=encoder_attention_mask,
-                **cross_attention_kwargs,
-            )
-            hidden_states = attn_output + hidden_states
-
-        # 3. Feed-forward
-        norm_hidden_states = self.norm3(hidden_states)
-
-        if self.use_ada_layer_norm_zero:
-            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
-
-        if self._chunk_size is not None:
-            # "feed_forward_chunk_size" can be used to save memory
-            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
-                raise ValueError(f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: \
-                    {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`.")
-
-            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
-            ff_output = torch.cat(
-                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
-                dim=self._chunk_dim,
-            )
-        else:
-            ff_output = self.ff(norm_hidden_states)
-
-        if self.use_ada_layer_norm_zero:
-            ff_output = gate_mlp.unsqueeze(1) * ff_output
-
-        hidden_states = ff_output + hidden_states
-
-        return hidden_states, cache
-

 class ConditionalDecoder(nn.Module):
     def __init__(
@@ -640,7 +335,7 @@ class CausalConditionalDecoder(ConditionalDecoder):
             resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
             transformer_blocks = nn.ModuleList(
                 [
-                    CausalBasicTransformerBlock(
+                    BasicTransformerBlock(
                         dim=output_channel,
                         num_attention_heads=num_heads,
                         attention_head_dim=attention_head_dim,
@@ -662,7 +357,7 @@ class CausalConditionalDecoder(ConditionalDecoder):

             transformer_blocks = nn.ModuleList(
                 [
-                    CausalBasicTransformerBlock(
+                    BasicTransformerBlock(
                         dim=output_channel,
                         num_attention_heads=num_heads,
                         attention_head_dim=attention_head_dim,
@@ -687,7 +382,7 @@ class CausalConditionalDecoder(ConditionalDecoder):
             )
             transformer_blocks = nn.ModuleList(
                 [
-                    CausalBasicTransformerBlock(
+                    BasicTransformerBlock(
                         dim=output_channel,
                         num_attention_heads=num_heads,
                         attention_head_dim=attention_head_dim,
@@ -724,7 +419,6 @@ class CausalConditionalDecoder(ConditionalDecoder):
         Returns:
             _type_: _description_
         """
-
         t = self.time_embeddings(t).to(t.dtype)
         t = self.time_mlp(t)

@@ -740,36 +434,36 @@ class CausalConditionalDecoder(ConditionalDecoder):
         masks = [mask]
         for resnet, transformer_blocks, downsample in self.down_blocks:
             mask_down = masks[-1]
-            x, _, _ = resnet(x, mask_down, t)
+            x = resnet(x, mask_down, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
             if streaming is True:
-                attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+                attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
             else:
                 attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
-                x, _ = transformer_block(
+                x = transformer_block(
                     hidden_states=x,
                     attention_mask=attn_mask,
                     timestep=t,
                 )
             x = rearrange(x, "b t c -> b c t").contiguous()
             hiddens.append(x)  # Save hidden states for skip connections
-            x, _ = downsample(x * mask_down)
+            x = downsample(x * mask_down)
             masks.append(mask_down[:, :, ::2])
         masks = masks[:-1]
         mask_mid = masks[-1]

         for resnet, transformer_blocks in self.mid_blocks:
-            x, _, _ = resnet(x, mask_mid, t)
+            x = resnet(x, mask_mid, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
             if streaming is True:
-                attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+                attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
             else:
                 attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
-                x, _ = transformer_block(
+                x = transformer_block(
                     hidden_states=x,
                     attention_mask=attn_mask,
                     timestep=t,
@@ -780,124 +474,21 @@ class CausalConditionalDecoder(ConditionalDecoder):
             mask_up = masks.pop()
             skip = hiddens.pop()
             x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
-            x, _, _ = resnet(x, mask_up, t)
+            x = resnet(x, mask_up, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
             if streaming is True:
-                attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+                attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
             else:
                 attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
-                x, _ = transformer_block(
+                x = transformer_block(
                     hidden_states=x,
                     attention_mask=attn_mask,
                     timestep=t,
                 )
             x = rearrange(x, "b t c -> b c t").contiguous()
-            x, _ = upsample(x * mask_up)
-        x, _ = self.final_block(x, mask_up)
+            x = upsample(x * mask_up)
+        x = self.final_block(x, mask_up)
         output = self.final_proj(x * mask_up)
         return output * mask
-
-    @torch.inference_mode()
-    def forward_chunk(self, x, mask, mu, t, spks=None, cond=None,
-                      down_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
-                      down_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
-                      mid_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
-                      mid_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
-                      up_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
-                      up_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
-                      final_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0)
-                      ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        """Forward pass of the UNet1DConditional model.
-
-        Args:
-            x (torch.Tensor): shape (batch_size, in_channels, time)
-            mask (_type_): shape (batch_size, 1, time)
-            t (_type_): shape (batch_size)
-            spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
-            cond (_type_, optional): placeholder for future use. Defaults to None.
-
-        Raises:
-            ValueError: _description_
-            ValueError: _description_
-
-        Returns:
-            _type_: _description_
-        """
-
-        t = self.time_embeddings(t).to(t.dtype)
-        t = self.time_mlp(t)
-
-        x = pack([x, mu], "b * t")[0]
-
-        if spks is not None:
-            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
-            x = pack([x, spks], "b * t")[0]
-        if cond is not None:
-            x = pack([x, cond], "b * t")[0]
-
-        hiddens = []
-        masks = [mask]
-
-        down_blocks_kv_cache_new = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x.device)
-        mid_blocks_kv_cache_new = torch.zeros(12, 4, 2, x.size(2), 512, 2).to(x.device)
-        up_blocks_kv_cache_new = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x.device)
-        for index, (resnet, transformer_blocks, downsample) in enumerate(self.down_blocks):
-            mask_down = masks[-1]
-            x, down_blocks_conv_cache[index][:, :320], down_blocks_conv_cache[index][:, 320: 576] = \
-                resnet(x, mask_down, t, down_blocks_conv_cache[index][:, :320], down_blocks_conv_cache[index][:, 320: 576])
-            x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + down_blocks_kv_cache.size(3), device=x.device).bool()
-            attn_mask = mask_to_bias(attn_mask, x.dtype)
-            for i, transformer_block in enumerate(transformer_blocks):
-                x, down_blocks_kv_cache_new[index, i] = transformer_block(
-                    hidden_states=x,
-                    attention_mask=attn_mask,
-                    timestep=t,
-                    cache=down_blocks_kv_cache[index, i],
-                )
-            x = rearrange(x, "b t c -> b c t").contiguous()
-            hiddens.append(x)  # Save hidden states for skip connections
-            x, down_blocks_conv_cache[index][:, 576:] = downsample(x * mask_down, down_blocks_conv_cache[index][:, 576:])
-            masks.append(mask_down[:, :, ::2])
-        masks = masks[:-1]
-        mask_mid = masks[-1]
-
-        for index, (resnet, transformer_blocks) in enumerate(self.mid_blocks):
-            x, mid_blocks_conv_cache[index][:, :256], mid_blocks_conv_cache[index][:, 256:] = \
-                resnet(x, mask_mid, t, mid_blocks_conv_cache[index][:, :256], mid_blocks_conv_cache[index][:, 256:])
-            x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + mid_blocks_kv_cache.size(3), device=x.device).bool()
-            attn_mask = mask_to_bias(attn_mask, x.dtype)
-            for i, transformer_block in enumerate(transformer_blocks):
-                x, mid_blocks_kv_cache_new[index, i] = transformer_block(
-                    hidden_states=x,
-                    attention_mask=attn_mask,
-                    timestep=t,
-                    cache=mid_blocks_kv_cache[index, i]
-                )
-            x = rearrange(x, "b t c -> b c t").contiguous()
-
-        for index, (resnet, transformer_blocks, upsample) in enumerate(self.up_blocks):
-            mask_up = masks.pop()
-            skip = hiddens.pop()
-            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
-            x, up_blocks_conv_cache[index][:, :512], up_blocks_conv_cache[index][:, 512: 768] = \
-                resnet(x, mask_up, t, up_blocks_conv_cache[index][:, :512], up_blocks_conv_cache[index][:, 512: 768])
-            x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + up_blocks_kv_cache.size(3), device=x.device).bool()
-            attn_mask = mask_to_bias(attn_mask, x.dtype)
-            for i, transformer_block in enumerate(transformer_blocks):
-                x, up_blocks_kv_cache_new[index, i] = transformer_block(
-                    hidden_states=x,
-                    attention_mask=attn_mask,
-                    timestep=t,
-                    cache=up_blocks_kv_cache[index, i]
-                )
-            x = rearrange(x, "b t c -> b c t").contiguous()
-            x, up_blocks_conv_cache[index][:, 768:] = upsample(x * mask_up, up_blocks_conv_cache[index][:, 768:])
-        x, final_blocks_conv_cache = self.final_block(x, mask_up, final_blocks_conv_cache)
-        output = self.final_proj(x * mask_up)
-        return output * mask, down_blocks_conv_cache, down_blocks_kv_cache_new, mid_blocks_conv_cache, mid_blocks_kv_cache_new, \
-            up_blocks_conv_cache, up_blocks_kv_cache_new, final_blocks_conv_cache
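
Taken together, the decoder.py changes drop the cache-threading streaming path: CausalConv1d and CausalBlock1D now return a single tensor, the bespoke Causal* attention wrappers are replaced by matcha's stock BasicTransformerBlock, and forward_chunk is removed in favour of a streaming flag on forward. A minimal sketch of the simplified causal convolution, illustrative only (the example shapes and constructor below are assumptions, not code from the package):

import torch
import torch.nn.functional as F

class CausalConv1dSketch(torch.nn.Conv1d):
    # Mirrors the post-diff behaviour: pad on the left by kernel_size - 1 so each
    # output frame depends only on the current and past input frames.
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int) -> None:
        super().__init__(in_channels, out_channels, kernel_size, stride=1)
        self.causal_padding = kernel_size - 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.pad(x, (self.causal_padding, 0), value=0.0)
        return super().forward(x)

conv = CausalConv1dSketch(4, 4, kernel_size=3)
x = torch.randn(1, 4, 10)                # (batch, channels, time)
assert conv(x).shape[-1] == x.shape[-1]  # causal padding preserves the time length
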
xinference/thirdparty/cosyvoice/flow/flow.py

@@ -92,7 +92,6 @@ class MaskedDiffWithXvec(torch.nn.Module):

         mask = (~make_pad_mask(feat_len)).to(h)
         # NOTE this is unnecessary, feat/h already same shape
-        feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
         loss, _ = self.decoder.compute_loss(
             feat.transpose(1, 2).contiguous(),
             mask.unsqueeze(1),
@@ -214,7 +213,6 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         h = self.encoder_proj(h)

         # get conditions
-        feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
         conds = torch.zeros(feat.shape, device=token.device)
         for i, j in enumerate(feat_len):
             if random.random() < 0.5:
@@ -243,7 +241,7 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
                   prompt_feat,
                   prompt_feat_len,
                   embedding,
-                  cache,
+                  streaming,
                   finalize):
         assert token.shape[0] == 1
         # xvec projection
@@ -257,16 +255,10 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):

         # text encode
         if finalize is True:
-            h, h_lengths, encoder_cache = self.encoder.forward_chunk(token, token_len, **cache['encoder_cache'])
+            h, h_lengths = self.encoder(token, token_len, streaming=streaming)
         else:
             token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
-            h, h_lengths, encoder_cache = self.encoder.forward_chunk(token, token_len, context=context, **cache['encoder_cache'])
-            cache['encoder_cache']['offset'] = encoder_cache[0]
-            cache['encoder_cache']['pre_lookahead_layer_conv2_cache'] = encoder_cache[1]
-            cache['encoder_cache']['encoders_kv_cache'] = encoder_cache[2]
-            cache['encoder_cache']['upsample_offset'] = encoder_cache[3]
-            cache['encoder_cache']['upsample_conv_cache'] = encoder_cache[4]
-            cache['encoder_cache']['upsample_kv_cache'] = encoder_cache[5]
+            h, h_lengths = self.encoder(token, token_len, context=context, streaming=streaming)
         mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
         h = self.encoder_proj(h)

@@ -276,14 +268,14 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         conds = conds.transpose(1, 2)

         mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
-        feat, cache['decoder_cache'] = self.decoder(
+        feat, _ = self.decoder(
             mu=h.transpose(1, 2).contiguous(),
             mask=mask.unsqueeze(1),
             spks=embedding,
             cond=conds,
             n_timesteps=10,
-            cache=cache['decoder_cache']
+            streaming=streaming
         )
         feat = feat[:, :, mel_len1:]
         assert feat.shape[2] == mel_len2
-        return feat.float(), cache
+        return feat.float(), None
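
In flow.py the chunked-cache inference path is likewise replaced by a boolean streaming argument that is simply forwarded to the encoder and decoder, with None returned in place of the old cache. Inside the decoder, streaming=True selects a chunk-wise attention mask with a static chunk size and unlimited left context (the -1 passed to add_optional_chunk_mask). A rough reimplementation of that mask shape, for illustration only; this is my own sketch, not the cosyvoice/wenet helper itself:

import torch

def chunk_causal_mask(num_frames: int, static_chunk_size: int) -> torch.Tensor:
    # mask[i, j] is True when frame i may attend to frame j: every frame sees all
    # frames up to the end of its own chunk, with unlimited left context.
    idx = torch.arange(num_frames)
    chunk_end = (idx // static_chunk_size + 1) * static_chunk_size
    return idx.unsqueeze(0) < chunk_end.unsqueeze(1)

mask = chunk_causal_mask(num_frames=8, static_chunk_size=4)
# Frames 0-3 attend to frames 0-3; frames 4-7 attend to frames 0-7.
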