xinference 1.10.0__py3-none-any.whl → 1.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (317)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +11 -28
  3. xinference/client/restful/async_restful_client.py +20 -3
  4. xinference/client/restful/restful_client.py +20 -3
  5. xinference/core/supervisor.py +87 -53
  6. xinference/core/worker.py +10 -0
  7. xinference/deploy/cmdline.py +15 -0
  8. xinference/model/audio/core.py +21 -6
  9. xinference/model/audio/indextts2.py +166 -0
  10. xinference/model/audio/model_spec.json +38 -1
  11. xinference/model/image/model_spec.json +69 -0
  12. xinference/model/image/stable_diffusion/core.py +13 -4
  13. xinference/model/llm/__init__.py +4 -0
  14. xinference/model/llm/llm_family.json +464 -2
  15. xinference/model/llm/sglang/core.py +30 -11
  16. xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +94 -32
  17. xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
  18. xinference/model/llm/utils.py +12 -9
  19. xinference/model/llm/vllm/core.py +93 -17
  20. xinference/thirdparty/audiotools/__init__.py +10 -0
  21. xinference/thirdparty/audiotools/core/__init__.py +4 -0
  22. xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
  23. xinference/thirdparty/audiotools/core/display.py +194 -0
  24. xinference/thirdparty/audiotools/core/dsp.py +390 -0
  25. xinference/thirdparty/audiotools/core/effects.py +647 -0
  26. xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
  27. xinference/thirdparty/audiotools/core/loudness.py +320 -0
  28. xinference/thirdparty/audiotools/core/playback.py +252 -0
  29. xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
  30. xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
  31. xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
  32. xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
  33. xinference/thirdparty/audiotools/core/util.py +671 -0
  34. xinference/thirdparty/audiotools/core/whisper.py +97 -0
  35. xinference/thirdparty/audiotools/data/__init__.py +3 -0
  36. xinference/thirdparty/audiotools/data/datasets.py +517 -0
  37. xinference/thirdparty/audiotools/data/preprocess.py +81 -0
  38. xinference/thirdparty/audiotools/data/transforms.py +1592 -0
  39. xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
  40. xinference/thirdparty/audiotools/metrics/distance.py +131 -0
  41. xinference/thirdparty/audiotools/metrics/quality.py +159 -0
  42. xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
  43. xinference/thirdparty/audiotools/ml/__init__.py +5 -0
  44. xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
  45. xinference/thirdparty/audiotools/ml/decorators.py +440 -0
  46. xinference/thirdparty/audiotools/ml/experiment.py +90 -0
  47. xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
  48. xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
  49. xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
  50. xinference/thirdparty/audiotools/post.py +140 -0
  51. xinference/thirdparty/audiotools/preference.py +600 -0
  52. xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
  53. xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
  54. xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
  55. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
  56. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
  57. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
  58. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
  59. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  60. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
  61. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
  62. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
  63. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
  64. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
  65. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
  66. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
  67. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
  68. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
  69. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
  70. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
  71. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
  72. xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
  73. xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
  74. xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
  75. xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
  76. xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
  77. xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
  78. xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
  79. xinference/thirdparty/indextts/__init__.py +0 -0
  80. xinference/thirdparty/indextts/cli.py +65 -0
  81. xinference/thirdparty/indextts/gpt/__init__.py +0 -0
  82. xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
  83. xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
  84. xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
  85. xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
  86. xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
  87. xinference/thirdparty/indextts/gpt/model.py +713 -0
  88. xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
  89. xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
  90. xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
  91. xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
  92. xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
  93. xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
  94. xinference/thirdparty/indextts/infer.py +690 -0
  95. xinference/thirdparty/indextts/infer_v2.py +739 -0
  96. xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
  97. xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
  98. xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
  99. xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
  100. xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
  101. xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
  102. xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
  103. xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
  104. xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
  105. xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
  106. xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
  107. xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
  108. xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
  109. xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
  110. xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
  111. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
  112. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
  113. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
  114. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
  115. xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
  116. xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
  117. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
  118. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
  119. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  120. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
  121. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
  122. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
  123. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
  124. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
  125. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
  126. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
  127. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
  128. xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
  129. xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
  130. xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
  131. xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
  132. xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
  133. xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
  134. xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
  135. xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
  136. xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
  137. xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
  138. xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
  139. xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
  140. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
  141. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
  142. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
  143. xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
  144. xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
  145. xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
  146. xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
  147. xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
  148. xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
  149. xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
  150. xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
  151. xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
  152. xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
  153. xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
  154. xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
  155. xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
  156. xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
  157. xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
  158. xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
  159. xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
  160. xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
  161. xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
  162. xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
  163. xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
  164. xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
  165. xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
  166. xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
  167. xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
  168. xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
  169. xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
  170. xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
  171. xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
  172. xinference/thirdparty/indextts/utils/__init__.py +0 -0
  173. xinference/thirdparty/indextts/utils/arch_util.py +120 -0
  174. xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
  175. xinference/thirdparty/indextts/utils/common.py +121 -0
  176. xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
  177. xinference/thirdparty/indextts/utils/front.py +536 -0
  178. xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
  179. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
  180. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
  181. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
  182. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
  183. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
  184. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
  185. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
  186. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
  187. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
  188. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
  189. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
  190. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
  191. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
  192. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
  193. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
  194. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
  195. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
  196. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
  197. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
  198. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
  199. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
  200. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
  201. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
  202. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
  203. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
  204. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
  205. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
  206. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
  207. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
  208. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
  209. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
  210. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
  211. xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
  212. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
  213. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
  214. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
  215. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
  216. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
  217. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
  218. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
  219. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
  220. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
  221. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
  222. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
  223. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
  224. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
  225. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
  226. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
  227. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
  228. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
  229. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
  230. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
  231. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
  232. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
  233. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
  234. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
  235. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
  236. xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
  237. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
  238. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
  239. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
  240. xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
  241. xinference/thirdparty/indextts/utils/text_utils.py +41 -0
  242. xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
  243. xinference/thirdparty/indextts/utils/utils.py +93 -0
  244. xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
  245. xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
  246. xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
  247. xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
  248. xinference/ui/gradio/media_interface.py +66 -8
  249. xinference/ui/web/ui/build/asset-manifest.json +6 -6
  250. xinference/ui/web/ui/build/index.html +1 -1
  251. xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
  252. xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
  253. xinference/ui/web/ui/build/static/js/main.d192c4f3.js +3 -0
  254. xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.d192c4f3.js.LICENSE.txt} +0 -7
  255. xinference/ui/web/ui/build/static/js/main.d192c4f3.js.map +1 -0
  256. xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
  257. xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
  258. xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
  259. xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
  260. xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
  261. xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
  262. xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
  263. xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
  264. xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
  265. xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
  266. xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
  267. xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
  268. xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
  269. xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
  270. xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
  271. xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
  272. xinference/ui/web/ui/node_modules/.cache/babel-loader/f995a2425dfb0822fd07127f66ffe9b026883bc156b402eb8bd0b83d52460a93.json +1 -0
  273. xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
  274. xinference/ui/web/ui/package-lock.json +0 -34
  275. xinference/ui/web/ui/package.json +0 -1
  276. xinference/ui/web/ui/src/locales/en.json +9 -3
  277. xinference/ui/web/ui/src/locales/ja.json +9 -3
  278. xinference/ui/web/ui/src/locales/ko.json +9 -3
  279. xinference/ui/web/ui/src/locales/zh.json +9 -3
  280. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/METADATA +18 -2
  281. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/RECORD +285 -67
  282. xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
  283. xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
  284. xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
  285. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
  286. xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
  287. xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
  288. xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
  289. xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
  290. xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
  291. xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
  292. xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
  293. xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
  294. xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
  295. xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
  296. xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
  297. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
  298. xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
  299. xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
  300. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
  301. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
  302. xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
  303. xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
  304. xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
  305. xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
  306. xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
  307. xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
  308. xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
  309. xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
  310. xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
  311. xinference/ui/web/ui/node_modules/select/bower.json +0 -13
  312. xinference/ui/web/ui/node_modules/select/package.json +0 -29
  313. xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
  314. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/WHEEL +0 -0
  315. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/entry_points.txt +0 -0
  316. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/licenses/LICENSE +0 -0
  317. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/top_level.txt +0 -0
xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py
@@ -0,0 +1,164 @@
+ from typing import Optional
+
+ import torch
+ from torch import nn
+ from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz
+
+ from .spectral_ops import IMDCT, ISTFT
+ from .modules import symexp
+
+
+ class FourierHead(nn.Module):
+     """Base class for inverse fourier modules."""
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Args:
+             x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+                 L is the sequence length, and H denotes the model dimension.
+
+         Returns:
+             Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+         """
+         raise NotImplementedError("Subclasses must implement the forward method.")
+
+
+ class ISTFTHead(FourierHead):
+     """
+     ISTFT Head module for predicting STFT complex coefficients.
+
+     Args:
+         dim (int): Hidden dimension of the model.
+         n_fft (int): Size of Fourier transform.
+         hop_length (int): The distance between neighboring sliding window frames, which should align with
+             the resolution of the input features.
+         padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+     """
+
+     def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
+         super().__init__()
+         out_dim = n_fft + 2
+         self.out = torch.nn.Linear(dim, out_dim)
+         self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass of the ISTFTHead module.
+
+         Args:
+             x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+                 L is the sequence length, and H denotes the model dimension.
+
+         Returns:
+             Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+         """
+         x = self.out(x).transpose(1, 2)
+         mag, p = x.chunk(2, dim=1)
+         mag = torch.exp(mag)
+         mag = torch.clip(mag, max=1e2)  # safeguard to prevent excessively large magnitudes
+         # wrapping happens here. These two lines produce real and imaginary value
+         x = torch.cos(p)
+         y = torch.sin(p)
+         # recalculating phase here does not produce anything new
+         # only costs time
+         # phase = torch.atan2(y, x)
+         # S = mag * torch.exp(phase * 1j)
+         # better directly produce the complex value
+         S = mag * (x + 1j * y)
+         audio = self.istft(S)
+         return audio
+
+
+ class IMDCTSymExpHead(FourierHead):
+     """
+     IMDCT Head module for predicting MDCT coefficients with symmetric exponential function
+
+     Args:
+         dim (int): Hidden dimension of the model.
+         mdct_frame_len (int): Length of the MDCT frame.
+         padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+         sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
+             based on perceptual scaling. Defaults to None.
+         clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         mdct_frame_len: int,
+         padding: str = "same",
+         sample_rate: Optional[int] = None,
+         clip_audio: bool = False,
+     ):
+         super().__init__()
+         out_dim = mdct_frame_len // 2
+         self.out = nn.Linear(dim, out_dim)
+         self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
+         self.clip_audio = clip_audio
+
+         if sample_rate is not None:
+             # optionally init the last layer following mel-scale
+             m_max = _hz_to_mel(sample_rate // 2)
+             m_pts = torch.linspace(0, m_max, out_dim)
+             f_pts = _mel_to_hz(m_pts)
+             scale = 1 - (f_pts / f_pts.max())
+
+             with torch.no_grad():
+                 self.out.weight.mul_(scale.view(-1, 1))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass of the IMDCTSymExpHead module.
+
+         Args:
+             x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+                 L is the sequence length, and H denotes the model dimension.
+
+         Returns:
+             Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+         """
+         x = self.out(x)
+         x = symexp(x)
+         x = torch.clip(x, min=-1e2, max=1e2)  # safeguard to prevent excessively large magnitudes
+         audio = self.imdct(x)
+         if self.clip_audio:
+             audio = torch.clip(audio, min=-1.0, max=1.0)  # clip the reconstructed waveform, not the MDCT coefficients
+
+         return audio
+
+
+ class IMDCTCosHead(FourierHead):
+     """
+     IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) · cos(p)
+
+     Args:
+         dim (int): Hidden dimension of the model.
+         mdct_frame_len (int): Length of the MDCT frame.
+         padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+         clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
+     """
+
+     def __init__(self, dim: int, mdct_frame_len: int, padding: str = "same", clip_audio: bool = False):
+         super().__init__()
+         self.clip_audio = clip_audio
+         self.out = nn.Linear(dim, mdct_frame_len)
+         self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass of the IMDCTCosHead module.
+
+         Args:
+             x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+                 L is the sequence length, and H denotes the model dimension.
+
+         Returns:
+             Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+         """
+         x = self.out(x)
+         m, p = x.chunk(2, dim=2)
+         m = torch.exp(m).clip(max=1e2)  # safeguard to prevent excessively large magnitudes
+         audio = self.imdct(m * torch.cos(p))
+         if self.clip_audio:
+             audio = torch.clip(audio, min=-1.0, max=1.0)  # clip the reconstructed waveform
+         return audio
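
For orientation, a minimal shape-flow sketch of the ISTFT head added above. This snippet is not part of the diff: the import path follows the vendored file location, the parameter values are hypothetical, and it assumes the vendored package (including its .spectral_ops sibling) is importable.

import torch
from xinference.thirdparty.indextts.s2mel.modules.vocos.heads import ISTFTHead

# (B, L, H) per-frame features in, (B, T) waveform out.
head = ISTFTHead(dim=512, n_fft=1024, hop_length=256)
features = torch.randn(2, 100, 512)  # 2 clips, 100 frames, 512-dim features
audio = head(features)               # roughly 100 * 256 samples per clip with "same" padding
print(audio.shape)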
xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py
@@ -0,0 +1,71 @@
+ import matplotlib
+ import numpy as np
+ import torch
+ from matplotlib import pyplot as plt
+ from pytorch_lightning import Callback
+
+ matplotlib.use("Agg")
+
+
+ def save_figure_to_numpy(fig: plt.Figure) -> np.ndarray:
+     """
+     Save a matplotlib figure to a numpy array.
+
+     Args:
+         fig (Figure): Matplotlib figure object.
+
+     Returns:
+         ndarray: Numpy array representing the figure.
+     """
+     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)  # raw RGB bytes of the rendered canvas
+     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+     return data
+
+
+ def plot_spectrogram_to_numpy(spectrogram: np.ndarray) -> np.ndarray:
+     """
+     Plot a spectrogram and convert it to a numpy array.
+
+     Args:
+         spectrogram (ndarray): Spectrogram data.
+
+     Returns:
+         ndarray: Numpy array representing the plotted spectrogram.
+     """
+     spectrogram = spectrogram.astype(np.float32)
+     fig, ax = plt.subplots(figsize=(12, 3))
+     im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
+     plt.colorbar(im, ax=ax)
+     plt.xlabel("Frames")
+     plt.ylabel("Channels")
+     plt.tight_layout()
+
+     fig.canvas.draw()
+     data = save_figure_to_numpy(fig)
+     plt.close()
+     return data
+
+
+ class GradNormCallback(Callback):
+     """
+     Callback to log the gradient norm.
+     """
+
+     def on_after_backward(self, trainer, model):
+         model.log("grad_norm", gradient_norm(model))
+
+
+ def gradient_norm(model: torch.nn.Module, norm_type: float = 2.0) -> torch.Tensor:
+     """
+     Compute the gradient norm.
+
+     Args:
+         model (Module): PyTorch model.
+         norm_type (float, optional): Type of the norm. Defaults to 2.0.
+
+     Returns:
+         Tensor: Gradient norm.
+     """
+     grads = [p.grad for p in model.parameters() if p.grad is not None]
+     total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type) for g in grads]), norm_type)
+     return total_norm
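
A self-contained toy run of the gradient_norm helper above (the linear model is made up for illustration, not part of the package):

import torch

model = torch.nn.Linear(4, 1)
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()

# Same computation gradient_norm() performs: a global L2 norm over per-parameter grad norms.
grads = [p.grad for p in model.parameters() if p.grad is not None]
total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2.0) for g in grads]), 2.0)
print(float(total_norm))  # the value GradNormCallback would log as "grad_norm"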
xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py
@@ -0,0 +1,114 @@
+ from typing import List, Tuple
+
+ import torch
+ import torchaudio
+ from torch import nn
+
+ from vocos.modules import safe_log
+
+
+ class MelSpecReconstructionLoss(nn.Module):
+     """
+     L1 distance between the mel-scaled magnitude spectrograms of the ground truth sample and the generated sample
+     """
+
+     def __init__(
+         self, sample_rate: int = 24000, n_fft: int = 1024, hop_length: int = 256, n_mels: int = 100,
+     ):
+         super().__init__()
+         self.mel_spec = torchaudio.transforms.MelSpectrogram(
+             sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, center=True, power=1,
+         )
+
+     def forward(self, y_hat, y) -> torch.Tensor:
+         """
+         Args:
+             y_hat (Tensor): Predicted audio waveform.
+             y (Tensor): Ground truth audio waveform.
+
+         Returns:
+             Tensor: L1 loss between the mel-scaled magnitude spectrograms.
+         """
+         mel_hat = safe_log(self.mel_spec(y_hat))
+         mel = safe_log(self.mel_spec(y))
+
+         loss = torch.nn.functional.l1_loss(mel, mel_hat)
+
+         return loss
+
+
+ class GeneratorLoss(nn.Module):
+     """
+     Generator Loss module. Calculates the loss for the generator based on discriminator outputs.
+     """
+
+     def forward(self, disc_outputs: List[torch.Tensor]) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+         """
+         Args:
+             disc_outputs (List[Tensor]): List of discriminator outputs.
+
+         Returns:
+             Tuple[Tensor, List[Tensor]]: Tuple containing the total loss and a list of loss values from
+                 the sub-discriminators.
+         """
+         loss = torch.zeros(1, device=disc_outputs[0].device, dtype=disc_outputs[0].dtype)
+         gen_losses = []
+         for dg in disc_outputs:
+             l = torch.mean(torch.clamp(1 - dg, min=0))
+             gen_losses.append(l)
+             loss += l
+
+         return loss, gen_losses
+
+
+ class DiscriminatorLoss(nn.Module):
+     """
+     Discriminator Loss module. Calculates the loss for the discriminator based on real and generated outputs.
+     """
+
+     def forward(
+         self, disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor]
+     ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
+         """
+         Args:
+             disc_real_outputs (List[Tensor]): List of discriminator outputs for real samples.
+             disc_generated_outputs (List[Tensor]): List of discriminator outputs for generated samples.
+
+         Returns:
+             Tuple[Tensor, List[Tensor], List[Tensor]]: A tuple containing the total loss, a list of loss values from
+                 the sub-discriminators for real outputs, and a list of loss values for generated outputs.
+         """
+         loss = torch.zeros(1, device=disc_real_outputs[0].device, dtype=disc_real_outputs[0].dtype)
+         r_losses = []
+         g_losses = []
+         for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+             r_loss = torch.mean(torch.clamp(1 - dr, min=0))
+             g_loss = torch.mean(torch.clamp(1 + dg, min=0))
+             loss += r_loss + g_loss
+             r_losses.append(r_loss)
+             g_losses.append(g_loss)
+
+         return loss, r_losses, g_losses
+
+
+ class FeatureMatchingLoss(nn.Module):
+     """
+     Feature Matching Loss module. Calculates the feature matching loss between feature maps of the sub-discriminators.
+     """
+
+     def forward(self, fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]) -> torch.Tensor:
+         """
+         Args:
+             fmap_r (List[List[Tensor]]): List of feature maps from real samples.
+             fmap_g (List[List[Tensor]]): List of feature maps from generated samples.
+
+         Returns:
+             Tensor: The calculated feature matching loss.
+         """
+         loss = torch.zeros(1, device=fmap_r[0][0].device, dtype=fmap_r[0][0].dtype)
+         for dr, dg in zip(fmap_r, fmap_g):
+             for rl, gl in zip(dr, dg):
+                 loss += torch.mean(torch.abs(rl - gl))
+
+         return loss
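
The GAN losses above follow the hinge-loss convention. A small self-contained illustration with made-up discriminator scores (this snippet is not part of the diff):

import torch

dr = torch.tensor([1.2, 0.4])   # discriminator outputs on real samples
dg = torch.tensor([-0.8, 0.1])  # discriminator outputs on generated samples

# DiscriminatorLoss: real outputs are pushed above 1, generated outputs below -1.
d_loss = torch.mean(torch.clamp(1 - dr, min=0)) + torch.mean(torch.clamp(1 + dg, min=0))
# GeneratorLoss: the generator pushes the same outputs above 1 instead.
g_loss = torch.mean(torch.clamp(1 - dg, min=0))
print(d_loss.item(), g_loss.item())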
xinference/thirdparty/indextts/s2mel/modules/vocos/models.py
@@ -0,0 +1,118 @@
+ from typing import Optional
+
+ import torch
+ from torch import nn
+ from torch.nn.utils import weight_norm
+
+ from .modules import ConvNeXtBlock, ResBlock1, AdaLayerNorm
+
+
+ class Backbone(nn.Module):
+     """Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
+
+     def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+         """
+         Args:
+             x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
+                 C denotes output features, and L is the sequence length.
+
+         Returns:
+             Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
+                 and H denotes the model dimension.
+         """
+         raise NotImplementedError("Subclasses must implement the forward method.")
+
+
+ class VocosBackbone(Backbone):
+     """
+     Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization
+
+     Args:
+         input_channels (int): Number of input features channels.
+         dim (int): Hidden dimension of the model.
+         intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
+         num_layers (int): Number of ConvNeXtBlock layers.
+         layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
+         adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
+             None means non-conditional model. Defaults to None.
+     """
+
+     def __init__(
+         self,
+         input_channels: int,
+         dim: int,
+         intermediate_dim: int,
+         num_layers: int,
+         layer_scale_init_value: Optional[float] = None,
+         adanorm_num_embeddings: Optional[int] = None,
+     ):
+         super().__init__()
+         self.input_channels = input_channels
+         self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
+         self.adanorm = adanorm_num_embeddings is not None
+         if adanorm_num_embeddings:
+             self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
+         else:
+             self.norm = nn.LayerNorm(dim, eps=1e-6)
+         layer_scale_init_value = layer_scale_init_value or 1 / num_layers
+         self.convnext = nn.ModuleList(
+             [
+                 ConvNeXtBlock(
+                     dim=dim,
+                     intermediate_dim=intermediate_dim,
+                     layer_scale_init_value=layer_scale_init_value,
+                     adanorm_num_embeddings=adanorm_num_embeddings,
+                 )
+                 for _ in range(num_layers)
+             ]
+         )
+         self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, (nn.Conv1d, nn.Linear)):
+             nn.init.trunc_normal_(m.weight, std=0.02)
+             nn.init.constant_(m.bias, 0)
+
+     def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+         bandwidth_id = kwargs.get('bandwidth_id', None)
+         x = self.embed(x)
+         if self.adanorm:
+             assert bandwidth_id is not None
+             x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id)
+         else:
+             x = self.norm(x.transpose(1, 2))
+         x = x.transpose(1, 2)
+         for conv_block in self.convnext:
+             x = conv_block(x, cond_embedding_id=bandwidth_id)
+         x = self.final_layer_norm(x.transpose(1, 2))
+         return x
+
+
+ class VocosResNetBackbone(Backbone):
+     """
+     Vocos backbone module built with ResBlocks.
+
+     Args:
+         input_channels (int): Number of input features channels.
+         dim (int): Hidden dimension of the model.
+         num_blocks (int): Number of ResBlock1 blocks.
+         layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
+     """
+
+     def __init__(
+         self, input_channels, dim, num_blocks, layer_scale_init_value=None,
+     ):
+         super().__init__()
+         self.input_channels = input_channels
+         self.embed = weight_norm(nn.Conv1d(input_channels, dim, kernel_size=3, padding=1))
+         layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
+         self.resnet = nn.Sequential(
+             *[ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value) for _ in range(num_blocks)]
+         )
+
+     def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+         x = self.embed(x)
+         x = self.resnet(x)
+         x = x.transpose(1, 2)
+         return x
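
A minimal usage sketch of the backbone above (hypothetical sizes; assumes the vendored module is importable; not part of the diff):

import torch
from xinference.thirdparty.indextts.s2mel.modules.vocos.models import VocosBackbone

# Mel-like input (B, C, L) -> hidden features (B, L, H); temporal resolution is preserved.
backbone = VocosBackbone(input_channels=100, dim=512, intermediate_dim=1536, num_layers=8)
mel = torch.randn(2, 100, 200)  # 2 clips, 100 mel channels, 200 frames
h = backbone(mel)               # torch.Size([2, 200, 512]), ready for a FourierHead
print(h.shape)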
xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py
@@ -0,0 +1,213 @@
+ from typing import Optional, Tuple
+
+ import torch
+ from torch import nn
+ from torch.nn.utils import weight_norm, remove_weight_norm
+
+
+ class ConvNeXtBlock(nn.Module):
+     """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
+
+     Args:
+         dim (int): Number of input channels.
+         intermediate_dim (int): Dimensionality of the intermediate layer.
+         layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
+             Defaults to None.
+         adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
+             None means non-conditional LayerNorm. Defaults to None.
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         intermediate_dim: int,
+         layer_scale_init_value: float,
+         adanorm_num_embeddings: Optional[int] = None,
+     ):
+         super().__init__()
+         self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
+         self.adanorm = adanorm_num_embeddings is not None
+         if adanorm_num_embeddings:
+             self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
+         else:
+             self.norm = nn.LayerNorm(dim, eps=1e-6)
+         self.pwconv1 = nn.Linear(dim, intermediate_dim)  # pointwise/1x1 convs, implemented with linear layers
+         self.act = nn.GELU()
+         self.pwconv2 = nn.Linear(intermediate_dim, dim)
+         self.gamma = (
+             nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
+             if layer_scale_init_value > 0
+             else None
+         )
+
+     def forward(self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None) -> torch.Tensor:
+         residual = x
+         x = self.dwconv(x)
+         x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
+         if self.adanorm:
+             assert cond_embedding_id is not None
+             x = self.norm(x, cond_embedding_id)
+         else:
+             x = self.norm(x)
+         x = self.pwconv1(x)
+         x = self.act(x)
+         x = self.pwconv2(x)
+         if self.gamma is not None:
+             x = self.gamma * x
+         x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)
+
+         x = residual + x
+         return x
+
+
+ class AdaLayerNorm(nn.Module):
+     """
+     Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
+
+     Args:
+         num_embeddings (int): Number of embeddings.
+         embedding_dim (int): Dimension of the embeddings.
+     """
+
+     def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
+         super().__init__()
+         self.eps = eps
+         self.dim = embedding_dim
+         self.scale = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
+         self.shift = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
+         torch.nn.init.ones_(self.scale.weight)
+         torch.nn.init.zeros_(self.shift.weight)
+
+     def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
+         scale = self.scale(cond_embedding_id)
+         shift = self.shift(cond_embedding_id)
+         x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
+         x = x * scale + shift
+         return x
+
+
+ class ResBlock1(nn.Module):
+     """
+     ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
+     but without upsampling layers.
+
+     Args:
+         dim (int): Number of input channels.
+         kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
+         dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
+             Defaults to (1, 3, 5).
+         lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
+             Defaults to 0.1.
+         layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
+             Defaults to None.
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         kernel_size: int = 3,
+         dilation: Tuple[int, int, int] = (1, 3, 5),
+         lrelu_slope: float = 0.1,
+         layer_scale_init_value: Optional[float] = None,
+     ):
+         super().__init__()
+         self.lrelu_slope = lrelu_slope
+         self.convs1 = nn.ModuleList(
+             [
+                 weight_norm(
+                     nn.Conv1d(
+                         dim,
+                         dim,
+                         kernel_size,
+                         1,
+                         dilation=dilation[0],
+                         padding=self.get_padding(kernel_size, dilation[0]),
+                     )
+                 ),
+                 weight_norm(
+                     nn.Conv1d(
+                         dim,
+                         dim,
+                         kernel_size,
+                         1,
+                         dilation=dilation[1],
+                         padding=self.get_padding(kernel_size, dilation[1]),
+                     )
+                 ),
+                 weight_norm(
+                     nn.Conv1d(
+                         dim,
+                         dim,
+                         kernel_size,
+                         1,
+                         dilation=dilation[2],
+                         padding=self.get_padding(kernel_size, dilation[2]),
+                     )
+                 ),
+             ]
+         )
+
+         self.convs2 = nn.ModuleList(
+             [
+                 weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
+                 weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
+                 weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
+             ]
+         )
+
+         self.gamma = nn.ParameterList(
+             [
+                 nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
+                 if layer_scale_init_value is not None
+                 else None,
+                 nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
+                 if layer_scale_init_value is not None
+                 else None,
+                 nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
+                 if layer_scale_init_value is not None
+                 else None,
+             ]
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
+             xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
+             xt = c1(xt)
+             xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
+             xt = c2(xt)
+             if gamma is not None:
+                 xt = gamma * xt
+             x = xt + x
+         return x
+
+     def remove_weight_norm(self):
+         for l in self.convs1:
+             remove_weight_norm(l)
+         for l in self.convs2:
+             remove_weight_norm(l)
+
+     @staticmethod
+     def get_padding(kernel_size: int, dilation: int = 1) -> int:
+         return int((kernel_size * dilation - dilation) / 2)
+
+
+ def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
+     """
+     Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.
+
+     Args:
+         x (Tensor): Input tensor.
+         clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.
+
+     Returns:
+         Tensor: Element-wise logarithm of the input tensor with clipping applied.
+     """
+     return torch.log(torch.clip(x, min=clip_val))
+
+
+ def symlog(x: torch.Tensor) -> torch.Tensor:
+     return torch.sign(x) * torch.log1p(x.abs())
+
+
+ def symexp(x: torch.Tensor) -> torch.Tensor:
+     return torch.sign(x) * (torch.exp(x.abs()) - 1)
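
symlog and symexp above are mutual inverses, which is what lets IMDCTSymExpHead predict wide-range MDCT coefficients from a compressed representation. A quick self-contained check (not part of the diff):

import torch

x = torch.tensor([-10.0, -0.5, 0.0, 0.5, 10.0])
y = torch.sign(x) * torch.log1p(x.abs())          # symlog: sign-preserving log compression
x_rec = torch.sign(y) * (torch.exp(y.abs()) - 1)  # symexp: its exact inverse
print(torch.allclose(x, x_rec))  # True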