xinference 1.10.0__py3-none-any.whl → 1.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic. Click here for more details.

Files changed (317) hide show
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +11 -28
  3. xinference/client/restful/async_restful_client.py +20 -3
  4. xinference/client/restful/restful_client.py +20 -3
  5. xinference/core/supervisor.py +87 -53
  6. xinference/core/worker.py +10 -0
  7. xinference/deploy/cmdline.py +15 -0
  8. xinference/model/audio/core.py +21 -6
  9. xinference/model/audio/indextts2.py +166 -0
  10. xinference/model/audio/model_spec.json +38 -1
  11. xinference/model/image/model_spec.json +69 -0
  12. xinference/model/image/stable_diffusion/core.py +13 -4
  13. xinference/model/llm/__init__.py +4 -0
  14. xinference/model/llm/llm_family.json +464 -2
  15. xinference/model/llm/sglang/core.py +30 -11
  16. xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +94 -32
  17. xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
  18. xinference/model/llm/utils.py +12 -9
  19. xinference/model/llm/vllm/core.py +93 -17
  20. xinference/thirdparty/audiotools/__init__.py +10 -0
  21. xinference/thirdparty/audiotools/core/__init__.py +4 -0
  22. xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
  23. xinference/thirdparty/audiotools/core/display.py +194 -0
  24. xinference/thirdparty/audiotools/core/dsp.py +390 -0
  25. xinference/thirdparty/audiotools/core/effects.py +647 -0
  26. xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
  27. xinference/thirdparty/audiotools/core/loudness.py +320 -0
  28. xinference/thirdparty/audiotools/core/playback.py +252 -0
  29. xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
  30. xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
  31. xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
  32. xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
  33. xinference/thirdparty/audiotools/core/util.py +671 -0
  34. xinference/thirdparty/audiotools/core/whisper.py +97 -0
  35. xinference/thirdparty/audiotools/data/__init__.py +3 -0
  36. xinference/thirdparty/audiotools/data/datasets.py +517 -0
  37. xinference/thirdparty/audiotools/data/preprocess.py +81 -0
  38. xinference/thirdparty/audiotools/data/transforms.py +1592 -0
  39. xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
  40. xinference/thirdparty/audiotools/metrics/distance.py +131 -0
  41. xinference/thirdparty/audiotools/metrics/quality.py +159 -0
  42. xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
  43. xinference/thirdparty/audiotools/ml/__init__.py +5 -0
  44. xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
  45. xinference/thirdparty/audiotools/ml/decorators.py +440 -0
  46. xinference/thirdparty/audiotools/ml/experiment.py +90 -0
  47. xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
  48. xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
  49. xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
  50. xinference/thirdparty/audiotools/post.py +140 -0
  51. xinference/thirdparty/audiotools/preference.py +600 -0
  52. xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
  53. xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
  54. xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
  55. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
  56. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
  57. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
  58. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
  59. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  60. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
  61. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
  62. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
  63. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
  64. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
  65. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
  66. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
  67. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
  68. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
  69. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
  70. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
  71. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
  72. xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
  73. xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
  74. xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
  75. xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
  76. xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
  77. xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
  78. xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
  79. xinference/thirdparty/indextts/__init__.py +0 -0
  80. xinference/thirdparty/indextts/cli.py +65 -0
  81. xinference/thirdparty/indextts/gpt/__init__.py +0 -0
  82. xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
  83. xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
  84. xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
  85. xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
  86. xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
  87. xinference/thirdparty/indextts/gpt/model.py +713 -0
  88. xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
  89. xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
  90. xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
  91. xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
  92. xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
  93. xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
  94. xinference/thirdparty/indextts/infer.py +690 -0
  95. xinference/thirdparty/indextts/infer_v2.py +739 -0
  96. xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
  97. xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
  98. xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
  99. xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
  100. xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
  101. xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
  102. xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
  103. xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
  104. xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
  105. xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
  106. xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
  107. xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
  108. xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
  109. xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
  110. xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
  111. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
  112. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
  113. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
  114. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
  115. xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
  116. xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
  117. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
  118. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
  119. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  120. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
  121. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
  122. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
  123. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
  124. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
  125. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
  126. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
  127. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
  128. xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
  129. xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
  130. xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
  131. xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
  132. xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
  133. xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
  134. xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
  135. xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
  136. xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
  137. xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
  138. xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
  139. xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
  140. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
  141. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
  142. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
  143. xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
  144. xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
  145. xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
  146. xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
  147. xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
  148. xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
  149. xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
  150. xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
  151. xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
  152. xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
  153. xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
  154. xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
  155. xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
  156. xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
  157. xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
  158. xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
  159. xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
  160. xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
  161. xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
  162. xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
  163. xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
  164. xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
  165. xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
  166. xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
  167. xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
  168. xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
  169. xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
  170. xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
  171. xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
  172. xinference/thirdparty/indextts/utils/__init__.py +0 -0
  173. xinference/thirdparty/indextts/utils/arch_util.py +120 -0
  174. xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
  175. xinference/thirdparty/indextts/utils/common.py +121 -0
  176. xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
  177. xinference/thirdparty/indextts/utils/front.py +536 -0
  178. xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
  179. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
  180. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
  181. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
  182. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
  183. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
  184. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
  185. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
  186. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
  187. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
  188. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
  189. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
  190. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
  191. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
  192. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
  193. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
  194. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
  195. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
  196. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
  197. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
  198. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
  199. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
  200. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
  201. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
  202. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
  203. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
  204. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
  205. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
  206. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
  207. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
  208. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
  209. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
  210. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
  211. xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
  212. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
  213. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
  214. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
  215. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
  216. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
  217. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
  218. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
  219. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
  220. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
  221. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
  222. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
  223. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
  224. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
  225. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
  226. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
  227. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
  228. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
  229. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
  230. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
  231. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
  232. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
  233. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
  234. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
  235. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
  236. xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
  237. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
  238. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
  239. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
  240. xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
  241. xinference/thirdparty/indextts/utils/text_utils.py +41 -0
  242. xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
  243. xinference/thirdparty/indextts/utils/utils.py +93 -0
  244. xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
  245. xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
  246. xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
  247. xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
  248. xinference/ui/gradio/media_interface.py +66 -8
  249. xinference/ui/web/ui/build/asset-manifest.json +6 -6
  250. xinference/ui/web/ui/build/index.html +1 -1
  251. xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
  252. xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
  253. xinference/ui/web/ui/build/static/js/main.d192c4f3.js +3 -0
  254. xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.d192c4f3.js.LICENSE.txt} +0 -7
  255. xinference/ui/web/ui/build/static/js/main.d192c4f3.js.map +1 -0
  256. xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
  257. xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
  258. xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
  259. xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
  260. xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
  261. xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
  262. xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
  263. xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
  264. xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
  265. xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
  266. xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
  267. xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
  268. xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
  269. xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
  270. xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
  271. xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
  272. xinference/ui/web/ui/node_modules/.cache/babel-loader/f995a2425dfb0822fd07127f66ffe9b026883bc156b402eb8bd0b83d52460a93.json +1 -0
  273. xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
  274. xinference/ui/web/ui/package-lock.json +0 -34
  275. xinference/ui/web/ui/package.json +0 -1
  276. xinference/ui/web/ui/src/locales/en.json +9 -3
  277. xinference/ui/web/ui/src/locales/ja.json +9 -3
  278. xinference/ui/web/ui/src/locales/ko.json +9 -3
  279. xinference/ui/web/ui/src/locales/zh.json +9 -3
  280. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/METADATA +18 -2
  281. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/RECORD +285 -67
  282. xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
  283. xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
  284. xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
  285. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
  286. xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
  287. xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
  288. xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
  289. xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
  290. xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
  291. xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
  292. xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
  293. xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
  294. xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
  295. xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
  296. xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
  297. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
  298. xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
  299. xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
  300. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
  301. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
  302. xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
  303. xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
  304. xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
  305. xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
  306. xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
  307. xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
  308. xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
  309. xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
  310. xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
  311. xinference/ui/web/ui/node_modules/select/bower.json +0 -13
  312. xinference/ui/web/ui/node_modules/select/package.json +0 -29
  313. xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
  314. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/WHEEL +0 -0
  315. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/entry_points.txt +0 -0
  316. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/licenses/LICENSE +0 -0
  317. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,135 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ # This source file is copied from https://github.com/facebookresearch/encodec
6
+
7
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
8
+ # All rights reserved.
9
+ #
10
+ # This source code is licensed under the license found in the
11
+ # LICENSE file in the root directory of this source tree.
12
+
13
+ """Torch distributed utilities."""
14
+
15
+ import typing as tp
16
+
17
+ import torch
18
+
19
+
20
def rank():
    """Return this process's rank in the default process group.

    Falls back to 0 when torch.distributed has not been initialized.
    """
    dist = torch.distributed
    return dist.get_rank() if dist.is_initialized() else 0
25
+
26
+
27
+ def world_size():
28
+ if torch.distributed.is_initialized():
29
+ return torch.distributed.get_world_size()
30
+ else:
31
+ return 1
32
+
33
+
34
def is_distributed():
    """Report whether more than one worker takes part in the current run."""
    n_workers = world_size()
    return n_workers > 1
36
+
37
+
38
def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
    """Reduce ``tensor`` across workers with ``op`` when running distributed.

    When only a single worker is present the tensor is left untouched and
    ``None`` is returned; otherwise the result of
    ``torch.distributed.all_reduce`` is returned.
    """
    if not is_distributed():
        return None
    return torch.distributed.all_reduce(tensor, op)
41
+
42
+
43
+ def _is_complex_or_float(tensor):
44
+ return torch.is_floating_point(tensor) or torch.is_complex(tensor)
45
+
46
+
47
def _check_number_of_params(params: tp.List[torch.Tensor]):
    """Verify that every worker holds the same number of parameters.

    A mismatch would make a later collective call hang, so it is detected
    up front: the local counts are summed over all workers and compared
    with ``len(params) * world_size()``.

    Does nothing when not distributed or when ``params`` is empty.

    Raises:
        RuntimeError: if the summed count differs from the expected total.
    """
    if not (is_distributed() and params):
        return
    local_count = len(params)
    total = torch.tensor([local_count], device=params[0].device, dtype=torch.long)
    all_reduce(total)
    expected_total = local_count * world_size()
    if total.item() == expected_total:
        return
    # If the workers disagree, at least one of them sees a sum that differs
    # from its own count multiplied by the world size.
    raise RuntimeError(
        f"Mismatch in number of params: ours is {local_count}, "
        "at least one worker has a different one."
    )
62
+
63
+
64
def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
    """Broadcast floating-point and complex tensors from ``src`` to all workers.

    Useful to make sure every worker starts from the same model weights.
    Tensors of any other dtype are ignored. Does nothing when not distributed.
    """
    if not is_distributed():
        return
    eligible = [t for t in tensors if _is_complex_or_float(t)]
    _check_number_of_params(eligible)
    # Launch every broadcast asynchronously first, then wait for all of them.
    pending = [
        torch.distributed.broadcast(t.data, src=src, async_op=True)
        for t in eligible
    ]
    for work in pending:
        work.wait()
79
+
80
+
81
def sync_buffer(buffers, average=True):
    """Synchronize floating-point buffers across all workers.

    Does nothing when not distributed. Buffers whose ``.data`` is not a
    floating-point tensor are skipped.

    Args:
        buffers: Iterable of objects exposing a ``.data`` tensor.
        average: When True, each buffer is summed over all workers and then
            divided in place by the world size. When False, the values held
            by rank 0 are broadcast to every worker instead.
    """
    if not is_distributed():
        return
    handles = []
    for buffer in buffers:
        if not torch.is_floating_point(buffer.data):
            continue
        if average:
            handle = torch.distributed.all_reduce(
                buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True
            )
        else:
            handle = torch.distributed.broadcast(buffer.data, src=0, async_op=True)
        handles.append((buffer, handle))
    for buffer, handle in handles:
        handle.wait()
        if average:
            # world_size must be called: dividing by the function object
            # itself raises TypeError instead of averaging.
            buffer.data /= world_size()
101
+
102
+
103
def sync_grad(params):
    """Average gradients across workers after a backward pass.

    A lightweight alternative to DistributedDataParallel: each parameter
    that has a gradient gets it summed over all workers and divided in
    place by the world size. Parameters without a gradient are skipped.
    Does nothing when not distributed.
    """
    if not is_distributed():
        return
    # Start all reductions asynchronously before waiting on any of them.
    pending = [
        (
            param,
            torch.distributed.all_reduce(
                param.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True
            ),
        )
        for param in params
        if param.grad is not None
    ]
    for param, work in pending:
        work.wait()
        param.grad.data /= world_size()
121
+
122
+
123
def average_metrics(metrics: tp.Dict[str, float], count=1.0):
    """Average a dictionary of metrics across all workers.

    Each worker's values are scaled by ``count`` (an unnormalized weight),
    summed over all workers, and divided by the summed weights.

    Args:
        metrics: Mapping from metric name to its local value.
        count: Weight applied to this worker's values.

    Returns:
        ``metrics`` itself when not distributed; otherwise a new dict with
        the same keys and the weighted averages. An empty ``metrics`` gives
        an empty dict.
    """
    if not is_distributed():
        return metrics
    # Build keys and values without tuple-unpacking a zip, which raises
    # ValueError when ``metrics`` is empty.
    keys = list(metrics)
    values = [metrics[key] for key in keys]
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # The trailing 1 becomes this worker's weight once scaled by ``count``.
    # NOTE(review): presumably every worker passes the same keys so the
    # reduced tensors have matching sizes; confirm against callers.
    tensor = torch.tensor(values + [1], device=device, dtype=torch.float32)
    tensor *= count
    all_reduce(tensor)
    averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
    return dict(zip(keys, averaged))
@@ -0,0 +1,125 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ # This source file is copied from https://github.com/facebookresearch/encodec
6
+
7
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
8
+ # All rights reserved.
9
+ #
10
+ # This source code is licensed under the license found in the
11
+ # LICENSE file in the root directory of this source tree.
12
+
13
+ """Residual vector quantizer implementation."""
14
+
15
+ from dataclasses import dataclass, field
16
+ import math
17
+ import typing as tp
18
+
19
+ import torch
20
+ from torch import nn
21
+
22
+ from .core_vq import ResidualVectorQuantization
23
+
24
+
25
@dataclass
class QuantizedResult:
    """Plain container bundling the tensors produced by a quantization pass.

    The first three fields are required; ``penalty`` defaults to None and
    ``metrics`` defaults to a fresh empty dict per instance.
    """

    quantized: torch.Tensor
    codes: torch.Tensor
    # Bandwidth in kb/s used, per batch item.
    bandwidth: torch.Tensor
    penalty: tp.Optional[torch.Tensor] = None
    # default_factory gives every instance its own dict.
    metrics: dict = field(default_factory=dict)
32
+
33
+
34
class ResidualVectorQuantizer(nn.Module):
    """Residual Vector Quantizer.

    Stores the configuration and delegates the actual work to a
    ``ResidualVectorQuantization`` instance held in ``self.vq``.

    Args:
        dimension (int): Dimension of the codebooks.
        n_q (int): Number of residual vector quantizers used.
        bins (int): Codebook size.
        decay (float): Decay for exponential moving average over the codebooks.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
    """

    def __init__(
        self,
        dimension: int = 256,
        n_q: int = 8,
        bins: int = 1024,
        decay: float = 0.99,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: int = 2,
    ):
        super().__init__()
        self.n_q = n_q
        self.dimension = dimension
        self.bins = bins
        self.decay = decay
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.vq = ResidualVectorQuantization(
            dim=self.dimension,
            codebook_size=self.bins,
            num_quantizers=self.n_q,
            decay=self.decay,
            kmeans_init=self.kmeans_init,
            kmeans_iters=self.kmeans_iters,
            threshold_ema_dead_code=self.threshold_ema_dead_code,
        )

    def forward(
        self,
        x: torch.Tensor,
        n_q: tp.Optional[int] = None,
        layers: tp.Optional[list] = None,
    ) -> tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, tp.Any]:
        """Residual vector quantization on the given input tensor.

        Args:
            x (torch.Tensor): Input tensor.
            n_q (int): Number of quantizers used to quantize. Default: all
                quantizers (a falsy value such as None or 0 selects the default).
            layers (list): Layer indices whose quantized output should be
                returned. Default: None.

        Returns:
            A 4-tuple ``(quantized, codes, commit_loss, quantized_list)``
            where ``quantized``, ``codes`` and ``quantized_list`` are passed
            through from ``self.vq`` and ``commit_loss`` is the mean of the
            commitment loss it returned.

        Raises:
            ValueError: If the largest index in ``layers`` is not smaller
                than the number of quantizers in use.
        """
        n_q = n_q if n_q else self.n_q
        if layers and max(layers) >= n_q:
            # Report the bound that was actually checked (the effective n_q),
            # not self.n_q, so the message cannot contradict itself.
            raise ValueError(
                f"Last layer index in layers: A {max(layers)}. Number of quantizers used: B {n_q}. A must less than B."
            )
        quantized, codes, commit_loss, quantized_list = self.vq(
            x, n_q=n_q, layers=layers
        )
        return quantized, codes, torch.mean(commit_loss), quantized_list

    def encode(
        self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
    ) -> torch.Tensor:
        """Encode the input tensor into indices for each quantizer.

        Args:
            x (torch.Tensor): Input tensor.
            n_q (int): Number of quantizers used to quantize. Default: all quantizers.
            st (int): Layer from which to start encoding. Default: 0.

        Returns:
            The codes returned by ``self.vq.encode``.
        """
        n_q = n_q if n_q else self.n_q
        st = st or 0
        codes = self.vq.encode(x, n_q=n_q, st=st)
        return codes

    def decode(self, codes: torch.Tensor, st: int = 0) -> torch.Tensor:
        """Decode the given codes to the quantized representation.

        Args:
            codes (torch.Tensor): Input indices for each quantizer.
            st (int): Layer from which to start decoding. Default: 0.

        Returns:
            The quantized representation returned by ``self.vq.decode``.
        """
        quantized = self.vq.decode(codes, st=st)
        return quantized
@@ -0,0 +1,414 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ # This source file is copied from https://github.com/facebookresearch/encodec
6
+
7
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
8
+ # All rights reserved.
9
+ #
10
+ # This source code is licensed under the license found in the
11
+ # LICENSE file in the root directory of this source tree.
12
+
13
+ """Encodec SEANet-based encoder and decoder implementation."""
14
+
15
+ import typing as tp
16
+
17
+ import numpy as np
18
+ import torch.nn as nn
19
+ import torch
20
+
21
+ from . import SConv1d, SConvTranspose1d, SLSTM
22
+
23
+
24
@torch.jit.script
def snake(x, alpha):
    """Snake activation: ``x + sin(alpha * x)^2 / (alpha + 1e-9)``.

    The input is flattened to three dimensions (first, second, everything
    else) for the computation and restored to its original shape afterwards.
    """
    original_shape = x.shape
    flat = x.reshape(original_shape[0], original_shape[1], -1)
    periodic = torch.sin(alpha * flat).pow(2)
    # The small epsilon guards the reciprocal against alpha == 0.
    inv_alpha = (alpha + 1e-9).reciprocal()
    flat = flat + inv_alpha * periodic
    return flat.reshape(original_shape)
31
+
32
+
33
class Snake1d(nn.Module):
    """Module wrapper around ``snake`` with one learnable ``alpha`` per channel."""

    def __init__(self, channels):
        super().__init__()
        # alpha has shape (1, channels, 1) and starts at one.
        initial_alpha = torch.ones(1, channels, 1)
        self.alpha = nn.Parameter(initial_alpha)

    def forward(self, x):
        """Apply ``snake`` to ``x`` using this module's ``alpha``."""
        activated = snake(x, self.alpha)
        return activated
40
+
41
+
42
class SEANetResnetBlock(nn.Module):
    """Residual block from SEANet model.

    Args:
        dim (int): Dimension of the input/output
        kernel_sizes (list): List of kernel sizes for the convolutions.
            Defaults to ``[3, 1]`` when None.
        dilations (list): List of dilations for the convolutions.
            Defaults to ``[1, 1]`` when None.
        activation (str): Activation function. ``"Snake"`` selects ``Snake1d``;
            any other value is looked up on ``torch.nn``.
        activation_params (dict): Parameters to provide to the activation function.
            Defaults to ``{"alpha": 1.0}`` when None; ignored for ``"Snake"``.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
            Defaults to ``{}`` when None.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3)
        true_skip (bool): Whether to use true skip connection or a simple convolution as the skip connection.
    """

    def __init__(
        self,
        dim: int,
        kernel_sizes: tp.Optional[tp.List[int]] = None,
        dilations: tp.Optional[tp.List[int]] = None,
        activation: str = "ELU",
        activation_params: tp.Optional[dict] = None,
        norm: str = "weight_norm",
        norm_params: tp.Optional[tp.Dict[str, tp.Any]] = None,
        causal: bool = False,
        pad_mode: str = "reflect",
        compress: int = 2,
        true_skip: bool = True,
    ):
        super().__init__()
        # Defaults are created per call so no list/dict object is shared
        # between instances.
        if kernel_sizes is None:
            kernel_sizes = [3, 1]
        if dilations is None:
            dilations = [1, 1]
        if activation_params is None:
            activation_params = {"alpha": 1.0}
        if norm_params is None:
            norm_params = {}
        assert len(kernel_sizes) == len(
            dilations
        ), "Number of kernel sizes should match number of dilations"
        act = getattr(nn, activation) if activation != "Snake" else Snake1d
        hidden = dim // compress
        block = []
        last = len(kernel_sizes) - 1
        for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
            # First convolution reads `dim` channels and the last one writes
            # `dim` channels; everything in between uses the compressed width.
            in_chs = dim if i == 0 else hidden
            out_chs = dim if i == last else hidden
            block += [
                act(**activation_params) if activation != "Snake" else act(in_chs),
                SConv1d(
                    in_chs,
                    out_chs,
                    kernel_size=kernel_size,
                    dilation=dilation,
                    norm=norm,
                    norm_kwargs=norm_params,
                    causal=causal,
                    pad_mode=pad_mode,
                ),
            ]
        self.block = nn.Sequential(*block)
        self.shortcut: nn.Module
        if true_skip:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = SConv1d(
                dim,
                dim,
                kernel_size=1,
                norm=norm,
                norm_kwargs=norm_params,
                causal=causal,
                pad_mode=pad_mode,
            )

    def forward(self, x):
        """Return the sum of the shortcut branch and the residual branch."""
        return self.shortcut(x) + self.block(x)
112
+
113
+
114
class SEANetEncoder(nn.Module):
    """SEANet encoder.

    Args:
        channels (int): Audio channels.
        dimension (int): Intermediate representation dimension.
        n_filters (int): Base width for the model.
        n_residual_layers (int): nb of residual layers.
        ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
            upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here
            that must match the decoder order. Defaults to ``[8, 5, 4, 2]`` when None.
        activation (str): Activation function. ``"Snake"`` selects ``Snake1d``;
            any other value is looked up on ``torch.nn``.
        activation_params (dict): Parameters to provide to the activation function.
            Defaults to ``{"alpha": 1.0}`` when None.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
            Defaults to ``{}`` when None.
        kernel_size (int): Kernel size for the initial convolution.
        last_kernel_size (int): Kernel size for the final convolution.
        residual_kernel_size (int): Kernel size for the residual layers.
        dilation_base (int): How much to increase the dilation with each layer.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection in the residual network blocks.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        lstm (int): Number of LSTM layers at the end of the encoder.
        bidirectional (bool): Forwarded to ``SLSTM``. When true, the channel
            multiplier is doubled before the final activation and convolution.
    """

    def __init__(
        self,
        channels: int = 1,
        dimension: int = 128,
        n_filters: int = 32,
        n_residual_layers: int = 1,
        ratios: tp.Optional[tp.List[int]] = None,
        activation: str = "ELU",
        activation_params: tp.Optional[dict] = None,
        norm: str = "weight_norm",
        norm_params: tp.Optional[tp.Dict[str, tp.Any]] = None,
        kernel_size: int = 7,
        last_kernel_size: int = 7,
        residual_kernel_size: int = 3,
        dilation_base: int = 2,
        causal: bool = False,
        pad_mode: str = "reflect",
        true_skip: bool = False,
        compress: int = 2,
        lstm: int = 2,
        bidirectional: bool = False,
    ):
        super().__init__()
        # Defaults are created per call so no list/dict object is shared
        # between instances.
        if ratios is None:
            ratios = [8, 5, 4, 2]
        if activation_params is None:
            activation_params = {"alpha": 1.0}
        if norm_params is None:
            norm_params = {}
        self.channels = channels
        self.dimension = dimension
        self.n_filters = n_filters
        # The encoder walks the ratios in reverse order; this also copies
        # the caller's sequence.
        self.ratios = list(reversed(ratios))
        del ratios  # only self.ratios may be used from here on
        self.n_residual_layers = n_residual_layers
        self.hop_length = np.prod(self.ratios)  # product of all ratios

        act = getattr(nn, activation) if activation != "Snake" else Snake1d
        mult = 1
        model: tp.List[nn.Module] = [
            SConv1d(
                channels,
                mult * n_filters,
                kernel_size,
                norm=norm,
                norm_kwargs=norm_params,
                causal=causal,
                pad_mode=pad_mode,
            )
        ]
        # One downsampling stage per ratio
        for ratio in self.ratios:
            # Add residual layers
            for j in range(n_residual_layers):
                model += [
                    SEANetResnetBlock(
                        mult * n_filters,
                        kernel_sizes=[residual_kernel_size, 1],
                        dilations=[dilation_base**j, 1],
                        norm=norm,
                        norm_params=norm_params,
                        activation=activation,
                        activation_params=activation_params,
                        causal=causal,
                        pad_mode=pad_mode,
                        compress=compress,
                        true_skip=true_skip,
                    )
                ]

            # Add downsampling layers: stride `ratio`, channel count doubled
            model += [
                (
                    act(**activation_params)
                    if activation != "Snake"
                    else act(mult * n_filters)
                ),
                SConv1d(
                    mult * n_filters,
                    mult * n_filters * 2,
                    kernel_size=ratio * 2,
                    stride=ratio,
                    norm=norm,
                    norm_kwargs=norm_params,
                    causal=causal,
                    pad_mode=pad_mode,
                ),
            ]
            mult *= 2

        if lstm:
            model += [
                SLSTM(mult * n_filters, num_layers=lstm, bidirectional=bidirectional)
            ]

        # NOTE(review): the multiplier is doubled for bidirectional even when
        # lstm == 0 and no SLSTM is added -- confirm this is intended.
        mult = mult * 2 if bidirectional else mult
        model += [
            (
                act(**activation_params)
                if activation != "Snake"
                else act(mult * n_filters)
            ),
            SConv1d(
                mult * n_filters,
                dimension,
                last_kernel_size,
                norm=norm,
                norm_kwargs=norm_params,
                causal=causal,
                pad_mode=pad_mode,
            ),
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Run ``x`` through the sequential encoder stack."""
        return self.model(x)
251
+
252
+
253
class SEANetDecoder(nn.Module):
    """SEANet decoder.

    Args:
        channels (int): Audio channels.
        dimension (int): Intermediate representation dimension.
        n_filters (int): Base width for the model.
        n_residual_layers (int): nb of residual layers.
        ratios (Sequence[int]): kernel size and stride ratios.
            Defaults to ``[8, 5, 4, 2]`` when None.
        activation (str): Activation function. ``"Snake"`` selects ``Snake1d``;
            any other value is looked up on ``torch.nn``.
        activation_params (dict): Parameters to provide to the activation function.
            Defaults to ``{"alpha": 1.0}`` when None.
        final_activation (str): Final activation function after all convolutions.
        final_activation_params (dict): Parameters to provide to the final activation function
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
            Defaults to ``{}`` when None.
        kernel_size (int): Kernel size for the initial convolution.
        last_kernel_size (int): Kernel size for the final convolution.
        residual_kernel_size (int): Kernel size for the residual layers.
        dilation_base (int): How much to increase the dilation with each layer.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection in the residual network blocks.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        lstm (int): Number of LSTM layers placed right after the initial convolution.
        trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
            If equal to 1.0, it means that all the trimming is done at the right.
        bidirectional (bool): Forwarded to ``SLSTM``.
    """

    def __init__(
        self,
        channels: int = 1,
        dimension: int = 128,
        n_filters: int = 32,
        n_residual_layers: int = 1,
        ratios: tp.Optional[tp.List[int]] = None,
        activation: str = "ELU",
        activation_params: tp.Optional[dict] = None,
        final_activation: tp.Optional[str] = None,
        final_activation_params: tp.Optional[dict] = None,
        norm: str = "weight_norm",
        norm_params: tp.Optional[tp.Dict[str, tp.Any]] = None,
        kernel_size: int = 7,
        last_kernel_size: int = 7,
        residual_kernel_size: int = 3,
        dilation_base: int = 2,
        causal: bool = False,
        pad_mode: str = "reflect",
        true_skip: bool = False,
        compress: int = 2,
        lstm: int = 2,
        trim_right_ratio: float = 1.0,
        bidirectional: bool = False,
    ):
        super().__init__()
        # Defaults are created per call so no list/dict object is shared
        # between instances.
        if ratios is None:
            ratios = [8, 5, 4, 2]
        if activation_params is None:
            activation_params = {"alpha": 1.0}
        if norm_params is None:
            norm_params = {}
        self.dimension = dimension
        self.channels = channels
        self.n_filters = n_filters
        # Store a copy: keeping the caller's (or the default) list itself would
        # let a later mutation of `decoder.ratios` leak into other objects.
        self.ratios = list(ratios)
        del ratios  # only self.ratios may be used from here on
        self.n_residual_layers = n_residual_layers
        self.hop_length = np.prod(self.ratios)

        act = getattr(nn, activation) if activation != "Snake" else Snake1d
        # Start at the widest channel multiplier and halve it per stage.
        mult = int(2 ** len(self.ratios))
        model: tp.List[nn.Module] = [
            SConv1d(
                dimension,
                mult * n_filters,
                kernel_size,
                norm=norm,
                norm_kwargs=norm_params,
                causal=causal,
                pad_mode=pad_mode,
            )
        ]

        if lstm:
            # NOTE(review): unlike SEANetEncoder, the channel multiplier is not
            # adjusted here when bidirectional is set -- confirm against the
            # output width of SLSTM.
            model += [
                SLSTM(mult * n_filters, num_layers=lstm, bidirectional=bidirectional)
            ]

        # Upsample to raw audio scale
        for ratio in self.ratios:
            # Add upsampling layers: stride `ratio`, channel count halved
            model += [
                (
                    act(**activation_params)
                    if activation != "Snake"
                    else act(mult * n_filters)
                ),
                SConvTranspose1d(
                    mult * n_filters,
                    mult * n_filters // 2,
                    kernel_size=ratio * 2,
                    stride=ratio,
                    norm=norm,
                    norm_kwargs=norm_params,
                    causal=causal,
                    trim_right_ratio=trim_right_ratio,
                ),
            ]
            # Add residual layers
            for j in range(n_residual_layers):
                model += [
                    SEANetResnetBlock(
                        mult * n_filters // 2,
                        kernel_sizes=[residual_kernel_size, 1],
                        dilations=[dilation_base**j, 1],
                        activation=activation,
                        activation_params=activation_params,
                        norm=norm,
                        norm_params=norm_params,
                        causal=causal,
                        pad_mode=pad_mode,
                        compress=compress,
                        true_skip=true_skip,
                    )
                ]

            mult //= 2

        # Add final layers
        model += [
            act(**activation_params) if activation != "Snake" else act(n_filters),
            SConv1d(
                n_filters,
                channels,
                last_kernel_size,
                norm=norm,
                norm_kwargs=norm_params,
                causal=causal,
                pad_mode=pad_mode,
            ),
        ]
        # Add optional final activation to decoder (eg. tanh)
        if final_activation is not None:
            final_act = getattr(nn, final_activation)
            final_activation_params = final_activation_params or {}
            model += [final_act(**final_activation_params)]
        self.model = nn.Sequential(*model)

    def forward(self, z):
        """Run ``z`` through the sequential decoder stack."""
        y = self.model(z)
        return y
397
+
398
+
399
def test():
    """Smoke test: encode a random signal and decode it back.

    Builds a default encoder/decoder pair, checks the latent shape for a
    24000-sample mono input and checks that decoding restores the input
    shape.
    """
    import torch

    encoder = SEANetEncoder()
    decoder = SEANetDecoder()
    x = torch.randn(1, 1, 24000)
    z = encoder(x)
    print("z ", z.shape)
    # A leftover `assert 1 == 2` used to abort here unconditionally, so none
    # of the real checks below could ever run.
    assert list(z.shape) == [1, 128, 75], z.shape
    y = decoder(z)
    assert y.shape == x.shape, (x.shape, y.shape)
411
+
412
+
413
# Run the smoke test only when this file is executed as a script.
if __name__ == "__main__":
    test()