xinference 1.9.1__py3-none-any.whl → 1.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (334)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +400 -3
  3. xinference/client/restful/async_restful_client.py +20 -3
  4. xinference/client/restful/restful_client.py +20 -3
  5. xinference/constants.py +2 -0
  6. xinference/core/supervisor.py +111 -49
  7. xinference/core/worker.py +10 -0
  8. xinference/deploy/cmdline.py +15 -0
  9. xinference/model/audio/core.py +26 -6
  10. xinference/model/audio/indextts2.py +166 -0
  11. xinference/model/audio/kokoro.py +1 -1
  12. xinference/model/audio/kokoro_zh.py +124 -0
  13. xinference/model/audio/model_spec.json +58 -1
  14. xinference/model/embedding/sentence_transformers/core.py +4 -4
  15. xinference/model/embedding/vllm/core.py +7 -1
  16. xinference/model/image/model_spec.json +71 -3
  17. xinference/model/image/stable_diffusion/core.py +13 -4
  18. xinference/model/llm/__init__.py +4 -0
  19. xinference/model/llm/core.py +10 -0
  20. xinference/model/llm/llama_cpp/core.py +1 -0
  21. xinference/model/llm/llm_family.json +503 -21
  22. xinference/model/llm/llm_family.py +1 -0
  23. xinference/model/llm/mlx/core.py +52 -33
  24. xinference/model/llm/sglang/core.py +32 -55
  25. xinference/model/llm/tool_parsers/__init__.py +58 -0
  26. xinference/model/llm/tool_parsers/abstract_tool_parser.py +33 -0
  27. xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +190 -0
  28. xinference/model/llm/tool_parsers/deepseek_v3_tool_parser.py +145 -0
  29. xinference/model/llm/tool_parsers/glm4_tool_parser.py +123 -0
  30. xinference/model/llm/tool_parsers/llama3_tool_parser.py +77 -0
  31. xinference/model/llm/tool_parsers/qwen_tool_parser.py +320 -0
  32. xinference/model/llm/transformers/core.py +1 -1
  33. xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
  34. xinference/model/llm/utils.py +138 -53
  35. xinference/model/llm/vllm/core.py +95 -78
  36. xinference/thirdparty/audiotools/__init__.py +10 -0
  37. xinference/thirdparty/audiotools/core/__init__.py +4 -0
  38. xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
  39. xinference/thirdparty/audiotools/core/display.py +194 -0
  40. xinference/thirdparty/audiotools/core/dsp.py +390 -0
  41. xinference/thirdparty/audiotools/core/effects.py +647 -0
  42. xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
  43. xinference/thirdparty/audiotools/core/loudness.py +320 -0
  44. xinference/thirdparty/audiotools/core/playback.py +252 -0
  45. xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
  46. xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
  47. xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
  48. xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
  49. xinference/thirdparty/audiotools/core/util.py +671 -0
  50. xinference/thirdparty/audiotools/core/whisper.py +97 -0
  51. xinference/thirdparty/audiotools/data/__init__.py +3 -0
  52. xinference/thirdparty/audiotools/data/datasets.py +517 -0
  53. xinference/thirdparty/audiotools/data/preprocess.py +81 -0
  54. xinference/thirdparty/audiotools/data/transforms.py +1592 -0
  55. xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
  56. xinference/thirdparty/audiotools/metrics/distance.py +131 -0
  57. xinference/thirdparty/audiotools/metrics/quality.py +159 -0
  58. xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
  59. xinference/thirdparty/audiotools/ml/__init__.py +5 -0
  60. xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
  61. xinference/thirdparty/audiotools/ml/decorators.py +440 -0
  62. xinference/thirdparty/audiotools/ml/experiment.py +90 -0
  63. xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
  64. xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
  65. xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
  66. xinference/thirdparty/audiotools/post.py +140 -0
  67. xinference/thirdparty/audiotools/preference.py +600 -0
  68. xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
  69. xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
  70. xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
  71. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
  72. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
  73. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
  74. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
  75. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  76. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
  77. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
  78. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
  79. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
  80. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
  81. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
  82. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
  83. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
  84. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
  85. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
  86. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
  87. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
  88. xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
  89. xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
  90. xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
  91. xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
  92. xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
  93. xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
  94. xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
  95. xinference/thirdparty/indextts/__init__.py +0 -0
  96. xinference/thirdparty/indextts/cli.py +65 -0
  97. xinference/thirdparty/indextts/gpt/__init__.py +0 -0
  98. xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
  99. xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
  100. xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
  101. xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
  102. xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
  103. xinference/thirdparty/indextts/gpt/model.py +713 -0
  104. xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
  105. xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
  106. xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
  107. xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
  108. xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
  109. xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
  110. xinference/thirdparty/indextts/infer.py +690 -0
  111. xinference/thirdparty/indextts/infer_v2.py +739 -0
  112. xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
  113. xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
  114. xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
  115. xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
  116. xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
  117. xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
  118. xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
  119. xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
  120. xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
  121. xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
  122. xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
  123. xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
  124. xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
  125. xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
  126. xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
  127. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
  128. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
  129. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
  130. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
  131. xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
  132. xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
  133. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
  134. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
  135. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  136. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
  137. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
  138. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
  139. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
  140. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
  141. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
  142. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
  143. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
  144. xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
  145. xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
  146. xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
  147. xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
  148. xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
  149. xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
  150. xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
  151. xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
  152. xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
  153. xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
  154. xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
  155. xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
  156. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
  157. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
  158. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
  159. xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
  160. xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
  161. xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
  162. xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
  163. xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
  164. xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
  165. xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
  166. xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
  167. xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
  168. xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
  169. xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
  170. xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
  171. xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
  172. xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
  173. xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
  174. xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
  175. xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
  176. xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
  177. xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
  178. xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
  179. xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
  180. xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
  181. xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
  182. xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
  183. xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
  184. xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
  185. xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
  186. xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
  187. xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
  188. xinference/thirdparty/indextts/utils/__init__.py +0 -0
  189. xinference/thirdparty/indextts/utils/arch_util.py +120 -0
  190. xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
  191. xinference/thirdparty/indextts/utils/common.py +121 -0
  192. xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
  193. xinference/thirdparty/indextts/utils/front.py +536 -0
  194. xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
  195. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
  196. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
  197. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
  198. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
  199. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
  200. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
  201. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
  202. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
  203. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
  204. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
  205. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
  206. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
  207. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
  208. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
  209. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
  210. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
  211. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
  212. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
  213. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
  214. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
  215. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
  216. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
  217. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
  218. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
  219. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
  220. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
  221. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
  222. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
  223. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
  224. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
  225. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
  226. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
  227. xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
  228. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
  229. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
  230. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
  231. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
  232. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
  233. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
  234. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
  235. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
  236. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
  237. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
  238. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
  239. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
  240. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
  241. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
  242. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
  243. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
  244. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
  245. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
  246. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
  247. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
  248. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
  249. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
  250. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
  251. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
  252. xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
  253. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
  254. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
  255. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
  256. xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
  257. xinference/thirdparty/indextts/utils/text_utils.py +41 -0
  258. xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
  259. xinference/thirdparty/indextts/utils/utils.py +93 -0
  260. xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
  261. xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
  262. xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
  263. xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
  264. xinference/types.py +105 -2
  265. xinference/ui/gradio/media_interface.py +66 -8
  266. xinference/ui/web/ui/build/asset-manifest.json +6 -6
  267. xinference/ui/web/ui/build/index.html +1 -1
  268. xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
  269. xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
  270. xinference/ui/web/ui/build/static/js/main.d192c4f3.js +3 -0
  271. xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.d192c4f3.js.LICENSE.txt} +0 -7
  272. xinference/ui/web/ui/build/static/js/main.d192c4f3.js.map +1 -0
  273. xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
  274. xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
  275. xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
  276. xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
  277. xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
  278. xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
  279. xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
  280. xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
  281. xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
  282. xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
  283. xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
  284. xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
  285. xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
  286. xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
  287. xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
  288. xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
  289. xinference/ui/web/ui/node_modules/.cache/babel-loader/f995a2425dfb0822fd07127f66ffe9b026883bc156b402eb8bd0b83d52460a93.json +1 -0
  290. xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
  291. xinference/ui/web/ui/package-lock.json +0 -34
  292. xinference/ui/web/ui/package.json +0 -1
  293. xinference/ui/web/ui/src/locales/en.json +9 -3
  294. xinference/ui/web/ui/src/locales/ja.json +9 -3
  295. xinference/ui/web/ui/src/locales/ko.json +9 -3
  296. xinference/ui/web/ui/src/locales/zh.json +9 -3
  297. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/METADATA +24 -4
  298. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/RECORD +302 -76
  299. xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
  300. xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
  301. xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
  302. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
  303. xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
  304. xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
  305. xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
  306. xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
  307. xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
  308. xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
  309. xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
  310. xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
  311. xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
  312. xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
  313. xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
  314. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
  315. xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
  316. xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
  317. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
  318. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
  319. xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
  320. xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
  321. xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
  322. xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
  323. xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
  324. xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
  325. xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
  326. xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
  327. xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
  328. xinference/ui/web/ui/node_modules/select/bower.json +0 -13
  329. xinference/ui/web/ui/node_modules/select/package.json +0 -29
  330. xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
  331. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/WHEEL +0 -0
  332. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/entry_points.txt +0 -0
  333. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/licenses/LICENSE +0 -0
  334. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/top_level.txt +0 -0

xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py (new file)
@@ -0,0 +1,670 @@
+"""Library implementing normalization.
+
+Authors
+ * Mirco Ravanelli 2020
+ * Guillermo Cámbara 2021
+ * Sarthak Yadav 2022
+"""
+
+import torch
+import torch.nn as nn
+
+
+class BatchNorm1d(nn.Module):
+    """Applies 1d batch normalization to the input tensor.
+
+    Arguments
+    ---------
+    input_shape : tuple
+        The expected shape of the input. Alternatively, use ``input_size``.
+    input_size : int
+        The expected size of the input. Alternatively, use ``input_shape``.
+    eps : float
+        This value is added to std deviation estimation to improve the numerical
+        stability.
+    momentum : float
+        It is a value used for the running_mean and running_var computation.
+    affine : bool
+        When set to True, the affine parameters are learned.
+    track_running_stats : bool
+        When set to True, this module tracks the running mean and variance,
+        and when set to False, this module does not track such statistics.
+    combine_batch_time : bool
+        When True, it combines the batch and time axes.
+    skip_transpose : bool
+        Whether to skip the transposition.
+
+    Example
+    -------
+    >>> input = torch.randn(100, 10)
+    >>> norm = BatchNorm1d(input_shape=input.shape)
+    >>> output = norm(input)
+    >>> output.shape
+    torch.Size([100, 10])
+    """
+
+    def __init__(
+        self,
+        input_shape=None,
+        input_size=None,
+        eps=1e-05,
+        momentum=0.1,
+        affine=True,
+        track_running_stats=True,
+        combine_batch_time=False,
+        skip_transpose=False,
+    ):
+        super().__init__()
+        self.combine_batch_time = combine_batch_time
+        self.skip_transpose = skip_transpose
+
+        if input_size is None and skip_transpose:
+            input_size = input_shape[1]
+        elif input_size is None:
+            input_size = input_shape[-1]
+
+        self.norm = nn.BatchNorm1d(
+            input_size,
+            eps=eps,
+            momentum=momentum,
+            affine=affine,
+            track_running_stats=track_running_stats,
+        )
+
+    def forward(self, x):
+        """Returns the normalized input tensor.
+
+        Arguments
+        ---------
+        x : torch.Tensor (batch, time, [channels])
+            input to normalize. 2d or 3d tensors are expected in input;
+            4d tensors can be used when combine_batch_time=True.
+
+        Returns
+        -------
+        x_n : torch.Tensor
+            The normalized outputs.
+        """
+        shape_or = x.shape
+        if self.combine_batch_time:
+            if x.ndim == 3:
+                x = x.reshape(shape_or[0] * shape_or[1], shape_or[2])
+            else:
+                x = x.reshape(
+                    shape_or[0] * shape_or[1], shape_or[3], shape_or[2]
+                )
+
+        elif not self.skip_transpose:
+            x = x.transpose(-1, 1)
+
+        x_n = self.norm(x)
+
+        if self.combine_batch_time:
+            x_n = x_n.reshape(shape_or)
+        elif not self.skip_transpose:
+            x_n = x_n.transpose(1, -1)
+
+        return x_n
+
+
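
Editor's note: a minimal usage sketch for the BatchNorm1d wrapper above (an illustration, not lines from the released file). It assumes the class has been imported from the vendored module, and exercises the channels-last 3d path and the combine_batch_time path described in the docstring.

import torch

x = torch.randn(8, 120, 40)  # (batch, time, channels), speechbrain-style

# Default path: transpose to (batch, channels, time), apply
# nn.BatchNorm1d(40) per channel, then transpose back.
norm = BatchNorm1d(input_shape=x.shape)
assert norm(x).shape == (8, 120, 40)

# combine_batch_time=True flattens (batch, time) into one axis instead,
# so statistics are computed over batch*time frames per channel.
norm_bt = BatchNorm1d(input_shape=x.shape, combine_batch_time=True)
assert norm_bt(x).shape == (8, 120, 40)
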
+class BatchNorm2d(nn.Module):
+    """Applies 2d batch normalization to the input tensor.
+
+    Arguments
+    ---------
+    input_shape : tuple
+        The expected shape of the input. Alternatively, use ``input_size``.
+    input_size : int
+        The expected size of the input. Alternatively, use ``input_shape``.
+    eps : float
+        This value is added to std deviation estimation to improve the numerical
+        stability.
+    momentum : float
+        It is a value used for the running_mean and running_var computation.
+    affine : bool
+        When set to True, the affine parameters are learned.
+    track_running_stats : bool
+        When set to True, this module tracks the running mean and variance,
+        and when set to False, this module does not track such statistics.
+
+    Example
+    -------
+    >>> input = torch.randn(100, 10, 5, 20)
+    >>> norm = BatchNorm2d(input_shape=input.shape)
+    >>> output = norm(input)
+    >>> output.shape
+    torch.Size([100, 10, 5, 20])
+    """
+
+    def __init__(
+        self,
+        input_shape=None,
+        input_size=None,
+        eps=1e-05,
+        momentum=0.1,
+        affine=True,
+        track_running_stats=True,
+    ):
+        super().__init__()
+
+        if input_shape is None and input_size is None:
+            raise ValueError("Expected input_shape or input_size as input")
+
+        if input_size is None:
+            input_size = input_shape[-1]
+
+        self.norm = nn.BatchNorm2d(
+            input_size,
+            eps=eps,
+            momentum=momentum,
+            affine=affine,
+            track_running_stats=track_running_stats,
+        )
+
+    def forward(self, x):
+        """Returns the normalized input tensor.
+
+        Arguments
+        ---------
+        x : torch.Tensor (batch, time, channel1, channel2)
+            input to normalize. 4d tensors are expected.
+
+        Returns
+        -------
+        x_n : torch.Tensor
+            The normalized outputs.
+        """
+        x = x.transpose(-1, 1)
+        x_n = self.norm(x)
+        x_n = x_n.transpose(1, -1)
+
+        return x_n
+
+
+class LayerNorm(nn.Module):
+    """Applies layer normalization to the input tensor.
+
+    Arguments
+    ---------
+    input_size : int
+        The expected size of the dimension to be normalized.
+    input_shape : tuple
+        The expected shape of the input.
+    eps : float
+        This value is added to std deviation estimation to improve the numerical
+        stability.
+    elementwise_affine : bool
+        If True, this module has learnable per-element affine parameters
+        initialized to ones (for weights) and zeros (for biases).
+
+    Example
+    -------
+    >>> input = torch.randn(100, 101, 128)
+    >>> norm = LayerNorm(input_shape=input.shape)
+    >>> output = norm(input)
+    >>> output.shape
+    torch.Size([100, 101, 128])
+    """
+
+    def __init__(
+        self,
+        input_size=None,
+        input_shape=None,
+        eps=1e-05,
+        elementwise_affine=True,
+    ):
+        super().__init__()
+        self.eps = eps
+        self.elementwise_affine = elementwise_affine
+
+        if input_shape is not None:
+            input_size = input_shape[2:]
+
+        self.norm = torch.nn.LayerNorm(
+            input_size,
+            eps=self.eps,
+            elementwise_affine=self.elementwise_affine,
+        )
+
+    def forward(self, x):
+        """Returns the normalized input tensor.
+
+        Arguments
+        ---------
+        x : torch.Tensor (batch, time, channels)
+            input to normalize. 3d or 4d tensors are expected.
+
+        Returns
+        -------
+        The normalized outputs.
+        """
+        return self.norm(x)
+
+
+class InstanceNorm1d(nn.Module):
+    """Applies 1d instance normalization to the input tensor.
+
+    Arguments
+    ---------
+    input_shape : tuple
+        The expected shape of the input. Alternatively, use ``input_size``.
+    input_size : int
+        The expected size of the input. Alternatively, use ``input_shape``.
+    eps : float
+        This value is added to std deviation estimation to improve the numerical
+        stability.
+    momentum : float
+        It is a value used for the running_mean and running_var computation.
+    track_running_stats : bool
+        When set to True, this module tracks the running mean and variance,
+        and when set to False, this module does not track such statistics.
+    affine : bool
+        A boolean value that when set to True, this module has learnable
+        affine parameters, initialized the same way as done for
+        batch normalization. Default: False.
+
+    Example
+    -------
+    >>> input = torch.randn(100, 10, 20)
+    >>> norm = InstanceNorm1d(input_shape=input.shape)
+    >>> output = norm(input)
+    >>> output.shape
+    torch.Size([100, 10, 20])
+    """
+
+    def __init__(
+        self,
+        input_shape=None,
+        input_size=None,
+        eps=1e-05,
+        momentum=0.1,
+        track_running_stats=True,
+        affine=False,
+    ):
+        super().__init__()
+
+        if input_shape is None and input_size is None:
+            raise ValueError("Expected input_shape or input_size as input")
+
+        if input_size is None:
+            input_size = input_shape[-1]
+
+        self.norm = nn.InstanceNorm1d(
+            input_size,
+            eps=eps,
+            momentum=momentum,
+            track_running_stats=track_running_stats,
+            affine=affine,
+        )
+
+    def forward(self, x):
+        """Returns the normalized input tensor.
+
+        Arguments
+        ---------
+        x : torch.Tensor (batch, time, channels)
+            input to normalize. 3d tensors are expected.
+
+        Returns
+        -------
+        x_n : torch.Tensor
+            The normalized outputs.
+        """
+        x = x.transpose(-1, 1)
+        x_n = self.norm(x)
+        x_n = x_n.transpose(1, -1)
+
+        return x_n
+
+
+class InstanceNorm2d(nn.Module):
+    """Applies 2d instance normalization to the input tensor.
+
+    Arguments
+    ---------
+    input_shape : tuple
+        The expected shape of the input. Alternatively, use ``input_size``.
+    input_size : int
+        The expected size of the input. Alternatively, use ``input_shape``.
+    eps : float
+        This value is added to std deviation estimation to improve the numerical
+        stability.
+    momentum : float
+        It is a value used for the running_mean and running_var computation.
+    track_running_stats : bool
+        When set to True, this module tracks the running mean and variance,
+        and when set to False, this module does not track such statistics.
+    affine : bool
+        A boolean value that when set to True, this module has learnable
+        affine parameters, initialized the same way as done for
+        batch normalization. Default: False.
+
+    Example
+    -------
+    >>> input = torch.randn(100, 10, 20, 2)
+    >>> norm = InstanceNorm2d(input_shape=input.shape)
+    >>> output = norm(input)
+    >>> output.shape
+    torch.Size([100, 10, 20, 2])
+    """
+
+    def __init__(
+        self,
+        input_shape=None,
+        input_size=None,
+        eps=1e-05,
+        momentum=0.1,
+        track_running_stats=True,
+        affine=False,
+    ):
+        super().__init__()
+
+        if input_shape is None and input_size is None:
+            raise ValueError("Expected input_shape or input_size as input")
+
+        if input_size is None:
+            input_size = input_shape[-1]
+
+        self.norm = nn.InstanceNorm2d(
+            input_size,
+            eps=eps,
+            momentum=momentum,
+            track_running_stats=track_running_stats,
+            affine=affine,
+        )
+
+    def forward(self, x):
+        """Returns the normalized input tensor.
+
+        Arguments
+        ---------
+        x : torch.Tensor (batch, time, channel1, channel2)
+            input to normalize. 4d tensors are expected.
+
+        Returns
+        -------
+        x_n : torch.Tensor
+            The normalized outputs.
+        """
+        x = x.transpose(-1, 1)
+        x_n = self.norm(x)
+        x_n = x_n.transpose(1, -1)
+
+        return x_n
+
+
+class GroupNorm(nn.Module):
+    """Applies group normalization to the input tensor.
+
+    Arguments
+    ---------
+    input_shape : tuple
+        The expected shape of the input. Alternatively, use ``input_size``.
+    input_size : int
+        The expected size of the input. Alternatively, use ``input_shape``.
+    num_groups : int
+        Number of groups to separate the channels into.
+    eps : float
+        This value is added to std deviation estimation to improve the numerical
+        stability.
+    affine : bool
+        A boolean value that when set to True, this module has learnable per-channel
+        affine parameters initialized to ones (for weights) and zeros (for biases).
+
+    Example
+    -------
+    >>> input = torch.randn(100, 101, 128)
+    >>> norm = GroupNorm(input_size=128, num_groups=128)
+    >>> output = norm(input)
+    >>> output.shape
+    torch.Size([100, 101, 128])
+    """
+
+    def __init__(
+        self,
+        input_shape=None,
+        input_size=None,
+        num_groups=None,
+        eps=1e-05,
+        affine=True,
+    ):
+        super().__init__()
+        self.eps = eps
+        self.affine = affine
+
+        if input_shape is None and input_size is None:
+            raise ValueError("Expected input_shape or input_size as input")
+
+        if num_groups is None:
+            raise ValueError("Expected num_groups as input")
+
+        if input_shape is not None:
+            input_size = input_shape[-1]
+
+        self.norm = torch.nn.GroupNorm(
+            num_groups,
+            input_size,
+            eps=self.eps,
+            affine=self.affine,
+        )
+
+    def forward(self, x):
+        """Returns the normalized input tensor.
+
+        Arguments
+        ---------
+        x : torch.Tensor (batch, time, channels)
+            input to normalize. 3d or 4d tensors are expected.
+
+        Returns
+        -------
+        x_n : torch.Tensor
+            The normalized outputs.
+        """
+        x = x.transpose(-1, 1)
+        x_n = self.norm(x)
+        x_n = x_n.transpose(1, -1)
+
+        return x_n
+
+
+class ExponentialMovingAverage(nn.Module):
+    """
+    Applies a learnable exponential moving average, as required by the learnable PCEN layer.
+
+    Arguments
+    ---------
+    input_size : int
+        The expected size of the input.
+    coeff_init : float
+        Initial smoothing coefficient value.
+    per_channel : bool
+        Controls whether the smoothing coefficients are learned
+        independently for every input channel.
+    trainable : bool
+        Whether to learn the PCEN parameters or use fixed ones.
+    skip_transpose : bool
+        If False, uses batch x time x channel convention of speechbrain.
+        If True, uses batch x channel x time convention.
+
+    Example
+    -------
+    >>> inp_tensor = torch.rand([10, 50, 40])
+    >>> pcen = ExponentialMovingAverage(40)
+    >>> out_tensor = pcen(inp_tensor)
+    >>> out_tensor.shape
+    torch.Size([10, 50, 40])
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        coeff_init: float = 0.04,
+        per_channel: bool = False,
+        trainable: bool = True,
+        skip_transpose: bool = False,
+    ):
+        super().__init__()
+        self._coeff_init = coeff_init
+        self._per_channel = per_channel
+        self.skip_transpose = skip_transpose
+        self.trainable = trainable
+        weights = (
+            torch.ones(input_size) if self._per_channel else torch.ones(1)
+        )
+        self._weights = nn.Parameter(
+            weights * self._coeff_init, requires_grad=trainable
+        )
+
+    def forward(self, x):
+        """Returns the smoothed input tensor.
+
+        Arguments
+        ---------
+        x : torch.Tensor (batch, time, channels)
+            input to smooth.
+        """
+        if not self.skip_transpose:
+            x = x.transpose(1, -1)
+        w = torch.clamp(self._weights, min=0.0, max=1.0)
+        initial_state = x[:, :, 0]
+
+        def scan(init_state, x, w):
+            """Loops over the time axis and accumulates the moving average."""
+            x = x.permute(2, 0, 1)
+            acc = init_state
+            results = []
+            for ix in range(x.shape[0]):
+                acc = (w * x[ix]) + ((1.0 - w) * acc)
+                results.append(acc.unsqueeze(0))
+            results = torch.cat(results, dim=0)
+            results = results.permute(1, 2, 0)
+            return results
+
+        output = scan(initial_state, x, w)
+        if not self.skip_transpose:
+            output = output.transpose(1, -1)
+        return output
+
+
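
Editor's note: the `scan` helper above is a sequential first-order IIR smoother. With the smoothing weight $w$ clamped to $[0, 1]$ and the first frame used as the initial state, it computes, per channel,

$$m_0 = x_0, \qquad m_t = w\,x_t + (1 - w)\,m_{t-1}, \quad t = 1, \dots, T - 1,$$

so each output frame depends on the previous one and the loop is inherently $O(T)$ sequential along the time axis.
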
+class PCEN(nn.Module):
+    """
+    This class implements a learnable Per-Channel Energy Normalization (PCEN) layer, supporting both
+    the original PCEN as specified in [1] and sPCEN as specified in [2].
+
+    [1] Yuxuan Wang, Pascal Getreuer, Thad Hughes, Richard F. Lyon, Rif A. Saurous, "Trainable Frontend For
+    Robust and Far-Field Keyword Spotting", in Proc. of ICASSP 2017 (https://arxiv.org/abs/1607.05666)
+
+    [2] Neil Zeghidour, Olivier Teboul, Félix de Chaumont Quitry & Marco Tagliasacchi, "LEAF: A Learnable
+    Frontend for Audio Classification", in Proc. of ICLR 2021 (https://arxiv.org/abs/2101.08596)
+
+    The default argument values correspond with those used by [2].
+
+    Arguments
+    ---------
+    input_size : int
+        The expected size of the input.
+    alpha : float
+        Specifies the alpha coefficient for PCEN.
+    smooth_coef : float
+        Specifies the smoothing coefficient for PCEN.
+    delta : float
+        Specifies the delta coefficient for PCEN.
+    root : float
+        Specifies the root coefficient for PCEN.
+    floor : float
+        Specifies the floor coefficient for PCEN.
+    trainable : bool
+        Whether to learn the PCEN parameters or use fixed ones.
+    per_channel_smooth_coef : bool
+        Whether to learn independent smoothing coefficients for every channel;
+        when True, this is essentially sPCEN from [2].
+    skip_transpose : bool
+        If False, uses batch x time x channel convention of speechbrain.
+        If True, uses batch x channel x time convention.
+
+    Example
+    -------
+    >>> inp_tensor = torch.rand([10, 50, 40])
+    >>> pcen = PCEN(40, alpha=0.96)  # sPCEN
+    >>> out_tensor = pcen(inp_tensor)
+    >>> out_tensor.shape
+    torch.Size([10, 50, 40])
+    """
+
+    def __init__(
+        self,
+        input_size,
+        alpha: float = 0.96,
+        smooth_coef: float = 0.04,
+        delta: float = 2.0,
+        root: float = 2.0,
+        floor: float = 1e-12,
+        trainable: bool = True,
+        per_channel_smooth_coef: bool = True,
+        skip_transpose: bool = False,
+    ):
+        super().__init__()
+        self._smooth_coef = smooth_coef
+        self._floor = floor
+        self._per_channel_smooth_coef = per_channel_smooth_coef
+        self.skip_transpose = skip_transpose
+        self.alpha = nn.Parameter(
+            torch.ones(input_size) * alpha, requires_grad=trainable
+        )
+        self.delta = nn.Parameter(
+            torch.ones(input_size) * delta, requires_grad=trainable
+        )
+        self.root = nn.Parameter(
+            torch.ones(input_size) * root, requires_grad=trainable
+        )
+
+        self.ema = ExponentialMovingAverage(
+            input_size,
+            coeff_init=self._smooth_coef,
+            per_channel=self._per_channel_smooth_coef,
+            skip_transpose=True,
+            trainable=trainable,
+        )
+
+    def forward(self, x):
+        """Returns the normalized input tensor.
+
+        Arguments
+        ---------
+        x : torch.Tensor (batch, time, channels)
+            input to normalize.
+
+        Returns
+        -------
+        output : torch.Tensor
+            The normalized outputs.
+        """
+        if not self.skip_transpose:
+            x = x.transpose(1, -1)
+        alpha = torch.min(
+            self.alpha, torch.tensor(1.0, dtype=x.dtype, device=x.device)
+        )
+        root = torch.max(
+            self.root, torch.tensor(1.0, dtype=x.dtype, device=x.device)
+        )
+        ema_smoother = self.ema(x)
+        one_over_root = 1.0 / root
+        output = (
+            x / (self._floor + ema_smoother) ** alpha.view(1, -1, 1)
+            + self.delta.view(1, -1, 1)
+        ) ** one_over_root.view(1, -1, 1) - self.delta.view(
+            1, -1, 1
+        ) ** one_over_root.view(1, -1, 1)
+        if not self.skip_transpose:
+            output = output.transpose(1, -1)
+        return output
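
Editor's note: combining the two modules, with $m(t, f)$ the exponential moving average above, $\epsilon$ the floor, and with $\alpha_f \le 1$ and $r_f \ge 1$ enforced by the clamps in `forward`, the layer computes, per frame $t$ and channel $f$,

$$\mathrm{PCEN}(t, f) = \left( \frac{x(t, f)}{\left( \epsilon + m(t, f) \right)^{\alpha_f}} + \delta_f \right)^{1/r_f} - \delta_f^{1/r_f},$$

which is the trainable-frontend formulation of [1]; learning a separate smoothing coefficient per channel (per_channel_smooth_coef=True, the default) yields the sPCEN variant of [2]. As the doctest shows, the layer is meant to be applied to a channels-last spectrogram-like tensor such as (batch, time, n_mels).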