xinference 0.15.0__py3-none-any.whl → 0.15.2__py3-none-any.whl

This diff shows the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only and reflects only the changes between those published versions.

Potentially problematic release: this version of xinference might be problematic.

Files changed (84)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +204 -1
  3. xinference/client/restful/restful_client.py +4 -2
  4. xinference/core/image_interface.py +28 -0
  5. xinference/core/model.py +30 -2
  6. xinference/core/supervisor.py +6 -0
  7. xinference/model/audio/cosyvoice.py +3 -3
  8. xinference/model/audio/fish_speech.py +9 -9
  9. xinference/model/audio/model_spec.json +9 -9
  10. xinference/model/audio/whisper.py +4 -1
  11. xinference/model/image/core.py +2 -1
  12. xinference/model/image/model_spec.json +16 -4
  13. xinference/model/image/model_spec_modelscope.json +16 -4
  14. xinference/model/image/sdapi.py +136 -0
  15. xinference/model/image/stable_diffusion/core.py +163 -24
  16. xinference/model/llm/__init__.py +9 -1
  17. xinference/model/llm/llm_family.json +1241 -0
  18. xinference/model/llm/llm_family.py +3 -1
  19. xinference/model/llm/llm_family_modelscope.json +1301 -3
  20. xinference/model/llm/sglang/core.py +7 -0
  21. xinference/model/llm/transformers/chatglm.py +1 -1
  22. xinference/model/llm/transformers/core.py +6 -0
  23. xinference/model/llm/transformers/deepseek_v2.py +340 -0
  24. xinference/model/llm/transformers/qwen2_audio.py +168 -0
  25. xinference/model/llm/transformers/qwen2_vl.py +31 -5
  26. xinference/model/llm/utils.py +104 -84
  27. xinference/model/llm/vllm/core.py +13 -0
  28. xinference/thirdparty/fish_speech/fish_speech/configs/firefly_gan_vq.yaml +2 -3
  29. xinference/thirdparty/fish_speech/fish_speech/configs/text2semantic_finetune.yaml +1 -1
  30. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +1 -1
  31. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +1 -1
  32. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +1 -1
  33. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +1 -1
  34. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +1 -1
  35. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +2 -2
  36. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +0 -3
  37. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +169 -198
  38. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +4 -27
  39. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +9 -47
  40. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +2 -2
  41. xinference/thirdparty/fish_speech/fish_speech/train.py +2 -0
  42. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +12 -10
  43. xinference/thirdparty/fish_speech/tools/api.py +79 -134
  44. xinference/thirdparty/fish_speech/tools/commons.py +35 -0
  45. xinference/thirdparty/fish_speech/tools/download_models.py +3 -3
  46. xinference/thirdparty/fish_speech/tools/file.py +17 -0
  47. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +1 -1
  48. xinference/thirdparty/fish_speech/tools/llama/generate.py +29 -24
  49. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +1 -1
  50. xinference/thirdparty/fish_speech/tools/llama/quantize.py +2 -2
  51. xinference/thirdparty/fish_speech/tools/msgpack_api.py +34 -0
  52. xinference/thirdparty/fish_speech/tools/post_api.py +85 -44
  53. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +1 -1
  54. xinference/thirdparty/fish_speech/tools/smart_pad.py +16 -3
  55. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +2 -2
  56. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +4 -2
  57. xinference/thirdparty/fish_speech/tools/webui.py +12 -146
  58. xinference/types.py +7 -4
  59. xinference/web/ui/build/asset-manifest.json +6 -6
  60. xinference/web/ui/build/index.html +1 -1
  61. xinference/web/ui/build/static/css/{main.632e9148.css → main.5061c4c3.css} +2 -2
  62. xinference/web/ui/build/static/css/main.5061c4c3.css.map +1 -0
  63. xinference/web/ui/build/static/js/{main.9cfafbd6.js → main.29578905.js} +3 -3
  64. xinference/web/ui/build/static/js/main.29578905.js.map +1 -0
  65. xinference/web/ui/node_modules/.cache/babel-loader/c7bf40bab396765f67d0fed627ed3665890608b2d0edaa3e8cb7cfc96310db45.json +1 -0
  66. xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +1 -0
  67. {xinference-0.15.0.dist-info → xinference-0.15.2.dist-info}/METADATA +13 -7
  68. {xinference-0.15.0.dist-info → xinference-0.15.2.dist-info}/RECORD +73 -75
  69. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +0 -442
  70. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +0 -44
  71. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +0 -115
  72. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +0 -225
  73. xinference/thirdparty/fish_speech/tools/auto_rerank.py +0 -159
  74. xinference/thirdparty/fish_speech/tools/gen_ref.py +0 -36
  75. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +0 -55
  76. xinference/web/ui/build/static/css/main.632e9148.css.map +0 -1
  77. xinference/web/ui/build/static/js/main.9cfafbd6.js.map +0 -1
  78. xinference/web/ui/node_modules/.cache/babel-loader/01d6d198156bacbd436c51435edbd4b2cacd47a79db929105eba30f74b67d48d.json +0 -1
  79. xinference/web/ui/node_modules/.cache/babel-loader/59eb25f514afcc4fefd1b309d192b2455f1e0aec68a9de598ca4b2333fe2c774.json +0 -1
  80. /xinference/web/ui/build/static/js/{main.9cfafbd6.js.LICENSE.txt → main.29578905.js.LICENSE.txt} +0 -0
  81. {xinference-0.15.0.dist-info → xinference-0.15.2.dist-info}/LICENSE +0 -0
  82. {xinference-0.15.0.dist-info → xinference-0.15.2.dist-info}/WHEEL +0 -0
  83. {xinference-0.15.0.dist-info → xinference-0.15.2.dist-info}/entry_points.txt +0 -0
  84. {xinference-0.15.0.dist-info → xinference-0.15.2.dist-info}/top_level.txt +0 -0
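
Before reviewing the per-file changes above, it can help to confirm which of the two wheels is actually installed in your environment. The snippet below is a minimal sketch using only the Python standard library; it assumes the distribution name is xinference, as in the wheel filenames above.

    # Minimal sketch: report the installed xinference version so the relevant
    # side of this 0.15.0 -> 0.15.2 diff can be identified.
    from importlib.metadata import PackageNotFoundError, version

    try:
        print(version("xinference"))  # e.g. "0.15.0" or "0.15.2"
    except PackageNotFoundError:
        print("xinference is not installed in this environment")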
xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py (deleted)
@@ -1,442 +0,0 @@
-import itertools
-import math
-from typing import Any, Callable
-
-import lightning as L
-import torch
-import torch.nn.functional as F
-# import wandb
-from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
-from matplotlib import pyplot as plt
-from torch import nn
-
-from fish_speech.models.vqgan.modules.discriminator import Discriminator
-from fish_speech.models.vqgan.modules.wavenet import WaveNet
-from fish_speech.models.vqgan.utils import avg_with_mask, plot_mel, sequence_mask
-
-
-class VQGAN(L.LightningModule):
-    def __init__(
-        self,
-        optimizer: Callable,
-        lr_scheduler: Callable,
-        encoder: WaveNet,
-        quantizer: nn.Module,
-        decoder: WaveNet,
-        discriminator: Discriminator,
-        vocoder: nn.Module,
-        encode_mel_transform: nn.Module,
-        gt_mel_transform: nn.Module,
-        weight_adv: float = 1.0,
-        weight_vq: float = 1.0,
-        weight_mel: float = 1.0,
-        sampling_rate: int = 44100,
-        freeze_encoder: bool = False,
-    ):
-        super().__init__()
-
-        # Model parameters
-        self.optimizer_builder = optimizer
-        self.lr_scheduler_builder = lr_scheduler
-
-        # Modules
-        self.encoder = encoder
-        self.quantizer = quantizer
-        self.decoder = decoder
-        self.vocoder = vocoder
-        self.discriminator = discriminator
-        self.encode_mel_transform = encode_mel_transform
-        self.gt_mel_transform = gt_mel_transform
-
-        # A simple linear layer to project quality to condition channels
-        self.quality_projection = nn.Linear(1, 768)
-
-        # Freeze vocoder
-        for param in self.vocoder.parameters():
-            param.requires_grad = False
-
-        # Loss weights
-        self.weight_adv = weight_adv
-        self.weight_vq = weight_vq
-        self.weight_mel = weight_mel
-
-        # Other parameters
-        self.sampling_rate = sampling_rate
-
-        # Disable strict loading
-        self.strict_loading = False
-
-        # If encoder is frozen
-        if freeze_encoder:
-            for param in self.encoder.parameters():
-                param.requires_grad = False
-
-            for param in self.quantizer.parameters():
-                param.requires_grad = False
-
-        self.automatic_optimization = False
-
-    def on_save_checkpoint(self, checkpoint):
-        # Do not save vocoder
-        state_dict = checkpoint["state_dict"]
-        for name in list(state_dict.keys()):
-            if "vocoder" in name:
-                state_dict.pop(name)
-
-    def configure_optimizers(self):
-        optimizer_generator = self.optimizer_builder(
-            itertools.chain(
-                self.encoder.parameters(),
-                self.quantizer.parameters(),
-                self.decoder.parameters(),
-                self.quality_projection.parameters(),
-            )
-        )
-        optimizer_discriminator = self.optimizer_builder(
-            self.discriminator.parameters()
-        )
-
-        lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
-        lr_scheduler_discriminator = self.lr_scheduler_builder(optimizer_discriminator)
-
-        return (
-            {
-                "optimizer": optimizer_generator,
-                "lr_scheduler": {
-                    "scheduler": lr_scheduler_generator,
-                    "interval": "step",
-                    "name": "optimizer/generator",
-                },
-            },
-            {
-                "optimizer": optimizer_discriminator,
-                "lr_scheduler": {
-                    "scheduler": lr_scheduler_discriminator,
-                    "interval": "step",
-                    "name": "optimizer/discriminator",
-                },
-            },
-        )
-
-    def training_step(self, batch, batch_idx):
-        optim_g, optim_d = self.optimizers()
-
-        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
-
-        audios = audios.float()
-        audios = audios[:, None, :]
-
-        with torch.no_grad():
-            encoded_mels = self.encode_mel_transform(audios)
-            gt_mels = self.gt_mel_transform(audios)
-            quality = ((gt_mels.mean(-1) > -8).sum(-1) - 90) / 10
-            quality = quality.unsqueeze(-1)
-
-        mel_lengths = audio_lengths // self.gt_mel_transform.hop_length
-        mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
-        mel_masks_float_conv = mel_masks[:, None, :].float()
-        gt_mels = gt_mels * mel_masks_float_conv
-        encoded_mels = encoded_mels * mel_masks_float_conv
-
-        # Encode
-        encoded_features = self.encoder(encoded_mels) * mel_masks_float_conv
-
-        # Quantize
-        vq_result = self.quantizer(encoded_features)
-        loss_vq = getattr("vq_result", "loss", 0.0)
-        vq_recon_features = vq_result.z * mel_masks_float_conv
-        vq_recon_features = (
-            vq_recon_features + self.quality_projection(quality)[:, :, None]
-        )
-
-        # VQ Decode
-        gen_mel = (
-            self.decoder(
-                torch.randn_like(vq_recon_features) * mel_masks_float_conv,
-                condition=vq_recon_features,
-            )
-            * mel_masks_float_conv
-        )
-
-        # Discriminator
-        real_logits = self.discriminator(gt_mels)
-        fake_logits = self.discriminator(gen_mel.detach())
-        d_mask = F.interpolate(
-            mel_masks_float_conv, size=(real_logits.shape[2],), mode="nearest"
-        )
-
-        loss_real = avg_with_mask((real_logits - 1) ** 2, d_mask)
-        loss_fake = avg_with_mask(fake_logits**2, d_mask)
-
-        loss_d = loss_real + loss_fake
-
-        self.log(
-            "train/discriminator/loss",
-            loss_d,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=True,
-            logger=True,
-        )
-
-        # Discriminator backward
-        optim_d.zero_grad()
-        self.manual_backward(loss_d)
-        self.clip_gradients(
-            optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
-        )
-        optim_d.step()
-
-        # Mel Loss, applying l1, using a weighted sum
-        mel_distance = (
-            gen_mel - gt_mels
-        ).abs()  # * 0.5 + self.ssim(gen_mel, gt_mels) * 0.5
-        loss_mel_low_freq = avg_with_mask(mel_distance[:, :40, :], mel_masks_float_conv)
-        loss_mel_mid_freq = avg_with_mask(
-            mel_distance[:, 40:70, :], mel_masks_float_conv
-        )
-        loss_mel_high_freq = avg_with_mask(
-            mel_distance[:, 70:, :], mel_masks_float_conv
-        )
-        loss_mel = (
-            loss_mel_low_freq * 0.6 + loss_mel_mid_freq * 0.3 + loss_mel_high_freq * 0.1
-        )
-
-        # Adversarial Loss
-        fake_logits = self.discriminator(gen_mel)
-        loss_adv = avg_with_mask((fake_logits - 1) ** 2, d_mask)
-
-        # Total loss
-        loss = (
-            self.weight_vq * loss_vq
-            + self.weight_mel * loss_mel
-            + self.weight_adv * loss_adv
-        )
-
-        # Log losses
-        self.log(
-            "train/generator/loss",
-            loss,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=True,
-            logger=True,
-        )
-        self.log(
-            "train/generator/loss_vq",
-            loss_vq,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-        )
-        self.log(
-            "train/generator/loss_mel",
-            loss_mel,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-        )
-        self.log(
-            "train/generator/loss_adv",
-            loss_adv,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-        )
-
-        # Generator backward
-        optim_g.zero_grad()
-        self.manual_backward(loss)
-        self.clip_gradients(
-            optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
-        )
-        optim_g.step()
-
-        scheduler_g, scheduler_d = self.lr_schedulers()
-        scheduler_g.step()
-        scheduler_d.step()
-
-    def validation_step(self, batch: Any, batch_idx: int):
-        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
-
-        audios = audios.float()
-        audios = audios[:, None, :]
-
-        encoded_mels = self.encode_mel_transform(audios)
-        gt_mels = self.gt_mel_transform(audios)
-
-        mel_lengths = audio_lengths // self.gt_mel_transform.hop_length
-        mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
-        mel_masks_float_conv = mel_masks[:, None, :].float()
-        gt_mels = gt_mels * mel_masks_float_conv
-        encoded_mels = encoded_mels * mel_masks_float_conv
-
-        # Encode
-        encoded_features = self.encoder(encoded_mels) * mel_masks_float_conv
-
-        # Quantize
-        vq_recon_features = self.quantizer(encoded_features).z * mel_masks_float_conv
-        vq_recon_features = (
-            vq_recon_features
-            + self.quality_projection(
-                torch.ones(
-                    vq_recon_features.shape[0], 1, device=vq_recon_features.device
-                )
-                * 2
-            )[:, :, None]
-        )
-
-        # VQ Decode
-        gen_aux_mels = (
-            self.decoder(
-                torch.randn_like(vq_recon_features) * mel_masks_float_conv,
-                condition=vq_recon_features,
-            )
-            * mel_masks_float_conv
-        )
-        loss_mel = avg_with_mask((gen_aux_mels - gt_mels).abs(), mel_masks_float_conv)
-
-        self.log(
-            "val/loss_mel",
-            loss_mel,
-            on_step=False,
-            on_epoch=True,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
-        )
-
-        recon_audios = self.vocoder(gt_mels)
-        gen_aux_audios = self.vocoder(gen_aux_mels)
-
-        # only log the first batch
-        if batch_idx != 0:
-            return
-
-        for idx, (
-            gt_mel,
-            gen_aux_mel,
-            audio,
-            gen_aux_audio,
-            recon_audio,
-            audio_len,
-        ) in enumerate(
-            zip(
-                gt_mels,
-                gen_aux_mels,
-                audios.cpu().float(),
-                gen_aux_audios.cpu().float(),
-                recon_audios.cpu().float(),
-                audio_lengths,
-            )
-        ):
-            if idx > 4:
-                break
-
-            mel_len = audio_len // self.gt_mel_transform.hop_length
-
-            image_mels = plot_mel(
-                [
-                    gt_mel[:, :mel_len],
-                    gen_aux_mel[:, :mel_len],
-                ],
-                [
-                    "Ground-Truth",
-                    "Auxiliary",
-                ],
-            )
-
-            if isinstance(self.logger, WandbLogger):
-                self.logger.experiment.log(
-                    {
-                        "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
-                        "wavs": [
-                            wandb.Audio(
-                                audio[0, :audio_len],
-                                sample_rate=self.sampling_rate,
-                                caption="gt",
-                            ),
-                            wandb.Audio(
-                                gen_aux_audio[0, :audio_len],
-                                sample_rate=self.sampling_rate,
-                                caption="aux",
-                            ),
-                            wandb.Audio(
-                                recon_audio[0, :audio_len],
-                                sample_rate=self.sampling_rate,
-                                caption="recon",
-                            ),
-                        ],
-                    },
-                )
-
-            if isinstance(self.logger, TensorBoardLogger):
-                self.logger.experiment.add_figure(
-                    f"sample-{idx}/mels",
-                    image_mels,
-                    global_step=self.global_step,
-                )
-                self.logger.experiment.add_audio(
-                    f"sample-{idx}/wavs/gt",
-                    audio[0, :audio_len],
-                    self.global_step,
-                    sample_rate=self.sampling_rate,
-                )
-                self.logger.experiment.add_audio(
-                    f"sample-{idx}/wavs/gen",
-                    gen_aux_audio[0, :audio_len],
-                    self.global_step,
-                    sample_rate=self.sampling_rate,
-                )
-                self.logger.experiment.add_audio(
-                    f"sample-{idx}/wavs/recon",
-                    recon_audio[0, :audio_len],
-                    self.global_step,
-                    sample_rate=self.sampling_rate,
-                )
-
-            plt.close(image_mels)
-
-    def encode(self, audios, audio_lengths):
-        audios = audios.float()
-
-        mels = self.encode_mel_transform(audios)
-        mel_lengths = audio_lengths // self.encode_mel_transform.hop_length
-        mel_masks = sequence_mask(mel_lengths, mels.shape[2])
-        mel_masks_float_conv = mel_masks[:, None, :].float()
-        mels = mels * mel_masks_float_conv
-
-        # Encode
-        encoded_features = self.encoder(mels) * mel_masks_float_conv
-        feature_lengths = mel_lengths // math.prod(self.quantizer.downsample_factor)
-
-        return self.quantizer.encode(encoded_features), feature_lengths
-
-    def decode(self, indices, feature_lengths, return_audios=False):
-        factor = math.prod(self.quantizer.downsample_factor)
-        mel_masks = sequence_mask(feature_lengths * factor, indices.shape[2] * factor)
-        mel_masks_float_conv = mel_masks[:, None, :].float()
-
-        z = self.quantizer.decode(indices) * mel_masks_float_conv
-        z = (
-            z
-            + self.quality_projection(torch.ones(z.shape[0], 1, device=z.device) * 2)[
-                :, :, None
-            ]
-        )
-
-        gen_mel = (
-            self.decoder(
-                torch.randn_like(z) * mel_masks_float_conv,
-                condition=z,
-            )
-            * mel_masks_float_conv
-        )
-
-        if return_audios:
-            return self.vocoder(gen_mel)
-
-        return gen_mel
xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py (deleted)
@@ -1,44 +0,0 @@
-import torch
-from torch import nn
-from torch.nn.utils.parametrizations import weight_norm
-
-
-class Discriminator(nn.Module):
-    def __init__(self):
-        super().__init__()
-
-        blocks = []
-        convs = [
-            (1, 64, (3, 9), 1, (1, 4)),
-            (64, 128, (3, 9), (1, 2), (1, 4)),
-            (128, 256, (3, 9), (1, 2), (1, 4)),
-            (256, 512, (3, 9), (1, 2), (1, 4)),
-            (512, 1024, (3, 3), 1, (1, 1)),
-            (1024, 1, (3, 3), 1, (1, 1)),
-        ]
-
-        for idx, (in_channels, out_channels, kernel_size, stride, padding) in enumerate(
-            convs
-        ):
-            blocks.append(
-                weight_norm(
-                    nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
-                )
-            )
-
-            if idx != len(convs) - 1:
-                blocks.append(nn.SiLU(inplace=True))
-
-        self.blocks = nn.Sequential(*blocks)
-
-    def forward(self, x):
-        return self.blocks(x[:, None])[:, 0]
-
-
-if __name__ == "__main__":
-    model = Discriminator()
-    print(sum(p.numel() for p in model.parameters()) / 1_000_000)
-    x = torch.randn(1, 128, 1024)
-    y = model(x)
-    print(y.shape)
-    print(y)
xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py (deleted)
@@ -1,115 +0,0 @@
-from typing import Optional
-
-import torch
-import torch.nn.functional as F
-from torch import nn
-
-from fish_speech.utils import autocast_exclude_mps
-
-from .wavenet import WaveNet
-
-
-class ReferenceEncoder(WaveNet):
-    def __init__(
-        self,
-        input_channels: Optional[int] = None,
-        output_channels: Optional[int] = None,
-        residual_channels: int = 512,
-        residual_layers: int = 20,
-        dilation_cycle: Optional[int] = 4,
-        num_heads: int = 8,
-        latent_len: int = 4,
-    ):
-        super().__init__(
-            input_channels=input_channels,
-            residual_channels=residual_channels,
-            residual_layers=residual_layers,
-            dilation_cycle=dilation_cycle,
-        )
-
-        self.head_dim = residual_channels // num_heads
-        self.num_heads = num_heads
-
-        self.latent_len = latent_len
-        self.latent = nn.Parameter(torch.zeros(1, self.latent_len, residual_channels))
-
-        self.q = nn.Linear(residual_channels, residual_channels, bias=True)
-        self.kv = nn.Linear(residual_channels, residual_channels * 2, bias=True)
-        self.q_norm = nn.LayerNorm(self.head_dim)
-        self.k_norm = nn.LayerNorm(self.head_dim)
-        self.proj = nn.Linear(residual_channels, residual_channels)
-        self.proj_drop = nn.Dropout(0.1)
-
-        self.norm = nn.LayerNorm(residual_channels)
-        self.mlp = nn.Sequential(
-            nn.Linear(residual_channels, residual_channels * 4),
-            nn.SiLU(),
-            nn.Linear(residual_channels * 4, residual_channels),
-        )
-        self.output_projection_attn = nn.Linear(residual_channels, output_channels)
-
-        torch.nn.init.trunc_normal_(self.latent, std=0.02)
-        self.apply(self.init_weights)
-
-    def init_weights(self, m):
-        if isinstance(m, nn.Linear):
-            torch.nn.init.trunc_normal_(m.weight, std=0.02)
-            if m.bias is not None:
-                torch.nn.init.constant_(m.bias, 0)
-
-    def forward(self, x, attn_mask=None):
-        x = super().forward(x).mT
-        B, N, C = x.shape
-
-        # Calculate mask
-        if attn_mask is not None:
-            assert attn_mask.shape == (B, N) and attn_mask.dtype == torch.bool
-
-            attn_mask = attn_mask[:, None, None, :].expand(
-                B, self.num_heads, self.latent_len, N
-            )
-
-        q_latent = self.latent.expand(B, -1, -1)
-        q = (
-            self.q(q_latent)
-            .reshape(B, self.latent_len, self.num_heads, self.head_dim)
-            .transpose(1, 2)
-        )
-
-        kv = (
-            self.kv(x)
-            .reshape(B, N, 2, self.num_heads, self.head_dim)
-            .permute(2, 0, 3, 1, 4)
-        )
-        k, v = kv.unbind(0)
-
-        q, k = self.q_norm(q), self.k_norm(k)
-        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
-
-        x = x.transpose(1, 2).reshape(B, self.latent_len, C)
-        x = self.proj(x)
-        x = self.proj_drop(x)
-
-        x = x + self.mlp(self.norm(x))
-        x = self.output_projection_attn(x)
-        x = x.mean(1)
-
-        return x
-
-
-if __name__ == "__main__":
-    with autocast_exclude_mps(device_type="cpu", dtype=torch.bfloat16):
-        model = ReferenceEncoder(
-            input_channels=128,
-            output_channels=64,
-            residual_channels=384,
-            residual_layers=20,
-            dilation_cycle=4,
-            num_heads=8,
-        )
-        x = torch.randn(4, 128, 64)
-        mask = torch.ones(4, 64, dtype=torch.bool)
-        y = model(x, mask)
-        print(y.shape)
-        loss = F.mse_loss(y, torch.randn(4, 64))
-        loss.backward()