xinference 1.9.1__py3-none-any.whl → 1.10.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only. A sketch for reproducing the comparison locally follows the file list.

Potentially problematic release: this version of xinference might be problematic.

Files changed (334)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +400 -3
  3. xinference/client/restful/async_restful_client.py +20 -3
  4. xinference/client/restful/restful_client.py +20 -3
  5. xinference/constants.py +2 -0
  6. xinference/core/supervisor.py +111 -49
  7. xinference/core/worker.py +10 -0
  8. xinference/deploy/cmdline.py +15 -0
  9. xinference/model/audio/core.py +26 -6
  10. xinference/model/audio/indextts2.py +166 -0
  11. xinference/model/audio/kokoro.py +1 -1
  12. xinference/model/audio/kokoro_zh.py +124 -0
  13. xinference/model/audio/model_spec.json +58 -1
  14. xinference/model/embedding/sentence_transformers/core.py +4 -4
  15. xinference/model/embedding/vllm/core.py +7 -1
  16. xinference/model/image/model_spec.json +71 -3
  17. xinference/model/image/stable_diffusion/core.py +13 -4
  18. xinference/model/llm/__init__.py +4 -0
  19. xinference/model/llm/core.py +10 -0
  20. xinference/model/llm/llama_cpp/core.py +1 -0
  21. xinference/model/llm/llm_family.json +503 -21
  22. xinference/model/llm/llm_family.py +1 -0
  23. xinference/model/llm/mlx/core.py +52 -33
  24. xinference/model/llm/sglang/core.py +32 -55
  25. xinference/model/llm/tool_parsers/__init__.py +58 -0
  26. xinference/model/llm/tool_parsers/abstract_tool_parser.py +33 -0
  27. xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +190 -0
  28. xinference/model/llm/tool_parsers/deepseek_v3_tool_parser.py +145 -0
  29. xinference/model/llm/tool_parsers/glm4_tool_parser.py +123 -0
  30. xinference/model/llm/tool_parsers/llama3_tool_parser.py +77 -0
  31. xinference/model/llm/tool_parsers/qwen_tool_parser.py +320 -0
  32. xinference/model/llm/transformers/core.py +1 -1
  33. xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
  34. xinference/model/llm/utils.py +138 -53
  35. xinference/model/llm/vllm/core.py +95 -78
  36. xinference/thirdparty/audiotools/__init__.py +10 -0
  37. xinference/thirdparty/audiotools/core/__init__.py +4 -0
  38. xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
  39. xinference/thirdparty/audiotools/core/display.py +194 -0
  40. xinference/thirdparty/audiotools/core/dsp.py +390 -0
  41. xinference/thirdparty/audiotools/core/effects.py +647 -0
  42. xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
  43. xinference/thirdparty/audiotools/core/loudness.py +320 -0
  44. xinference/thirdparty/audiotools/core/playback.py +252 -0
  45. xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
  46. xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
  47. xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
  48. xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
  49. xinference/thirdparty/audiotools/core/util.py +671 -0
  50. xinference/thirdparty/audiotools/core/whisper.py +97 -0
  51. xinference/thirdparty/audiotools/data/__init__.py +3 -0
  52. xinference/thirdparty/audiotools/data/datasets.py +517 -0
  53. xinference/thirdparty/audiotools/data/preprocess.py +81 -0
  54. xinference/thirdparty/audiotools/data/transforms.py +1592 -0
  55. xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
  56. xinference/thirdparty/audiotools/metrics/distance.py +131 -0
  57. xinference/thirdparty/audiotools/metrics/quality.py +159 -0
  58. xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
  59. xinference/thirdparty/audiotools/ml/__init__.py +5 -0
  60. xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
  61. xinference/thirdparty/audiotools/ml/decorators.py +440 -0
  62. xinference/thirdparty/audiotools/ml/experiment.py +90 -0
  63. xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
  64. xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
  65. xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
  66. xinference/thirdparty/audiotools/post.py +140 -0
  67. xinference/thirdparty/audiotools/preference.py +600 -0
  68. xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
  69. xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
  70. xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
  71. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
  72. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
  73. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
  74. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
  75. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  76. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
  77. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
  78. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
  79. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
  80. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
  81. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
  82. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
  83. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
  84. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
  85. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
  86. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
  87. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
  88. xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
  89. xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
  90. xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
  91. xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
  92. xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
  93. xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
  94. xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
  95. xinference/thirdparty/indextts/__init__.py +0 -0
  96. xinference/thirdparty/indextts/cli.py +65 -0
  97. xinference/thirdparty/indextts/gpt/__init__.py +0 -0
  98. xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
  99. xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
  100. xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
  101. xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
  102. xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
  103. xinference/thirdparty/indextts/gpt/model.py +713 -0
  104. xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
  105. xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
  106. xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
  107. xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
  108. xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
  109. xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
  110. xinference/thirdparty/indextts/infer.py +690 -0
  111. xinference/thirdparty/indextts/infer_v2.py +739 -0
  112. xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
  113. xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
  114. xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
  115. xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
  116. xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
  117. xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
  118. xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
  119. xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
  120. xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
  121. xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
  122. xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
  123. xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
  124. xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
  125. xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
  126. xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
  127. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
  128. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
  129. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
  130. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
  131. xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
  132. xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
  133. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
  134. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
  135. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  136. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
  137. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
  138. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
  139. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
  140. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
  141. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
  142. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
  143. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
  144. xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
  145. xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
  146. xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
  147. xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
  148. xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
  149. xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
  150. xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
  151. xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
  152. xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
  153. xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
  154. xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
  155. xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
  156. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
  157. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
  158. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
  159. xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
  160. xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
  161. xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
  162. xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
  163. xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
  164. xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
  165. xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
  166. xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
  167. xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
  168. xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
  169. xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
  170. xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
  171. xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
  172. xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
  173. xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
  174. xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
  175. xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
  176. xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
  177. xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
  178. xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
  179. xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
  180. xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
  181. xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
  182. xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
  183. xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
  184. xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
  185. xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
  186. xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
  187. xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
  188. xinference/thirdparty/indextts/utils/__init__.py +0 -0
  189. xinference/thirdparty/indextts/utils/arch_util.py +120 -0
  190. xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
  191. xinference/thirdparty/indextts/utils/common.py +121 -0
  192. xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
  193. xinference/thirdparty/indextts/utils/front.py +536 -0
  194. xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
  195. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
  196. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
  197. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
  198. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
  199. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
  200. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
  201. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
  202. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
  203. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
  204. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
  205. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
  206. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
  207. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
  208. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
  209. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
  210. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
  211. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
  212. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
  213. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
  214. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
  215. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
  216. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
  217. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
  218. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
  219. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
  220. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
  221. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
  222. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
  223. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
  224. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
  225. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
  226. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
  227. xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
  228. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
  229. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
  230. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
  231. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
  232. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
  233. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
  234. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
  235. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
  236. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
  237. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
  238. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
  239. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
  240. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
  241. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
  242. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
  243. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
  244. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
  245. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
  246. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
  247. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
  248. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
  249. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
  250. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
  251. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
  252. xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
  253. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
  254. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
  255. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
  256. xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
  257. xinference/thirdparty/indextts/utils/text_utils.py +41 -0
  258. xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
  259. xinference/thirdparty/indextts/utils/utils.py +93 -0
  260. xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
  261. xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
  262. xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
  263. xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
  264. xinference/types.py +105 -2
  265. xinference/ui/gradio/media_interface.py +66 -8
  266. xinference/ui/web/ui/build/asset-manifest.json +6 -6
  267. xinference/ui/web/ui/build/index.html +1 -1
  268. xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
  269. xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
  270. xinference/ui/web/ui/build/static/js/main.d192c4f3.js +3 -0
  271. xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.d192c4f3.js.LICENSE.txt} +0 -7
  272. xinference/ui/web/ui/build/static/js/main.d192c4f3.js.map +1 -0
  273. xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
  274. xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
  275. xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
  276. xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
  277. xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
  278. xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
  279. xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
  280. xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
  281. xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
  282. xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
  283. xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
  284. xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
  285. xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
  286. xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
  287. xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
  288. xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
  289. xinference/ui/web/ui/node_modules/.cache/babel-loader/f995a2425dfb0822fd07127f66ffe9b026883bc156b402eb8bd0b83d52460a93.json +1 -0
  290. xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
  291. xinference/ui/web/ui/package-lock.json +0 -34
  292. xinference/ui/web/ui/package.json +0 -1
  293. xinference/ui/web/ui/src/locales/en.json +9 -3
  294. xinference/ui/web/ui/src/locales/ja.json +9 -3
  295. xinference/ui/web/ui/src/locales/ko.json +9 -3
  296. xinference/ui/web/ui/src/locales/zh.json +9 -3
  297. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/METADATA +24 -4
  298. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/RECORD +302 -76
  299. xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
  300. xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
  301. xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
  302. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
  303. xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
  304. xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
  305. xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
  306. xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
  307. xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
  308. xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
  309. xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
  310. xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
  311. xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
  312. xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
  313. xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
  314. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
  315. xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
  316. xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
  317. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
  318. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
  319. xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
  320. xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
  321. xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
  322. xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
  323. xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
  324. xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
  325. xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
  326. xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
  327. xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
  328. xinference/ui/web/ui/node_modules/select/bower.json +0 -13
  329. xinference/ui/web/ui/node_modules/select/package.json +0 -29
  330. xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
  331. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/WHEEL +0 -0
  332. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/entry_points.txt +0 -0
  333. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/licenses/LICENSE +0 -0
  334. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/top_level.txt +0 -0
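To reproduce this file-level comparison locally, a minimal sketch along the following lines should work. Assumptions: both wheels have already been downloaded (for example with `pip download xinference==1.9.1 --no-deps` and `pip download xinference==1.10.1 --no-deps`) and the filenames below match what pip produced.

import zipfile

def wheel_members(path):
    # A wheel is a zip archive; namelist() returns every file it contains.
    with zipfile.ZipFile(path) as zf:
        return set(zf.namelist())

old = wheel_members("xinference-1.9.1-py3-none-any.whl")
new = wheel_members("xinference-1.10.1-py3-none-any.whl")

print("added:", len(new - old), "removed:", len(old - new))
for name in sorted(new - old):
    print("+", name)
for name in sorted(old - new):
    print("-", name)

This only lists added and removed files; per-file line counts like those above would additionally require extracting and diffing the file contents.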
xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py (new file)
@@ -0,0 +1,656 @@
+"""A popular speaker recognition and diarization model.
+
+Authors
+ * Hwidong Na 2020
+"""
+
+import torch  # noqa: F401
+import torch.nn as nn
+import torch.nn.functional as F
+
+from indextts.BigVGAN.nnet.CNN import Conv1d as _Conv1d
+from indextts.BigVGAN.nnet.linear import Linear
+from indextts.BigVGAN.nnet.normalization import BatchNorm1d as _BatchNorm1d
+
+
+def length_to_mask(length, max_len=None, dtype=None, device=None):
+    """Creates a binary mask for each sequence.
+
+    Reference: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397/3
+
+    Arguments
+    ---------
+    length : torch.LongTensor
+        Containing the length of each sequence in the batch. Must be 1D.
+    max_len : int
+        Max length for the mask, also the size of the second dimension.
+    dtype : torch.dtype, default: None
+        The dtype of the generated mask.
+    device: torch.device, default: None
+        The device to put the mask variable.
+
+    Returns
+    -------
+    mask : tensor
+        The binary mask.
+
+    Example
+    -------
+    >>> length=torch.Tensor([1,2,3])
+    >>> mask=length_to_mask(length)
+    >>> mask
+    tensor([[1., 0., 0.],
+            [1., 1., 0.],
+            [1., 1., 1.]])
+    """
+    assert len(length.shape) == 1
+
+    if max_len is None:
+        max_len = length.max().long().item()  # using arange to generate mask
+    mask = torch.arange(
+        max_len, device=length.device, dtype=length.dtype
+    ).expand(len(length), max_len) < length.unsqueeze(1)
+
+    if dtype is None:
+        dtype = length.dtype
+
+    if device is None:
+        device = length.device
+
+    mask = torch.as_tensor(mask, dtype=dtype, device=device)
+    return mask
+
+
+# Skip transpose as much as possible for efficiency
+class Conv1d(_Conv1d):
+    """1D convolution. Skip transpose is used to improve efficiency."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(skip_transpose=True, *args, **kwargs)
+
+
+class BatchNorm1d(_BatchNorm1d):
+    """1D batch normalization. Skip transpose is used to improve efficiency."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(skip_transpose=True, *args, **kwargs)
+
+
+class TDNNBlock(nn.Module):
+    """An implementation of TDNN.
+
+    Arguments
+    ---------
+    in_channels : int
+        Number of input channels.
+    out_channels : int
+        The number of output channels.
+    kernel_size : int
+        The kernel size of the TDNN blocks.
+    dilation : int
+        The dilation of the TDNN block.
+    activation : torch class
+        A class for constructing the activation layers.
+    groups : int
+        The groups size of the TDNN blocks.
+
+    Example
+    -------
+    >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
+    >>> layer = TDNNBlock(64, 64, kernel_size=3, dilation=1)
+    >>> out_tensor = layer(inp_tensor).transpose(1, 2)
+    >>> out_tensor.shape
+    torch.Size([8, 120, 64])
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        dilation,
+        activation=nn.ReLU,
+        groups=1,
+    ):
+        super().__init__()
+        self.conv = Conv1d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            dilation=dilation,
+            groups=groups,
+        )
+        self.activation = activation()
+        self.norm = BatchNorm1d(input_size=out_channels)
+
+    def forward(self, x):
+        """Processes the input tensor x and returns an output tensor."""
+        return self.norm(self.activation(self.conv(x)))
+
+
+class Res2NetBlock(torch.nn.Module):
+    """An implementation of Res2NetBlock w/ dilation.
+
+    Arguments
+    ---------
+    in_channels : int
+        The number of channels expected in the input.
+    out_channels : int
+        The number of output channels.
+    scale : int
+        The scale of the Res2Net block.
+    kernel_size: int
+        The kernel size of the Res2Net block.
+    dilation : int
+        The dilation of the Res2Net block.
+
+    Example
+    -------
+    >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
+    >>> layer = Res2NetBlock(64, 64, scale=4, dilation=3)
+    >>> out_tensor = layer(inp_tensor).transpose(1, 2)
+    >>> out_tensor.shape
+    torch.Size([8, 120, 64])
+    """
+
+    def __init__(
+        self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1
+    ):
+        super().__init__()
+        assert in_channels % scale == 0
+        assert out_channels % scale == 0
+
+        in_channel = in_channels // scale
+        hidden_channel = out_channels // scale
+
+        self.blocks = nn.ModuleList(
+            [
+                TDNNBlock(
+                    in_channel,
+                    hidden_channel,
+                    kernel_size=kernel_size,
+                    dilation=dilation,
+                )
+                for i in range(scale - 1)
+            ]
+        )
+        self.scale = scale
+
+    def forward(self, x):
+        """Processes the input tensor x and returns an output tensor."""
+        y = []
+        for i, x_i in enumerate(torch.chunk(x, self.scale, dim=1)):
+            if i == 0:
+                y_i = x_i
+            elif i == 1:
+                y_i = self.blocks[i - 1](x_i)
+            else:
+                y_i = self.blocks[i - 1](x_i + y_i)
+            y.append(y_i)
+        y = torch.cat(y, dim=1)
+        return y
+
+
+class SEBlock(nn.Module):
+    """An implementation of squeeze-and-excitation block.
+
+    Arguments
+    ---------
+    in_channels : int
+        The number of input channels.
+    se_channels : int
+        The number of output channels after squeeze.
+    out_channels : int
+        The number of output channels.
+
+    Example
+    -------
+    >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
+    >>> se_layer = SEBlock(64, 16, 64)
+    >>> lengths = torch.rand((8,))
+    >>> out_tensor = se_layer(inp_tensor, lengths).transpose(1, 2)
+    >>> out_tensor.shape
+    torch.Size([8, 120, 64])
+    """
+
+    def __init__(self, in_channels, se_channels, out_channels):
+        super().__init__()
+
+        self.conv1 = Conv1d(
+            in_channels=in_channels, out_channels=se_channels, kernel_size=1
+        )
+        self.relu = torch.nn.ReLU(inplace=True)
+        self.conv2 = Conv1d(
+            in_channels=se_channels, out_channels=out_channels, kernel_size=1
+        )
+        self.sigmoid = torch.nn.Sigmoid()
+
+    def forward(self, x, lengths=None):
+        """Processes the input tensor x and returns an output tensor."""
+        L = x.shape[-1]
+        if lengths is not None:
+            mask = length_to_mask(lengths * L, max_len=L, device=x.device)
+            mask = mask.unsqueeze(1)
+            total = mask.sum(dim=2, keepdim=True)
+            s = (x * mask).sum(dim=2, keepdim=True) / total
+        else:
+            s = x.mean(dim=2, keepdim=True)
+
+        s = self.relu(self.conv1(s))
+        s = self.sigmoid(self.conv2(s))
+
+        return s * x
+
+
+class AttentiveStatisticsPooling(nn.Module):
+    """This class implements an attentive statistic pooling layer for each channel.
+    It returns the concatenated mean and std of the input tensor.
+
+    Arguments
+    ---------
+    channels: int
+        The number of input channels.
+    attention_channels: int
+        The number of attention channels.
+    global_context: bool
+        Whether to use global context.
+
+    Example
+    -------
+    >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
+    >>> asp_layer = AttentiveStatisticsPooling(64)
+    >>> lengths = torch.rand((8,))
+    >>> out_tensor = asp_layer(inp_tensor, lengths).transpose(1, 2)
+    >>> out_tensor.shape
+    torch.Size([8, 1, 128])
+    """
+
+    def __init__(self, channels, attention_channels=128, global_context=True):
+        super().__init__()
+
+        self.eps = 1e-12
+        self.global_context = global_context
+        if global_context:
+            self.tdnn = TDNNBlock(channels * 3, attention_channels, 1, 1)
+        else:
+            self.tdnn = TDNNBlock(channels, attention_channels, 1, 1)
+        self.tanh = nn.Tanh()
+        self.conv = Conv1d(
+            in_channels=attention_channels, out_channels=channels, kernel_size=1
+        )
+
+    def forward(self, x, lengths=None):
+        """Calculates mean and std for a batch (input tensor).
+
+        Arguments
+        ---------
+        x : torch.Tensor
+            Tensor of shape [N, C, L].
+        lengths : torch.Tensor
+            The corresponding relative lengths of the inputs.
+
+        Returns
+        -------
+        pooled_stats : torch.Tensor
+            mean and std of batch
+        """
+        L = x.shape[-1]
+
+        def _compute_statistics(x, m, dim=2, eps=self.eps):
+            mean = (m * x).sum(dim)
+            std = torch.sqrt(
+                (m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps)
+            )
+            return mean, std
+
+        if lengths is None:
+            lengths = torch.ones(x.shape[0], device=x.device)
+
+        # Make binary mask of shape [N, 1, L]
+        mask = length_to_mask(lengths * L, max_len=L, device=x.device)
+        mask = mask.unsqueeze(1)
+
+        # Expand the temporal context of the pooling layer by allowing the
+        # self-attention to look at global properties of the utterance.
+        if self.global_context:
+            # torch.std is unstable for backward computation
+            # https://github.com/pytorch/pytorch/issues/4320
+            total = mask.sum(dim=2, keepdim=True).float()
+            mean, std = _compute_statistics(x, mask / total)
+            mean = mean.unsqueeze(2).repeat(1, 1, L)
+            std = std.unsqueeze(2).repeat(1, 1, L)
+            attn = torch.cat([x, mean, std], dim=1)
+        else:
+            attn = x
+
+        # Apply layers
+        attn = self.conv(self.tanh(self.tdnn(attn)))
+
+        # Filter out zero-paddings
+        attn = attn.masked_fill(mask == 0, float("-inf"))
+
+        attn = F.softmax(attn, dim=2)
+        mean, std = _compute_statistics(x, attn)
+        # Append mean and std of the batch
+        pooled_stats = torch.cat((mean, std), dim=1)
+        pooled_stats = pooled_stats.unsqueeze(2)
+
+        return pooled_stats
+
+
+class SERes2NetBlock(nn.Module):
+    """An implementation of building block in ECAPA-TDNN, i.e.,
+    TDNN-Res2Net-TDNN-SEBlock.
+
+    Arguments
+    ---------
+    in_channels: int
+        Expected size of input channels.
+    out_channels: int
+        The number of output channels.
+    res2net_scale: int
+        The scale of the Res2Net block.
+    se_channels : int
+        The number of output channels after squeeze.
+    kernel_size: int
+        The kernel size of the TDNN blocks.
+    dilation: int
+        The dilation of the Res2Net block.
+    activation : torch class
+        A class for constructing the activation layers.
+    groups: int
+        Number of blocked connections from input channels to output channels.
+
+    Example
+    -------
+    >>> x = torch.rand(8, 120, 64).transpose(1, 2)
+    >>> conv = SERes2NetBlock(64, 64, res2net_scale=4)
+    >>> out = conv(x).transpose(1, 2)
+    >>> out.shape
+    torch.Size([8, 120, 64])
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        res2net_scale=8,
+        se_channels=128,
+        kernel_size=1,
+        dilation=1,
+        activation=torch.nn.ReLU,
+        groups=1,
+    ):
+        super().__init__()
+        self.out_channels = out_channels
+        self.tdnn1 = TDNNBlock(
+            in_channels,
+            out_channels,
+            kernel_size=1,
+            dilation=1,
+            activation=activation,
+            groups=groups,
+        )
+        self.res2net_block = Res2NetBlock(
+            out_channels, out_channels, res2net_scale, kernel_size, dilation
+        )
+        self.tdnn2 = TDNNBlock(
+            out_channels,
+            out_channels,
+            kernel_size=1,
+            dilation=1,
+            activation=activation,
+            groups=groups,
+        )
+        self.se_block = SEBlock(out_channels, se_channels, out_channels)
+
+        self.shortcut = None
+        if in_channels != out_channels:
+            self.shortcut = Conv1d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=1,
+            )
+
+    def forward(self, x, lengths=None):
+        """Processes the input tensor x and returns an output tensor."""
+        residual = x
+        if self.shortcut:
+            residual = self.shortcut(x)
+
+        x = self.tdnn1(x)
+        x = self.res2net_block(x)
+        x = self.tdnn2(x)
+        x = self.se_block(x, lengths)
+
+        return x + residual
+
+
+class ECAPA_TDNN(torch.nn.Module):
+    """An implementation of the speaker embedding model in a paper.
+    "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in
+    TDNN Based Speaker Verification" (https://arxiv.org/abs/2005.07143).
+
+    Arguments
+    ---------
+    input_size : int
+        Expected size of the input dimension.
+    device : str
+        Device used, e.g., "cpu" or "cuda".
+    lin_neurons : int
+        Number of neurons in linear layers.
+    activation : torch class
+        A class for constructing the activation layers.
+    channels : list of ints
+        Output channels for TDNN/SERes2Net layer.
+    kernel_sizes : list of ints
+        List of kernel sizes for each layer.
+    dilations : list of ints
+        List of dilations for kernels in each layer.
+    attention_channels: int
+        The number of attention channels.
+    res2net_scale : int
+        The scale of the Res2Net block.
+    se_channels : int
+        The number of output channels after squeeze.
+    global_context: bool
+        Whether to use global context.
+    groups : list of ints
+        List of groups for kernels in each layer.
+
+    Example
+    -------
+    >>> input_feats = torch.rand([5, 120, 80])
+    >>> compute_embedding = ECAPA_TDNN(80, lin_neurons=192)
+    >>> outputs = compute_embedding(input_feats)
+    >>> outputs.shape
+    torch.Size([5, 1, 192])
+    """
+
+    def __init__(
+        self,
+        input_size,
+        device="cpu",
+        lin_neurons=192,
+        activation=torch.nn.ReLU,
+        channels=[512, 512, 512, 512, 1536],
+        kernel_sizes=[5, 3, 3, 3, 1],
+        dilations=[1, 2, 3, 4, 1],
+        attention_channels=128,
+        res2net_scale=8,
+        se_channels=128,
+        global_context=True,
+        groups=[1, 1, 1, 1, 1],
+    ):
+        super().__init__()
+        assert len(channels) == len(kernel_sizes)
+        assert len(channels) == len(dilations)
+        self.channels = channels
+        self.blocks = nn.ModuleList()
+
+        # The initial TDNN layer
+        self.blocks.append(
+            TDNNBlock(
+                input_size,
+                channels[0],
+                kernel_sizes[0],
+                dilations[0],
+                activation,
+                groups[0],
+            )
+        )
+
+        # SE-Res2Net layers
+        for i in range(1, len(channels) - 1):
+            self.blocks.append(
+                SERes2NetBlock(
+                    channels[i - 1],
+                    channels[i],
+                    res2net_scale=res2net_scale,
+                    se_channels=se_channels,
+                    kernel_size=kernel_sizes[i],
+                    dilation=dilations[i],
+                    activation=activation,
+                    groups=groups[i],
+                )
+            )
+
+        # Multi-layer feature aggregation
+        self.mfa = TDNNBlock(
+            channels[-2] * (len(channels) - 2),
+            channels[-1],
+            kernel_sizes[-1],
+            dilations[-1],
+            activation,
+            groups=groups[-1],
+        )
+
+        # Attentive Statistical Pooling
+        self.asp = AttentiveStatisticsPooling(
+            channels[-1],
+            attention_channels=attention_channels,
+            global_context=global_context,
+        )
+        self.asp_bn = BatchNorm1d(input_size=channels[-1] * 2)
+
+        # Final linear transformation
+        self.fc = Conv1d(
+            in_channels=channels[-1] * 2,
+            out_channels=lin_neurons,
+            kernel_size=1,
+        )
+
+    def forward(self, x, lengths=None):
+        """Returns the embedding vector.
+
+        Arguments
+        ---------
+        x : torch.Tensor
+            Tensor of shape (batch, time, channel).
+        lengths : torch.Tensor
+            Corresponding relative lengths of inputs.
+
+        Returns
+        -------
+        x : torch.Tensor
+            Embedding vector.
+        """
+        # Minimize transpose for efficiency
+        x = x.transpose(1, 2)
+
+        xl = []
+        for layer in self.blocks:
+            try:
+                x = layer(x, lengths=lengths)
+            except TypeError:
+                x = layer(x)
+            xl.append(x)
+
+        # Multi-layer feature aggregation
+        x = torch.cat(xl[1:], dim=1)
+        x = self.mfa(x)
+
+        # Attentive Statistical Pooling
+        x = self.asp(x, lengths=lengths)
+        x = self.asp_bn(x)
+
+        # Final linear transformation
+        x = self.fc(x)
+
+        x = x.transpose(1, 2)
+        return x
+
+
+class Classifier(torch.nn.Module):
+    """This class implements the cosine similarity on the top of features.
+
+    Arguments
+    ---------
+    input_size : int
+        Expected size of input dimension.
+    device : str
+        Device used, e.g., "cpu" or "cuda".
+    lin_blocks : int
+        Number of linear layers.
+    lin_neurons : int
+        Number of neurons in linear layers.
+    out_neurons : int
+        Number of classes.
+
+    Example
+    -------
+    >>> classify = Classifier(input_size=2, lin_neurons=2, out_neurons=2)
+    >>> outputs = torch.tensor([ [1., -1.], [-9., 1.], [0.9, 0.1], [0.1, 0.9] ])
+    >>> outputs = outputs.unsqueeze(1)
+    >>> cos = classify(outputs)
+    >>> (cos < -1.0).long().sum()
+    tensor(0)
+    >>> (cos > 1.0).long().sum()
+    tensor(0)
+    """
+
+    def __init__(
+        self,
+        input_size,
+        device="cpu",
+        lin_blocks=0,
+        lin_neurons=192,
+        out_neurons=1211,
+    ):
+        super().__init__()
+        self.blocks = nn.ModuleList()
+
+        for block_index in range(lin_blocks):
+            self.blocks.extend(
+                [
+                    _BatchNorm1d(input_size=input_size),
+                    Linear(input_size=input_size, n_neurons=lin_neurons),
+                ]
+            )
+            input_size = lin_neurons
+
+        # Final Layer
+        self.weight = nn.Parameter(
+            torch.FloatTensor(out_neurons, input_size, device=device)
+        )
+        nn.init.xavier_uniform_(self.weight)
+
+    def forward(self, x):
+        """Returns the output probabilities over speakers.
+
+        Arguments
+        ---------
+        x : torch.Tensor
+            Torch tensor.
+
+        Returns
+        -------
+        out : torch.Tensor
+            Output probabilities over speakers.
+        """
+        for layer in self.blocks:
+            x = layer(x)
+
+        # Need to be normalized
+        x = F.linear(F.normalize(x.squeeze(1)), F.normalize(self.weight))
+        return x.unsqueeze(1)
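The vendored ECAPA_TDNN module above is a speaker-embedding network that this release appears to pull in for the new IndexTTS-2 audio model (see xinference/model/audio/indextts2.py in the file list); it is loaded internally rather than exposed as public API. As a rough usage sketch, mirroring the module's own docstring example and its internal import style (assumption: xinference's thirdparty directory is on sys.path so that the top-level name indextts resolves):

import torch
# Assumes xinference/thirdparty is importable, matching the module's own
# "from indextts.BigVGAN..." imports; adjust the path for your environment.
from indextts.BigVGAN.ECAPA_TDNN import ECAPA_TDNN

# 80-dimensional input features -> 192-dimensional speaker embedding,
# the same shapes as in the class docstring example.
model = ECAPA_TDNN(input_size=80, lin_neurons=192)
feats = torch.rand([5, 120, 80])  # (batch, time, feature)
with torch.no_grad():
    emb = model(feats)
print(emb.shape)  # expected: torch.Size([5, 1, 192])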