xinference 1.9.1__py3-none-any.whl → 1.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic. Click here for more details.

Files changed (334) hide show
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +400 -3
  3. xinference/client/restful/async_restful_client.py +20 -3
  4. xinference/client/restful/restful_client.py +20 -3
  5. xinference/constants.py +2 -0
  6. xinference/core/supervisor.py +111 -49
  7. xinference/core/worker.py +10 -0
  8. xinference/deploy/cmdline.py +15 -0
  9. xinference/model/audio/core.py +26 -6
  10. xinference/model/audio/indextts2.py +166 -0
  11. xinference/model/audio/kokoro.py +1 -1
  12. xinference/model/audio/kokoro_zh.py +124 -0
  13. xinference/model/audio/model_spec.json +58 -1
  14. xinference/model/embedding/sentence_transformers/core.py +4 -4
  15. xinference/model/embedding/vllm/core.py +7 -1
  16. xinference/model/image/model_spec.json +71 -3
  17. xinference/model/image/stable_diffusion/core.py +13 -4
  18. xinference/model/llm/__init__.py +4 -0
  19. xinference/model/llm/core.py +10 -0
  20. xinference/model/llm/llama_cpp/core.py +1 -0
  21. xinference/model/llm/llm_family.json +503 -21
  22. xinference/model/llm/llm_family.py +1 -0
  23. xinference/model/llm/mlx/core.py +52 -33
  24. xinference/model/llm/sglang/core.py +32 -55
  25. xinference/model/llm/tool_parsers/__init__.py +58 -0
  26. xinference/model/llm/tool_parsers/abstract_tool_parser.py +33 -0
  27. xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +190 -0
  28. xinference/model/llm/tool_parsers/deepseek_v3_tool_parser.py +145 -0
  29. xinference/model/llm/tool_parsers/glm4_tool_parser.py +123 -0
  30. xinference/model/llm/tool_parsers/llama3_tool_parser.py +77 -0
  31. xinference/model/llm/tool_parsers/qwen_tool_parser.py +320 -0
  32. xinference/model/llm/transformers/core.py +1 -1
  33. xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
  34. xinference/model/llm/utils.py +138 -53
  35. xinference/model/llm/vllm/core.py +95 -78
  36. xinference/thirdparty/audiotools/__init__.py +10 -0
  37. xinference/thirdparty/audiotools/core/__init__.py +4 -0
  38. xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
  39. xinference/thirdparty/audiotools/core/display.py +194 -0
  40. xinference/thirdparty/audiotools/core/dsp.py +390 -0
  41. xinference/thirdparty/audiotools/core/effects.py +647 -0
  42. xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
  43. xinference/thirdparty/audiotools/core/loudness.py +320 -0
  44. xinference/thirdparty/audiotools/core/playback.py +252 -0
  45. xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
  46. xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
  47. xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
  48. xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
  49. xinference/thirdparty/audiotools/core/util.py +671 -0
  50. xinference/thirdparty/audiotools/core/whisper.py +97 -0
  51. xinference/thirdparty/audiotools/data/__init__.py +3 -0
  52. xinference/thirdparty/audiotools/data/datasets.py +517 -0
  53. xinference/thirdparty/audiotools/data/preprocess.py +81 -0
  54. xinference/thirdparty/audiotools/data/transforms.py +1592 -0
  55. xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
  56. xinference/thirdparty/audiotools/metrics/distance.py +131 -0
  57. xinference/thirdparty/audiotools/metrics/quality.py +159 -0
  58. xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
  59. xinference/thirdparty/audiotools/ml/__init__.py +5 -0
  60. xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
  61. xinference/thirdparty/audiotools/ml/decorators.py +440 -0
  62. xinference/thirdparty/audiotools/ml/experiment.py +90 -0
  63. xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
  64. xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
  65. xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
  66. xinference/thirdparty/audiotools/post.py +140 -0
  67. xinference/thirdparty/audiotools/preference.py +600 -0
  68. xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
  69. xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
  70. xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
  71. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
  72. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
  73. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
  74. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
  75. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  76. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
  77. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
  78. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
  79. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
  80. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
  81. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
  82. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
  83. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
  84. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
  85. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
  86. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
  87. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
  88. xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
  89. xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
  90. xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
  91. xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
  92. xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
  93. xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
  94. xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
  95. xinference/thirdparty/indextts/__init__.py +0 -0
  96. xinference/thirdparty/indextts/cli.py +65 -0
  97. xinference/thirdparty/indextts/gpt/__init__.py +0 -0
  98. xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
  99. xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
  100. xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
  101. xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
  102. xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
  103. xinference/thirdparty/indextts/gpt/model.py +713 -0
  104. xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
  105. xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
  106. xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
  107. xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
  108. xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
  109. xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
  110. xinference/thirdparty/indextts/infer.py +690 -0
  111. xinference/thirdparty/indextts/infer_v2.py +739 -0
  112. xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
  113. xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
  114. xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
  115. xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
  116. xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
  117. xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
  118. xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
  119. xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
  120. xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
  121. xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
  122. xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
  123. xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
  124. xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
  125. xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
  126. xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
  127. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
  128. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
  129. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
  130. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
  131. xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
  132. xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
  133. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
  134. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
  135. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  136. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
  137. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
  138. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
  139. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
  140. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
  141. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
  142. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
  143. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
  144. xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
  145. xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
  146. xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
  147. xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
  148. xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
  149. xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
  150. xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
  151. xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
  152. xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
  153. xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
  154. xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
  155. xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
  156. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
  157. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
  158. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
  159. xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
  160. xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
  161. xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
  162. xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
  163. xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
  164. xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
  165. xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
  166. xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
  167. xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
  168. xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
  169. xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
  170. xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
  171. xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
  172. xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
  173. xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
  174. xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
  175. xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
  176. xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
  177. xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
  178. xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
  179. xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
  180. xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
  181. xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
  182. xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
  183. xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
  184. xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
  185. xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
  186. xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
  187. xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
  188. xinference/thirdparty/indextts/utils/__init__.py +0 -0
  189. xinference/thirdparty/indextts/utils/arch_util.py +120 -0
  190. xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
  191. xinference/thirdparty/indextts/utils/common.py +121 -0
  192. xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
  193. xinference/thirdparty/indextts/utils/front.py +536 -0
  194. xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
  195. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
  196. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
  197. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
  198. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
  199. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
  200. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
  201. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
  202. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
  203. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
  204. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
  205. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
  206. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
  207. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
  208. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
  209. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
  210. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
  211. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
  212. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
  213. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
  214. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
  215. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
  216. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
  217. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
  218. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
  219. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
  220. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
  221. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
  222. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
  223. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
  224. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
  225. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
  226. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
  227. xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
  228. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
  229. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
  230. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
  231. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
  232. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
  233. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
  234. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
  235. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
  236. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
  237. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
  238. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
  239. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
  240. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
  241. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
  242. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
  243. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
  244. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
  245. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
  246. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
  247. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
  248. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
  249. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
  250. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
  251. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
  252. xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
  253. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
  254. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
  255. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
  256. xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
  257. xinference/thirdparty/indextts/utils/text_utils.py +41 -0
  258. xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
  259. xinference/thirdparty/indextts/utils/utils.py +93 -0
  260. xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
  261. xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
  262. xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
  263. xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
  264. xinference/types.py +105 -2
  265. xinference/ui/gradio/media_interface.py +66 -8
  266. xinference/ui/web/ui/build/asset-manifest.json +6 -6
  267. xinference/ui/web/ui/build/index.html +1 -1
  268. xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
  269. xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
  270. xinference/ui/web/ui/build/static/js/main.d192c4f3.js +3 -0
  271. xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.d192c4f3.js.LICENSE.txt} +0 -7
  272. xinference/ui/web/ui/build/static/js/main.d192c4f3.js.map +1 -0
  273. xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
  274. xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
  275. xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
  276. xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
  277. xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
  278. xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
  279. xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
  280. xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
  281. xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
  282. xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
  283. xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
  284. xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
  285. xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
  286. xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
  287. xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
  288. xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
  289. xinference/ui/web/ui/node_modules/.cache/babel-loader/f995a2425dfb0822fd07127f66ffe9b026883bc156b402eb8bd0b83d52460a93.json +1 -0
  290. xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
  291. xinference/ui/web/ui/package-lock.json +0 -34
  292. xinference/ui/web/ui/package.json +0 -1
  293. xinference/ui/web/ui/src/locales/en.json +9 -3
  294. xinference/ui/web/ui/src/locales/ja.json +9 -3
  295. xinference/ui/web/ui/src/locales/ko.json +9 -3
  296. xinference/ui/web/ui/src/locales/zh.json +9 -3
  297. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/METADATA +24 -4
  298. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/RECORD +302 -76
  299. xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
  300. xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
  301. xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
  302. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
  303. xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
  304. xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
  305. xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
  306. xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
  307. xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
  308. xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
  309. xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
  310. xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
  311. xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
  312. xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
  313. xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
  314. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
  315. xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
  316. xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
  317. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
  318. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
  319. xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
  320. xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
  321. xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
  322. xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
  323. xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
  324. xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
  325. xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
  326. xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
  327. xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
  328. xinference/ui/web/ui/node_modules/select/bower.json +0 -13
  329. xinference/ui/web/ui/node_modules/select/package.json +0 -29
  330. xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
  331. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/WHEEL +0 -0
  332. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/entry_points.txt +0 -0
  333. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/licenses/LICENSE +0 -0
  334. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,427 @@
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from einops import rearrange
12
+ from torch.nn.utils import weight_norm
13
+
14
+ from indextts.utils.maskgct.models.codec.amphion_codec.quantize import (
15
+ ResidualVQ,
16
+ VectorQuantize,
17
+ FactorizedVectorQuantize,
18
+ LookupFreeQuantize,
19
+ )
20
+
21
+ from indextts.utils.maskgct.models.codec.amphion_codec.vocos import Vocos
22
+
23
+
24
def WNConv1d(*args, **kwargs):
    """Build an ``nn.Conv1d`` and wrap it with weight normalisation.

    All positional and keyword arguments are forwarded unchanged to
    ``nn.Conv1d``; the resulting module is passed through ``weight_norm``
    and returned.
    """
    conv = nn.Conv1d(*args, **kwargs)
    return weight_norm(conv)
26
+
27
+
28
def WNConvTranspose1d(*args, **kwargs):
    """Build an ``nn.ConvTranspose1d`` and wrap it with weight normalisation.

    All positional and keyword arguments are forwarded unchanged to
    ``nn.ConvTranspose1d``; the resulting module is passed through
    ``weight_norm`` and returned.
    """
    deconv = nn.ConvTranspose1d(*args, **kwargs)
    return weight_norm(deconv)
30
+
31
+
32
# Snake activation: x + sin^2(alpha * x) / (alpha + 1e-9).
#
# The input is flattened to (shape[0], shape[1], -1) for the computation and
# restored to its original shape before returning. The 1e-9 term keeps the
# reciprocal finite when alpha is zero.
#
# Compiled with TorchScript; the original authors report a 1.4x model speed-up
# from scripting this function.
@torch.jit.script
def snake(x, alpha):
    original_shape = x.shape
    flat = x.reshape(original_shape[0], original_shape[1], -1)
    inv_alpha = (alpha + 1e-9).reciprocal()
    periodic = torch.sin(alpha * flat).pow(2)
    flat = flat + inv_alpha * periodic
    return flat.reshape(original_shape)
40
+
41
+
42
class Snake1d(nn.Module):
    """Snake activation module with a learnable ``alpha`` parameter.

    ``alpha`` is a parameter of shape ``(1, channels, 1)`` initialised to
    ones; ``forward`` applies :func:`snake` to the input with it.
    """

    def __init__(self, channels):
        super().__init__()
        initial_alpha = torch.ones(1, channels, 1)
        self.alpha = nn.Parameter(initial_alpha)

    def forward(self, x):
        """Return ``snake(x, self.alpha)``."""
        activated = snake(x, self.alpha)
        return activated
49
+
50
+
51
def init_weights(m):
    """Initialise ``nn.Conv1d`` and ``nn.Linear`` modules in place.

    Intended to be passed to ``nn.Module.apply``. For a matching module the
    weight is drawn from a truncated normal with ``std=0.02`` and the bias,
    when the layer has one, is set to zero. Any other module is left
    untouched.

    Args:
        m: The module visited by ``apply``.
    """
    if isinstance(m, (nn.Conv1d, nn.Linear)):
        # NOTE(review): for weight-normalised convs, ``m.weight`` is presumably
        # derived from ``weight_g``/``weight_v`` and recomputed on forward, so
        # this weight init may have no lasting effect there - confirm.
        nn.init.trunc_normal_(m.weight, std=0.02)
        # Layers created with ``bias=False`` expose ``bias`` as None; skip them
        # instead of failing inside ``nn.init.constant_``.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
58
+
59
+
60
class ResidualUnit(nn.Module):
    """Residual branch of two Snake1d + weight-normalised conv stages.

    The branch is ``Snake1d -> WNConv1d(kernel 7, dilated) -> Snake1d ->
    WNConv1d(kernel 1)``, and its output is added to the input.

    Args:
        dim: Channel count used for every layer in the unit.
        dilation: Dilation of the 7-tap convolution.
    """

    def __init__(self, dim: int = 16, dilation: int = 1):
        super().__init__()
        kernel_size = 7
        # Half of the dilated kernel extent on each side.
        same_padding = ((kernel_size - 1) * dilation) // 2
        layers = [
            Snake1d(dim),
            WNConv1d(
                dim,
                dim,
                kernel_size=kernel_size,
                dilation=dilation,
                padding=same_padding,
            ),
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=1),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the branch output, cropping ``x`` if needed."""
        residual = self.block(x)
        # If the branch shortened the last axis, trim the same amount from
        # both ends of the skip path so the two tensors line up.
        trim = (x.shape[-1] - residual.shape[-1]) // 2
        skip = x[..., trim:-trim] if trim > 0 else x
        return skip + residual
77
+
78
+
79
class EncoderBlock(nn.Module):
    """Three dilated residual units followed by a strided convolution.

    The residual units (dilations 1, 3, 9) and the Snake1d activation run at
    ``dim // 2`` channels; the final weight-normalised convolution maps to
    ``dim`` channels with the given ``stride``.

    Args:
        dim: Output channel count of the block.
        stride: Stride of the final convolution; its kernel size is
            ``2 * stride`` and its padding ``ceil(stride / 2)``.
    """

    def __init__(self, dim: int = 16, stride: int = 1):
        super().__init__()
        half_dim = dim // 2
        layers = [
            ResidualUnit(half_dim, dilation=rate) for rate in (1, 3, 9)
        ]
        layers.append(Snake1d(half_dim))
        layers.append(
            WNConv1d(
                half_dim,
                dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            )
        )
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the block."""
        out = self.block(x)
        return out
98
+
99
+
100
class CodecEncoder(nn.Module):
    """Convolutional encoder built from a stack of ``EncoderBlock`` stages.

    Layer order: a weight-normalised 7-tap convolution from 1 input channel
    to ``d_model``; one ``EncoderBlock`` per entry of ``up_ratios`` (each
    doubling the channel count and using that entry as its stride); a
    ``Snake1d`` activation; a 3-tap weight-normalised convolution to
    ``out_channels``; and optionally ``nn.Tanh``.

    Args:
        d_model: Channel count after the first convolution.
        up_ratios: Stride of each encoder block, in order. The default list
            is only iterated, never mutated.
        out_channels: Channel count of the encoder output.
        use_tanh: Append ``nn.Tanh`` as the last layer when true.
        cfg: Optional config object. Any of the four settings above that it
            defines as an attribute overrides the keyword argument; settings
            it lacks fall back to the keyword value.

    Attributes:
        block: The ``nn.Sequential`` holding all layers.
        enc_dim: Channel count after the last encoder block.
    """

    def __init__(
        self,
        d_model: int = 64,
        up_ratios: list = [4, 5, 5, 6],
        out_channels: int = 256,
        use_tanh: bool = False,
        cfg=None,
    ):
        super().__init__()

        def _cfg_or(name, default):
            # Same lookup rule as CodecDecoder: take the cfg value only when
            # cfg is given and actually defines the attribute.
            if cfg is not None and hasattr(cfg, name):
                return getattr(cfg, name)
            return default

        d_model = _cfg_or("d_model", d_model)
        up_ratios = _cfg_or("up_ratios", up_ratios)
        out_channels = _cfg_or("out_channels", out_channels)
        use_tanh = _cfg_or("use_tanh", use_tanh)

        # Create first convolution
        layers = [WNConv1d(1, d_model, kernel_size=7, padding=3)]

        # Create EncoderBlocks that double channels as they downsample by `stride`
        for stride in up_ratios:
            d_model *= 2
            layers.append(EncoderBlock(d_model, stride=stride))

        # Create last convolution
        layers.append(Snake1d(d_model))
        layers.append(WNConv1d(d_model, out_channels, kernel_size=3, padding=1))

        if use_tanh:
            layers.append(nn.Tanh())

        # Wrap the layer list into nn.Sequential; building it locally avoids
        # briefly storing a plain Python list on the module.
        self.block = nn.Sequential(*layers)
        self.enc_dim = d_model

        self.reset_parameters()

    def forward(self, x):
        """Run ``x`` through the encoder stack."""
        return self.block(x)

    def reset_parameters(self):
        """Apply ``init_weights`` to this module and all of its submodules."""
        self.apply(init_weights)
144
+
145
+
146
class DecoderBlock(nn.Module):
    """Snake activation, a strided transposed convolution, then three residual units.

    The transposed convolution maps ``input_dim`` to ``output_dim`` channels
    with a kernel twice as wide as ``stride``; the residual units use
    dilations 1, 3 and 9 on ``output_dim`` channels.
    """

    def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
        super().__init__()
        # 1 for odd strides, 0 for even ones; used for both paddings below.
        odd = stride % 2
        stages = [
            Snake1d(input_dim),
            WNConvTranspose1d(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=stride // 2 + odd,
                output_padding=odd,
            ),
        ]
        for rate in (1, 3, 9):
            stages.append(ResidualUnit(output_dim, dilation=rate))
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the activation, transposed conv and residual units."""
        return self.block(x)
166
+
167
+
168
class CodecDecoder(nn.Module):
    """Residual quantizer plus a decoder network.

    The module owns two parts:

    * ``self.quantizer`` -- a ``ResidualVQ`` configured for ``quantizer_type``
      (``"vq"``, ``"fvq"`` or ``"lfq"``).
    * ``self.model`` -- either a stack of ``DecoderBlock`` layers ending in a
      single-channel ``Tanh`` output (``use_vocos=False``) or a ``Vocos``
      module (``use_vocos=True``).

    Every constructor argument can be overridden by an attribute of the same
    name on ``cfg``; attributes missing from ``cfg`` fall back to the keyword
    argument.

    Raises:
        ValueError: if ``quantizer_type`` is not one of the supported values.
    """

    def __init__(
        self,
        in_channels: int = 256,
        upsample_initial_channel: int = 1536,
        up_ratios: list = [5, 5, 4, 2],
        num_quantizers: int = 8,
        codebook_size: int = 1024,
        codebook_dim: int = 256,
        quantizer_type: str = "vq",
        quantizer_dropout: float = 0.5,
        commitment: float = 0.25,
        codebook_loss_weight: float = 1.0,
        use_l2_normlize: bool = False,
        codebook_type: str = "euclidean",
        kmeans_init: bool = False,
        kmeans_iters: int = 10,
        decay: float = 0.8,
        eps: float = 1e-5,
        threshold_ema_dead_code: int = 2,
        weight_init: bool = False,
        use_vocos: bool = False,
        vocos_dim: int = 384,
        vocos_intermediate_dim: int = 1152,
        vocos_num_layers: int = 8,
        n_fft: int = 800,
        hop_size: int = 200,
        padding: str = "same",
        cfg=None,
    ):
        super().__init__()

        def pick(name, fallback):
            # cfg wins only when it exists and actually defines the attribute.
            if cfg is not None and hasattr(cfg, name):
                return getattr(cfg, name)
            return fallback

        in_channels = pick("in_channels", in_channels)
        upsample_initial_channel = pick(
            "upsample_initial_channel", upsample_initial_channel
        )
        up_ratios = pick("up_ratios", up_ratios)
        num_quantizers = pick("num_quantizers", num_quantizers)
        codebook_size = pick("codebook_size", codebook_size)
        codebook_dim = pick("codebook_dim", codebook_dim)
        quantizer_type = pick("quantizer_type", quantizer_type)
        quantizer_dropout = pick("quantizer_dropout", quantizer_dropout)
        commitment = pick("commitment", commitment)
        codebook_loss_weight = pick("codebook_loss_weight", codebook_loss_weight)
        use_l2_normlize = pick("use_l2_normlize", use_l2_normlize)
        codebook_type = pick("codebook_type", codebook_type)
        kmeans_init = pick("kmeans_init", kmeans_init)
        kmeans_iters = pick("kmeans_iters", kmeans_iters)
        decay = pick("decay", decay)
        eps = pick("eps", eps)
        threshold_ema_dead_code = pick(
            "threshold_ema_dead_code", threshold_ema_dead_code
        )
        weight_init = pick("weight_init", weight_init)
        use_vocos = pick("use_vocos", use_vocos)
        vocos_dim = pick("vocos_dim", vocos_dim)
        vocos_intermediate_dim = pick(
            "vocos_intermediate_dim", vocos_intermediate_dim
        )
        vocos_num_layers = pick("vocos_num_layers", vocos_num_layers)
        n_fft = pick("n_fft", n_fft)
        hop_size = pick("hop_size", hop_size)
        padding = pick("padding", padding)

        if quantizer_type not in ("vq", "fvq", "lfq"):
            raise ValueError(f"Unknown quantizer type {quantizer_type}")

        # Arguments shared by every quantizer type.
        quantizer_kwargs = dict(
            input_dim=in_channels,
            num_quantizers=num_quantizers,
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            quantizer_type=quantizer_type,
        )
        # "vq" and "fvq" additionally take the dropout / loss settings.
        if quantizer_type in ("vq", "fvq"):
            quantizer_kwargs.update(
                quantizer_dropout=quantizer_dropout,
                commitment=commitment,
                codebook_loss_weight=codebook_loss_weight,
                use_l2_normlize=use_l2_normlize,
            )
        # Only "vq" receives the codebook-maintenance settings.
        if quantizer_type == "vq":
            quantizer_kwargs.update(
                codebook_type=codebook_type,
                kmeans_init=kmeans_init,
                kmeans_iters=kmeans_iters,
                decay=decay,
                eps=eps,
                threshold_ema_dead_code=threshold_ema_dead_code,
                weight_init=weight_init,
            )
        self.quantizer = ResidualVQ(**quantizer_kwargs)

        if use_vocos:
            self.model = Vocos(
                input_channels=in_channels,
                dim=vocos_dim,
                intermediate_dim=vocos_intermediate_dim,
                num_layers=vocos_num_layers,
                adanorm_num_embeddings=None,
                n_fft=n_fft,
                hop_size=hop_size,
                padding=padding,
            )
        else:
            # Add first conv layer
            channels = upsample_initial_channel
            layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)]

            # Keeps the final conv well-defined when up_ratios is empty.
            output_dim = channels

            # Add upsampling blocks; channel count halves at every stage.
            for i, stride in enumerate(up_ratios):
                input_dim = channels // 2**i
                output_dim = channels // 2 ** (i + 1)
                layers.append(DecoderBlock(input_dim, output_dim, stride))

            # Add final conv layer
            layers += [
                Snake1d(output_dim),
                WNConv1d(output_dim, 1, kernel_size=7, padding=3),
                nn.Tanh(),
            ]

            self.model = nn.Sequential(*layers)

        self.reset_parameters()

    def forward(self, x=None, vq=False, eval_vq=False, n_quantizers=None):
        """Quantize or decode ``x`` depending on ``vq``.

        If ``vq`` is exactly ``True``, ``x`` is treated as encoder output and
        the quantizer's five results are returned as a tuple
        ``(quantized_out, all_indices, all_commit_losses,
        all_codebook_losses, all_quantized)``. With ``eval_vq`` set, the
        quantizer is switched to eval mode first.

        Otherwise ``x`` is treated as quantized output and the result of
        ``self.model(x)`` is returned.
        """
        if vq is True:
            if eval_vq:
                self.quantizer.eval()
            (
                quantized_out,
                all_indices,
                all_commit_losses,
                all_codebook_losses,
                all_quantized,
            ) = self.quantizer(x, n_quantizers=n_quantizers)
            return (
                quantized_out,
                all_indices,
                all_commit_losses,
                all_codebook_losses,
                all_quantized,
            )

        return self.model(x)

    def quantize(self, x, n_quantizers=None):
        """Put the quantizer in eval mode and return ``(quantized_out, indices)``."""
        self.quantizer.eval()
        quantized_out, vq, _, _, _ = self.quantizer(x, n_quantizers=n_quantizers)
        return quantized_out, vq

    # TODO: check consistency of vq2emb and quantize
    def vq2emb(self, vq, n_quantizers=None):
        """Delegate to ``self.quantizer.vq2emb``."""
        return self.quantizer.vq2emb(vq, n_quantizers=n_quantizers)

    def decode(self, x):
        """Run ``x`` through the decoder network only."""
        return self.model(x)

    def latent2dist(self, x, n_quantizers=None):
        """Delegate to ``self.quantizer.latent2dist``."""
        return self.quantizer.latent2dist(x, n_quantizers=n_quantizers)

    def reset_parameters(self):
        """Apply ``init_weights`` to every submodule."""
        self.apply(init_weights)
@@ -0,0 +1,11 @@
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from indextts.utils.maskgct.models.codec.amphion_codec.quantize.factorized_vector_quantize import (
7
+ FactorizedVectorQuantize,
8
+ )
9
+ from indextts.utils.maskgct.models.codec.amphion_codec.quantize.vector_quantize import VectorQuantize
10
+ from indextts.utils.maskgct.models.codec.amphion_codec.quantize.lookup_free_quantize import LookupFreeQuantize
11
+ from indextts.utils.maskgct.models.codec.amphion_codec.quantize.residual_vq import ResidualVQ
@@ -0,0 +1,150 @@
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from einops import rearrange
11
+ from torch.nn.utils import weight_norm
12
+
13
+
14
def WNConv1d(*args, **kwargs):
    """Build an ``nn.Conv1d`` from the given arguments and wrap it in ``weight_norm``."""
    conv = nn.Conv1d(*args, **kwargs)
    return weight_norm(conv)
16
+
17
+
18
def WNConvTranspose1d(*args, **kwargs):
    """Build an ``nn.ConvTranspose1d`` from the given arguments and wrap it in ``weight_norm``."""
    deconv = nn.ConvTranspose1d(*args, **kwargs)
    return weight_norm(deconv)
20
+
21
+
22
class FactorizedVectorQuantize(nn.Module):
    """Vector quantizer with an optional projection into the codebook space.

    When ``input_dim`` differs from ``codebook_dim`` the input is projected
    with a 1x1 weight-normalised convolution before the codebook lookup and
    projected back afterwards; otherwise both projections are identities.
    """

    def __init__(
        self,
        input_dim,
        codebook_size,
        codebook_dim,
        commitment=0.005,
        codebook_loss_weight=1.0,
        use_l2_normlize=True,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.commitment = commitment
        self.codebook_loss_weight = codebook_loss_weight
        self.use_l2_normlize = use_l2_normlize

        if self.input_dim != self.codebook_dim:
            self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
            self.out_project = WNConv1d(
                self.codebook_dim, self.input_dim, kernel_size=1
            )
        else:
            self.in_project = nn.Identity()
            self.out_project = nn.Identity()

        self.codebook = nn.Embedding(self.codebook_size, self.codebook_dim)

    def forward(self, z):
        """
        Parameters
        ----------
        z: torch.Tensor[B x D x T]

        Returns
        -------
        z_q: torch.Tensor[B x D x T]
            Quantized continuous representation of input
        commit_loss: Tensor[B]
            Commitment loss to train encoder to predict vectors closer to codebook entries
        codebook_loss: Tensor[B]
            Codebook loss to update the codebook
        indices: torch.Tensor[B x T]
            Codebook indices (quantized discrete representation of input)
        z_e: torch.Tensor[B x D x T]
            Projected latents (continuous representation of input before quantization)
        """

        # Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim
        z_e = self.in_project(z)
        z_q, indices = self.decode_latents(z_e)

        # Compute commitment loss and codebook loss (zero outside training)
        if self.training:
            commit_loss = (
                F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
                * self.commitment
            )
            codebook_loss = (
                F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
                * self.codebook_loss_weight
            )
        else:
            commit_loss = torch.zeros(z.shape[0], device=z.device)
            codebook_loss = torch.zeros(z.shape[0], device=z.device)

        # Forward value is z_q, but the gradient flows to z_e unchanged.
        z_q = z_e + (z_q - z_e).detach()

        z_q = self.out_project(z_q)

        return z_q, commit_loss, codebook_loss, indices, z_e

    def embed_code(self, embed_id):
        """Look up codebook rows for ``embed_id``; the code dimension is last."""
        return F.embedding(embed_id, self.codebook.weight)

    def decode_code(self, embed_id):
        """Look up codebook rows and swap axes 1 and 2 (``B x T x D -> B x D x T``)."""
        return self.embed_code(embed_id).transpose(1, 2)

    def _nearest_codes(self, latents):
        """Return per-frame distances to every code and the closest code index.

        Parameters
        ----------
        latents: torch.Tensor[B x D x T]

        Returns
        -------
        dist: torch.Tensor[(B*T) x K]
            Squared euclidean distance from each frame to each codebook entry
            (computed on L2-normalised vectors when ``use_l2_normlize`` is set)
        indices: torch.Tensor[B x T]
            Index of the closest codebook entry for each frame
        """
        batch, dim, steps = latents.shape
        # "b d t -> (b t) d": one row per frame.
        encodings = latents.transpose(1, 2).reshape(batch * steps, dim)
        codebook = self.codebook.weight

        # L2 normalize encodings and codebook
        if self.use_l2_normlize:
            encodings = F.normalize(encodings)
            codebook = F.normalize(codebook)

        # Compute euclidean distance between encodings and codebook,
        # if use_l2_normlize is True, the distance is equal to cosine distance
        dist = (
            encodings.pow(2).sum(1, keepdim=True)
            - 2 * encodings @ codebook.t()
            + codebook.pow(2).sum(1, keepdim=True).t()
        )  # (b*t, k)

        # Largest negated distance == smallest distance.
        indices = (-dist).max(1)[1].reshape(batch, steps)
        return dist, indices

    def decode_latents(self, latents):
        """Snap each frame of ``latents`` to its nearest codebook entry.

        Returns ``(z_q, indices)`` with ``z_q`` shaped ``B x D x T`` and
        ``indices`` shaped ``B x T``.
        """
        _, indices = self._nearest_codes(latents)
        z_q = self.decode_code(indices)

        return z_q, indices

    def vq2emb(self, vq, out_proj=True):
        """Turn code indices into embeddings, optionally applying ``out_project``."""
        emb = self.decode_code(vq)
        if out_proj:
            emb = self.out_project(emb)
        return emb

    def latent2dist(self, latents):
        """Return negated distances to all codes, nearest indices and quantized latents.

        Returns ``(-dist, indices, z_q)`` where ``dist`` is shaped
        ``B x T x K``, ``indices`` is ``B x T`` and ``z_q`` is ``B x D x T``.
        """
        batch, _, steps = latents.shape
        dist, indices = self._nearest_codes(latents)
        # "(b t) k -> b t k"
        dist = dist.reshape(batch, steps, dist.shape[1])
        z_q = self.decode_code(indices)

        return -dist, indices, z_q
@@ -0,0 +1,77 @@
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from einops import rearrange
11
+ from torch.nn.utils import weight_norm
12
+
13
+
14
def WNConv1d(*args, **kwargs):
    """Create a weight-normalised ``nn.Conv1d``; all arguments go to ``nn.Conv1d``."""
    layer = nn.Conv1d(*args, **kwargs)
    return weight_norm(layer)
16
+
17
+
18
def WNConvTranspose1d(*args, **kwargs):
    """Create a weight-normalised ``nn.ConvTranspose1d``; all arguments are forwarded."""
    layer = nn.ConvTranspose1d(*args, **kwargs)
    return weight_norm(layer)
20
+
21
+
22
class LookupFreeQuantize(nn.Module):
    """Quantizer that rounds each projected channel to a bit instead of using a codebook.

    Each of the ``codebook_dim`` channels contributes one bit, so
    ``codebook_size`` must equal ``2 ** codebook_dim``. Channel ``i`` carries
    the bit with place value ``2 ** i`` in the integer index.
    """

    def __init__(
        self,
        input_dim,
        codebook_size,
        codebook_dim,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim

        # One bit per channel: the index range has to be exactly 2**dim.
        assert 2**codebook_dim == codebook_size

        if self.input_dim == self.codebook_dim:
            # Nothing to project when the dimensions already agree.
            self.in_project = nn.Identity()
            self.out_project = nn.Identity()
        else:
            self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
            self.out_project = WNConv1d(
                self.codebook_dim, self.input_dim, kernel_size=1
            )

    def forward(self, z):
        """Binarise ``z`` and return ``(z_q, commit_loss, codebook_loss, indices, z_e)``.

        ``z_e`` is the sigmoid of the projected input, ``z_q`` is the rounded
        ``z_e`` passed through ``out_project``, both losses are zero tensors
        with one entry per batch element, and ``indices`` packs the rounded
        bits along dim 1 into one integer.
        """
        z_e = F.sigmoid(self.in_project(z))
        hard_bits = torch.round(z_e)

        # Forward value is the rounded tensor; the gradient follows z_e.
        z_q = self.out_project(z_e + (hard_bits - z_e).detach())

        batch = z.shape[0]
        commit_loss = torch.zeros(batch, device=z.device)
        codebook_loss = torch.zeros(batch, device=z.device)

        exponents = torch.arange(self.codebook_dim, device=z.device)
        place_values = 2 ** exponents.unsqueeze(0).unsqueeze(-1).long()  # (1, d, 1)
        indices = (hard_bits.detach().long() * place_values).sum(1).long()

        return z_q, commit_loss, codebook_loss, indices, z_e

    def vq2emb(self, vq, out_proj=True):
        """Unpack integer indices into their bits, optionally applying ``out_project``."""
        emb = torch.zeros(
            vq.shape[0], self.codebook_dim, vq.shape[-1], device=vq.device
        )  # (B, d, T)
        remaining = vq
        for bit in range(self.codebook_dim):
            # Peel off the lowest bit, then shift the rest down.
            emb[:, bit, :] = (remaining % 2).float()
            remaining = remaining // 2
        return self.out_project(emb) if out_proj else emb
+ return emb