xinference 1.10.0__py3-none-any.whl → 1.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this version of xinference has been flagged as a potentially problematic release.
Files changed (317)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +11 -28
  3. xinference/client/restful/async_restful_client.py +20 -3
  4. xinference/client/restful/restful_client.py +20 -3
  5. xinference/core/supervisor.py +87 -53
  6. xinference/core/worker.py +10 -0
  7. xinference/deploy/cmdline.py +15 -0
  8. xinference/model/audio/core.py +21 -6
  9. xinference/model/audio/indextts2.py +166 -0
  10. xinference/model/audio/model_spec.json +38 -1
  11. xinference/model/image/model_spec.json +69 -0
  12. xinference/model/image/stable_diffusion/core.py +13 -4
  13. xinference/model/llm/__init__.py +4 -0
  14. xinference/model/llm/llm_family.json +464 -2
  15. xinference/model/llm/sglang/core.py +30 -11
  16. xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +94 -32
  17. xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
  18. xinference/model/llm/utils.py +12 -9
  19. xinference/model/llm/vllm/core.py +93 -17
  20. xinference/thirdparty/audiotools/__init__.py +10 -0
  21. xinference/thirdparty/audiotools/core/__init__.py +4 -0
  22. xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
  23. xinference/thirdparty/audiotools/core/display.py +194 -0
  24. xinference/thirdparty/audiotools/core/dsp.py +390 -0
  25. xinference/thirdparty/audiotools/core/effects.py +647 -0
  26. xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
  27. xinference/thirdparty/audiotools/core/loudness.py +320 -0
  28. xinference/thirdparty/audiotools/core/playback.py +252 -0
  29. xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
  30. xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
  31. xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
  32. xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
  33. xinference/thirdparty/audiotools/core/util.py +671 -0
  34. xinference/thirdparty/audiotools/core/whisper.py +97 -0
  35. xinference/thirdparty/audiotools/data/__init__.py +3 -0
  36. xinference/thirdparty/audiotools/data/datasets.py +517 -0
  37. xinference/thirdparty/audiotools/data/preprocess.py +81 -0
  38. xinference/thirdparty/audiotools/data/transforms.py +1592 -0
  39. xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
  40. xinference/thirdparty/audiotools/metrics/distance.py +131 -0
  41. xinference/thirdparty/audiotools/metrics/quality.py +159 -0
  42. xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
  43. xinference/thirdparty/audiotools/ml/__init__.py +5 -0
  44. xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
  45. xinference/thirdparty/audiotools/ml/decorators.py +440 -0
  46. xinference/thirdparty/audiotools/ml/experiment.py +90 -0
  47. xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
  48. xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
  49. xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
  50. xinference/thirdparty/audiotools/post.py +140 -0
  51. xinference/thirdparty/audiotools/preference.py +600 -0
  52. xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
  53. xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
  54. xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
  55. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
  56. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
  57. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
  58. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
  59. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  60. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
  61. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
  62. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
  63. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
  64. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
  65. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
  66. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
  67. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
  68. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
  69. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
  70. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
  71. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
  72. xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
  73. xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
  74. xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
  75. xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
  76. xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
  77. xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
  78. xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
  79. xinference/thirdparty/indextts/__init__.py +0 -0
  80. xinference/thirdparty/indextts/cli.py +65 -0
  81. xinference/thirdparty/indextts/gpt/__init__.py +0 -0
  82. xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
  83. xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
  84. xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
  85. xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
  86. xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
  87. xinference/thirdparty/indextts/gpt/model.py +713 -0
  88. xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
  89. xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
  90. xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
  91. xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
  92. xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
  93. xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
  94. xinference/thirdparty/indextts/infer.py +690 -0
  95. xinference/thirdparty/indextts/infer_v2.py +739 -0
  96. xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
  97. xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
  98. xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
  99. xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
  100. xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
  101. xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
  102. xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
  103. xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
  104. xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
  105. xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
  106. xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
  107. xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
  108. xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
  109. xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
  110. xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
  111. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
  112. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
  113. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
  114. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
  115. xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
  116. xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
  117. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
  118. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
  119. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  120. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
  121. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
  122. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
  123. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
  124. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
  125. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
  126. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
  127. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
  128. xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
  129. xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
  130. xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
  131. xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
  132. xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
  133. xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
  134. xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
  135. xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
  136. xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
  137. xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
  138. xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
  139. xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
  140. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
  141. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
  142. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
  143. xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
  144. xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
  145. xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
  146. xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
  147. xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
  148. xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
  149. xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
  150. xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
  151. xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
  152. xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
  153. xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
  154. xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
  155. xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
  156. xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
  157. xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
  158. xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
  159. xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
  160. xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
  161. xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
  162. xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
  163. xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
  164. xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
  165. xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
  166. xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
  167. xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
  168. xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
  169. xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
  170. xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
  171. xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
  172. xinference/thirdparty/indextts/utils/__init__.py +0 -0
  173. xinference/thirdparty/indextts/utils/arch_util.py +120 -0
  174. xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
  175. xinference/thirdparty/indextts/utils/common.py +121 -0
  176. xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
  177. xinference/thirdparty/indextts/utils/front.py +536 -0
  178. xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
  179. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
  180. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
  181. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
  182. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
  183. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
  184. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
  185. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
  186. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
  187. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
  188. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
  189. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
  190. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
  191. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
  192. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
  193. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
  194. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
  195. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
  196. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
  197. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
  198. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
  199. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
  200. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
  201. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
  202. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
  203. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
  204. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
  205. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
  206. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
  207. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
  208. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
  209. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
  210. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
  211. xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
  212. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
  213. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
  214. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
  215. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
  216. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
  217. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
  218. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
  219. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
  220. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
  221. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
  222. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
  223. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
  224. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
  225. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
  226. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
  227. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
  228. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
  229. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
  230. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
  231. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
  232. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
  233. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
  234. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
  235. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
  236. xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
  237. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
  238. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
  239. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
  240. xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
  241. xinference/thirdparty/indextts/utils/text_utils.py +41 -0
  242. xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
  243. xinference/thirdparty/indextts/utils/utils.py +93 -0
  244. xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
  245. xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
  246. xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
  247. xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
  248. xinference/ui/gradio/media_interface.py +66 -8
  249. xinference/ui/web/ui/build/asset-manifest.json +6 -6
  250. xinference/ui/web/ui/build/index.html +1 -1
  251. xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
  252. xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
  253. xinference/ui/web/ui/build/static/js/main.d192c4f3.js +3 -0
  254. xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.d192c4f3.js.LICENSE.txt} +0 -7
  255. xinference/ui/web/ui/build/static/js/main.d192c4f3.js.map +1 -0
  256. xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
  257. xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
  258. xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
  259. xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
  260. xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
  261. xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
  262. xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
  263. xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
  264. xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
  265. xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
  266. xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
  267. xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
  268. xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
  269. xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
  270. xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
  271. xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
  272. xinference/ui/web/ui/node_modules/.cache/babel-loader/f995a2425dfb0822fd07127f66ffe9b026883bc156b402eb8bd0b83d52460a93.json +1 -0
  273. xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
  274. xinference/ui/web/ui/package-lock.json +0 -34
  275. xinference/ui/web/ui/package.json +0 -1
  276. xinference/ui/web/ui/src/locales/en.json +9 -3
  277. xinference/ui/web/ui/src/locales/ja.json +9 -3
  278. xinference/ui/web/ui/src/locales/ko.json +9 -3
  279. xinference/ui/web/ui/src/locales/zh.json +9 -3
  280. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/METADATA +18 -2
  281. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/RECORD +285 -67
  282. xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
  283. xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
  284. xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
  285. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
  286. xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
  287. xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
  288. xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
  289. xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
  290. xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
  291. xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
  292. xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
  293. xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
  294. xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
  295. xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
  296. xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
  297. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
  298. xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
  299. xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
  300. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
  301. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
  302. xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
  303. xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
  304. xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
  305. xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
  306. xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
  307. xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
  308. xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
  309. xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
  310. xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
  311. xinference/ui/web/ui/node_modules/select/bower.json +0 -13
  312. xinference/ui/web/ui/node_modules/select/package.json +0 -29
  313. xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
  314. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/WHEEL +0 -0
  315. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/entry_points.txt +0 -0
  316. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/licenses/LICENSE +0 -0
  317. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/top_level.txt +0 -0
xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py
@@ -0,0 +1,881 @@
+# Copyright (c) 2024 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional, Tuple
+
+import numpy as np
+import scipy
+import torch
+from torch import nn, view_as_real, view_as_complex
+from torch.nn.utils import weight_norm, remove_weight_norm
+from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz
+import librosa
+
+
+def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
+    """
+    Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.
+
+    Args:
+        x (Tensor): Input tensor.
+        clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.
+
+    Returns:
+        Tensor: Element-wise logarithm of the input tensor with clipping applied.
+    """
+    return torch.log(torch.clip(x, min=clip_val))
+
+
+def symlog(x: torch.Tensor) -> torch.Tensor:
+    return torch.sign(x) * torch.log1p(x.abs())
+
+
+def symexp(x: torch.Tensor) -> torch.Tensor:
+    return torch.sign(x) * (torch.exp(x.abs()) - 1)
+
+
+class STFT(nn.Module):
+    def __init__(
+        self,
+        n_fft: int,
+        hop_length: int,
+        win_length: int,
+        center=True,
+    ):
+        super().__init__()
+        self.center = center
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+        self.win_length = win_length
+        window = torch.hann_window(win_length)
+        self.register_buffer("window", window)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # x: (B, T * hop_length)
+
+        if not self.center:
+            pad = self.win_length - self.hop_length
+            x = torch.nn.functional.pad(x, (pad // 2, pad // 2), mode="reflect")
+
+        stft_spec = torch.stft(
+            x,
+            self.n_fft,
+            hop_length=self.hop_length,
+            win_length=self.win_length,
+            window=self.window,
+            center=self.center,
+            return_complex=False,
+        )  # (B, n_fft // 2 + 1, T, 2)
+
+        rea = stft_spec[:, :, :, 0]  # (B, n_fft // 2 + 1, T)
+        imag = stft_spec[:, :, :, 1]  # (B, n_fft // 2 + 1, T)
+
+        log_mag = torch.log(
+            torch.abs(torch.sqrt(torch.pow(rea, 2) + torch.pow(imag, 2))) + 1e-5
+        )  # (B, n_fft // 2 + 1, T)
+        phase = torch.atan2(imag, rea)  # (B, n_fft // 2 + 1, T)
+
+        return log_mag, phase
+
+
+class ISTFT(nn.Module):
+    """
+    Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
+    windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
+    See issue: https://github.com/pytorch/pytorch/issues/62323
+    Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
+    The NOLA constraint is met as we trim padded samples anyway.
+
+    Args:
+        n_fft (int): Size of Fourier transform.
+        hop_length (int): The distance between neighboring sliding window frames.
+        win_length (int): The size of window frame and STFT filter.
+        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+    """
+
+    def __init__(
+        self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"
+    ):
+        super().__init__()
+        if padding not in ["center", "same"]:
+            raise ValueError("Padding must be 'center' or 'same'.")
+        self.padding = padding
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+        self.win_length = win_length
+        window = torch.hann_window(win_length)
+        self.register_buffer("window", window)
+
+    def forward(self, spec: torch.Tensor) -> torch.Tensor:
+        """
+        Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
+
+        Args:
+            spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
+                N is the number of frequency bins, and T is the number of time frames.
+
+        Returns:
+            Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
+        """
+        if self.padding == "center":
+            # Fallback to pytorch native implementation
+            return torch.istft(
+                spec,
+                self.n_fft,
+                self.hop_length,
+                self.win_length,
+                self.window,
+                center=True,
+            )
+        elif self.padding == "same":
+            pad = (self.win_length - self.hop_length) // 2
+        else:
+            raise ValueError("Padding must be 'center' or 'same'.")
+
+        assert spec.dim() == 3, "Expected a 3D tensor as input"
+        B, N, T = spec.shape
+
+        # Inverse FFT
+        ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
+        ifft = ifft * self.window[None, :, None]
+
+        # Overlap and Add
+        output_size = (T - 1) * self.hop_length + self.win_length
+        y = torch.nn.functional.fold(
+            ifft,
+            output_size=(1, output_size),
+            kernel_size=(1, self.win_length),
+            stride=(1, self.hop_length),
+        )[:, 0, 0, pad:-pad]
+
+        # Window envelope
+        window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
+        window_envelope = torch.nn.functional.fold(
+            window_sq,
+            output_size=(1, output_size),
+            kernel_size=(1, self.win_length),
+            stride=(1, self.hop_length),
+        ).squeeze()[pad:-pad]
+
+        # Normalize
+        assert (window_envelope > 1e-11).all()
+        y = y / window_envelope
+
+        return y
+
+
+class MDCT(nn.Module):
+    """
+    Modified Discrete Cosine Transform (MDCT) module.
+
+    Args:
+        frame_len (int): Length of the MDCT frame.
+        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+    """
+
+    def __init__(self, frame_len: int, padding: str = "same"):
+        super().__init__()
+        if padding not in ["center", "same"]:
+            raise ValueError("Padding must be 'center' or 'same'.")
+        self.padding = padding
+        self.frame_len = frame_len
+        N = frame_len // 2
+        n0 = (N + 1) / 2
+        window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
+        self.register_buffer("window", window)
+
+        pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len)
+        post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N)
+        # view_as_real: NCCL Backend does not support ComplexFloat data type
+        # https://github.com/pytorch/pytorch/issues/71613
+        self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
+        self.register_buffer("post_twiddle", view_as_real(post_twiddle))
+
+    def forward(self, audio: torch.Tensor) -> torch.Tensor:
+        """
+        Apply the Modified Discrete Cosine Transform (MDCT) to the input audio.
+
+        Args:
+            audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size
+                and T is the length of the audio.
+
+        Returns:
+            Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames
+                and N is the number of frequency bins.
+        """
+        if self.padding == "center":
+            audio = torch.nn.functional.pad(
+                audio, (self.frame_len // 2, self.frame_len // 2)
+            )
+        elif self.padding == "same":
+            # hop_length is 1/2 frame_len
+            audio = torch.nn.functional.pad(
+                audio, (self.frame_len // 4, self.frame_len // 4)
+            )
+        else:
+            raise ValueError("Padding must be 'center' or 'same'.")
+
+        x = audio.unfold(-1, self.frame_len, self.frame_len // 2)
+        N = self.frame_len // 2
+        x = x * self.window.expand(x.shape)
+        X = torch.fft.fft(
+            x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1
+        )[..., :N]
+        res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N)
+        return torch.real(res) * np.sqrt(2)
+
+
+class IMDCT(nn.Module):
+    """
+    Inverse Modified Discrete Cosine Transform (IMDCT) module.
+
+    Args:
+        frame_len (int): Length of the MDCT frame.
+        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+    """
+
+    def __init__(self, frame_len: int, padding: str = "same"):
+        super().__init__()
+        if padding not in ["center", "same"]:
+            raise ValueError("Padding must be 'center' or 'same'.")
+        self.padding = padding
+        self.frame_len = frame_len
+        N = frame_len // 2
+        n0 = (N + 1) / 2
+        window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
+        self.register_buffer("window", window)
+
+        pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N)
+        post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2))
+        self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
+        self.register_buffer("post_twiddle", view_as_real(post_twiddle))
+
+    def forward(self, X: torch.Tensor) -> torch.Tensor:
+        """
+        Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients.
+
+        Args:
+            X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size,
+                L is the number of frames, and N is the number of frequency bins.
+
+        Returns:
+            Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio.
+        """
+        B, L, N = X.shape
+        Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device)
+        Y[..., :N] = X
+        Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,)))
+        y = torch.fft.ifft(
+            Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1
+        )
+        y = (
+            torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape))
+            * np.sqrt(N)
+            * np.sqrt(2)
+        )
+        result = y * self.window.expand(y.shape)
+        output_size = (1, (L + 1) * N)
+        audio = torch.nn.functional.fold(
+            result.transpose(1, 2),
+            output_size=output_size,
+            kernel_size=(1, self.frame_len),
+            stride=(1, self.frame_len // 2),
+        )[:, 0, 0, :]
+
+        if self.padding == "center":
+            pad = self.frame_len // 2
+        elif self.padding == "same":
+            pad = self.frame_len // 4
+        else:
+            raise ValueError("Padding must be 'center' or 'same'.")
+
+        audio = audio[:, pad:-pad]
+        return audio
+
+
+class FourierHead(nn.Module):
+    """Base class for inverse fourier modules."""
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+                L is the sequence length, and H denotes the model dimension.
+
+        Returns:
+            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+        """
+        raise NotImplementedError("Subclasses must implement the forward method.")
+
+
+class ISTFTHead(FourierHead):
+    """
+    ISTFT Head module for predicting STFT complex coefficients.
+
+    Args:
+        dim (int): Hidden dimension of the model.
+        n_fft (int): Size of Fourier transform.
+        hop_length (int): The distance between neighboring sliding window frames, which should align with
+            the resolution of the input features.
+        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+    """
+
+    def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
+        super().__init__()
+        out_dim = n_fft + 2
+        self.out = torch.nn.Linear(dim, out_dim)
+        self.istft = ISTFT(
+            n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass of the ISTFTHead module.
+
+        Args:
+            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+                L is the sequence length, and H denotes the model dimension.
+
+        Returns:
+            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+        """
+        x = self.out(x).transpose(1, 2)
+        mag, p = x.chunk(2, dim=1)
+        mag = torch.exp(mag)
+        mag = torch.clip(
+            mag, max=1e2
+        )  # safeguard to prevent excessively large magnitudes
+        # wrapping happens here. These two lines produce real and imaginary value
+        x = torch.cos(p)
+        y = torch.sin(p)
+        # recalculating phase here does not produce anything new
+        # only costs time
+        # phase = torch.atan2(y, x)
+        # S = mag * torch.exp(phase * 1j)
+        # better directly produce the complex value
+        S = mag * (x + 1j * y)
+        audio = self.istft(S)
+        return audio
+
+
+class IMDCTSymExpHead(FourierHead):
+    """
+    IMDCT Head module for predicting MDCT coefficients with symmetric exponential function
+
+    Args:
+        dim (int): Hidden dimension of the model.
+        mdct_frame_len (int): Length of the MDCT frame.
+        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+        sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
+            based on perceptual scaling. Defaults to None.
+        clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        mdct_frame_len: int,
+        padding: str = "same",
+        sample_rate: Optional[int] = None,
+        clip_audio: bool = False,
+    ):
+        super().__init__()
+        out_dim = mdct_frame_len // 2
+        self.out = nn.Linear(dim, out_dim)
+        self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
+        self.clip_audio = clip_audio
+
+        if sample_rate is not None:
+            # optionally init the last layer following mel-scale
+            m_max = _hz_to_mel(sample_rate // 2)
+            m_pts = torch.linspace(0, m_max, out_dim)
+            f_pts = _mel_to_hz(m_pts)
+            scale = 1 - (f_pts / f_pts.max())
+
+            with torch.no_grad():
+                self.out.weight.mul_(scale.view(-1, 1))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass of the IMDCTSymExpHead module.
+
+        Args:
+            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+                L is the sequence length, and H denotes the model dimension.
+
+        Returns:
+            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+        """
+        x = self.out(x)
+        x = symexp(x)
+        x = torch.clip(
+            x, min=-1e2, max=1e2
+        )  # safeguard to prevent excessively large magnitudes
+        audio = self.imdct(x)
+        if self.clip_audio:
+            # clip the synthesized waveform, not the MDCT coefficients
+            audio = torch.clip(audio, min=-1.0, max=1.0)
+
+        return audio
+
+
+class IMDCTCosHead(FourierHead):
+    """
+    IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) · cos(p)
+
+    Args:
+        dim (int): Hidden dimension of the model.
+        mdct_frame_len (int): Length of the MDCT frame.
+        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+        clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        mdct_frame_len: int,
+        padding: str = "same",
+        clip_audio: bool = False,
+    ):
+        super().__init__()
+        self.clip_audio = clip_audio
+        self.out = nn.Linear(dim, mdct_frame_len)
+        self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass of the IMDCTCosHead module.
+
+        Args:
+            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+                L is the sequence length, and H denotes the model dimension.
+
+        Returns:
+            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+        """
+        x = self.out(x)
+        m, p = x.chunk(2, dim=2)
+        m = torch.exp(m).clip(
+            max=1e2
+        )  # safeguard to prevent excessively large magnitudes
+        audio = self.imdct(m * torch.cos(p))
+        if self.clip_audio:
+            # clip the synthesized waveform, not the MDCT coefficients
+            audio = torch.clip(audio, min=-1.0, max=1.0)
+        return audio
+
+
+class ConvNeXtBlock(nn.Module):
+    """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
+
+    Args:
+        dim (int): Number of input channels.
+        intermediate_dim (int): Dimensionality of the intermediate layer.
+        layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
+            Defaults to None.
+        adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
+            None means non-conditional LayerNorm. Defaults to None.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        intermediate_dim: int,
+        layer_scale_init_value: float,
+        adanorm_num_embeddings: Optional[int] = None,
+    ):
+        super().__init__()
+        self.dwconv = nn.Conv1d(
+            dim, dim, kernel_size=7, padding=3, groups=dim
+        )  # depthwise conv
+        self.adanorm = adanorm_num_embeddings is not None
+        if adanorm_num_embeddings:
+            self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
+        else:
+            self.norm = nn.LayerNorm(dim, eps=1e-6)
+        self.pwconv1 = nn.Linear(
+            dim, intermediate_dim
+        )  # pointwise/1x1 convs, implemented with linear layers
+        self.act = nn.GELU()
+        self.pwconv2 = nn.Linear(intermediate_dim, dim)
+        self.gamma = (
+            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
+            if layer_scale_init_value > 0
+            else None
+        )
+
+    def forward(
+        self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        residual = x
+        x = self.dwconv(x)
+        x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
+        if self.adanorm:
+            assert cond_embedding_id is not None
+            x = self.norm(x, cond_embedding_id)
+        else:
+            x = self.norm(x)
+        x = self.pwconv1(x)
+        x = self.act(x)
+        x = self.pwconv2(x)
+        if self.gamma is not None:
+            x = self.gamma * x
+        x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)
+
+        x = residual + x
+        return x
+
+
+class AdaLayerNorm(nn.Module):
+    """
+    Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
+
+    Args:
+        num_embeddings (int): Number of embeddings.
+        embedding_dim (int): Dimension of the embeddings.
+    """
+
+    def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.dim = embedding_dim
+        self.scale = nn.Embedding(
+            num_embeddings=num_embeddings, embedding_dim=embedding_dim
+        )
+        self.shift = nn.Embedding(
+            num_embeddings=num_embeddings, embedding_dim=embedding_dim
+        )
+        torch.nn.init.ones_(self.scale.weight)
+        torch.nn.init.zeros_(self.shift.weight)
+
+    def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
+        scale = self.scale(cond_embedding_id)
+        shift = self.shift(cond_embedding_id)
+        x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
+        x = x * scale + shift
+        return x
+
+
+class ResBlock1(nn.Module):
+    """
+    ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
+    but without upsampling layers.
+
+    Args:
+        dim (int): Number of input channels.
+        kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
+        dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
+            Defaults to (1, 3, 5).
+        lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
+            Defaults to 0.1.
+        layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
+            Defaults to None.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        kernel_size: int = 3,
+        dilation: Tuple[int, int, int] = (1, 3, 5),
+        lrelu_slope: float = 0.1,
+        layer_scale_init_value: Optional[float] = None,
+    ):
+        super().__init__()
+        self.lrelu_slope = lrelu_slope
+        self.convs1 = nn.ModuleList(
+            [
+                weight_norm(
+                    nn.Conv1d(
+                        dim,
+                        dim,
+                        kernel_size,
+                        1,
+                        dilation=dilation[0],
+                        padding=self.get_padding(kernel_size, dilation[0]),
+                    )
+                ),
+                weight_norm(
+                    nn.Conv1d(
+                        dim,
+                        dim,
+                        kernel_size,
+                        1,
+                        dilation=dilation[1],
+                        padding=self.get_padding(kernel_size, dilation[1]),
+                    )
+                ),
+                weight_norm(
+                    nn.Conv1d(
+                        dim,
+                        dim,
+                        kernel_size,
+                        1,
+                        dilation=dilation[2],
+                        padding=self.get_padding(kernel_size, dilation[2]),
+                    )
+                ),
+            ]
+        )
+
+        self.convs2 = nn.ModuleList(
+            [
+                weight_norm(
+                    nn.Conv1d(
+                        dim,
+                        dim,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=self.get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    nn.Conv1d(
+                        dim,
+                        dim,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=self.get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    nn.Conv1d(
+                        dim,
+                        dim,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=self.get_padding(kernel_size, 1),
+                    )
+                ),
+            ]
+        )
+
+        self.gamma = nn.ParameterList(
+            [
+                (
+                    nn.Parameter(
+                        layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
+                    )
+                    if layer_scale_init_value is not None
+                    else None
+                ),
+                (
+                    nn.Parameter(
+                        layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
+                    )
+                    if layer_scale_init_value is not None
+                    else None
+                ),
+                (
+                    nn.Parameter(
+                        layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
+                    )
+                    if layer_scale_init_value is not None
+                    else None
+                ),
+            ]
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
+            xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
+            xt = c1(xt)
+            xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
+            xt = c2(xt)
+            if gamma is not None:
+                xt = gamma * xt
+            x = xt + x
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs1:
+            remove_weight_norm(l)
+        for l in self.convs2:
+            remove_weight_norm(l)
+
+    @staticmethod
+    def get_padding(kernel_size: int, dilation: int = 1) -> int:
+        return int((kernel_size * dilation - dilation) / 2)
+
+
+class Backbone(nn.Module):
+    """Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
+
+    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+        """
+        Args:
+            x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
+                C denotes input features, and L is the sequence length.
+
+        Returns:
+            Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
+                and H denotes the model dimension.
+        """
+        raise NotImplementedError("Subclasses must implement the forward method.")
+
+
+class VocosBackbone(Backbone):
+    """
+    Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization
+
+    Args:
+        input_channels (int): Number of input features channels.
+        dim (int): Hidden dimension of the model.
+        intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
+        num_layers (int): Number of ConvNeXtBlock layers.
+        layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
+        adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
+            None means non-conditional model. Defaults to None.
+    """
+
+    def __init__(
+        self,
+        input_channels: int,
+        dim: int,
+        intermediate_dim: int,
+        num_layers: int,
+        layer_scale_init_value: Optional[float] = None,
+        adanorm_num_embeddings: Optional[int] = None,
+    ):
+        super().__init__()
+        self.input_channels = input_channels
+        self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
+        self.adanorm = adanorm_num_embeddings is not None
+        if adanorm_num_embeddings:
+            self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
+        else:
+            self.norm = nn.LayerNorm(dim, eps=1e-6)
+        layer_scale_init_value = layer_scale_init_value or 1 / num_layers
+        self.convnext = nn.ModuleList(
+            [
+                ConvNeXtBlock(
+                    dim=dim,
+                    intermediate_dim=intermediate_dim,
+                    layer_scale_init_value=layer_scale_init_value,
+                    adanorm_num_embeddings=adanorm_num_embeddings,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+        self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, (nn.Conv1d, nn.Linear)):
+            nn.init.trunc_normal_(m.weight, std=0.02)
+            nn.init.constant_(m.bias, 0)
+
+    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+        bandwidth_id = kwargs.get("bandwidth_id", None)
+        x = self.embed(x)
+        if self.adanorm:
+            assert bandwidth_id is not None
+            x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id)
+        else:
+            x = self.norm(x.transpose(1, 2))
+        x = x.transpose(1, 2)
+        for conv_block in self.convnext:
+            x = conv_block(x, cond_embedding_id=bandwidth_id)
+        x = self.final_layer_norm(x.transpose(1, 2))
+        return x
+
+
+class VocosResNetBackbone(Backbone):
+    """
+    Vocos backbone module built with ResBlocks.
+
+    Args:
+        input_channels (int): Number of input features channels.
+        dim (int): Hidden dimension of the model.
+        num_blocks (int): Number of ResBlock1 blocks.
+        layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
+    """
+
+    def __init__(
+        self,
+        input_channels,
+        dim,
+        num_blocks,
+        layer_scale_init_value=None,
+    ):
+        super().__init__()
+        self.input_channels = input_channels
+        self.embed = weight_norm(
+            nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)
+        )
+        layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
+        self.resnet = nn.Sequential(
+            *[
+                ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value)
+                for _ in range(num_blocks)
+            ]
+        )
+
+    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+        x = self.embed(x)
+        x = self.resnet(x)
+        x = x.transpose(1, 2)
+        return x
+
+
+class Vocos(nn.Module):
+    def __init__(
+        self,
+        input_channels: int = 256,
+        dim: int = 384,
+        intermediate_dim: int = 1152,
+        num_layers: int = 8,
+        n_fft: int = 800,
+        hop_size: int = 200,
+        padding: str = "same",
+        adanorm_num_embeddings=None,
+        cfg=None,
+    ):
+        super().__init__()
+
+        input_channels = (
+            cfg.input_channels
+            if cfg is not None and hasattr(cfg, "input_channels")
+            else input_channels
+        )
+        dim = cfg.dim if cfg is not None and hasattr(cfg, "dim") else dim
+        intermediate_dim = (
+            cfg.intermediate_dim
+            if cfg is not None and hasattr(cfg, "intermediate_dim")
+            else intermediate_dim
+        )
+        num_layers = (
+            cfg.num_layers
+            if cfg is not None and hasattr(cfg, "num_layers")
+            else num_layers
+        )
+        adanorm_num_embeddings = (
+            cfg.adanorm_num_embeddings
+            if cfg is not None and hasattr(cfg, "adanorm_num_embeddings")
+            else adanorm_num_embeddings
+        )
+        n_fft = cfg.n_fft if cfg is not None and hasattr(cfg, "n_fft") else n_fft
+        hop_size = (
+            cfg.hop_size if cfg is not None and hasattr(cfg, "hop_size") else hop_size
+        )
+        padding = (
+            cfg.padding if cfg is not None and hasattr(cfg, "padding") else padding
+        )
+
+        self.backbone = VocosBackbone(
+            input_channels=input_channels,
+            dim=dim,
+            intermediate_dim=intermediate_dim,
+            num_layers=num_layers,
+            adanorm_num_embeddings=adanorm_num_embeddings,
+        )
+        self.head = ISTFTHead(dim, n_fft, hop_size, padding)
+
+    def forward(self, x):
+        x = self.backbone(x)
+        x = self.head(x)
+
+        return x[:, None, :]
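
For orientation, below is a minimal smoke test of the Vocos module introduced by this hunk. It is a sketch, not part of the release: the import path is assumed from the wheel layout listed above, the constructor defaults (input_channels=256, n_fft=800, hop_size=200, padding="same") are taken from the diff, and the wheel's dependencies (torch, torchaudio, scipy, librosa) are assumed to be installed.

    # Sketch: exercise the newly vendored Vocos vocoder on random features.
    # The import path below is assumed from this diff's file listing; adjust
    # it if the module is vendored elsewhere.
    import torch

    from xinference.thirdparty.indextts.utils.maskgct.models.codec.amphion_codec.vocos import (
        Vocos,
    )

    # Defaults from the diff: ConvNeXt backbone plus ISTFT head.
    model = Vocos().eval()

    # Random acoustic features shaped (batch, input_channels, num_frames).
    features = torch.randn(1, 256, 50)
    with torch.no_grad():
        audio = model(features)

    # With hop_size=200 and "same" padding, each frame maps to 200 samples,
    # so 50 frames yield 10000 samples, returned as (batch, 1, num_samples).
    print(audio.shape)  # torch.Size([1, 1, 10000])

The head's n_fft=800 with hop_size=200 gives 75% window overlap, which is why the custom ISTFT's trimmed overlap-add (rather than torch.istft's center-padded NOLA check) reconstructs exactly num_frames * hop_size samples.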