xinference 0.14.1.post1__py3-none-any.whl → 0.14.3__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registries, and is provided for informational purposes only.

Potentially problematic release.


This version of xinference might be problematic.

Files changed (194)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +15 -34
  3. xinference/client/restful/restful_client.py +2 -2
  4. xinference/core/chat_interface.py +45 -10
  5. xinference/core/image_interface.py +9 -0
  6. xinference/core/model.py +8 -5
  7. xinference/core/scheduler.py +1 -2
  8. xinference/core/worker.py +49 -42
  9. xinference/deploy/cmdline.py +2 -2
  10. xinference/deploy/test/test_cmdline.py +7 -7
  11. xinference/model/audio/chattts.py +24 -9
  12. xinference/model/audio/core.py +8 -2
  13. xinference/model/audio/fish_speech.py +228 -0
  14. xinference/model/audio/model_spec.json +8 -0
  15. xinference/model/embedding/core.py +23 -1
  16. xinference/model/image/model_spec.json +2 -1
  17. xinference/model/image/model_spec_modelscope.json +2 -1
  18. xinference/model/image/stable_diffusion/core.py +49 -1
  19. xinference/model/llm/__init__.py +26 -27
  20. xinference/model/llm/{ggml/llamacpp.py → llama_cpp/core.py} +2 -35
  21. xinference/model/llm/llm_family.json +606 -1266
  22. xinference/model/llm/llm_family.py +16 -139
  23. xinference/model/llm/llm_family_modelscope.json +276 -313
  24. xinference/model/llm/lmdeploy/__init__.py +0 -0
  25. xinference/model/llm/lmdeploy/core.py +557 -0
  26. xinference/model/llm/memory.py +9 -9
  27. xinference/model/llm/sglang/core.py +2 -2
  28. xinference/model/llm/{pytorch → transformers}/chatglm.py +6 -13
  29. xinference/model/llm/{pytorch → transformers}/cogvlm2.py +4 -45
  30. xinference/model/llm/transformers/cogvlm2_video.py +524 -0
  31. xinference/model/llm/{pytorch → transformers}/core.py +3 -10
  32. xinference/model/llm/{pytorch → transformers}/glm4v.py +2 -23
  33. xinference/model/llm/transformers/intern_vl.py +540 -0
  34. xinference/model/llm/{pytorch → transformers}/internlm2.py +4 -8
  35. xinference/model/llm/{pytorch → transformers}/minicpmv25.py +2 -23
  36. xinference/model/llm/{pytorch → transformers}/minicpmv26.py +66 -41
  37. xinference/model/llm/{pytorch → transformers}/utils.py +1 -2
  38. xinference/model/llm/{pytorch → transformers}/yi_vl.py +2 -24
  39. xinference/model/llm/utils.py +85 -70
  40. xinference/model/llm/vllm/core.py +110 -11
  41. xinference/model/utils.py +1 -95
  42. xinference/thirdparty/fish_speech/__init__.py +0 -0
  43. xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
  44. xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
  45. xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
  46. xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
  47. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  48. xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
  49. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  50. xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
  51. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  52. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
  53. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
  54. xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
  55. xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
  56. xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
  57. xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
  58. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  59. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
  60. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
  61. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
  62. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
  63. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
  64. xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
  65. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  66. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
  67. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
  68. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
  69. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
  70. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
  71. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
  72. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  73. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
  74. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
  75. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
  76. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
  77. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
  78. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
  79. xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
  80. xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
  81. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
  82. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
  83. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
  84. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
  85. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
  86. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
  87. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
  88. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
  89. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
  90. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
  91. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
  92. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
  93. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
  94. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
  95. xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
  96. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
  97. xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
  98. xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
  99. xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
  100. xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
  101. xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
  102. xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
  103. xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
  104. xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
  105. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
  106. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  107. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
  108. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
  109. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  110. xinference/thirdparty/fish_speech/tools/api.py +495 -0
  111. xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
  112. xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
  113. xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
  114. xinference/thirdparty/fish_speech/tools/file.py +108 -0
  115. xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
  116. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  117. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
  118. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
  119. xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
  120. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
  121. xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
  122. xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
  123. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
  124. xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
  125. xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
  126. xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
  127. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
  128. xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
  129. xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
  130. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  131. xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
  132. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
  133. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
  134. xinference/thirdparty/fish_speech/tools/webui.py +619 -0
  135. xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
  136. xinference/thirdparty/internvl/__init__.py +0 -0
  137. xinference/thirdparty/internvl/conversation.py +393 -0
  138. xinference/thirdparty/omnilmm/model/utils.py +16 -1
  139. xinference/web/ui/build/asset-manifest.json +3 -3
  140. xinference/web/ui/build/index.html +1 -1
  141. xinference/web/ui/build/static/js/main.661c7b0a.js +3 -0
  142. xinference/web/ui/build/static/js/{main.17ca0398.js.map → main.661c7b0a.js.map} +1 -1
  143. xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
  144. xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +1 -0
  145. xinference/web/ui/node_modules/.cache/babel-loader/4de9a6942c5f1749d6cbfdd54279699975f16016b182848bc253886f52ec2ec3.json +1 -0
  146. xinference/web/ui/node_modules/.cache/babel-loader/5391543180fead1eeef5364300301498d58a7d91d62de3841a32768b67f4552f.json +1 -0
  147. xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +1 -0
  148. xinference/web/ui/node_modules/.cache/babel-loader/714c37ce0ec5b5c591033f02be2f3f491fdd70da3ef568ee4a4f94689a3d5ca2.json +1 -0
  149. xinference/web/ui/node_modules/.cache/babel-loader/822586ed1077201b64b954f12f25e3f9b45678c1acbabe53d8af3ca82ca71f33.json +1 -0
  150. xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +1 -0
  151. xinference/web/ui/node_modules/.cache/babel-loader/a797831de0dc74897f4b50b3426555d748f328b4c2cc391de709eadaf6a5f3e3.json +1 -0
  152. xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +1 -0
  153. xinference/web/ui/node_modules/.cache/babel-loader/e64b7e8cedcf43d4c95deba60ec1341855c887705805bb62431693118b870c69.json +1 -0
  154. xinference/web/ui/node_modules/.cache/babel-loader/e91938976f229ce986b2907e51e1f00540b584ced0a315d498c172d13220739d.json +1 -0
  155. xinference/web/ui/node_modules/.cache/babel-loader/f72f011744c4649fabddca6f7a9327861ac0a315a89b1a2e62a39774e7863845.json +1 -0
  156. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/METADATA +22 -13
  157. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/RECORD +170 -79
  158. xinference/locale/utils.py +0 -39
  159. xinference/locale/zh_CN.json +0 -26
  160. xinference/model/llm/ggml/tools/__init__.py +0 -15
  161. xinference/model/llm/ggml/tools/convert_ggml_to_gguf.py +0 -498
  162. xinference/model/llm/ggml/tools/gguf.py +0 -884
  163. xinference/model/llm/pytorch/__init__.py +0 -13
  164. xinference/model/llm/pytorch/baichuan.py +0 -81
  165. xinference/model/llm/pytorch/falcon.py +0 -138
  166. xinference/model/llm/pytorch/intern_vl.py +0 -352
  167. xinference/model/llm/pytorch/vicuna.py +0 -69
  168. xinference/web/ui/build/static/js/main.17ca0398.js +0 -3
  169. xinference/web/ui/node_modules/.cache/babel-loader/1444c41a4d04494f1cbc2d8c1537df107b451cb569cb2c1fbf5159f3a4841a5f.json +0 -1
  170. xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
  171. xinference/web/ui/node_modules/.cache/babel-loader/44774c783428f952d8e2e4ad0998a9c5bc16a57cd9c68b7c5ff18aaa5a41d65c.json +0 -1
  172. xinference/web/ui/node_modules/.cache/babel-loader/5262556baf9207738bf6a8ba141ec6599d0a636345c245d61fdf88d3171998cb.json +0 -1
  173. xinference/web/ui/node_modules/.cache/babel-loader/6450605fac003812485f6251b9f0caafbf2e5bfc3bbe2f000050d9e2fdb8dcd3.json +0 -1
  174. xinference/web/ui/node_modules/.cache/babel-loader/71684495d995c7e266eecc6a0ad8ea0284cc785f80abddf863789c57a6134969.json +0 -1
  175. xinference/web/ui/node_modules/.cache/babel-loader/80acd1edf31542ab1dcccfad02cb4b38f3325cff847a781fcce97500cfd6f878.json +0 -1
  176. xinference/web/ui/node_modules/.cache/babel-loader/8a9742ddd8ba8546ef42dc14caca443f2b4524fabed7bf269e0eff3b7b64ee7d.json +0 -1
  177. xinference/web/ui/node_modules/.cache/babel-loader/d06a96a3c9c32e42689094aa3aaad41c8125894e956b8f84a70fadce6e3f65b3.json +0 -1
  178. xinference/web/ui/node_modules/.cache/babel-loader/d93730e2b5d7e8c957b4d0965d2ed1dac9045a649adbd47c220d11f255d4b1e0.json +0 -1
  179. xinference/web/ui/node_modules/.cache/babel-loader/e656dc00b4d8b387f0a81ba8fc558767df1601c66369e2eb86a5ef27cf080572.json +0 -1
  180. xinference/web/ui/node_modules/.cache/babel-loader/f28b83886159d83b84f099b05d607a822dca4dd7f2d8aa6d56fe08bab0b5b086.json +0 -1
  181. xinference/web/ui/node_modules/.cache/babel-loader/f3e02274cb1964e99b1fe69cbb6db233d3d8d7dd05d50ebcdb8e66d50b224b7b.json +0 -1
  182. /xinference/{locale → model/llm/llama_cpp}/__init__.py +0 -0
  183. /xinference/model/llm/{ggml → transformers}/__init__.py +0 -0
  184. /xinference/model/llm/{pytorch → transformers}/compression.py +0 -0
  185. /xinference/model/llm/{pytorch → transformers}/deepseek_vl.py +0 -0
  186. /xinference/model/llm/{pytorch → transformers}/llama_2.py +0 -0
  187. /xinference/model/llm/{pytorch → transformers}/omnilmm.py +0 -0
  188. /xinference/model/llm/{pytorch → transformers}/qwen_vl.py +0 -0
  189. /xinference/model/llm/{pytorch → transformers}/tensorizer_utils.py +0 -0
  190. /xinference/web/ui/build/static/js/{main.17ca0398.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
  191. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/LICENSE +0 -0
  192. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/WHEEL +0 -0
  193. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/entry_points.txt +0 -0
  194. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/top_level.txt +0 -0
xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py
@@ -0,0 +1,92 @@
+ from dataclasses import dataclass
+
+ import loralib as lora
+
+
+ @dataclass
+ class LoraConfig:
+     r: int
+     lora_alpha: float
+     lora_dropout: float = 0.0
+
+
+ def setup_lora(model, lora_config):
+     # Replace the embedding layer with a LoRA layer
+     model.embeddings = lora.Embedding(
+         num_embeddings=model.embeddings.num_embeddings,
+         embedding_dim=model.embeddings.embedding_dim,
+         padding_idx=model.embeddings.padding_idx,
+         r=lora_config.r,
+         lora_alpha=lora_config.lora_alpha,
+     )
+
+     model.codebook_embeddings = lora.Embedding(
+         num_embeddings=model.codebook_embeddings.num_embeddings,
+         embedding_dim=model.codebook_embeddings.embedding_dim,
+         padding_idx=model.codebook_embeddings.padding_idx,
+         r=lora_config.r,
+         lora_alpha=lora_config.lora_alpha,
+     )
+
+     # Replace output layer with a LoRA layer
+     linears = [(model, "output")]
+
+     # Replace all linear layers with LoRA layers
+     for layer in model.layers:
+         linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
+         linears.extend(
+             [
+                 (layer.feed_forward, "w1"),
+                 (layer.feed_forward, "w2"),
+                 (layer.feed_forward, "w3"),
+             ]
+         )
+
+     if hasattr(model, "fast_layers"):
+         model.fast_embeddings = lora.Embedding(
+             num_embeddings=model.fast_embeddings.num_embeddings,
+             embedding_dim=model.fast_embeddings.embedding_dim,
+             padding_idx=model.fast_embeddings.padding_idx,
+             r=lora_config.r,
+             lora_alpha=lora_config.lora_alpha,
+         )
+
+         # Dual-AR model
+         linears.append((model, "fast_output"))
+
+         for layer in model.fast_layers:
+             linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
+             linears.extend(
+                 [
+                     (layer.feed_forward, "w1"),
+                     (layer.feed_forward, "w2"),
+                     (layer.feed_forward, "w3"),
+                 ]
+             )
+
+     for module, layer in linears:
+         updated_linear = lora.Linear(
+             in_features=getattr(module, layer).in_features,
+             out_features=getattr(module, layer).out_features,
+             bias=getattr(module, layer).bias,
+             r=lora_config.r,
+             lora_alpha=lora_config.lora_alpha,
+             lora_dropout=lora_config.lora_dropout,
+         )
+         setattr(module, layer, updated_linear)
+
+     # Mark only the LoRA layers as trainable
+     lora.mark_only_lora_as_trainable(model, bias="none")
+
+
+ def get_merged_state_dict(model):
+     # This line will merge the state dict of the model and the LoRA parameters
+     model.eval()
+
+     # Then we need to remove the LoRA parameters from the state dict
+     state_dict = model.state_dict()
+     for name in list(state_dict.keys()):
+         if "lora" in name:
+             state_dict.pop(name)
+
+     return state_dict
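
Note: the snippet below is not part of the diff. It is a hypothetical sketch of how the two helpers above might be combined, assuming model is an already-instantiated fish_speech text2semantic model and that the actual fine-tuning loop lives elsewhere.

    import torch
    from fish_speech.models.text2semantic.lora import (
        LoraConfig,
        get_merged_state_dict,
        setup_lora,
    )

    config = LoraConfig(r=8, lora_alpha=16.0, lora_dropout=0.01)
    setup_lora(model, config)  # wraps embeddings, attention and MLP linears with LoRA layers
    # ... run fine-tuning; mark_only_lora_as_trainable() has frozen all non-LoRA weights ...
    state_dict = get_merged_state_dict(model)  # state dict with the LoRA-specific tensors stripped
    torch.save(state_dict, "finetuned.pth")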
xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py
@@ -0,0 +1,3 @@
+ from .lit_module import VQGAN
+
+ __all__ = ["VQGAN"]
xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py
@@ -0,0 +1,442 @@
+ import itertools
+ import math
+ from typing import Any, Callable
+
+ import lightning as L
+ import torch
+ import torch.nn.functional as F
+ # import wandb
+ from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
+ from matplotlib import pyplot as plt
+ from torch import nn
+
+ from fish_speech.models.vqgan.modules.discriminator import Discriminator
+ from fish_speech.models.vqgan.modules.wavenet import WaveNet
+ from fish_speech.models.vqgan.utils import avg_with_mask, plot_mel, sequence_mask
+
+
+ class VQGAN(L.LightningModule):
+     def __init__(
+         self,
+         optimizer: Callable,
+         lr_scheduler: Callable,
+         encoder: WaveNet,
+         quantizer: nn.Module,
+         decoder: WaveNet,
+         discriminator: Discriminator,
+         vocoder: nn.Module,
+         encode_mel_transform: nn.Module,
+         gt_mel_transform: nn.Module,
+         weight_adv: float = 1.0,
+         weight_vq: float = 1.0,
+         weight_mel: float = 1.0,
+         sampling_rate: int = 44100,
+         freeze_encoder: bool = False,
+     ):
+         super().__init__()
+
+         # Model parameters
+         self.optimizer_builder = optimizer
+         self.lr_scheduler_builder = lr_scheduler
+
+         # Modules
+         self.encoder = encoder
+         self.quantizer = quantizer
+         self.decoder = decoder
+         self.vocoder = vocoder
+         self.discriminator = discriminator
+         self.encode_mel_transform = encode_mel_transform
+         self.gt_mel_transform = gt_mel_transform
+
+         # A simple linear layer to project quality to condition channels
+         self.quality_projection = nn.Linear(1, 768)
+
+         # Freeze vocoder
+         for param in self.vocoder.parameters():
+             param.requires_grad = False
+
+         # Loss weights
+         self.weight_adv = weight_adv
+         self.weight_vq = weight_vq
+         self.weight_mel = weight_mel
+
+         # Other parameters
+         self.sampling_rate = sampling_rate
+
+         # Disable strict loading
+         self.strict_loading = False
+
+         # If encoder is frozen
+         if freeze_encoder:
+             for param in self.encoder.parameters():
+                 param.requires_grad = False
+
+             for param in self.quantizer.parameters():
+                 param.requires_grad = False
+
+         self.automatic_optimization = False
+
+     def on_save_checkpoint(self, checkpoint):
+         # Do not save vocoder
+         state_dict = checkpoint["state_dict"]
+         for name in list(state_dict.keys()):
+             if "vocoder" in name:
+                 state_dict.pop(name)
+
+     def configure_optimizers(self):
+         optimizer_generator = self.optimizer_builder(
+             itertools.chain(
+                 self.encoder.parameters(),
+                 self.quantizer.parameters(),
+                 self.decoder.parameters(),
+                 self.quality_projection.parameters(),
+             )
+         )
+         optimizer_discriminator = self.optimizer_builder(
+             self.discriminator.parameters()
+         )
+
+         lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
+         lr_scheduler_discriminator = self.lr_scheduler_builder(optimizer_discriminator)
+
+         return (
+             {
+                 "optimizer": optimizer_generator,
+                 "lr_scheduler": {
+                     "scheduler": lr_scheduler_generator,
+                     "interval": "step",
+                     "name": "optimizer/generator",
+                 },
+             },
+             {
+                 "optimizer": optimizer_discriminator,
+                 "lr_scheduler": {
+                     "scheduler": lr_scheduler_discriminator,
+                     "interval": "step",
+                     "name": "optimizer/discriminator",
+                 },
+             },
+         )
+
+     def training_step(self, batch, batch_idx):
+         optim_g, optim_d = self.optimizers()
+
+         audios, audio_lengths = batch["audios"], batch["audio_lengths"]
+
+         audios = audios.float()
+         audios = audios[:, None, :]
+
+         with torch.no_grad():
+             encoded_mels = self.encode_mel_transform(audios)
+             gt_mels = self.gt_mel_transform(audios)
+             quality = ((gt_mels.mean(-1) > -8).sum(-1) - 90) / 10
+             quality = quality.unsqueeze(-1)
+
+         mel_lengths = audio_lengths // self.gt_mel_transform.hop_length
+         mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
+         mel_masks_float_conv = mel_masks[:, None, :].float()
+         gt_mels = gt_mels * mel_masks_float_conv
+         encoded_mels = encoded_mels * mel_masks_float_conv
+
+         # Encode
+         encoded_features = self.encoder(encoded_mels) * mel_masks_float_conv
+
+         # Quantize
+         vq_result = self.quantizer(encoded_features)
+         loss_vq = getattr("vq_result", "loss", 0.0)
+         vq_recon_features = vq_result.z * mel_masks_float_conv
+         vq_recon_features = (
+             vq_recon_features + self.quality_projection(quality)[:, :, None]
+         )
+
+         # VQ Decode
+         gen_mel = (
+             self.decoder(
+                 torch.randn_like(vq_recon_features) * mel_masks_float_conv,
+                 condition=vq_recon_features,
+             )
+             * mel_masks_float_conv
+         )
+
+         # Discriminator
+         real_logits = self.discriminator(gt_mels)
+         fake_logits = self.discriminator(gen_mel.detach())
+         d_mask = F.interpolate(
+             mel_masks_float_conv, size=(real_logits.shape[2],), mode="nearest"
+         )
+
+         loss_real = avg_with_mask((real_logits - 1) ** 2, d_mask)
+         loss_fake = avg_with_mask(fake_logits**2, d_mask)
+
+         loss_d = loss_real + loss_fake
+
+         self.log(
+             "train/discriminator/loss",
+             loss_d,
+             on_step=True,
+             on_epoch=False,
+             prog_bar=True,
+             logger=True,
+         )
+
+         # Discriminator backward
+         optim_d.zero_grad()
+         self.manual_backward(loss_d)
+         self.clip_gradients(
+             optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
+         )
+         optim_d.step()
+
+         # Mel Loss, applying l1, using a weighted sum
+         mel_distance = (
+             gen_mel - gt_mels
+         ).abs()  # * 0.5 + self.ssim(gen_mel, gt_mels) * 0.5
+         loss_mel_low_freq = avg_with_mask(mel_distance[:, :40, :], mel_masks_float_conv)
+         loss_mel_mid_freq = avg_with_mask(
+             mel_distance[:, 40:70, :], mel_masks_float_conv
+         )
+         loss_mel_high_freq = avg_with_mask(
+             mel_distance[:, 70:, :], mel_masks_float_conv
+         )
+         loss_mel = (
+             loss_mel_low_freq * 0.6 + loss_mel_mid_freq * 0.3 + loss_mel_high_freq * 0.1
+         )
+
+         # Adversarial Loss
+         fake_logits = self.discriminator(gen_mel)
+         loss_adv = avg_with_mask((fake_logits - 1) ** 2, d_mask)
+
+         # Total loss
+         loss = (
+             self.weight_vq * loss_vq
+             + self.weight_mel * loss_mel
+             + self.weight_adv * loss_adv
+         )
+
+         # Log losses
+         self.log(
+             "train/generator/loss",
+             loss,
+             on_step=True,
+             on_epoch=False,
+             prog_bar=True,
+             logger=True,
+         )
+         self.log(
+             "train/generator/loss_vq",
+             loss_vq,
+             on_step=True,
+             on_epoch=False,
+             prog_bar=False,
+             logger=True,
+         )
+         self.log(
+             "train/generator/loss_mel",
+             loss_mel,
+             on_step=True,
+             on_epoch=False,
+             prog_bar=False,
+             logger=True,
+         )
+         self.log(
+             "train/generator/loss_adv",
+             loss_adv,
+             on_step=True,
+             on_epoch=False,
+             prog_bar=False,
+             logger=True,
+         )
+
+         # Generator backward
+         optim_g.zero_grad()
+         self.manual_backward(loss)
+         self.clip_gradients(
+             optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
+         )
+         optim_g.step()
+
+         scheduler_g, scheduler_d = self.lr_schedulers()
+         scheduler_g.step()
+         scheduler_d.step()
+
+     def validation_step(self, batch: Any, batch_idx: int):
+         audios, audio_lengths = batch["audios"], batch["audio_lengths"]
+
+         audios = audios.float()
+         audios = audios[:, None, :]
+
+         encoded_mels = self.encode_mel_transform(audios)
+         gt_mels = self.gt_mel_transform(audios)
+
+         mel_lengths = audio_lengths // self.gt_mel_transform.hop_length
+         mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
+         mel_masks_float_conv = mel_masks[:, None, :].float()
+         gt_mels = gt_mels * mel_masks_float_conv
+         encoded_mels = encoded_mels * mel_masks_float_conv
+
+         # Encode
+         encoded_features = self.encoder(encoded_mels) * mel_masks_float_conv
+
+         # Quantize
+         vq_recon_features = self.quantizer(encoded_features).z * mel_masks_float_conv
+         vq_recon_features = (
+             vq_recon_features
+             + self.quality_projection(
+                 torch.ones(
+                     vq_recon_features.shape[0], 1, device=vq_recon_features.device
+                 )
+                 * 2
+             )[:, :, None]
+         )
+
+         # VQ Decode
+         gen_aux_mels = (
+             self.decoder(
+                 torch.randn_like(vq_recon_features) * mel_masks_float_conv,
+                 condition=vq_recon_features,
+             )
+             * mel_masks_float_conv
+         )
+         loss_mel = avg_with_mask((gen_aux_mels - gt_mels).abs(), mel_masks_float_conv)
+
+         self.log(
+             "val/loss_mel",
+             loss_mel,
+             on_step=False,
+             on_epoch=True,
+             prog_bar=False,
+             logger=True,
+             sync_dist=True,
+         )
+
+         recon_audios = self.vocoder(gt_mels)
+         gen_aux_audios = self.vocoder(gen_aux_mels)
+
+         # only log the first batch
+         if batch_idx != 0:
+             return
+
+         for idx, (
+             gt_mel,
+             gen_aux_mel,
+             audio,
+             gen_aux_audio,
+             recon_audio,
+             audio_len,
+         ) in enumerate(
+             zip(
+                 gt_mels,
+                 gen_aux_mels,
+                 audios.cpu().float(),
+                 gen_aux_audios.cpu().float(),
+                 recon_audios.cpu().float(),
+                 audio_lengths,
+             )
+         ):
+             if idx > 4:
+                 break
+
+             mel_len = audio_len // self.gt_mel_transform.hop_length
+
+             image_mels = plot_mel(
+                 [
+                     gt_mel[:, :mel_len],
+                     gen_aux_mel[:, :mel_len],
+                 ],
+                 [
+                     "Ground-Truth",
+                     "Auxiliary",
+                 ],
+             )
+
+             if isinstance(self.logger, WandbLogger):
+                 self.logger.experiment.log(
+                     {
+                         "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
+                         "wavs": [
+                             wandb.Audio(
+                                 audio[0, :audio_len],
+                                 sample_rate=self.sampling_rate,
+                                 caption="gt",
+                             ),
+                             wandb.Audio(
+                                 gen_aux_audio[0, :audio_len],
+                                 sample_rate=self.sampling_rate,
+                                 caption="aux",
+                             ),
+                             wandb.Audio(
+                                 recon_audio[0, :audio_len],
+                                 sample_rate=self.sampling_rate,
+                                 caption="recon",
+                             ),
+                         ],
+                     },
+                 )
+
+             if isinstance(self.logger, TensorBoardLogger):
+                 self.logger.experiment.add_figure(
+                     f"sample-{idx}/mels",
+                     image_mels,
+                     global_step=self.global_step,
+                 )
+                 self.logger.experiment.add_audio(
+                     f"sample-{idx}/wavs/gt",
+                     audio[0, :audio_len],
+                     self.global_step,
+                     sample_rate=self.sampling_rate,
+                 )
+                 self.logger.experiment.add_audio(
+                     f"sample-{idx}/wavs/gen",
+                     gen_aux_audio[0, :audio_len],
+                     self.global_step,
+                     sample_rate=self.sampling_rate,
+                 )
+                 self.logger.experiment.add_audio(
+                     f"sample-{idx}/wavs/recon",
+                     recon_audio[0, :audio_len],
+                     self.global_step,
+                     sample_rate=self.sampling_rate,
+                 )
+
+             plt.close(image_mels)
+
+     def encode(self, audios, audio_lengths):
+         audios = audios.float()
+
+         mels = self.encode_mel_transform(audios)
+         mel_lengths = audio_lengths // self.encode_mel_transform.hop_length
+         mel_masks = sequence_mask(mel_lengths, mels.shape[2])
+         mel_masks_float_conv = mel_masks[:, None, :].float()
+         mels = mels * mel_masks_float_conv
+
+         # Encode
+         encoded_features = self.encoder(mels) * mel_masks_float_conv
+         feature_lengths = mel_lengths // math.prod(self.quantizer.downsample_factor)
+
+         return self.quantizer.encode(encoded_features), feature_lengths
+
+     def decode(self, indices, feature_lengths, return_audios=False):
+         factor = math.prod(self.quantizer.downsample_factor)
+         mel_masks = sequence_mask(feature_lengths * factor, indices.shape[2] * factor)
+         mel_masks_float_conv = mel_masks[:, None, :].float()
+
+         z = self.quantizer.decode(indices) * mel_masks_float_conv
+         z = (
+             z
+             + self.quality_projection(torch.ones(z.shape[0], 1, device=z.device) * 2)[
+                 :, :, None
+             ]
+         )
+
+         gen_mel = (
+             self.decoder(
+                 torch.randn_like(z) * mel_masks_float_conv,
+                 condition=z,
+             )
+             * mel_masks_float_conv
+         )
+
+         if return_audios:
+             return self.vocoder(gen_mel)
+
+         return gen_mel
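
Note: the snippet below is not part of the diff. It is a hypothetical sketch of the encode/decode round trip exposed by the VQGAN module above, assuming vqgan is a loaded VQGAN instance and audios / audio_lengths are batched waveform tensors on the same device.

    # Waveform batch -> discrete VQ indices plus per-item feature lengths
    indices, feature_lengths = vqgan.encode(audios, audio_lengths)

    # Indices -> reconstructed mel spectrogram, or straight to audio via the frozen vocoder
    gen_mel = vqgan.decode(indices, feature_lengths)
    audio_hat = vqgan.decode(indices, feature_lengths, return_audios=True)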
xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py
@@ -0,0 +1,44 @@
+ import torch
+ from torch import nn
+ from torch.nn.utils.parametrizations import weight_norm
+
+
+ class Discriminator(nn.Module):
+     def __init__(self):
+         super().__init__()
+
+         blocks = []
+         convs = [
+             (1, 64, (3, 9), 1, (1, 4)),
+             (64, 128, (3, 9), (1, 2), (1, 4)),
+             (128, 256, (3, 9), (1, 2), (1, 4)),
+             (256, 512, (3, 9), (1, 2), (1, 4)),
+             (512, 1024, (3, 3), 1, (1, 1)),
+             (1024, 1, (3, 3), 1, (1, 1)),
+         ]
+
+         for idx, (in_channels, out_channels, kernel_size, stride, padding) in enumerate(
+             convs
+         ):
+             blocks.append(
+                 weight_norm(
+                     nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
+                 )
+             )
+
+             if idx != len(convs) - 1:
+                 blocks.append(nn.SiLU(inplace=True))
+
+         self.blocks = nn.Sequential(*blocks)
+
+     def forward(self, x):
+         return self.blocks(x[:, None])[:, 0]
+
+
+ if __name__ == "__main__":
+     model = Discriminator()
+     print(sum(p.numel() for p in model.parameters()) / 1_000_000)
+     x = torch.randn(1, 128, 1024)
+     y = model(x)
+     print(y.shape)
+     print(y)