xinference 0.14.4.post1__py3-none-any.whl → 0.15.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of xinference might be problematic.

Files changed (194)
  1. xinference/_compat.py +51 -0
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +209 -40
  4. xinference/client/restful/restful_client.py +7 -26
  5. xinference/conftest.py +1 -1
  6. xinference/constants.py +5 -0
  7. xinference/core/cache_tracker.py +1 -1
  8. xinference/core/chat_interface.py +8 -14
  9. xinference/core/event.py +1 -1
  10. xinference/core/image_interface.py +28 -0
  11. xinference/core/model.py +110 -31
  12. xinference/core/scheduler.py +37 -37
  13. xinference/core/status_guard.py +1 -1
  14. xinference/core/supervisor.py +17 -10
  15. xinference/core/utils.py +80 -22
  16. xinference/core/worker.py +17 -16
  17. xinference/deploy/cmdline.py +8 -16
  18. xinference/deploy/local.py +1 -1
  19. xinference/deploy/supervisor.py +1 -1
  20. xinference/deploy/utils.py +1 -1
  21. xinference/deploy/worker.py +1 -1
  22. xinference/model/audio/cosyvoice.py +86 -41
  23. xinference/model/audio/fish_speech.py +9 -9
  24. xinference/model/audio/model_spec.json +9 -9
  25. xinference/model/audio/whisper.py +4 -1
  26. xinference/model/embedding/core.py +52 -31
  27. xinference/model/image/core.py +2 -1
  28. xinference/model/image/model_spec.json +16 -4
  29. xinference/model/image/model_spec_modelscope.json +16 -4
  30. xinference/model/image/sdapi.py +136 -0
  31. xinference/model/image/stable_diffusion/core.py +164 -19
  32. xinference/model/llm/__init__.py +29 -11
  33. xinference/model/llm/llama_cpp/core.py +16 -33
  34. xinference/model/llm/llm_family.json +1011 -1296
  35. xinference/model/llm/llm_family.py +34 -53
  36. xinference/model/llm/llm_family_csghub.json +18 -35
  37. xinference/model/llm/llm_family_modelscope.json +981 -1122
  38. xinference/model/llm/lmdeploy/core.py +56 -88
  39. xinference/model/llm/mlx/core.py +46 -69
  40. xinference/model/llm/sglang/core.py +36 -18
  41. xinference/model/llm/transformers/chatglm.py +168 -306
  42. xinference/model/llm/transformers/cogvlm2.py +36 -63
  43. xinference/model/llm/transformers/cogvlm2_video.py +33 -223
  44. xinference/model/llm/transformers/core.py +55 -50
  45. xinference/model/llm/transformers/deepseek_v2.py +340 -0
  46. xinference/model/llm/transformers/deepseek_vl.py +53 -96
  47. xinference/model/llm/transformers/glm4v.py +55 -111
  48. xinference/model/llm/transformers/intern_vl.py +39 -70
  49. xinference/model/llm/transformers/internlm2.py +32 -54
  50. xinference/model/llm/transformers/minicpmv25.py +22 -55
  51. xinference/model/llm/transformers/minicpmv26.py +158 -68
  52. xinference/model/llm/transformers/omnilmm.py +5 -28
  53. xinference/model/llm/transformers/qwen2_audio.py +168 -0
  54. xinference/model/llm/transformers/qwen2_vl.py +234 -0
  55. xinference/model/llm/transformers/qwen_vl.py +34 -86
  56. xinference/model/llm/transformers/utils.py +32 -38
  57. xinference/model/llm/transformers/yi_vl.py +32 -72
  58. xinference/model/llm/utils.py +280 -554
  59. xinference/model/llm/vllm/core.py +161 -100
  60. xinference/model/rerank/core.py +41 -8
  61. xinference/model/rerank/model_spec.json +7 -0
  62. xinference/model/rerank/model_spec_modelscope.json +7 -1
  63. xinference/model/utils.py +1 -31
  64. xinference/thirdparty/cosyvoice/bin/export_jit.py +64 -0
  65. xinference/thirdparty/cosyvoice/bin/export_trt.py +8 -0
  66. xinference/thirdparty/cosyvoice/bin/inference.py +5 -2
  67. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +38 -22
  68. xinference/thirdparty/cosyvoice/cli/model.py +139 -26
  69. xinference/thirdparty/cosyvoice/flow/flow.py +15 -9
  70. xinference/thirdparty/cosyvoice/flow/length_regulator.py +20 -1
  71. xinference/thirdparty/cosyvoice/hifigan/generator.py +8 -4
  72. xinference/thirdparty/cosyvoice/llm/llm.py +14 -13
  73. xinference/thirdparty/cosyvoice/transformer/attention.py +7 -3
  74. xinference/thirdparty/cosyvoice/transformer/decoder.py +1 -1
  75. xinference/thirdparty/cosyvoice/transformer/embedding.py +4 -3
  76. xinference/thirdparty/cosyvoice/transformer/encoder.py +4 -2
  77. xinference/thirdparty/cosyvoice/utils/common.py +36 -0
  78. xinference/thirdparty/cosyvoice/utils/file_utils.py +16 -0
  79. xinference/thirdparty/deepseek_vl/serve/assets/Kelpy-Codos.js +100 -0
  80. xinference/thirdparty/deepseek_vl/serve/assets/avatar.png +0 -0
  81. xinference/thirdparty/deepseek_vl/serve/assets/custom.css +355 -0
  82. xinference/thirdparty/deepseek_vl/serve/assets/custom.js +22 -0
  83. xinference/thirdparty/deepseek_vl/serve/assets/favicon.ico +0 -0
  84. xinference/thirdparty/deepseek_vl/serve/examples/app.png +0 -0
  85. xinference/thirdparty/deepseek_vl/serve/examples/chart.png +0 -0
  86. xinference/thirdparty/deepseek_vl/serve/examples/mirror.png +0 -0
  87. xinference/thirdparty/deepseek_vl/serve/examples/pipeline.png +0 -0
  88. xinference/thirdparty/deepseek_vl/serve/examples/puzzle.png +0 -0
  89. xinference/thirdparty/deepseek_vl/serve/examples/rap.jpeg +0 -0
  90. xinference/thirdparty/fish_speech/fish_speech/configs/base.yaml +87 -0
  91. xinference/thirdparty/fish_speech/fish_speech/configs/firefly_gan_vq.yaml +33 -0
  92. xinference/thirdparty/fish_speech/fish_speech/configs/lora/r_8_alpha_16.yaml +4 -0
  93. xinference/thirdparty/fish_speech/fish_speech/configs/text2semantic_finetune.yaml +83 -0
  94. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text-data.proto +24 -0
  95. xinference/thirdparty/fish_speech/fish_speech/i18n/README.md +27 -0
  96. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +1 -1
  97. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +1 -1
  98. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +1 -1
  99. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +1 -1
  100. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +1 -1
  101. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +2 -2
  102. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +0 -3
  103. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +169 -198
  104. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +4 -27
  105. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/.gitignore +114 -0
  106. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/README.md +36 -0
  107. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +9 -47
  108. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +2 -2
  109. xinference/thirdparty/fish_speech/fish_speech/train.py +2 -0
  110. xinference/thirdparty/fish_speech/fish_speech/webui/css/style.css +161 -0
  111. xinference/thirdparty/fish_speech/fish_speech/webui/html/footer.html +11 -0
  112. xinference/thirdparty/fish_speech/fish_speech/webui/js/animate.js +69 -0
  113. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +12 -10
  114. xinference/thirdparty/fish_speech/tools/api.py +79 -134
  115. xinference/thirdparty/fish_speech/tools/commons.py +35 -0
  116. xinference/thirdparty/fish_speech/tools/download_models.py +3 -3
  117. xinference/thirdparty/fish_speech/tools/file.py +17 -0
  118. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +1 -1
  119. xinference/thirdparty/fish_speech/tools/llama/generate.py +29 -24
  120. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +1 -1
  121. xinference/thirdparty/fish_speech/tools/llama/quantize.py +2 -2
  122. xinference/thirdparty/fish_speech/tools/msgpack_api.py +34 -0
  123. xinference/thirdparty/fish_speech/tools/post_api.py +85 -44
  124. xinference/thirdparty/fish_speech/tools/sensevoice/README.md +59 -0
  125. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +1 -1
  126. xinference/thirdparty/fish_speech/tools/smart_pad.py +16 -3
  127. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +2 -2
  128. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +4 -2
  129. xinference/thirdparty/fish_speech/tools/webui.py +12 -146
  130. xinference/thirdparty/matcha/VERSION +1 -0
  131. xinference/thirdparty/matcha/hifigan/LICENSE +21 -0
  132. xinference/thirdparty/matcha/hifigan/README.md +101 -0
  133. xinference/thirdparty/omnilmm/LICENSE +201 -0
  134. xinference/thirdparty/whisper/__init__.py +156 -0
  135. xinference/thirdparty/whisper/__main__.py +3 -0
  136. xinference/thirdparty/whisper/assets/gpt2.tiktoken +50256 -0
  137. xinference/thirdparty/whisper/assets/mel_filters.npz +0 -0
  138. xinference/thirdparty/whisper/assets/multilingual.tiktoken +50257 -0
  139. xinference/thirdparty/whisper/audio.py +157 -0
  140. xinference/thirdparty/whisper/decoding.py +826 -0
  141. xinference/thirdparty/whisper/model.py +314 -0
  142. xinference/thirdparty/whisper/normalizers/__init__.py +2 -0
  143. xinference/thirdparty/whisper/normalizers/basic.py +76 -0
  144. xinference/thirdparty/whisper/normalizers/english.json +1741 -0
  145. xinference/thirdparty/whisper/normalizers/english.py +550 -0
  146. xinference/thirdparty/whisper/timing.py +386 -0
  147. xinference/thirdparty/whisper/tokenizer.py +395 -0
  148. xinference/thirdparty/whisper/transcribe.py +605 -0
  149. xinference/thirdparty/whisper/triton_ops.py +109 -0
  150. xinference/thirdparty/whisper/utils.py +316 -0
  151. xinference/thirdparty/whisper/version.py +1 -0
  152. xinference/types.py +14 -53
  153. xinference/web/ui/build/asset-manifest.json +6 -6
  154. xinference/web/ui/build/index.html +1 -1
  155. xinference/web/ui/build/static/css/{main.4bafd904.css → main.5061c4c3.css} +2 -2
  156. xinference/web/ui/build/static/css/main.5061c4c3.css.map +1 -0
  157. xinference/web/ui/build/static/js/main.754740c0.js +3 -0
  158. xinference/web/ui/build/static/js/{main.eb13fe95.js.LICENSE.txt → main.754740c0.js.LICENSE.txt} +2 -0
  159. xinference/web/ui/build/static/js/main.754740c0.js.map +1 -0
  160. xinference/web/ui/node_modules/.cache/babel-loader/10c69dc7a296779fcffedeff9393d832dfcb0013c36824adf623d3c518b801ff.json +1 -0
  161. xinference/web/ui/node_modules/.cache/babel-loader/68bede6d95bb5ef0b35bbb3ec5b8c937eaf6862c6cdbddb5ef222a7776aaf336.json +1 -0
  162. xinference/web/ui/node_modules/.cache/babel-loader/77d50223f3e734d4485cca538cb098a8c3a7a0a1a9f01f58cdda3af42fe1adf5.json +1 -0
  163. xinference/web/ui/node_modules/.cache/babel-loader/a56d5a642409a84988891089c98ca28ad0546432dfbae8aaa51bc5a280e1cdd2.json +1 -0
  164. xinference/web/ui/node_modules/.cache/babel-loader/cd90b08d177025dfe84209596fc51878f8a86bcaa6a240848a3d2e5fd4c7ff24.json +1 -0
  165. xinference/web/ui/node_modules/.cache/babel-loader/d9ff696a3e3471f01b46c63d18af32e491eb5dc0e43cb30202c96871466df57f.json +1 -0
  166. xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +1 -0
  167. xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +1 -0
  168. xinference/web/ui/node_modules/.package-lock.json +37 -0
  169. xinference/web/ui/node_modules/a-sync-waterfall/package.json +21 -0
  170. xinference/web/ui/node_modules/nunjucks/node_modules/commander/package.json +48 -0
  171. xinference/web/ui/node_modules/nunjucks/package.json +112 -0
  172. xinference/web/ui/package-lock.json +38 -0
  173. xinference/web/ui/package.json +1 -0
  174. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/METADATA +16 -10
  175. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/RECORD +179 -127
  176. xinference/model/llm/transformers/llama_2.py +0 -108
  177. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +0 -442
  178. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +0 -44
  179. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +0 -115
  180. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +0 -225
  181. xinference/thirdparty/fish_speech/tools/auto_rerank.py +0 -159
  182. xinference/thirdparty/fish_speech/tools/gen_ref.py +0 -36
  183. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +0 -55
  184. xinference/web/ui/build/static/css/main.4bafd904.css.map +0 -1
  185. xinference/web/ui/build/static/js/main.eb13fe95.js +0 -3
  186. xinference/web/ui/build/static/js/main.eb13fe95.js.map +0 -1
  187. xinference/web/ui/node_modules/.cache/babel-loader/0b11a5339468c13b2d31ac085e7effe4303259b2071abd46a0a8eb8529233a5e.json +0 -1
  188. xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +0 -1
  189. xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +0 -1
  190. xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +0 -1
  191. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/LICENSE +0 -0
  192. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/WHEEL +0 -0
  193. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/entry_points.txt +0 -0
  194. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/top_level.txt +0 -0
xinference/model/llm/transformers/llama_2.py
@@ -1,108 +0,0 @@
- # Copyright 2022-2023 XProbe Inc.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- from typing import List, Optional
-
- from ....types import LoRA
- from ..llm_family import LLMFamilyV1, LLMSpecV1
- from .core import PytorchChatModel, PytorchModel, PytorchModelConfig
-
-
- class LlamaPytorchModel(PytorchModel):
-     def __init__(
-         self,
-         model_uid: str,
-         model_family: "LLMFamilyV1",
-         model_spec: "LLMSpecV1",
-         quantization: str,
-         model_path: str,
-         pytorch_model_config: Optional[PytorchModelConfig] = None,
-         peft_model: Optional[List[LoRA]] = None,
-     ):
-         super().__init__(
-             model_uid,
-             model_family,
-             model_spec,
-             quantization,
-             model_path,
-             pytorch_model_config=pytorch_model_config,
-             peft_model=peft_model,
-         )
-
-     def _load_model(self, **kwargs):
-         model, tokenizer = super()._load_model(**kwargs)
-         # Llama has no pad token by default
-         # https://github.com/huggingface/transformers/blob/07998ef39926b76d3f6667025535d0859eed61c3/docs/source/en/llm_tutorial.md?plain=1#L125
-         tokenizer.pad_token = tokenizer.eos_token
-         model.config.eos_token_id = tokenizer.eos_token_id
-         model.config.pad_token_id = tokenizer.pad_token_id
-         return model, tokenizer
-
-     @classmethod
-     def match(
-         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
-     ) -> bool:
-         if llm_spec.model_format != "pytorch":
-             return False
-         model_family = llm_family.model_family or llm_family.model_name
-         if "llama-2" not in model_family:
-             return False
-         if "generate" not in llm_family.model_ability:
-             return False
-         return True
-
-
- class LlamaPytorchChatModel(PytorchChatModel):
-     def __init__(
-         self,
-         model_uid: str,
-         model_family: "LLMFamilyV1",
-         model_spec: "LLMSpecV1",
-         quantization: str,
-         model_path: str,
-         pytorch_model_config: Optional["PytorchModelConfig"] = None,
-         peft_model: Optional[List[LoRA]] = None,
-     ):
-         super().__init__(
-             model_uid,
-             model_family,
-             model_spec,
-             quantization,
-             model_path,
-             peft_model=peft_model,
-             pytorch_model_config=pytorch_model_config,
-         )
-         self._use_fast_tokenizer = False
-
-     def _load_model(self, **kwargs):
-         model, tokenizer = super()._load_model(**kwargs)
-         # Llama has no pad token by default
-         # https://github.com/huggingface/transformers/blob/07998ef39926b76d3f6667025535d0859eed61c3/docs/source/en/llm_tutorial.md?plain=1#L125
-         tokenizer.pad_token = tokenizer.eos_token
-         model.config.eos_token_id = tokenizer.eos_token_id
-         model.config.pad_token_id = tokenizer.pad_token_id
-         return model, tokenizer
-
-     @classmethod
-     def match(
-         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
-     ) -> bool:
-         if llm_spec.model_format != "pytorch":
-             return False
-         model_family = llm_family.model_family or llm_family.model_name
-         if "llama-2" not in model_family:
-             return False
-         if "chat" not in llm_family.model_ability:
-             return False
-         return True
xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py
@@ -1,442 +0,0 @@
- import itertools
- import math
- from typing import Any, Callable
-
- import lightning as L
- import torch
- import torch.nn.functional as F
- # import wandb
- from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
- from matplotlib import pyplot as plt
- from torch import nn
-
- from fish_speech.models.vqgan.modules.discriminator import Discriminator
- from fish_speech.models.vqgan.modules.wavenet import WaveNet
- from fish_speech.models.vqgan.utils import avg_with_mask, plot_mel, sequence_mask
-
-
- class VQGAN(L.LightningModule):
-     def __init__(
-         self,
-         optimizer: Callable,
-         lr_scheduler: Callable,
-         encoder: WaveNet,
-         quantizer: nn.Module,
-         decoder: WaveNet,
-         discriminator: Discriminator,
-         vocoder: nn.Module,
-         encode_mel_transform: nn.Module,
-         gt_mel_transform: nn.Module,
-         weight_adv: float = 1.0,
-         weight_vq: float = 1.0,
-         weight_mel: float = 1.0,
-         sampling_rate: int = 44100,
-         freeze_encoder: bool = False,
-     ):
-         super().__init__()
-
-         # Model parameters
-         self.optimizer_builder = optimizer
-         self.lr_scheduler_builder = lr_scheduler
-
-         # Modules
-         self.encoder = encoder
-         self.quantizer = quantizer
-         self.decoder = decoder
-         self.vocoder = vocoder
-         self.discriminator = discriminator
-         self.encode_mel_transform = encode_mel_transform
-         self.gt_mel_transform = gt_mel_transform
-
-         # A simple linear layer to project quality to condition channels
-         self.quality_projection = nn.Linear(1, 768)
-
-         # Freeze vocoder
-         for param in self.vocoder.parameters():
-             param.requires_grad = False
-
-         # Loss weights
-         self.weight_adv = weight_adv
-         self.weight_vq = weight_vq
-         self.weight_mel = weight_mel
-
-         # Other parameters
-         self.sampling_rate = sampling_rate
-
-         # Disable strict loading
-         self.strict_loading = False
-
-         # If encoder is frozen
-         if freeze_encoder:
-             for param in self.encoder.parameters():
-                 param.requires_grad = False
-
-             for param in self.quantizer.parameters():
-                 param.requires_grad = False
-
-         self.automatic_optimization = False
-
-     def on_save_checkpoint(self, checkpoint):
-         # Do not save vocoder
-         state_dict = checkpoint["state_dict"]
-         for name in list(state_dict.keys()):
-             if "vocoder" in name:
-                 state_dict.pop(name)
-
-     def configure_optimizers(self):
-         optimizer_generator = self.optimizer_builder(
-             itertools.chain(
-                 self.encoder.parameters(),
-                 self.quantizer.parameters(),
-                 self.decoder.parameters(),
-                 self.quality_projection.parameters(),
-             )
-         )
-         optimizer_discriminator = self.optimizer_builder(
-             self.discriminator.parameters()
-         )
-
-         lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
-         lr_scheduler_discriminator = self.lr_scheduler_builder(optimizer_discriminator)
-
-         return (
-             {
-                 "optimizer": optimizer_generator,
-                 "lr_scheduler": {
-                     "scheduler": lr_scheduler_generator,
-                     "interval": "step",
-                     "name": "optimizer/generator",
-                 },
-             },
-             {
-                 "optimizer": optimizer_discriminator,
-                 "lr_scheduler": {
-                     "scheduler": lr_scheduler_discriminator,
-                     "interval": "step",
-                     "name": "optimizer/discriminator",
-                 },
-             },
-         )
-
-     def training_step(self, batch, batch_idx):
-         optim_g, optim_d = self.optimizers()
-
-         audios, audio_lengths = batch["audios"], batch["audio_lengths"]
-
-         audios = audios.float()
-         audios = audios[:, None, :]
-
-         with torch.no_grad():
-             encoded_mels = self.encode_mel_transform(audios)
-             gt_mels = self.gt_mel_transform(audios)
-             quality = ((gt_mels.mean(-1) > -8).sum(-1) - 90) / 10
-             quality = quality.unsqueeze(-1)
-
-         mel_lengths = audio_lengths // self.gt_mel_transform.hop_length
-         mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
-         mel_masks_float_conv = mel_masks[:, None, :].float()
-         gt_mels = gt_mels * mel_masks_float_conv
-         encoded_mels = encoded_mels * mel_masks_float_conv
-
-         # Encode
-         encoded_features = self.encoder(encoded_mels) * mel_masks_float_conv
-
-         # Quantize
-         vq_result = self.quantizer(encoded_features)
-         loss_vq = getattr("vq_result", "loss", 0.0)
-         vq_recon_features = vq_result.z * mel_masks_float_conv
-         vq_recon_features = (
-             vq_recon_features + self.quality_projection(quality)[:, :, None]
-         )
-
-         # VQ Decode
-         gen_mel = (
-             self.decoder(
-                 torch.randn_like(vq_recon_features) * mel_masks_float_conv,
-                 condition=vq_recon_features,
-             )
-             * mel_masks_float_conv
-         )
-
-         # Discriminator
-         real_logits = self.discriminator(gt_mels)
-         fake_logits = self.discriminator(gen_mel.detach())
-         d_mask = F.interpolate(
-             mel_masks_float_conv, size=(real_logits.shape[2],), mode="nearest"
-         )
-
-         loss_real = avg_with_mask((real_logits - 1) ** 2, d_mask)
-         loss_fake = avg_with_mask(fake_logits**2, d_mask)
-
-         loss_d = loss_real + loss_fake
-
-         self.log(
-             "train/discriminator/loss",
-             loss_d,
-             on_step=True,
-             on_epoch=False,
-             prog_bar=True,
-             logger=True,
-         )
-
-         # Discriminator backward
-         optim_d.zero_grad()
-         self.manual_backward(loss_d)
-         self.clip_gradients(
-             optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
-         )
-         optim_d.step()
-
-         # Mel Loss, applying l1, using a weighted sum
-         mel_distance = (
-             gen_mel - gt_mels
-         ).abs()  # * 0.5 + self.ssim(gen_mel, gt_mels) * 0.5
-         loss_mel_low_freq = avg_with_mask(mel_distance[:, :40, :], mel_masks_float_conv)
-         loss_mel_mid_freq = avg_with_mask(
-             mel_distance[:, 40:70, :], mel_masks_float_conv
-         )
-         loss_mel_high_freq = avg_with_mask(
-             mel_distance[:, 70:, :], mel_masks_float_conv
-         )
-         loss_mel = (
-             loss_mel_low_freq * 0.6 + loss_mel_mid_freq * 0.3 + loss_mel_high_freq * 0.1
-         )
-
-         # Adversarial Loss
-         fake_logits = self.discriminator(gen_mel)
-         loss_adv = avg_with_mask((fake_logits - 1) ** 2, d_mask)
-
-         # Total loss
-         loss = (
-             self.weight_vq * loss_vq
-             + self.weight_mel * loss_mel
-             + self.weight_adv * loss_adv
-         )
-
-         # Log losses
-         self.log(
-             "train/generator/loss",
-             loss,
-             on_step=True,
-             on_epoch=False,
-             prog_bar=True,
-             logger=True,
-         )
-         self.log(
-             "train/generator/loss_vq",
-             loss_vq,
-             on_step=True,
-             on_epoch=False,
-             prog_bar=False,
-             logger=True,
-         )
-         self.log(
-             "train/generator/loss_mel",
-             loss_mel,
-             on_step=True,
-             on_epoch=False,
-             prog_bar=False,
-             logger=True,
-         )
-         self.log(
-             "train/generator/loss_adv",
-             loss_adv,
-             on_step=True,
-             on_epoch=False,
-             prog_bar=False,
-             logger=True,
-         )
-
-         # Generator backward
-         optim_g.zero_grad()
-         self.manual_backward(loss)
-         self.clip_gradients(
-             optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
-         )
-         optim_g.step()
-
-         scheduler_g, scheduler_d = self.lr_schedulers()
-         scheduler_g.step()
-         scheduler_d.step()
-
-     def validation_step(self, batch: Any, batch_idx: int):
-         audios, audio_lengths = batch["audios"], batch["audio_lengths"]
-
-         audios = audios.float()
-         audios = audios[:, None, :]
-
-         encoded_mels = self.encode_mel_transform(audios)
-         gt_mels = self.gt_mel_transform(audios)
-
-         mel_lengths = audio_lengths // self.gt_mel_transform.hop_length
-         mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
-         mel_masks_float_conv = mel_masks[:, None, :].float()
-         gt_mels = gt_mels * mel_masks_float_conv
-         encoded_mels = encoded_mels * mel_masks_float_conv
-
-         # Encode
-         encoded_features = self.encoder(encoded_mels) * mel_masks_float_conv
-
-         # Quantize
-         vq_recon_features = self.quantizer(encoded_features).z * mel_masks_float_conv
-         vq_recon_features = (
-             vq_recon_features
-             + self.quality_projection(
-                 torch.ones(
-                     vq_recon_features.shape[0], 1, device=vq_recon_features.device
-                 )
-                 * 2
-             )[:, :, None]
-         )
-
-         # VQ Decode
-         gen_aux_mels = (
-             self.decoder(
-                 torch.randn_like(vq_recon_features) * mel_masks_float_conv,
-                 condition=vq_recon_features,
-             )
-             * mel_masks_float_conv
-         )
-         loss_mel = avg_with_mask((gen_aux_mels - gt_mels).abs(), mel_masks_float_conv)
-
-         self.log(
-             "val/loss_mel",
-             loss_mel,
-             on_step=False,
-             on_epoch=True,
-             prog_bar=False,
-             logger=True,
-             sync_dist=True,
-         )
-
-         recon_audios = self.vocoder(gt_mels)
-         gen_aux_audios = self.vocoder(gen_aux_mels)
-
-         # only log the first batch
-         if batch_idx != 0:
-             return
-
-         for idx, (
-             gt_mel,
-             gen_aux_mel,
-             audio,
-             gen_aux_audio,
-             recon_audio,
-             audio_len,
-         ) in enumerate(
-             zip(
-                 gt_mels,
-                 gen_aux_mels,
-                 audios.cpu().float(),
-                 gen_aux_audios.cpu().float(),
-                 recon_audios.cpu().float(),
-                 audio_lengths,
-             )
-         ):
-             if idx > 4:
-                 break
-
-             mel_len = audio_len // self.gt_mel_transform.hop_length
-
-             image_mels = plot_mel(
-                 [
-                     gt_mel[:, :mel_len],
-                     gen_aux_mel[:, :mel_len],
-                 ],
-                 [
-                     "Ground-Truth",
-                     "Auxiliary",
-                 ],
-             )
-
-             if isinstance(self.logger, WandbLogger):
-                 self.logger.experiment.log(
-                     {
-                         "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
-                         "wavs": [
-                             wandb.Audio(
-                                 audio[0, :audio_len],
-                                 sample_rate=self.sampling_rate,
-                                 caption="gt",
-                             ),
-                             wandb.Audio(
-                                 gen_aux_audio[0, :audio_len],
-                                 sample_rate=self.sampling_rate,
-                                 caption="aux",
-                             ),
-                             wandb.Audio(
-                                 recon_audio[0, :audio_len],
-                                 sample_rate=self.sampling_rate,
-                                 caption="recon",
-                             ),
-                         ],
-                     },
-                 )
-
-             if isinstance(self.logger, TensorBoardLogger):
-                 self.logger.experiment.add_figure(
-                     f"sample-{idx}/mels",
-                     image_mels,
-                     global_step=self.global_step,
-                 )
-                 self.logger.experiment.add_audio(
-                     f"sample-{idx}/wavs/gt",
-                     audio[0, :audio_len],
-                     self.global_step,
-                     sample_rate=self.sampling_rate,
-                 )
-                 self.logger.experiment.add_audio(
-                     f"sample-{idx}/wavs/gen",
-                     gen_aux_audio[0, :audio_len],
-                     self.global_step,
-                     sample_rate=self.sampling_rate,
-                 )
-                 self.logger.experiment.add_audio(
-                     f"sample-{idx}/wavs/recon",
-                     recon_audio[0, :audio_len],
-                     self.global_step,
-                     sample_rate=self.sampling_rate,
-                 )
-
-             plt.close(image_mels)
-
-     def encode(self, audios, audio_lengths):
-         audios = audios.float()
-
-         mels = self.encode_mel_transform(audios)
-         mel_lengths = audio_lengths // self.encode_mel_transform.hop_length
-         mel_masks = sequence_mask(mel_lengths, mels.shape[2])
-         mel_masks_float_conv = mel_masks[:, None, :].float()
-         mels = mels * mel_masks_float_conv
-
-         # Encode
-         encoded_features = self.encoder(mels) * mel_masks_float_conv
-         feature_lengths = mel_lengths // math.prod(self.quantizer.downsample_factor)
-
-         return self.quantizer.encode(encoded_features), feature_lengths
-
-     def decode(self, indices, feature_lengths, return_audios=False):
-         factor = math.prod(self.quantizer.downsample_factor)
-         mel_masks = sequence_mask(feature_lengths * factor, indices.shape[2] * factor)
-         mel_masks_float_conv = mel_masks[:, None, :].float()
-
-         z = self.quantizer.decode(indices) * mel_masks_float_conv
-         z = (
-             z
-             + self.quality_projection(torch.ones(z.shape[0], 1, device=z.device) * 2)[
-                 :, :, None
-             ]
-         )
-
-         gen_mel = (
-             self.decoder(
-                 torch.randn_like(z) * mel_masks_float_conv,
-                 condition=z,
-             )
-             * mel_masks_float_conv
-         )
-
-         if return_audios:
-             return self.vocoder(gen_mel)
-
-         return gen_mel
xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py
@@ -1,44 +0,0 @@
- import torch
- from torch import nn
- from torch.nn.utils.parametrizations import weight_norm
-
-
- class Discriminator(nn.Module):
-     def __init__(self):
-         super().__init__()
-
-         blocks = []
-         convs = [
-             (1, 64, (3, 9), 1, (1, 4)),
-             (64, 128, (3, 9), (1, 2), (1, 4)),
-             (128, 256, (3, 9), (1, 2), (1, 4)),
-             (256, 512, (3, 9), (1, 2), (1, 4)),
-             (512, 1024, (3, 3), 1, (1, 1)),
-             (1024, 1, (3, 3), 1, (1, 1)),
-         ]
-
-         for idx, (in_channels, out_channels, kernel_size, stride, padding) in enumerate(
-             convs
-         ):
-             blocks.append(
-                 weight_norm(
-                     nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
-                 )
-             )
-
-             if idx != len(convs) - 1:
-                 blocks.append(nn.SiLU(inplace=True))
-
-         self.blocks = nn.Sequential(*blocks)
-
-     def forward(self, x):
-         return self.blocks(x[:, None])[:, 0]
-
-
- if __name__ == "__main__":
-     model = Discriminator()
-     print(sum(p.numel() for p in model.parameters()) / 1_000_000)
-     x = torch.randn(1, 128, 1024)
-     y = model(x)
-     print(y.shape)
-     print(y)
xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py
@@ -1,115 +0,0 @@
- from typing import Optional
-
- import torch
- import torch.nn.functional as F
- from torch import nn
-
- from fish_speech.utils import autocast_exclude_mps
-
- from .wavenet import WaveNet
-
-
- class ReferenceEncoder(WaveNet):
-     def __init__(
-         self,
-         input_channels: Optional[int] = None,
-         output_channels: Optional[int] = None,
-         residual_channels: int = 512,
-         residual_layers: int = 20,
-         dilation_cycle: Optional[int] = 4,
-         num_heads: int = 8,
-         latent_len: int = 4,
-     ):
-         super().__init__(
-             input_channels=input_channels,
-             residual_channels=residual_channels,
-             residual_layers=residual_layers,
-             dilation_cycle=dilation_cycle,
-         )
-
-         self.head_dim = residual_channels // num_heads
-         self.num_heads = num_heads
-
-         self.latent_len = latent_len
-         self.latent = nn.Parameter(torch.zeros(1, self.latent_len, residual_channels))
-
-         self.q = nn.Linear(residual_channels, residual_channels, bias=True)
-         self.kv = nn.Linear(residual_channels, residual_channels * 2, bias=True)
-         self.q_norm = nn.LayerNorm(self.head_dim)
-         self.k_norm = nn.LayerNorm(self.head_dim)
-         self.proj = nn.Linear(residual_channels, residual_channels)
-         self.proj_drop = nn.Dropout(0.1)
-
-         self.norm = nn.LayerNorm(residual_channels)
-         self.mlp = nn.Sequential(
-             nn.Linear(residual_channels, residual_channels * 4),
-             nn.SiLU(),
-             nn.Linear(residual_channels * 4, residual_channels),
-         )
-         self.output_projection_attn = nn.Linear(residual_channels, output_channels)
-
-         torch.nn.init.trunc_normal_(self.latent, std=0.02)
-         self.apply(self.init_weights)
-
-     def init_weights(self, m):
-         if isinstance(m, nn.Linear):
-             torch.nn.init.trunc_normal_(m.weight, std=0.02)
-             if m.bias is not None:
-                 torch.nn.init.constant_(m.bias, 0)
-
-     def forward(self, x, attn_mask=None):
-         x = super().forward(x).mT
-         B, N, C = x.shape
-
-         # Calculate mask
-         if attn_mask is not None:
-             assert attn_mask.shape == (B, N) and attn_mask.dtype == torch.bool
-
-             attn_mask = attn_mask[:, None, None, :].expand(
-                 B, self.num_heads, self.latent_len, N
-             )
-
-         q_latent = self.latent.expand(B, -1, -1)
-         q = (
-             self.q(q_latent)
-             .reshape(B, self.latent_len, self.num_heads, self.head_dim)
-             .transpose(1, 2)
-         )
-
-         kv = (
-             self.kv(x)
-             .reshape(B, N, 2, self.num_heads, self.head_dim)
-             .permute(2, 0, 3, 1, 4)
-         )
-         k, v = kv.unbind(0)
-
-         q, k = self.q_norm(q), self.k_norm(k)
-         x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
-
-         x = x.transpose(1, 2).reshape(B, self.latent_len, C)
-         x = self.proj(x)
-         x = self.proj_drop(x)
-
-         x = x + self.mlp(self.norm(x))
-         x = self.output_projection_attn(x)
-         x = x.mean(1)
-
-         return x
-
-
- if __name__ == "__main__":
-     with autocast_exclude_mps(device_type="cpu", dtype=torch.bfloat16):
-         model = ReferenceEncoder(
-             input_channels=128,
-             output_channels=64,
-             residual_channels=384,
-             residual_layers=20,
-             dilation_cycle=4,
-             num_heads=8,
-         )
-         x = torch.randn(4, 128, 64)
-         mask = torch.ones(4, 64, dtype=torch.bool)
-         y = model(x, mask)
-         print(y.shape)
-         loss = F.mse_loss(y, torch.randn(4, 64))
-         loss.backward()