xinference 1.9.1__py3-none-any.whl → 1.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic.

Files changed (334)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +400 -3
  3. xinference/client/restful/async_restful_client.py +20 -3
  4. xinference/client/restful/restful_client.py +20 -3
  5. xinference/constants.py +2 -0
  6. xinference/core/supervisor.py +111 -49
  7. xinference/core/worker.py +10 -0
  8. xinference/deploy/cmdline.py +15 -0
  9. xinference/model/audio/core.py +26 -6
  10. xinference/model/audio/indextts2.py +166 -0
  11. xinference/model/audio/kokoro.py +1 -1
  12. xinference/model/audio/kokoro_zh.py +124 -0
  13. xinference/model/audio/model_spec.json +58 -1
  14. xinference/model/embedding/sentence_transformers/core.py +4 -4
  15. xinference/model/embedding/vllm/core.py +7 -1
  16. xinference/model/image/model_spec.json +71 -3
  17. xinference/model/image/stable_diffusion/core.py +13 -4
  18. xinference/model/llm/__init__.py +4 -0
  19. xinference/model/llm/core.py +10 -0
  20. xinference/model/llm/llama_cpp/core.py +1 -0
  21. xinference/model/llm/llm_family.json +503 -21
  22. xinference/model/llm/llm_family.py +1 -0
  23. xinference/model/llm/mlx/core.py +52 -33
  24. xinference/model/llm/sglang/core.py +32 -55
  25. xinference/model/llm/tool_parsers/__init__.py +58 -0
  26. xinference/model/llm/tool_parsers/abstract_tool_parser.py +33 -0
  27. xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +190 -0
  28. xinference/model/llm/tool_parsers/deepseek_v3_tool_parser.py +145 -0
  29. xinference/model/llm/tool_parsers/glm4_tool_parser.py +123 -0
  30. xinference/model/llm/tool_parsers/llama3_tool_parser.py +77 -0
  31. xinference/model/llm/tool_parsers/qwen_tool_parser.py +320 -0
  32. xinference/model/llm/transformers/core.py +1 -1
  33. xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
  34. xinference/model/llm/utils.py +138 -53
  35. xinference/model/llm/vllm/core.py +95 -78
  36. xinference/thirdparty/audiotools/__init__.py +10 -0
  37. xinference/thirdparty/audiotools/core/__init__.py +4 -0
  38. xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
  39. xinference/thirdparty/audiotools/core/display.py +194 -0
  40. xinference/thirdparty/audiotools/core/dsp.py +390 -0
  41. xinference/thirdparty/audiotools/core/effects.py +647 -0
  42. xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
  43. xinference/thirdparty/audiotools/core/loudness.py +320 -0
  44. xinference/thirdparty/audiotools/core/playback.py +252 -0
  45. xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
  46. xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
  47. xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
  48. xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
  49. xinference/thirdparty/audiotools/core/util.py +671 -0
  50. xinference/thirdparty/audiotools/core/whisper.py +97 -0
  51. xinference/thirdparty/audiotools/data/__init__.py +3 -0
  52. xinference/thirdparty/audiotools/data/datasets.py +517 -0
  53. xinference/thirdparty/audiotools/data/preprocess.py +81 -0
  54. xinference/thirdparty/audiotools/data/transforms.py +1592 -0
  55. xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
  56. xinference/thirdparty/audiotools/metrics/distance.py +131 -0
  57. xinference/thirdparty/audiotools/metrics/quality.py +159 -0
  58. xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
  59. xinference/thirdparty/audiotools/ml/__init__.py +5 -0
  60. xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
  61. xinference/thirdparty/audiotools/ml/decorators.py +440 -0
  62. xinference/thirdparty/audiotools/ml/experiment.py +90 -0
  63. xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
  64. xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
  65. xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
  66. xinference/thirdparty/audiotools/post.py +140 -0
  67. xinference/thirdparty/audiotools/preference.py +600 -0
  68. xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
  69. xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
  70. xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
  71. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
  72. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
  73. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
  74. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
  75. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  76. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
  77. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
  78. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
  79. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
  80. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
  81. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
  82. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
  83. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
  84. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
  85. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
  86. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
  87. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
  88. xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
  89. xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
  90. xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
  91. xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
  92. xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
  93. xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
  94. xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
  95. xinference/thirdparty/indextts/__init__.py +0 -0
  96. xinference/thirdparty/indextts/cli.py +65 -0
  97. xinference/thirdparty/indextts/gpt/__init__.py +0 -0
  98. xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
  99. xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
  100. xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
  101. xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
  102. xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
  103. xinference/thirdparty/indextts/gpt/model.py +713 -0
  104. xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
  105. xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
  106. xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
  107. xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
  108. xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
  109. xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
  110. xinference/thirdparty/indextts/infer.py +690 -0
  111. xinference/thirdparty/indextts/infer_v2.py +739 -0
  112. xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
  113. xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
  114. xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
  115. xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
  116. xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
  117. xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
  118. xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
  119. xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
  120. xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
  121. xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
  122. xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
  123. xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
  124. xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
  125. xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
  126. xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
  127. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
  128. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
  129. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
  130. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
  131. xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
  132. xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
  133. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
  134. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
  135. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  136. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
  137. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
  138. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
  139. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
  140. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
  141. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
  142. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
  143. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
  144. xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
  145. xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
  146. xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
  147. xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
  148. xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
  149. xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
  150. xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
  151. xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
  152. xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
  153. xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
  154. xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
  155. xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
  156. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
  157. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
  158. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
  159. xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
  160. xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
  161. xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
  162. xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
  163. xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
  164. xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
  165. xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
  166. xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
  167. xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
  168. xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
  169. xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
  170. xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
  171. xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
  172. xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
  173. xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
  174. xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
  175. xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
  176. xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
  177. xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
  178. xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
  179. xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
  180. xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
  181. xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
  182. xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
  183. xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
  184. xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
  185. xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
  186. xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
  187. xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
  188. xinference/thirdparty/indextts/utils/__init__.py +0 -0
  189. xinference/thirdparty/indextts/utils/arch_util.py +120 -0
  190. xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
  191. xinference/thirdparty/indextts/utils/common.py +121 -0
  192. xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
  193. xinference/thirdparty/indextts/utils/front.py +536 -0
  194. xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
  195. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
  196. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
  197. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
  198. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
  199. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
  200. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
  201. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
  202. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
  203. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
  204. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
  205. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
  206. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
  207. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
  208. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
  209. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
  210. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
  211. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
  212. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
  213. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
  214. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
  215. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
  216. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
  217. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
  218. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
  219. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
  220. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
  221. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
  222. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
  223. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
  224. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
  225. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
  226. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
  227. xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
  228. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
  229. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
  230. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
  231. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
  232. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
  233. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
  234. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
  235. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
  236. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
  237. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
  238. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
  239. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
  240. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
  241. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
  242. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
  243. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
  244. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
  245. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
  246. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
  247. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
  248. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
  249. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
  250. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
  251. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
  252. xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
  253. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
  254. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
  255. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
  256. xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
  257. xinference/thirdparty/indextts/utils/text_utils.py +41 -0
  258. xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
  259. xinference/thirdparty/indextts/utils/utils.py +93 -0
  260. xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
  261. xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
  262. xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
  263. xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
  264. xinference/types.py +105 -2
  265. xinference/ui/gradio/media_interface.py +66 -8
  266. xinference/ui/web/ui/build/asset-manifest.json +6 -6
  267. xinference/ui/web/ui/build/index.html +1 -1
  268. xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
  269. xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
  270. xinference/ui/web/ui/build/static/js/main.d192c4f3.js +3 -0
  271. xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.d192c4f3.js.LICENSE.txt} +0 -7
  272. xinference/ui/web/ui/build/static/js/main.d192c4f3.js.map +1 -0
  273. xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
  274. xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
  275. xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
  276. xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
  277. xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
  278. xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
  279. xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
  280. xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
  281. xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
  282. xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
  283. xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
  284. xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
  285. xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
  286. xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
  287. xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
  288. xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
  289. xinference/ui/web/ui/node_modules/.cache/babel-loader/f995a2425dfb0822fd07127f66ffe9b026883bc156b402eb8bd0b83d52460a93.json +1 -0
  290. xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
  291. xinference/ui/web/ui/package-lock.json +0 -34
  292. xinference/ui/web/ui/package.json +0 -1
  293. xinference/ui/web/ui/src/locales/en.json +9 -3
  294. xinference/ui/web/ui/src/locales/ja.json +9 -3
  295. xinference/ui/web/ui/src/locales/ko.json +9 -3
  296. xinference/ui/web/ui/src/locales/zh.json +9 -3
  297. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/METADATA +24 -4
  298. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/RECORD +302 -76
  299. xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
  300. xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
  301. xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
  302. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
  303. xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
  304. xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
  305. xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
  306. xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
  307. xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
  308. xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
  309. xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
  310. xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
  311. xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
  312. xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
  313. xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
  314. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
  315. xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
  316. xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
  317. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
  318. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
  319. xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
  320. xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
  321. xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
  322. xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
  323. xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
  324. xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
  325. xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
  326. xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
  327. xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
  328. xinference/ui/web/ui/node_modules/select/bower.json +0 -13
  329. xinference/ui/web/ui/node_modules/select/package.json +0 -29
  330. xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
  331. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/WHEEL +0 -0
  332. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/entry_points.txt +0 -0
  333. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/licenses/LICENSE +0 -0
  334. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/top_level.txt +0 -0
xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py
@@ -0,0 +1,850 @@
+ # Copyright (c) 2024 Amphion.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from typing import Optional, Tuple
+
+ import numpy as np
+ import scipy.signal
+ import torch
+ from torch import nn, view_as_real, view_as_complex
+ from torch.nn.utils import weight_norm, remove_weight_norm
+ from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz
+
+
+ def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
+     """
+     Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.
+
+     Args:
+         x (Tensor): Input tensor.
+         clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.
+
+     Returns:
+         Tensor: Element-wise logarithm of the input tensor with clipping applied.
+     """
+     return torch.log(torch.clip(x, min=clip_val))
+
+
+ def symlog(x: torch.Tensor) -> torch.Tensor:
+     return torch.sign(x) * torch.log1p(x.abs())
+
+
+ def symexp(x: torch.Tensor) -> torch.Tensor:
+     return torch.sign(x) * (torch.exp(x.abs()) - 1)
+
+
+ class STFT(nn.Module):
+     def __init__(
+         self,
+         n_fft: int,
+         hop_length: int,
+         win_length: int,
+         center=True,
+     ):
+         super().__init__()
+         self.center = center
+         self.n_fft = n_fft
+         self.hop_length = hop_length
+         self.win_length = win_length
+         window = torch.hann_window(win_length)
+         self.register_buffer("window", window)
+
+     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         # x: (B, T * hop_length)
+
+         if not self.center:
+             pad = self.win_length - self.hop_length
+             x = torch.nn.functional.pad(x, (pad // 2, pad // 2), mode="reflect")
+
+         stft_spec = torch.stft(
+             x,
+             self.n_fft,
+             hop_length=self.hop_length,
+             win_length=self.win_length,
+             window=self.window,
+             center=self.center,
+             return_complex=False,
+         )  # (B, n_fft // 2 + 1, T, 2)
+
+         rea = stft_spec[:, :, :, 0]  # (B, n_fft // 2 + 1, T)
+         imag = stft_spec[:, :, :, 1]  # (B, n_fft // 2 + 1, T)
+
+         log_mag = torch.log(
+             torch.abs(torch.sqrt(torch.pow(rea, 2) + torch.pow(imag, 2))) + 1e-5
+         )  # (B, n_fft // 2 + 1, T)
+         phase = torch.atan2(imag, rea)  # (B, n_fft // 2 + 1, T)
+
+         return log_mag, phase
+
+
+ class ISTFT(nn.Module):
+     """
+     Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
+     windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
+     See issue: https://github.com/pytorch/pytorch/issues/62323
+     Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
+     The NOLA constraint is met as we trim padded samples anyway.
+
+     Args:
+         n_fft (int): Size of Fourier transform.
+         hop_length (int): The distance between neighboring sliding window frames.
+         win_length (int): The size of window frame and STFT filter.
+         padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+     """
+
+     def __init__(
+         self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"
+     ):
+         super().__init__()
+         if padding not in ["center", "same"]:
+             raise ValueError("Padding must be 'center' or 'same'.")
+         self.padding = padding
+         self.n_fft = n_fft
+         self.hop_length = hop_length
+         self.win_length = win_length
+         window = torch.hann_window(win_length)
+         self.register_buffer("window", window)
+
+     def forward(self, spec: torch.Tensor) -> torch.Tensor:
+         """
+         Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
+
+         Args:
+             spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
+                 N is the number of frequency bins, and T is the number of time frames.
+
+         Returns:
+             Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
+         """
+         if self.padding == "center":
+             # Fallback to pytorch native implementation
+             return torch.istft(
+                 spec,
+                 self.n_fft,
+                 self.hop_length,
+                 self.win_length,
+                 self.window,
+                 center=True,
+             )
+         elif self.padding == "same":
+             pad = (self.win_length - self.hop_length) // 2
+         else:
+             raise ValueError("Padding must be 'center' or 'same'.")
+
+         assert spec.dim() == 3, "Expected a 3D tensor as input"
+         B, N, T = spec.shape
+
+         # Inverse FFT
+         ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
+         ifft = ifft * self.window[None, :, None]
+
+         # Overlap and Add
+         output_size = (T - 1) * self.hop_length + self.win_length
+         y = torch.nn.functional.fold(
+             ifft,
+             output_size=(1, output_size),
+             kernel_size=(1, self.win_length),
+             stride=(1, self.hop_length),
+         )[:, 0, 0, pad:-pad]
+
+         # Window envelope
+         window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
+         window_envelope = torch.nn.functional.fold(
+             window_sq,
+             output_size=(1, output_size),
+             kernel_size=(1, self.win_length),
+             stride=(1, self.hop_length),
+         ).squeeze()[pad:-pad]
+
+         # Normalize
+         assert (window_envelope > 1e-11).all()
+         y = y / window_envelope
+
+         return y
+
+
+ class MDCT(nn.Module):
+     """
+     Modified Discrete Cosine Transform (MDCT) module.
+
+     Args:
+         frame_len (int): Length of the MDCT frame.
+         padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+     """
+
+     def __init__(self, frame_len: int, padding: str = "same"):
+         super().__init__()
+         if padding not in ["center", "same"]:
+             raise ValueError("Padding must be 'center' or 'same'.")
+         self.padding = padding
+         self.frame_len = frame_len
+         N = frame_len // 2
+         n0 = (N + 1) / 2
+         window = torch.from_numpy(scipy.signal.windows.cosine(frame_len)).float()
+         self.register_buffer("window", window)
+
+         pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len)
+         post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N)
+         # view_as_real: NCCL Backend does not support ComplexFloat data type
+         # https://github.com/pytorch/pytorch/issues/71613
+         self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
+         self.register_buffer("post_twiddle", view_as_real(post_twiddle))
+
+     def forward(self, audio: torch.Tensor) -> torch.Tensor:
+         """
+         Apply the Modified Discrete Cosine Transform (MDCT) to the input audio.
+
+         Args:
+             audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size
+                 and T is the length of the audio.
+
+         Returns:
+             Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames
+                 and N is the number of frequency bins.
+         """
+         if self.padding == "center":
+             audio = torch.nn.functional.pad(
+                 audio, (self.frame_len // 2, self.frame_len // 2)
+             )
+         elif self.padding == "same":
+             # hop_length is 1/2 frame_len
+             audio = torch.nn.functional.pad(
+                 audio, (self.frame_len // 4, self.frame_len // 4)
+             )
+         else:
+             raise ValueError("Padding must be 'center' or 'same'.")
+
+         x = audio.unfold(-1, self.frame_len, self.frame_len // 2)
+         N = self.frame_len // 2
+         x = x * self.window.expand(x.shape)
+         X = torch.fft.fft(
+             x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1
+         )[..., :N]
+         res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N)
+         return torch.real(res) * np.sqrt(2)
+
+
+ class IMDCT(nn.Module):
+     """
+     Inverse Modified Discrete Cosine Transform (IMDCT) module.
+
+     Args:
+         frame_len (int): Length of the MDCT frame.
+         padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+     """
+
+     def __init__(self, frame_len: int, padding: str = "same"):
+         super().__init__()
+         if padding not in ["center", "same"]:
+             raise ValueError("Padding must be 'center' or 'same'.")
+         self.padding = padding
+         self.frame_len = frame_len
+         N = frame_len // 2
+         n0 = (N + 1) / 2
+         window = torch.from_numpy(scipy.signal.windows.cosine(frame_len)).float()
+         self.register_buffer("window", window)
+
+         pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N)
+         post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2))
+         self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
+         self.register_buffer("post_twiddle", view_as_real(post_twiddle))
+
+     def forward(self, X: torch.Tensor) -> torch.Tensor:
+         """
+         Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients.
+
+         Args:
+             X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size,
+                 L is the number of frames, and N is the number of frequency bins.
+
+         Returns:
+             Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio.
+         """
+         B, L, N = X.shape
+         Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device)
+         Y[..., :N] = X
+         Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,)))
+         y = torch.fft.ifft(
+             Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1
+         )
+         y = (
+             torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape))
+             * np.sqrt(N)
+             * np.sqrt(2)
+         )
+         result = y * self.window.expand(y.shape)
+         output_size = (1, (L + 1) * N)
+         audio = torch.nn.functional.fold(
+             result.transpose(1, 2),
+             output_size=output_size,
+             kernel_size=(1, self.frame_len),
+             stride=(1, self.frame_len // 2),
+         )[:, 0, 0, :]
+
+         if self.padding == "center":
+             pad = self.frame_len // 2
+         elif self.padding == "same":
+             pad = self.frame_len // 4
+         else:
+             raise ValueError("Padding must be 'center' or 'same'.")
+
+         audio = audio[:, pad:-pad]
+         return audio
+
+
+ class FourierHead(nn.Module):
+     """Base class for inverse Fourier modules."""
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Args:
+             x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+                 L is the sequence length, and H denotes the model dimension.
+
+         Returns:
+             Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+         """
+         raise NotImplementedError("Subclasses must implement the forward method.")
+
+
+ class ISTFTHead(FourierHead):
+     """
+     ISTFT Head module for predicting STFT complex coefficients.
+
+     Args:
+         dim (int): Hidden dimension of the model.
+         n_fft (int): Size of Fourier transform.
+         hop_length (int): The distance between neighboring sliding window frames, which should align with
+             the resolution of the input features.
+         padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+     """
+
+     def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
+         super().__init__()
+         out_dim = n_fft + 2
+         self.out = torch.nn.Linear(dim, out_dim)
+         self.istft = ISTFT(
+             n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass of the ISTFTHead module.
+
+         Args:
+             x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+                 L is the sequence length, and H denotes the model dimension.
+
+         Returns:
+             Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+         """
+         x = self.out(x).transpose(1, 2)
+         mag, p = x.chunk(2, dim=1)
+         mag = torch.exp(mag)
+         mag = torch.clip(
+             mag, max=1e2
+         )  # safeguard to prevent excessively large magnitudes
+         # wrapping happens here. These two lines produce the real and imaginary parts
+         x = torch.cos(p)
+         y = torch.sin(p)
+         # recalculating phase here does not produce anything new
+         # only costs time
+         # phase = torch.atan2(y, x)
+         # S = mag * torch.exp(phase * 1j)
+         # better to directly produce the complex value
+         S = mag * (x + 1j * y)
+         audio = self.istft(S)
+         return audio
+
+
+ class IMDCTSymExpHead(FourierHead):
+     """
+     IMDCT Head module for predicting MDCT coefficients with symmetric exponential function
+
+     Args:
+         dim (int): Hidden dimension of the model.
+         mdct_frame_len (int): Length of the MDCT frame.
+         padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+         sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
+             based on perceptual scaling. Defaults to None.
+         clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         mdct_frame_len: int,
+         padding: str = "same",
+         sample_rate: Optional[int] = None,
+         clip_audio: bool = False,
+     ):
+         super().__init__()
+         out_dim = mdct_frame_len // 2
+         self.out = nn.Linear(dim, out_dim)
+         self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
+         self.clip_audio = clip_audio
+
+         if sample_rate is not None:
+             # optionally init the last layer following mel-scale
+             m_max = _hz_to_mel(sample_rate // 2)
+             m_pts = torch.linspace(0, m_max, out_dim)
+             f_pts = _mel_to_hz(m_pts)
+             scale = 1 - (f_pts / f_pts.max())
+
+             with torch.no_grad():
+                 self.out.weight.mul_(scale.view(-1, 1))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass of the IMDCTSymExpHead module.
+
+         Args:
+             x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+                 L is the sequence length, and H denotes the model dimension.
+
+         Returns:
+             Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+         """
+         x = self.out(x)
+         x = symexp(x)
+         x = torch.clip(
+             x, min=-1e2, max=1e2
+         )  # safeguard to prevent excessively large magnitudes
+         audio = self.imdct(x)
+         if self.clip_audio:
+             audio = torch.clip(audio, min=-1.0, max=1.0)
+
+         return audio
+
+
+ class IMDCTCosHead(FourierHead):
+     """
+     IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) · cos(p)
+
+     Args:
+         dim (int): Hidden dimension of the model.
+         mdct_frame_len (int): Length of the MDCT frame.
+         padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+         clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         mdct_frame_len: int,
+         padding: str = "same",
+         clip_audio: bool = False,
+     ):
+         super().__init__()
+         self.clip_audio = clip_audio
+         self.out = nn.Linear(dim, mdct_frame_len)
+         self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass of the IMDCTCosHead module.
+
+         Args:
+             x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+                 L is the sequence length, and H denotes the model dimension.
+
+         Returns:
+             Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+         """
+         x = self.out(x)
+         m, p = x.chunk(2, dim=2)
+         m = torch.exp(m).clip(
+             max=1e2
+         )  # safeguard to prevent excessively large magnitudes
+         audio = self.imdct(m * torch.cos(p))
+         if self.clip_audio:
+             audio = torch.clip(audio, min=-1.0, max=1.0)
+         return audio
+
+
+ class ConvNeXtBlock(nn.Module):
+     """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
+
+     Args:
+         dim (int): Number of input channels.
+         intermediate_dim (int): Dimensionality of the intermediate layer.
+         layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
+             Defaults to None.
+         adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
+             None means non-conditional LayerNorm. Defaults to None.
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         intermediate_dim: int,
+         layer_scale_init_value: float,
+         adanorm_num_embeddings: Optional[int] = None,
+     ):
+         super().__init__()
+         self.dwconv = nn.Conv1d(
+             dim, dim, kernel_size=7, padding=3, groups=dim
+         )  # depthwise conv
+         self.adanorm = adanorm_num_embeddings is not None
+         if adanorm_num_embeddings:
+             self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
+         else:
+             self.norm = nn.LayerNorm(dim, eps=1e-6)
+         self.pwconv1 = nn.Linear(
+             dim, intermediate_dim
+         )  # pointwise/1x1 convs, implemented with linear layers
+         self.act = nn.GELU()
+         self.pwconv2 = nn.Linear(intermediate_dim, dim)
+         self.gamma = (
+             nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
+             if layer_scale_init_value > 0
+             else None
+         )
+
+     def forward(
+         self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
+     ) -> torch.Tensor:
+         residual = x
+         x = self.dwconv(x)
+         x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
+         if self.adanorm:
+             assert cond_embedding_id is not None
+             x = self.norm(x, cond_embedding_id)
+         else:
+             x = self.norm(x)
+         x = self.pwconv1(x)
+         x = self.act(x)
+         x = self.pwconv2(x)
+         if self.gamma is not None:
+             x = self.gamma * x
+         x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)
+
+         x = residual + x
+         return x
+
+
+ class AdaLayerNorm(nn.Module):
+     """
+     Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
+
+     Args:
+         num_embeddings (int): Number of embeddings.
+         embedding_dim (int): Dimension of the embeddings.
+     """
+
+     def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
+         super().__init__()
+         self.eps = eps
+         self.dim = embedding_dim
+         self.scale = nn.Embedding(
+             num_embeddings=num_embeddings, embedding_dim=embedding_dim
+         )
+         self.shift = nn.Embedding(
+             num_embeddings=num_embeddings, embedding_dim=embedding_dim
+         )
+         torch.nn.init.ones_(self.scale.weight)
+         torch.nn.init.zeros_(self.shift.weight)
+
+     def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
+         scale = self.scale(cond_embedding_id)
+         shift = self.shift(cond_embedding_id)
+         x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
+         x = x * scale + shift
+         return x
+
+
+ class ResBlock1(nn.Module):
+     """
+     ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
+     but without upsampling layers.
+
+     Args:
+         dim (int): Number of input channels.
+         kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
+         dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
+             Defaults to (1, 3, 5).
+         lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
+             Defaults to 0.1.
+         layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
+             Defaults to None.
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         kernel_size: int = 3,
+         dilation: Tuple[int, int, int] = (1, 3, 5),
+         lrelu_slope: float = 0.1,
+         layer_scale_init_value: Optional[float] = None,
+     ):
+         super().__init__()
+         self.lrelu_slope = lrelu_slope
+         self.convs1 = nn.ModuleList(
+             [
+                 weight_norm(
+                     nn.Conv1d(
+                         dim,
+                         dim,
+                         kernel_size,
+                         1,
+                         dilation=dilation[0],
+                         padding=self.get_padding(kernel_size, dilation[0]),
+                     )
+                 ),
+                 weight_norm(
+                     nn.Conv1d(
+                         dim,
+                         dim,
+                         kernel_size,
+                         1,
+                         dilation=dilation[1],
+                         padding=self.get_padding(kernel_size, dilation[1]),
+                     )
+                 ),
+                 weight_norm(
+                     nn.Conv1d(
+                         dim,
+                         dim,
+                         kernel_size,
+                         1,
+                         dilation=dilation[2],
+                         padding=self.get_padding(kernel_size, dilation[2]),
+                     )
+                 ),
+             ]
+         )
+
+         self.convs2 = nn.ModuleList(
+             [
+                 weight_norm(
+                     nn.Conv1d(
+                         dim,
+                         dim,
+                         kernel_size,
+                         1,
+                         dilation=1,
+                         padding=self.get_padding(kernel_size, 1),
+                     )
+                 ),
+                 weight_norm(
+                     nn.Conv1d(
+                         dim,
+                         dim,
+                         kernel_size,
+                         1,
+                         dilation=1,
+                         padding=self.get_padding(kernel_size, 1),
+                     )
+                 ),
+                 weight_norm(
+                     nn.Conv1d(
+                         dim,
+                         dim,
+                         kernel_size,
+                         1,
+                         dilation=1,
+                         padding=self.get_padding(kernel_size, 1),
+                     )
+                 ),
+             ]
+         )
+
+         self.gamma = nn.ParameterList(
+             [
+                 (
+                     nn.Parameter(
+                         layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
+                     )
+                     if layer_scale_init_value is not None
+                     else None
+                 ),
+                 (
+                     nn.Parameter(
+                         layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
+                     )
+                     if layer_scale_init_value is not None
+                     else None
+                 ),
+                 (
+                     nn.Parameter(
+                         layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
+                     )
+                     if layer_scale_init_value is not None
+                     else None
+                 ),
+             ]
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
+             xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
+             xt = c1(xt)
+             xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
+             xt = c2(xt)
+             if gamma is not None:
+                 xt = gamma * xt
+             x = xt + x
+         return x
+
+     def remove_weight_norm(self):
+         for l in self.convs1:
+             remove_weight_norm(l)
+         for l in self.convs2:
+             remove_weight_norm(l)
+
+     @staticmethod
+     def get_padding(kernel_size: int, dilation: int = 1) -> int:
+         return int((kernel_size * dilation - dilation) / 2)
+
+
+ class Backbone(nn.Module):
+     """Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
+
+     def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+         """
+         Args:
+             x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
+                 C denotes output features, and L is the sequence length.
+
+         Returns:
+             Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
+                 and H denotes the model dimension.
+         """
+         raise NotImplementedError("Subclasses must implement the forward method.")
+
+
+ class VocosBackbone(Backbone):
+     """
+     Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization
+
+     Args:
+         input_channels (int): Number of input features channels.
+         dim (int): Hidden dimension of the model.
+         intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
+         num_layers (int): Number of ConvNeXtBlock layers.
+         layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
+         adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
+             None means non-conditional model. Defaults to None.
+     """
+
+     def __init__(
+         self,
+         input_channels: int,
+         dim: int,
+         intermediate_dim: int,
+         num_layers: int,
+         layer_scale_init_value: Optional[float] = None,
+         adanorm_num_embeddings: Optional[int] = None,
+     ):
+         super().__init__()
+         self.input_channels = input_channels
+         self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
+         self.adanorm = adanorm_num_embeddings is not None
+         if adanorm_num_embeddings:
+             self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
+         else:
+             self.norm = nn.LayerNorm(dim, eps=1e-6)
+         layer_scale_init_value = layer_scale_init_value or 1 / num_layers
+         self.convnext = nn.ModuleList(
+             [
+                 ConvNeXtBlock(
+                     dim=dim,
+                     intermediate_dim=intermediate_dim,
+                     layer_scale_init_value=layer_scale_init_value,
+                     adanorm_num_embeddings=adanorm_num_embeddings,
+                 )
+                 for _ in range(num_layers)
+             ]
+         )
+         self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, (nn.Conv1d, nn.Linear)):
+             nn.init.trunc_normal_(m.weight, std=0.02)
+             nn.init.constant_(m.bias, 0)
+
+     def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+         bandwidth_id = kwargs.get("bandwidth_id", None)
+         x = self.embed(x)
+         if self.adanorm:
+             assert bandwidth_id is not None
+             x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id)
+         else:
+             x = self.norm(x.transpose(1, 2))
+         x = x.transpose(1, 2)
+         for conv_block in self.convnext:
+             x = conv_block(x, cond_embedding_id=bandwidth_id)
+         x = self.final_layer_norm(x.transpose(1, 2))
+         return x
+
+
+ class VocosResNetBackbone(Backbone):
+     """
+     Vocos backbone module built with ResBlocks.
+
+     Args:
+         input_channels (int): Number of input features channels.
+         dim (int): Hidden dimension of the model.
+         num_blocks (int): Number of ResBlock1 blocks.
+         layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
+     """
+
+     def __init__(
+         self,
+         input_channels,
+         dim,
+         num_blocks,
+         layer_scale_init_value=None,
+     ):
+         super().__init__()
+         self.input_channels = input_channels
+         self.embed = weight_norm(
+             nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)
+         )
+         layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
+         self.resnet = nn.Sequential(
+             *[
+                 ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value)
+                 for _ in range(num_blocks)
+             ]
+         )
+
+     def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+         x = self.embed(x)
+         x = self.resnet(x)
+         x = x.transpose(1, 2)
+         return x
+
+
+ class Vocos(nn.Module):
+     def __init__(
+         self,
+         input_channels: int = 256,
+         dim: int = 384,
+         intermediate_dim: int = 1152,
+         num_layers: int = 8,
+         adanorm_num_embeddings: int = 4,
+         n_fft: int = 800,
+         hop_size: int = 200,
+         padding: str = "same",
+     ):
+         super().__init__()
+
+         self.backbone = VocosBackbone(
+             input_channels=input_channels,
+             dim=dim,
+             intermediate_dim=intermediate_dim,
+             num_layers=num_layers,
+             adanorm_num_embeddings=adanorm_num_embeddings,
+         )
+         self.head = ISTFTHead(dim, n_fft, hop_size, padding)
+
+     def forward(self, x):
+         x = self.backbone(x)
+         x = self.head(x)
+
+         return x[:, None, :]
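
The hunk ends with the `Vocos` wrapper, which feeds a `VocosBackbone` of ConvNeXt blocks into an `ISTFTHead` that predicts per-frame magnitude and phase and inverts them with the custom `ISTFT`. For orientation, a minimal usage sketch follows. Two assumptions are labeled in the comments: the import path presumes this hunk belongs to kmeans/vocos.py (item 226 in the file list above, the only entry adding 850 lines), and `adanorm_num_embeddings` is set to None because the default of 4 selects the AdaLayerNorm path, whose assertion requires a `bandwidth_id` keyword that `Vocos.forward` never passes through to the backbone.

import torch

# Hypothetical import path, inferred from the file list above (item 226).
from xinference.thirdparty.indextts.utils.maskgct.models.codec.kmeans.vocos import Vocos

# Use the plain LayerNorm path; the AdaLayerNorm default (4 embeddings) would
# assert on a bandwidth_id kwarg that Vocos.forward does not supply.
model = Vocos(adanorm_num_embeddings=None).eval()

features = torch.randn(2, 256, 100)  # (batch, input_channels, frames)
with torch.no_grad():
    audio = model(features)

# With n_fft=800, hop_size=200 and "same" padding, the ISTFT head emits
# frames * hop_size samples per item, so this prints torch.Size([2, 1, 20000]).
print(audio.shape)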