xinference 0.13.2__py3-none-any.whl → 0.13.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic.

Files changed (78)
  1. xinference/__init__.py +0 -1
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +26 -4
  4. xinference/client/restful/restful_client.py +16 -1
  5. xinference/core/chat_interface.py +2 -2
  6. xinference/core/model.py +8 -3
  7. xinference/core/scheduler.py +4 -4
  8. xinference/model/audio/core.py +5 -2
  9. xinference/model/audio/cosyvoice.py +136 -0
  10. xinference/model/audio/model_spec.json +24 -0
  11. xinference/model/audio/model_spec_modelscope.json +27 -0
  12. xinference/model/flexible/launchers/__init__.py +1 -0
  13. xinference/model/flexible/launchers/image_process_launcher.py +70 -0
  14. xinference/model/image/model_spec.json +7 -0
  15. xinference/model/image/stable_diffusion/core.py +6 -1
  16. xinference/model/llm/llm_family.json +802 -82
  17. xinference/model/llm/llm_family_csghub.json +39 -0
  18. xinference/model/llm/llm_family_modelscope.json +295 -47
  19. xinference/model/llm/pytorch/chatglm.py +243 -5
  20. xinference/model/llm/pytorch/cogvlm2.py +1 -1
  21. xinference/model/llm/utils.py +78 -1
  22. xinference/model/llm/vllm/core.py +8 -0
  23. xinference/thirdparty/cosyvoice/__init__.py +0 -0
  24. xinference/thirdparty/cosyvoice/bin/__init__.py +0 -0
  25. xinference/thirdparty/cosyvoice/bin/inference.py +114 -0
  26. xinference/thirdparty/cosyvoice/bin/train.py +136 -0
  27. xinference/thirdparty/cosyvoice/cli/__init__.py +0 -0
  28. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +83 -0
  29. xinference/thirdparty/cosyvoice/cli/frontend.py +168 -0
  30. xinference/thirdparty/cosyvoice/cli/model.py +60 -0
  31. xinference/thirdparty/cosyvoice/dataset/__init__.py +0 -0
  32. xinference/thirdparty/cosyvoice/dataset/dataset.py +160 -0
  33. xinference/thirdparty/cosyvoice/dataset/processor.py +369 -0
  34. xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
  35. xinference/thirdparty/cosyvoice/flow/decoder.py +222 -0
  36. xinference/thirdparty/cosyvoice/flow/flow.py +135 -0
  37. xinference/thirdparty/cosyvoice/flow/flow_matching.py +138 -0
  38. xinference/thirdparty/cosyvoice/flow/length_regulator.py +49 -0
  39. xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
  40. xinference/thirdparty/cosyvoice/hifigan/f0_predictor.py +55 -0
  41. xinference/thirdparty/cosyvoice/hifigan/generator.py +391 -0
  42. xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
  43. xinference/thirdparty/cosyvoice/llm/llm.py +206 -0
  44. xinference/thirdparty/cosyvoice/transformer/__init__.py +0 -0
  45. xinference/thirdparty/cosyvoice/transformer/activation.py +84 -0
  46. xinference/thirdparty/cosyvoice/transformer/attention.py +326 -0
  47. xinference/thirdparty/cosyvoice/transformer/convolution.py +145 -0
  48. xinference/thirdparty/cosyvoice/transformer/decoder.py +396 -0
  49. xinference/thirdparty/cosyvoice/transformer/decoder_layer.py +132 -0
  50. xinference/thirdparty/cosyvoice/transformer/embedding.py +293 -0
  51. xinference/thirdparty/cosyvoice/transformer/encoder.py +472 -0
  52. xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +236 -0
  53. xinference/thirdparty/cosyvoice/transformer/label_smoothing_loss.py +96 -0
  54. xinference/thirdparty/cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  55. xinference/thirdparty/cosyvoice/transformer/subsampling.py +383 -0
  56. xinference/thirdparty/cosyvoice/utils/__init__.py +0 -0
  57. xinference/thirdparty/cosyvoice/utils/class_utils.py +70 -0
  58. xinference/thirdparty/cosyvoice/utils/common.py +103 -0
  59. xinference/thirdparty/cosyvoice/utils/executor.py +110 -0
  60. xinference/thirdparty/cosyvoice/utils/file_utils.py +41 -0
  61. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +125 -0
  62. xinference/thirdparty/cosyvoice/utils/mask.py +227 -0
  63. xinference/thirdparty/cosyvoice/utils/scheduler.py +739 -0
  64. xinference/thirdparty/cosyvoice/utils/train_utils.py +289 -0
  65. xinference/web/ui/build/asset-manifest.json +3 -3
  66. xinference/web/ui/build/index.html +1 -1
  67. xinference/web/ui/build/static/js/{main.95c1d652.js → main.2ef0cfaf.js} +3 -3
  68. xinference/web/ui/build/static/js/main.2ef0cfaf.js.map +1 -0
  69. xinference/web/ui/node_modules/.cache/babel-loader/b6807ecc0c231fea699533518a0eb2a2bf68a081ce00d452be40600dbffa17a7.json +1 -0
  70. {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/METADATA +16 -8
  71. {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/RECORD +76 -32
  72. xinference/web/ui/build/static/js/main.95c1d652.js.map +0 -1
  73. xinference/web/ui/node_modules/.cache/babel-loader/709711edada3f1596b309d571285fd31f1c364d66f4425bc28723d0088cc351a.json +0 -1
  74. /xinference/web/ui/build/static/js/{main.95c1d652.js.LICENSE.txt → main.2ef0cfaf.js.LICENSE.txt} +0 -0
  75. {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/LICENSE +0 -0
  76. {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/WHEEL +0 -0
  77. {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/entry_points.txt +0 -0
  78. {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/top_level.txt +0 -0
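The bulk of this release is CosyVoice text-to-speech support: a new xinference/model/audio/cosyvoice.py wrapper, new entries in the audio model_spec files, and a vendored copy of the upstream CosyVoice code under xinference/thirdparty/cosyvoice/. A minimal client-side sketch is shown below; the registered model name, server address, and speech() arguments are assumptions based on xinference's existing audio API rather than anything stated in this diff.

```python
# Hypothetical usage sketch (not part of the diff): launching the newly added
# CosyVoice audio model through the xinference client. The model name
# "CosyVoice-300M-SFT" and the speech() call are assumptions; check the
# xinference audio-model documentation for the registered spec.
from xinference.client import Client

client = Client("http://127.0.0.1:9997")        # a running xinference server
model_uid = client.launch_model(
    model_name="CosyVoice-300M-SFT",            # assumed registered name
    model_type="audio",
)
model = client.get_model(model_uid)

# Text-to-speech through the audio handle; returns raw audio bytes.
audio_bytes = model.speech("Hello from CosyVoice.")
with open("cosyvoice_demo.wav", "wb") as f:
    f.write(audio_bytes)
```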
xinference/thirdparty/cosyvoice/dataset/processor.py
@@ -0,0 +1,369 @@
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import logging
+ import random
+
+ import pyarrow.parquet as pq
+ from io import BytesIO
+ import torch
+ import torchaudio
+ from torch.nn.utils.rnn import pad_sequence
+ import torch.nn.functional as F
+
+ torchaudio.set_audio_backend('soundfile')
+
+ AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])
+
+
+ def parquet_opener(data, mode='train', tts_data={}):
+     """ Give url or local file, return file descriptor
+         Inplace operation.
+
+         Args:
+             data(Iterable[str]): url or local file list
+
+         Returns:
+             Iterable[{src, stream}]
+     """
+     for sample in data:
+         assert 'src' in sample
+         url = sample['src']
+         try:
+             df = pq.read_table(url).to_pandas()
+             for i in range(len(df)):
+                 if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
+                     continue
+                 sample.update(dict(df.loc[i]))
+                 if mode == 'train':
+                     # NOTE do not return sample directly, must initialize a new dict
+                     yield {**sample}
+                 else:
+                     for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
+                         yield {**sample, 'tts_index': index, 'tts_text': text}
+         except Exception as ex:
+             logging.warning('Failed to open {}, ex info {}'.format(url, ex))
+
+ def filter(data,
+            max_length=10240,
+            min_length=10,
+            token_max_length=200,
+            token_min_length=1,
+            min_output_input_ratio=0.0005,
+            max_output_input_ratio=1,
+            mode='train'):
+     """ Filter sample according to feature and label length
+         Inplace operation.
+
+         Args::
+             data: Iterable[{key, wav, label, sample_rate}]
+             max_length: drop utterance which is greater than max_length(10ms)
+             min_length: drop utterance which is less than min_length(10ms)
+             token_max_length: drop utterance which is greater than
+                 token_max_length, especially when use char unit for
+                 english modeling
+             token_min_length: drop utterance which is
+                 less than token_max_length
+             min_output_input_ratio: minimal ration of
+                 token_length / feats_length(10ms)
+             max_output_input_ratio: maximum ration of
+                 token_length / feats_length(10ms)
+
+         Returns:
+             Iterable[{key, wav, label, sample_rate}]
+     """
+     for sample in data:
+         sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
+         del sample['audio_data']
+         # sample['wav'] is torch.Tensor, we have 100 frames every second
+         num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
+         if num_frames < min_length:
+             continue
+         if num_frames > max_length:
+             continue
+         if len(sample['text_token']) < token_min_length:
+             continue
+         if len(sample['text_token']) > token_max_length:
+             continue
+         if len(sample['speech_token']) == 0:
+             continue
+         if num_frames != 0:
+             if len(sample['text_token']) / num_frames < min_output_input_ratio:
+                 continue
+             if len(sample['text_token']) / num_frames > max_output_input_ratio:
+                 continue
+         yield sample
+
+
+ def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
+     """ Resample data.
+         Inplace operation.
+
+         Args:
+             data: Iterable[{key, wav, label, sample_rate}]
+             resample_rate: target resample rate
+
+         Returns:
+             Iterable[{key, wav, label, sample_rate}]
+     """
+     for sample in data:
+         assert 'sample_rate' in sample
+         assert 'speech' in sample
+         sample_rate = sample['sample_rate']
+         waveform = sample['speech']
+         if sample_rate != resample_rate:
+             if sample_rate < min_sample_rate:
+                 continue
+             sample['sample_rate'] = resample_rate
+             sample['speech'] = torchaudio.transforms.Resample(
+                 orig_freq=sample_rate, new_freq=resample_rate)(waveform)
+         max_val = sample['speech'].abs().max()
+         if max_val > 1:
+             sample['speech'] /= max_val
+         yield sample
+
+
+ def compute_fbank(data,
+                   feat_extractor,
+                   mode='train'):
+     """ Extract fbank
+
+         Args:
+             data: Iterable[{key, wav, label, sample_rate}]
+
+         Returns:
+             Iterable[{key, feat, label}]
+     """
+     for sample in data:
+         assert 'sample_rate' in sample
+         assert 'speech' in sample
+         assert 'utt' in sample
+         assert 'text_token' in sample
+         waveform = sample['speech']
+         mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
+         sample['speech_feat'] = mat
+         del sample['speech']
+         yield sample
+
+
+ def parse_embedding(data, normalize, mode='train'):
+     """ Parse utt_embedding/spk_embedding
+
+         Args:
+             data: Iterable[{key, wav, label, sample_rate}]
+
+         Returns:
+             Iterable[{key, feat, label}]
+     """
+     for sample in data:
+         sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
+         sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
+         if normalize:
+             sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
+             sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
+         yield sample
+
+
+ def tokenize(data, get_tokenizer, allowed_special, mode='train'):
+     """ Decode text to chars or BPE
+         Inplace operation
+
+         Args:
+             data: Iterable[{key, wav, txt, sample_rate}]
+
+         Returns:
+             Iterable[{key, wav, txt, tokens, label, sample_rate}]
+     """
+     tokenizer = get_tokenizer()
+     for sample in data:
+         assert 'text' in sample
+         sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
+         if mode == 'inference':
+             sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special)
+         yield sample
+
+
+ def shuffle(data, shuffle_size=10000, mode='train'):
+     """ Local shuffle the data
+
+         Args:
+             data: Iterable[{key, feat, label}]
+             shuffle_size: buffer size for shuffle
+
+         Returns:
+             Iterable[{key, feat, label}]
+     """
+     buf = []
+     for sample in data:
+         buf.append(sample)
+         if len(buf) >= shuffle_size:
+             random.shuffle(buf)
+             for x in buf:
+                 yield x
+             buf = []
+     # The sample left over
+     random.shuffle(buf)
+     for x in buf:
+         yield x
+
+
+ def sort(data, sort_size=500, mode='train'):
+     """ Sort the data by feature length.
+         Sort is used after shuffle and before batch, so we can group
+         utts with similar lengths into a batch, and `sort_size` should
+         be less than `shuffle_size`
+
+         Args:
+             data: Iterable[{key, feat, label}]
+             sort_size: buffer size for sort
+
+         Returns:
+             Iterable[{key, feat, label}]
+     """
+
+     buf = []
+     for sample in data:
+         buf.append(sample)
+         if len(buf) >= sort_size:
+             buf.sort(key=lambda x: x['speech_feat'].size(0))
+             for x in buf:
+                 yield x
+             buf = []
+     # The sample left over
+     buf.sort(key=lambda x: x['speech_feat'].size(0))
+     for x in buf:
+         yield x
+
+
+ def static_batch(data, batch_size=16):
+     """ Static batch the data by `batch_size`
+
+         Args:
+             data: Iterable[{key, feat, label}]
+             batch_size: batch size
+
+         Returns:
+             Iterable[List[{key, feat, label}]]
+     """
+     buf = []
+     for sample in data:
+         buf.append(sample)
+         if len(buf) >= batch_size:
+             yield buf
+             buf = []
+     if len(buf) > 0:
+         yield buf
+
+
+ def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
+     """ Dynamic batch the data until the total frames in batch
+         reach `max_frames_in_batch`
+
+         Args:
+             data: Iterable[{key, feat, label}]
+             max_frames_in_batch: max_frames in one batch
+
+         Returns:
+             Iterable[List[{key, feat, label}]]
+     """
+     buf = []
+     longest_frames = 0
+     for sample in data:
+         assert 'speech_feat' in sample
+         assert isinstance(sample['speech_feat'], torch.Tensor)
+         new_sample_frames = sample['speech_feat'].size(0)
+         longest_frames = max(longest_frames, new_sample_frames)
+         frames_after_padding = longest_frames * (len(buf) + 1)
+         if frames_after_padding > max_frames_in_batch:
+             yield buf
+             buf = [sample]
+             longest_frames = new_sample_frames
+         else:
+             buf.append(sample)
+     if len(buf) > 0:
+         yield buf
+
+
+ def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
+     """ Wrapper for static/dynamic batch
+     """
+     if mode == 'inference':
+         return static_batch(data, 1)
+     else:
+         if batch_type == 'static':
+             return static_batch(data, batch_size)
+         elif batch_type == 'dynamic':
+             return dynamic_batch(data, max_frames_in_batch)
+         else:
+             logging.fatal('Unsupported batch type {}'.format(batch_type))
+
+
+ def padding(data, use_spk_embedding, mode='train'):
+     """ Padding the data into training data
+
+         Args:
+             data: Iterable[List[{key, feat, label}]]
+
+         Returns:
+             Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
+     """
+     for sample in data:
+         assert isinstance(sample, list)
+         speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
+                                        dtype=torch.int32)
+         order = torch.argsort(speech_feat_len, descending=True)
+
+         utts = [sample[i]['utt'] for i in order]
+         speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
+         speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
+         speech_token = pad_sequence(speech_token,
+                                     batch_first=True,
+                                     padding_value=0)
+         speech_feat = [sample[i]['speech_feat'] for i in order]
+         speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
+         speech_feat = pad_sequence(speech_feat,
+                                    batch_first=True,
+                                    padding_value=0)
+         text = [sample[i]['text'] for i in order]
+         text_token = [torch.tensor(sample[i]['text_token']) for i in order]
+         text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
+         text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
+         utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
+         spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
+         batch = {
+             "utts": utts,
+             "speech_token": speech_token,
+             "speech_token_len": speech_token_len,
+             "speech_feat": speech_feat,
+             "speech_feat_len": speech_feat_len,
+             "text": text,
+             "text_token": text_token,
+             "text_token_len": text_token_len,
+             "utt_embedding": utt_embedding,
+             "spk_embedding": spk_embedding,
+         }
+         if mode == 'inference':
+             tts_text = [sample[i]['tts_text'] for i in order]
+             tts_index = [sample[i]['tts_index'] for i in order]
+             tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
+             tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
+             tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
+             batch.update({'tts_text': tts_text,
+                           'tts_index': tts_index,
+                           'tts_text_token': tts_text_token,
+                           'tts_text_token_len': tts_text_token_len})
+         if use_spk_embedding is True:
+             batch["embedding"] = batch["spk_embedding"]
+         else:
+             batch["embedding"] = batch["utt_embedding"]
+         yield batch
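The vendored dataset/processor.py above is a set of WeNet-style generator stages: each function consumes an iterable of sample dicts and yields transformed dicts, so a training pipeline is built by chaining them. The sketch below illustrates one plausible composition; the stage order and arguments are assumptions (the real pipeline is wired up by CosyVoice's own dataset configuration), and get_tokenizer and feat_extractor are user-supplied callables.

```python
# Illustrative composition of the processor stages above (assumed ordering;
# get_tokenizer and feat_extractor are placeholders supplied by the caller).
from xinference.thirdparty.cosyvoice.dataset import processor

def build_pipeline(parquet_files, get_tokenizer, feat_extractor):
    data = ({"src": path} for path in parquet_files)       # Iterable[{src}]
    data = processor.parquet_opener(data)                   # rows from each parquet file
    data = processor.tokenize(data, get_tokenizer, allowed_special="all")
    data = processor.filter(data)                           # drop too-short/too-long samples
    data = processor.resample(data, resample_rate=22050)
    data = processor.compute_fbank(data, feat_extractor)    # adds 'speech_feat'
    data = processor.parse_embedding(data, normalize=True)
    data = processor.shuffle(data)
    data = processor.sort(data)                             # group similar lengths
    data = processor.batch(data, batch_type="dynamic", max_frames_in_batch=12000)
    data = processor.padding(data, use_spk_embedding=False)
    return data                                             # yields padded batch dicts
```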
xinference/thirdparty/cosyvoice/flow/__init__.py (file without changes)
xinference/thirdparty/cosyvoice/flow/decoder.py
@@ -0,0 +1,222 @@
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import torch
+ import torch.nn as nn
+ from einops import pack, rearrange, repeat
+ from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
+ from matcha.models.components.transformer import BasicTransformerBlock
+
+
+ class ConditionalDecoder(nn.Module):
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         channels=(256, 256),
+         dropout=0.05,
+         attention_head_dim=64,
+         n_blocks=1,
+         num_mid_blocks=2,
+         num_heads=4,
+         act_fn="snake",
+     ):
+         """
+         This decoder requires an input with the same shape of the target. So, if your text content
+         is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
+         """
+         super().__init__()
+         channels = tuple(channels)
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+
+         self.time_embeddings = SinusoidalPosEmb(in_channels)
+         time_embed_dim = channels[0] * 4
+         self.time_mlp = TimestepEmbedding(
+             in_channels=in_channels,
+             time_embed_dim=time_embed_dim,
+             act_fn="silu",
+         )
+         self.down_blocks = nn.ModuleList([])
+         self.mid_blocks = nn.ModuleList([])
+         self.up_blocks = nn.ModuleList([])
+
+         output_channel = in_channels
+         for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
+             input_channel = output_channel
+             output_channel = channels[i]
+             is_last = i == len(channels) - 1
+             resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+             transformer_blocks = nn.ModuleList(
+                 [
+                     BasicTransformerBlock(
+                         dim=output_channel,
+                         num_attention_heads=num_heads,
+                         attention_head_dim=attention_head_dim,
+                         dropout=dropout,
+                         activation_fn=act_fn,
+                     )
+                     for _ in range(n_blocks)
+                 ]
+             )
+             downsample = (
+                 Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+             )
+             self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
+
+         for i in range(num_mid_blocks):
+             input_channel = channels[-1]
+             out_channels = channels[-1]
+             resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+
+             transformer_blocks = nn.ModuleList(
+                 [
+                     BasicTransformerBlock(
+                         dim=output_channel,
+                         num_attention_heads=num_heads,
+                         attention_head_dim=attention_head_dim,
+                         dropout=dropout,
+                         activation_fn=act_fn,
+                     )
+                     for _ in range(n_blocks)
+                 ]
+             )
+
+             self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
+
+         channels = channels[::-1] + (channels[0],)
+         for i in range(len(channels) - 1):
+             input_channel = channels[i] * 2
+             output_channel = channels[i + 1]
+             is_last = i == len(channels) - 2
+             resnet = ResnetBlock1D(
+                 dim=input_channel,
+                 dim_out=output_channel,
+                 time_emb_dim=time_embed_dim,
+             )
+             transformer_blocks = nn.ModuleList(
+                 [
+                     BasicTransformerBlock(
+                         dim=output_channel,
+                         num_attention_heads=num_heads,
+                         attention_head_dim=attention_head_dim,
+                         dropout=dropout,
+                         activation_fn=act_fn,
+                     )
+                     for _ in range(n_blocks)
+                 ]
+             )
+             upsample = (
+                 Upsample1D(output_channel, use_conv_transpose=True)
+                 if not is_last
+                 else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+             )
+             self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
+         self.final_block = Block1D(channels[-1], channels[-1])
+         self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
+         self.initialize_weights()
+
+
+     def initialize_weights(self):
+         for m in self.modules():
+             if isinstance(m, nn.Conv1d):
+                 nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
+                 if m.bias is not None:
+                     nn.init.constant_(m.bias, 0)
+             elif isinstance(m, nn.GroupNorm):
+                 nn.init.constant_(m.weight, 1)
+                 nn.init.constant_(m.bias, 0)
+             elif isinstance(m, nn.Linear):
+                 nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
+                 if m.bias is not None:
+                     nn.init.constant_(m.bias, 0)
+
+     def forward(self, x, mask, mu, t, spks=None, cond=None):
+         """Forward pass of the UNet1DConditional model.
+
+         Args:
+             x (torch.Tensor): shape (batch_size, in_channels, time)
+             mask (_type_): shape (batch_size, 1, time)
+             t (_type_): shape (batch_size)
+             spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
+             cond (_type_, optional): placeholder for future use. Defaults to None.
+
+         Raises:
+             ValueError: _description_
+             ValueError: _description_
+
+         Returns:
+             _type_: _description_
+         """
+
+         t = self.time_embeddings(t)
+         t = self.time_mlp(t)
+
+         x = pack([x, mu], "b * t")[0]
+
+         if spks is not None:
+             spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
+             x = pack([x, spks], "b * t")[0]
+         if cond is not None:
+             x = pack([x, cond], "b * t")[0]
+
+         hiddens = []
+         masks = [mask]
+         for resnet, transformer_blocks, downsample in self.down_blocks:
+             mask_down = masks[-1]
+             x = resnet(x, mask_down, t)
+             x = rearrange(x, "b c t -> b t c").contiguous()
+             attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
+             for transformer_block in transformer_blocks:
+                 x = transformer_block(
+                     hidden_states=x,
+                     attention_mask=attn_mask,
+                     timestep=t,
+                 )
+             x = rearrange(x, "b t c -> b c t").contiguous()
+             hiddens.append(x)  # Save hidden states for skip connections
+             x = downsample(x * mask_down)
+             masks.append(mask_down[:, :, ::2])
+         masks = masks[:-1]
+         mask_mid = masks[-1]
+
+         for resnet, transformer_blocks in self.mid_blocks:
+             x = resnet(x, mask_mid, t)
+             x = rearrange(x, "b c t -> b t c").contiguous()
+             attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
+             for transformer_block in transformer_blocks:
+                 x = transformer_block(
+                     hidden_states=x,
+                     attention_mask=attn_mask,
+                     timestep=t,
+                 )
+             x = rearrange(x, "b t c -> b c t").contiguous()
+
+         for resnet, transformer_blocks, upsample in self.up_blocks:
+             mask_up = masks.pop()
+             skip = hiddens.pop()
+             x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
+             x = resnet(x, mask_up, t)
+             x = rearrange(x, "b c t -> b t c").contiguous()
+             attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
+             for transformer_block in transformer_blocks:
+                 x = transformer_block(
+                     hidden_states=x,
+                     attention_mask=attn_mask,
+                     timestep=t,
+                 )
+             x = rearrange(x, "b t c -> b c t").contiguous()
+             x = upsample(x * mask_up)
+         x = self.final_block(x, mask_up)
+         output = self.final_proj(x * mask_up)
+         return output * mask
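ConditionalDecoder above is the U-Net style estimator used by the vendored flow-matching code (see flow/flow_matching.py in the file list): forward() stacks x, mu, and the optional spks/cond tensors along the channel axis, so in_channels must equal the sum of those channel counts. The sketch below is a shape check only; the channel sizes are illustrative assumptions, act_fn="gelu" is passed explicitly because the matcha feed-forward block may not accept the "snake" default, and running it requires the matcha-tts package this file imports from.

```python
# Shape-only sketch for the vendored ConditionalDecoder (illustrative sizes;
# assumes matcha-tts and its dependencies are installed).
import torch

from xinference.thirdparty.cosyvoice.flow.decoder import ConditionalDecoder

# x and mu carry 80 channels each, spks and cond 80 each -> in_channels = 320.
decoder = ConditionalDecoder(in_channels=320, out_channels=80, act_fn="gelu")

batch, time = 2, 100                  # even length so the single down/upsample round-trips
x = torch.randn(batch, 80, time)      # noisy target being denoised
mu = torch.randn(batch, 80, time)     # encoder output, same shape as the target
mask = torch.ones(batch, 1, time)
t = torch.rand(batch)                 # flow-matching timestep
spks = torch.randn(batch, 80)         # speaker embedding
cond = torch.randn(batch, 80, time)   # prompt-conditioning features

out = decoder(x, mask, mu, t, spks=spks, cond=cond)
print(out.shape)                      # expected: torch.Size([2, 80, 100])
```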