xinference 1.9.1__py3-none-any.whl → 1.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of xinference might be problematic.
Files changed (334)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +400 -3
  3. xinference/client/restful/async_restful_client.py +20 -3
  4. xinference/client/restful/restful_client.py +20 -3
  5. xinference/constants.py +2 -0
  6. xinference/core/supervisor.py +111 -49
  7. xinference/core/worker.py +10 -0
  8. xinference/deploy/cmdline.py +15 -0
  9. xinference/model/audio/core.py +26 -6
  10. xinference/model/audio/indextts2.py +166 -0
  11. xinference/model/audio/kokoro.py +1 -1
  12. xinference/model/audio/kokoro_zh.py +124 -0
  13. xinference/model/audio/model_spec.json +58 -1
  14. xinference/model/embedding/sentence_transformers/core.py +4 -4
  15. xinference/model/embedding/vllm/core.py +7 -1
  16. xinference/model/image/model_spec.json +71 -3
  17. xinference/model/image/stable_diffusion/core.py +13 -4
  18. xinference/model/llm/__init__.py +4 -0
  19. xinference/model/llm/core.py +10 -0
  20. xinference/model/llm/llama_cpp/core.py +1 -0
  21. xinference/model/llm/llm_family.json +503 -21
  22. xinference/model/llm/llm_family.py +1 -0
  23. xinference/model/llm/mlx/core.py +52 -33
  24. xinference/model/llm/sglang/core.py +32 -55
  25. xinference/model/llm/tool_parsers/__init__.py +58 -0
  26. xinference/model/llm/tool_parsers/abstract_tool_parser.py +33 -0
  27. xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +190 -0
  28. xinference/model/llm/tool_parsers/deepseek_v3_tool_parser.py +145 -0
  29. xinference/model/llm/tool_parsers/glm4_tool_parser.py +123 -0
  30. xinference/model/llm/tool_parsers/llama3_tool_parser.py +77 -0
  31. xinference/model/llm/tool_parsers/qwen_tool_parser.py +320 -0
  32. xinference/model/llm/transformers/core.py +1 -1
  33. xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
  34. xinference/model/llm/utils.py +138 -53
  35. xinference/model/llm/vllm/core.py +95 -78
  36. xinference/thirdparty/audiotools/__init__.py +10 -0
  37. xinference/thirdparty/audiotools/core/__init__.py +4 -0
  38. xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
  39. xinference/thirdparty/audiotools/core/display.py +194 -0
  40. xinference/thirdparty/audiotools/core/dsp.py +390 -0
  41. xinference/thirdparty/audiotools/core/effects.py +647 -0
  42. xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
  43. xinference/thirdparty/audiotools/core/loudness.py +320 -0
  44. xinference/thirdparty/audiotools/core/playback.py +252 -0
  45. xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
  46. xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
  47. xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
  48. xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
  49. xinference/thirdparty/audiotools/core/util.py +671 -0
  50. xinference/thirdparty/audiotools/core/whisper.py +97 -0
  51. xinference/thirdparty/audiotools/data/__init__.py +3 -0
  52. xinference/thirdparty/audiotools/data/datasets.py +517 -0
  53. xinference/thirdparty/audiotools/data/preprocess.py +81 -0
  54. xinference/thirdparty/audiotools/data/transforms.py +1592 -0
  55. xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
  56. xinference/thirdparty/audiotools/metrics/distance.py +131 -0
  57. xinference/thirdparty/audiotools/metrics/quality.py +159 -0
  58. xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
  59. xinference/thirdparty/audiotools/ml/__init__.py +5 -0
  60. xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
  61. xinference/thirdparty/audiotools/ml/decorators.py +440 -0
  62. xinference/thirdparty/audiotools/ml/experiment.py +90 -0
  63. xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
  64. xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
  65. xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
  66. xinference/thirdparty/audiotools/post.py +140 -0
  67. xinference/thirdparty/audiotools/preference.py +600 -0
  68. xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
  69. xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
  70. xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
  71. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
  72. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
  73. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
  74. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
  75. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  76. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
  77. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
  78. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
  79. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
  80. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
  81. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
  82. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
  83. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
  84. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
  85. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
  86. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
  87. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
  88. xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
  89. xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
  90. xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
  91. xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
  92. xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
  93. xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
  94. xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
  95. xinference/thirdparty/indextts/__init__.py +0 -0
  96. xinference/thirdparty/indextts/cli.py +65 -0
  97. xinference/thirdparty/indextts/gpt/__init__.py +0 -0
  98. xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
  99. xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
  100. xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
  101. xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
  102. xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
  103. xinference/thirdparty/indextts/gpt/model.py +713 -0
  104. xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
  105. xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
  106. xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
  107. xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
  108. xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
  109. xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
  110. xinference/thirdparty/indextts/infer.py +690 -0
  111. xinference/thirdparty/indextts/infer_v2.py +739 -0
  112. xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
  113. xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
  114. xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
  115. xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
  116. xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
  117. xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
  118. xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
  119. xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
  120. xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
  121. xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
  122. xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
  123. xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
  124. xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
  125. xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
  126. xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
  127. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
  128. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
  129. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
  130. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
  131. xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
  132. xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
  133. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
  134. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
  135. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  136. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
  137. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
  138. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
  139. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
  140. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
  141. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
  142. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
  143. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
  144. xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
  145. xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
  146. xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
  147. xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
  148. xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
  149. xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
  150. xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
  151. xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
  152. xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
  153. xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
  154. xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
  155. xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
  156. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
  157. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
  158. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
  159. xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
  160. xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
  161. xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
  162. xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
  163. xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
  164. xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
  165. xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
  166. xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
  167. xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
  168. xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
  169. xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
  170. xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
  171. xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
  172. xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
  173. xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
  174. xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
  175. xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
  176. xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
  177. xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
  178. xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
  179. xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
  180. xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
  181. xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
  182. xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
  183. xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
  184. xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
  185. xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
  186. xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
  187. xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
  188. xinference/thirdparty/indextts/utils/__init__.py +0 -0
  189. xinference/thirdparty/indextts/utils/arch_util.py +120 -0
  190. xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
  191. xinference/thirdparty/indextts/utils/common.py +121 -0
  192. xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
  193. xinference/thirdparty/indextts/utils/front.py +536 -0
  194. xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
  195. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
  196. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
  197. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
  198. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
  199. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
  200. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
  201. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
  202. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
  203. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
  204. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
  205. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
  206. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
  207. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
  208. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
  209. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
  210. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
  211. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
  212. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
  213. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
  214. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
  215. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
  216. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
  217. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
  218. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
  219. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
  220. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
  221. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
  222. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
  223. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
  224. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
  225. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
  226. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
  227. xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
  228. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
  229. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
  230. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
  231. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
  232. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
  233. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
  234. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
  235. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
  236. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
  237. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
  238. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
  239. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
  240. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
  241. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
  242. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
  243. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
  244. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
  245. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
  246. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
  247. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
  248. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
  249. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
  250. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
  251. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
  252. xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
  253. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
  254. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
  255. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
  256. xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
  257. xinference/thirdparty/indextts/utils/text_utils.py +41 -0
  258. xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
  259. xinference/thirdparty/indextts/utils/utils.py +93 -0
  260. xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
  261. xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
  262. xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
  263. xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
  264. xinference/types.py +105 -2
  265. xinference/ui/gradio/media_interface.py +66 -8
  266. xinference/ui/web/ui/build/asset-manifest.json +6 -6
  267. xinference/ui/web/ui/build/index.html +1 -1
  268. xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
  269. xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
  270. xinference/ui/web/ui/build/static/js/main.d192c4f3.js +3 -0
  271. xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.d192c4f3.js.LICENSE.txt} +0 -7
  272. xinference/ui/web/ui/build/static/js/main.d192c4f3.js.map +1 -0
  273. xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
  274. xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
  275. xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
  276. xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
  277. xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
  278. xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
  279. xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
  280. xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
  281. xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
  282. xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
  283. xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
  284. xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
  285. xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
  286. xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
  287. xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
  288. xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
  289. xinference/ui/web/ui/node_modules/.cache/babel-loader/f995a2425dfb0822fd07127f66ffe9b026883bc156b402eb8bd0b83d52460a93.json +1 -0
  290. xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
  291. xinference/ui/web/ui/package-lock.json +0 -34
  292. xinference/ui/web/ui/package.json +0 -1
  293. xinference/ui/web/ui/src/locales/en.json +9 -3
  294. xinference/ui/web/ui/src/locales/ja.json +9 -3
  295. xinference/ui/web/ui/src/locales/ko.json +9 -3
  296. xinference/ui/web/ui/src/locales/zh.json +9 -3
  297. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/METADATA +24 -4
  298. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/RECORD +302 -76
  299. xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
  300. xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
  301. xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
  302. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
  303. xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
  304. xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
  305. xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
  306. xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
  307. xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
  308. xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
  309. xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
  310. xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
  311. xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
  312. xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
  313. xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
  314. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
  315. xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
  316. xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
  317. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
  318. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
  319. xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
  320. xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
  321. xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
  322. xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
  323. xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
  324. xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
  325. xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
  326. xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
  327. xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
  328. xinference/ui/web/ui/node_modules/select/bower.json +0 -13
  329. xinference/ui/web/ui/node_modules/select/package.json +0 -29
  330. xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
  331. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/WHEEL +0 -0
  332. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/entry_points.txt +0 -0
  333. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/licenses/LICENSE +0 -0
  334. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,101 @@
+ # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
+ # LICENSE is in incl_licenses directory.
+
+ import glob
+ import os
+
+ import matplotlib
+ import matplotlib.pylab as plt
+ import torch
+ from scipy.io.wavfile import write
+ from torch.nn.utils import weight_norm
+
+ matplotlib.use("Agg")
+
+ MAX_WAV_VALUE = 32768.0
+
+
+ def plot_spectrogram(spectrogram):
+     fig, ax = plt.subplots(figsize=(10, 2))
+     im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
+     plt.colorbar(im, ax=ax)
+
+     fig.canvas.draw()
+     plt.close()
+
+     return fig
+
+
+ def plot_spectrogram_clipped(spectrogram, clip_max=2.0):
+     fig, ax = plt.subplots(figsize=(10, 2))
+     im = ax.imshow(
+         spectrogram,
+         aspect="auto",
+         origin="lower",
+         interpolation="none",
+         vmin=1e-6,
+         vmax=clip_max,
+     )
+     plt.colorbar(im, ax=ax)
+
+     fig.canvas.draw()
+     plt.close()
+
+     return fig
+
+
+ def init_weights(m, mean=0.0, std=0.01):
+     classname = m.__class__.__name__
+     if classname.find("Conv") != -1:
+         m.weight.data.normal_(mean, std)
+
+
+ def apply_weight_norm(m):
+     classname = m.__class__.__name__
+     if classname.find("Conv") != -1:
+         weight_norm(m)
+
+
+ def get_padding(kernel_size, dilation=1):
+     return int((kernel_size * dilation - dilation) / 2)
+
+
+ def load_checkpoint(filepath, device):
+     assert os.path.isfile(filepath)
+     print(f"Loading '{filepath}'")
+     checkpoint_dict = torch.load(filepath, map_location=device)
+     print("Complete.")
+     return checkpoint_dict
+
+
+ def save_checkpoint(filepath, obj):
+     print(f"Saving checkpoint to {filepath}")
+     torch.save(obj, filepath)
+     print("Complete.")
+
+
+ def scan_checkpoint(cp_dir, prefix, renamed_file=None):
+     # Fallback to original scanning logic first
+     pattern = os.path.join(cp_dir, prefix + "????????")
+     cp_list = glob.glob(pattern)
+
+     if len(cp_list) > 0:
+         last_checkpoint_path = sorted(cp_list)[-1]
+         print(f"[INFO] Resuming from checkpoint: '{last_checkpoint_path}'")
+         return last_checkpoint_path
+
+     # If no pattern-based checkpoints are found, check for renamed file
+     if renamed_file:
+         renamed_path = os.path.join(cp_dir, renamed_file)
+         if os.path.isfile(renamed_path):
+             print(f"[INFO] Resuming from renamed checkpoint: '{renamed_file}'")
+             return renamed_path
+
+     return None
+
+
+ def save_audio(audio, path, sr):
+     # wav: torch with 1d shape
+     audio = audio * MAX_WAV_VALUE
+     audio = audio.cpu().numpy().astype("int16")
+     write(path, sr, audio)
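
The new BigVGAN utils module above centers on checkpoint handling: scan_checkpoint first globs for prefix-named checkpoint files and only then falls back to an explicitly renamed file. Below is a minimal sketch of how these helpers might be combined when resuming a vocoder; the import path is inferred from the file list above, and the "g_" prefix and "generator.pt" file name are hypothetical placeholders, not taken from xinference's own code.

import torch

from xinference.thirdparty.indextts.BigVGAN.utils import (
    load_checkpoint,
    scan_checkpoint,
)


def resume_generator(cp_dir: str, device: str = "cpu"):
    # Look for pattern-named checkpoints (e.g. "g_00001000") first, then fall
    # back to a single renamed checkpoint file if the scan finds nothing.
    path = scan_checkpoint(cp_dir, prefix="g_", renamed_file="generator.pt")
    if path is None:
        return None
    return load_checkpoint(path, torch.device(device))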
File without changes
@@ -0,0 +1,65 @@
+ import os
+ import sys
+ import warnings
+ # Suppress warnings from tensorflow and other libraries
+ warnings.filterwarnings("ignore", category=UserWarning)
+ warnings.filterwarnings("ignore", category=FutureWarning)
+ def main():
+     import argparse
+     parser = argparse.ArgumentParser(description="IndexTTS Command Line")
+     parser.add_argument("text", type=str, help="Text to be synthesized")
+     parser.add_argument("-v", "--voice", type=str, required=True, help="Path to the audio prompt file (wav format)")
+     parser.add_argument("-o", "--output_path", type=str, default="gen.wav", help="Path to the output wav file")
+     parser.add_argument("-c", "--config", type=str, default="checkpoints/config.yaml", help="Path to the config file. Default is 'checkpoints/config.yaml'")
+     parser.add_argument("--model_dir", type=str, default="checkpoints", help="Path to the model directory. Default is 'checkpoints'")
+     parser.add_argument("--fp16", action="store_true", default=False, help="Use FP16 for inference if available")
+     parser.add_argument("-f", "--force", action="store_true", default=False, help="Force to overwrite the output file if it exists")
+     parser.add_argument("-d", "--device", type=str, default=None, help="Device to run the model on (cpu, cuda, mps, xpu)." )
+     args = parser.parse_args()
+     if len(args.text.strip()) == 0:
+         print("ERROR: Text is empty.")
+         parser.print_help()
+         sys.exit(1)
+     if not os.path.exists(args.voice):
+         print(f"Audio prompt file {args.voice} does not exist.")
+         parser.print_help()
+         sys.exit(1)
+     if not os.path.exists(args.config):
+         print(f"Config file {args.config} does not exist.")
+         parser.print_help()
+         sys.exit(1)
+
+     output_path = args.output_path
+     if os.path.exists(output_path):
+         if not args.force:
+             print(f"ERROR: Output file {output_path} already exists. Use --force to overwrite.")
+             parser.print_help()
+             sys.exit(1)
+         else:
+             os.remove(output_path)
+
+     try:
+         import torch
+     except ImportError:
+         print("ERROR: PyTorch is not installed. Please install it first.")
+         sys.exit(1)
+
+     if args.device is None:
+         if torch.cuda.is_available():
+             args.device = "cuda:0"
+         elif hasattr(torch, "xpu") and torch.xpu.is_available():
+             args.device = "xpu"
+         elif hasattr(torch, "mps") and torch.mps.is_available():
+             args.device = "mps"
+         else:
+             args.device = "cpu"
+             args.fp16 = False  # Disable FP16 on CPU
+             print("WARNING: Running on CPU may be slow.")
+
+     # TODO: Add CLI support for IndexTTS2.
+     from indextts.infer import IndexTTS
+     tts = IndexTTS(cfg_path=args.config, model_dir=args.model_dir, use_fp16=args.fp16, device=args.device)
+     tts.infer(audio_prompt=args.voice, text=args.text.strip(), output_path=output_path)
+
+ if __name__ == "__main__":
+     main()
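
The new cli.py resolves the device automatically (CUDA, then XPU, then MPS, then CPU with FP16 disabled) before handing off to IndexTTS. Below is a minimal sketch of driving this entry point from Python by faking sys.argv; the argument values are placeholders, a real run needs the IndexTTS checkpoints and a reference wav on disk, and the import path assumes the vendored copy under xinference.thirdparty is importable.

import sys

from xinference.thirdparty.indextts import cli

sys.argv = [
    "indextts",
    "Hello from IndexTTS.",             # positional argument: text to synthesize
    "--voice", "examples/voice.wav",    # hypothetical reference voice prompt (wav)
    "--output_path", "gen.wav",
    "--config", "checkpoints/config.yaml",
    "--model_dir", "checkpoints",
]
cli.main()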
File without changes
@@ -0,0 +1,312 @@
+ # Copyright (c) 2019 Shigeki Karita
+ #               2020 Mobvoi Inc (Binbin Zhang)
+ #               2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Multi-Head Attention layer definition."""
+
+ import math
+ from typing import Tuple
+
+ import torch
+ from torch import nn
+
+
+ class MultiHeadedAttention(nn.Module):
+     """Multi-Head Attention layer.
+
+     Args:
+         n_head (int): The number of heads.
+         n_feat (int): The number of features.
+         dropout_rate (float): Dropout rate.
+
+     """
+     def __init__(self, n_head: int, n_feat: int, dropout_rate: float):
+         """Construct an MultiHeadedAttention object."""
+         super().__init__()
+         assert n_feat % n_head == 0
+         # We assume d_v always equals d_k
+         self.d_k = n_feat // n_head
+         self.h = n_head
+         self.linear_q = nn.Linear(n_feat, n_feat)
+         self.linear_k = nn.Linear(n_feat, n_feat)
+         self.linear_v = nn.Linear(n_feat, n_feat)
+         self.linear_out = nn.Linear(n_feat, n_feat)
+         self.dropout = nn.Dropout(p=dropout_rate)
+
+     def forward_qkv(
+         self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         """Transform query, key and value.
+
+         Args:
+             query (torch.Tensor): Query tensor (#batch, time1, size).
+             key (torch.Tensor): Key tensor (#batch, time2, size).
+             value (torch.Tensor): Value tensor (#batch, time2, size).
+
+         Returns:
+             torch.Tensor: Transformed query tensor, size
+                 (#batch, n_head, time1, d_k).
+             torch.Tensor: Transformed key tensor, size
+                 (#batch, n_head, time2, d_k).
+             torch.Tensor: Transformed value tensor, size
+                 (#batch, n_head, time2, d_k).
+
+         """
+         n_batch = query.size(0)
+         q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
+         k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
+         v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
+         q = q.transpose(1, 2)  # (batch, head, time1, d_k)
+         k = k.transpose(1, 2)  # (batch, head, time2, d_k)
+         v = v.transpose(1, 2)  # (batch, head, time2, d_k)
+
+         return q, k, v
+
+     def forward_attention(
+         self, value: torch.Tensor, scores: torch.Tensor,
+         mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
+     ) -> torch.Tensor:
+         """Compute attention context vector.
+
+         Args:
+             value (torch.Tensor): Transformed value, size
+                 (#batch, n_head, time2, d_k).
+             scores (torch.Tensor): Attention score, size
+                 (#batch, n_head, time1, time2).
+             mask (torch.Tensor): Mask, size (#batch, 1, time2) or
+                 (#batch, time1, time2), (0, 0, 0) means fake mask.
+
+         Returns:
+             torch.Tensor: Transformed value (#batch, time1, d_model)
+                 weighted by the attention score (#batch, time1, time2).
+
+         """
+         n_batch = value.size(0)
+         # NOTE(xcsong): When will `if mask.size(2) > 0` be True?
+         #   1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
+         #      1st chunk to ease the onnx export.]
+         #   2. pytorch training
+         if mask.size(2) > 0 :  # time2 > 0
+             mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
+             # For last chunk, time2 might be larger than scores.size(-1)
+             mask = mask[:, :, :, :scores.size(-1)]  # (batch, 1, *, time2)
+             scores = scores.masked_fill(mask, -float('inf'))
+             attn = torch.softmax(scores, dim=-1).masked_fill(
+                 mask, 0.0)  # (batch, head, time1, time2)
+         # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
+         #   1. onnx(16/-1, -1/-1, 16/0)
+         #   2. jit (16/-1, -1/-1, 16/0, 16/4)
+         else:
+             attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
+
+         p_attn = self.dropout(attn)
+         x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
+         x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
+                                                  self.h * self.d_k)
+              )  # (batch, time1, d_model)
+
+         return self.linear_out(x)  # (batch, time1, d_model)
+
+     def forward(self, query: torch.Tensor, key: torch.Tensor,
+                 value: torch.Tensor,
+                 mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+                 pos_emb: torch.Tensor = torch.empty(0),
+                 cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
+                 ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Compute scaled dot product attention.
+
+         Args:
+             query (torch.Tensor): Query tensor (#batch, time1, size).
+             key (torch.Tensor): Key tensor (#batch, time2, size).
+             value (torch.Tensor): Value tensor (#batch, time2, size).
+             mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                 (#batch, time1, time2).
+                 1.When applying cross attention between decoder and encoder,
+                 the batch padding mask for input is in (#batch, 1, T) shape.
+                 2.When applying self attention of encoder,
+                 the mask is in (#batch, T, T) shape.
+                 3.When applying self attention of decoder,
+                 the mask is in (#batch, L, L) shape.
+                 4.If the different position in decoder see different block
+                 of the encoder, such as Mocha, the passed in mask could be
+                 in (#batch, L, T) shape. But there is no such case in current
+                 Wenet.
+             cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
+                 where `cache_t == chunk_size * num_decoding_left_chunks`
+                 and `head * d_k == size`
+
+
+         Returns:
+             torch.Tensor: Output tensor (#batch, time1, d_model).
+             torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
+                 where `cache_t == chunk_size * num_decoding_left_chunks`
+                 and `head * d_k == size`
+
+         """
+         q, k, v = self.forward_qkv(query, key, value)
+
+         # NOTE(xcsong):
+         #   when export onnx model, for 1st chunk, we feed
+         #       cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
+         #       or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
+         #       In all modes, `if cache.size(0) > 0` will alwayse be `True`
+         #       and we will always do splitting and
+         #       concatnation(this will simplify onnx export). Note that
+         #       it's OK to concat & split zero-shaped tensors(see code below).
+         #   when export jit model, for 1st chunk, we always feed
+         #       cache(0, 0, 0, 0) since jit supports dynamic if-branch.
+         #   >>> a = torch.ones((1, 2, 0, 4))
+         #   >>> b = torch.ones((1, 2, 3, 4))
+         #   >>> c = torch.cat((a, b), dim=2)
+         #   >>> torch.equal(b, c)        # True
+         #   >>> d = torch.split(a, 2, dim=-1)
+         #   >>> torch.equal(d[0], d[1])  # True
+         if cache.size(0) > 0:
+             key_cache, value_cache = torch.split(
+                 cache, cache.size(-1) // 2, dim=-1)
+             k = torch.cat([key_cache, k], dim=2)
+             v = torch.cat([value_cache, v], dim=2)
+         # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
+         #   non-trivial to calculate `next_cache_start` here.
+         new_cache = torch.cat((k, v), dim=-1)
+
+         scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
+         return self.forward_attention(v, scores, mask), new_cache
+
+
+ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
+     """Multi-Head Attention layer with relative position encoding.
+     Paper: https://arxiv.org/abs/1901.02860
+     Args:
+         n_head (int): The number of heads.
+         n_feat (int): The number of features.
+         dropout_rate (float): Dropout rate.
+     """
+     def __init__(self, n_head, n_feat, dropout_rate):
+         """Construct an RelPositionMultiHeadedAttention object."""
+         super().__init__(n_head, n_feat, dropout_rate)
+         # linear transformation for positional encoding
+         self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
+         # these two learnable bias are used in matrix c and matrix d
+         # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+         self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
+         self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
+         torch.nn.init.xavier_uniform_(self.pos_bias_u)
+         torch.nn.init.xavier_uniform_(self.pos_bias_v)
+
+     def rel_shift(self, x, zero_triu: bool = False):
+         """Compute relative positinal encoding.
+         Args:
+             x (torch.Tensor): Input tensor (batch, time, size).
+             zero_triu (bool): If true, return the lower triangular part of
+                 the matrix.
+         Returns:
+             torch.Tensor: Output tensor.
+         """
+
+         zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
+                                device=x.device,
+                                dtype=x.dtype)
+         x_padded = torch.cat([zero_pad, x], dim=-1)
+
+         x_padded = x_padded.view(x.size()[0],
+                                  x.size()[1],
+                                  x.size(3) + 1, x.size(2))
+         x = x_padded[:, :, 1:].view_as(x)
+
+         if zero_triu:
+             ones = torch.ones((x.size(2), x.size(3)))
+             x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
+
+         return x
+
+     def forward(self, query: torch.Tensor,
+                 key: torch.Tensor, value: torch.Tensor,
+                 mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+                 pos_emb: torch.Tensor = torch.empty(0),
+                 cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
+                 ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
+         Args:
+             query (torch.Tensor): Query tensor (#batch, time1, size).
+             key (torch.Tensor): Key tensor (#batch, time2, size).
+             value (torch.Tensor): Value tensor (#batch, time2, size).
+             mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                 (#batch, time1, time2), (0, 0, 0) means fake mask.
+             pos_emb (torch.Tensor): Positional embedding tensor
+                 (#batch, time2, size).
+             cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
+                 where `cache_t == chunk_size * num_decoding_left_chunks`
+                 and `head * d_k == size`
+         Returns:
+             torch.Tensor: Output tensor (#batch, time1, d_model).
+             torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
+                 where `cache_t == chunk_size * num_decoding_left_chunks`
+                 and `head * d_k == size`
+         """
+         q, k, v = self.forward_qkv(query, key, value)
+         q = q.transpose(1, 2)  # (batch, time1, head, d_k)
+
+         # NOTE(xcsong):
+         #   when export onnx model, for 1st chunk, we feed
+         #       cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
+         #       or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
+         #       In all modes, `if cache.size(0) > 0` will alwayse be `True`
+         #       and we will always do splitting and
+         #       concatnation(this will simplify onnx export). Note that
+         #       it's OK to concat & split zero-shaped tensors(see code below).
+         #   when export jit model, for 1st chunk, we always feed
+         #       cache(0, 0, 0, 0) since jit supports dynamic if-branch.
+         #   >>> a = torch.ones((1, 2, 0, 4))
+         #   >>> b = torch.ones((1, 2, 3, 4))
+         #   >>> c = torch.cat((a, b), dim=2)
+         #   >>> torch.equal(b, c)        # True
+         #   >>> d = torch.split(a, 2, dim=-1)
+         #   >>> torch.equal(d[0], d[1])  # True
+         if cache.size(0) > 0:
+             key_cache, value_cache = torch.split(
+                 cache, cache.size(-1) // 2, dim=-1)
+             k = torch.cat([key_cache, k], dim=2)
+             v = torch.cat([value_cache, v], dim=2)
+         # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
+         #   non-trivial to calculate `next_cache_start` here.
+         new_cache = torch.cat((k, v), dim=-1)
+
+         n_batch_pos = pos_emb.size(0)
+         p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
+         p = p.transpose(1, 2)  # (batch, head, time1, d_k)
+
+         # (batch, head, time1, d_k)
+         q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
+         # (batch, head, time1, d_k)
+         q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
+
+         # compute attention score
+         # first compute matrix a and matrix c
+         # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+         # (batch, head, time1, time2)
+         matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
+
+         # compute matrix b and matrix d
+         # (batch, head, time1, time2)
+         matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
+         # Remove rel_shift since it is useless in speech recognition,
+         #   and it requires special attention for streaming.
+         # matrix_bd = self.rel_shift(matrix_bd)
+
+         scores = (matrix_ac + matrix_bd) / math.sqrt(
+             self.d_k)  # (batch, head, time1, time2)
+
+         return self.forward_attention(v, scores, mask), new_cache
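
The vendored MultiHeadedAttention above always returns an `(output, new_cache)` pair, with zero-sized mask and cache tensors standing in for "no mask" and "no cache". Below is a minimal, self-contained sketch of calling it as plain self-attention; the import path is inferred from the file list, and the shapes and hyperparameters are illustrative only.

import torch

from xinference.thirdparty.indextts.gpt.conformer.attention import (
    MultiHeadedAttention,
)

batch, time, size, heads = 2, 16, 256, 4
x = torch.randn(batch, time, size)

attn = MultiHeadedAttention(n_head=heads, n_feat=size, dropout_rate=0.1)
out, new_cache = attn(x, x, x)  # self-attention with the default "fake" mask and empty cache

assert out.shape == (batch, time, size)
# new_cache concatenates keys and values along the last dim: (batch, head, time, 2 * d_k)
assert new_cache.shape == (batch, heads, time, 2 * (size // heads))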
@@ -0,0 +1,163 @@
+ # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # Modified from ESPnet(https://github.com/espnet/espnet)
+
+ """Positonal Encoding Module."""
+
+ import math
+ from typing import Tuple, Union
+
+ import torch
+ import torch.nn.functional as F
+
+
+ class PositionalEncoding(torch.nn.Module):
+     """Positional encoding.
+
+     :param int d_model: embedding dim
+     :param float dropout_rate: dropout rate
+     :param int max_len: maximum input length
+
+     PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
+     PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
+     """
+     def __init__(self,
+                  d_model: int,
+                  dropout_rate: float,
+                  max_len: int = 5000,
+                  reverse: bool = False):
+         """Construct an PositionalEncoding object."""
+         super().__init__()
+         self.d_model = d_model
+         self.xscale = math.sqrt(self.d_model)
+         self.dropout = torch.nn.Dropout(p=dropout_rate)
+         self.max_len = max_len
+
+         pe = torch.zeros(self.max_len, self.d_model)
+         position = torch.arange(0, self.max_len).unsqueeze(1)
+         div_term = torch.exp(
+             torch.arange(0, self.d_model, 2) *
+             -(math.log(10000.0) / self.d_model))
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+         pe = pe.unsqueeze(0)
+         self.register_buffer('pe', pe)
+
+     def forward(self,
+                 x: torch.Tensor,
+                 offset: Union[int, torch.Tensor] = 0) \
+             -> Tuple[torch.Tensor, torch.Tensor]:
+         """Add positional encoding.
+
+         Args:
+             x (torch.Tensor): Input. Its shape is (batch, time, ...)
+             offset (int, torch.tensor): position offset
+
+         Returns:
+             torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
+             torch.Tensor: for compatibility to RelPositionalEncoding
+         """
+
+         self.pe = self.pe.to(x.device)
+         pos_emb = self.position_encoding(offset, x.size(1), False)
+         x = x * self.xscale + pos_emb
+         return self.dropout(x), self.dropout(pos_emb)
+
+     def position_encoding(self, offset: Union[int, torch.Tensor], size: int,
+                           apply_dropout: bool = True) -> torch.Tensor:
+         """ For getting encoding in a streaming fashion
+
+         Attention!!!!!
+         we apply dropout only once at the whole utterance level in a none
+         streaming way, but will call this function several times with
+         increasing input size in a streaming scenario, so the dropout will
+         be applied several times.
+
+         Args:
+             offset (int or torch.tensor): start offset
+             size (int): required size of position encoding
+
+         Returns:
+             torch.Tensor: Corresponding encoding
+         """
+         # How to subscript a Union type:
+         #   https://github.com/pytorch/pytorch/issues/69434
+         if isinstance(offset, int):
+             assert offset + size < self.max_len
+             pos_emb = self.pe[:, offset:offset + size]
+         elif isinstance(offset, torch.Tensor) and offset.dim() == 0:  # scalar
+             assert offset + size < self.max_len
+             pos_emb = self.pe[:, offset:offset + size]
+         else:  # for batched streaming decoding on GPU
+             assert torch.max(offset) + size < self.max_len
+             index = offset.unsqueeze(1) + \
+                 torch.arange(0, size).to(offset.device)  # B X T
+             flag = index > 0
+             # remove negative offset
+             index = index * flag
+             pos_emb = F.embedding(index, self.pe[0])  # B X T X d_model
+
+         if apply_dropout:
+             pos_emb = self.dropout(pos_emb)
+         return pos_emb
+
+ class RelPositionalEncoding(PositionalEncoding):
+     """Relative positional encoding module.
+     See : Appendix B in https://arxiv.org/abs/1901.02860
+     Args:
+         d_model (int): Embedding dimension.
+         dropout_rate (float): Dropout rate.
+         max_len (int): Maximum input length.
+     """
+     def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
+         """Initialize class."""
+         super().__init__(d_model, dropout_rate, max_len, reverse=True)
+
+     def forward(self,
+                 x: torch.Tensor,
+                 offset: Union[int, torch.Tensor] = 0) \
+             -> Tuple[torch.Tensor, torch.Tensor]:
+         """Compute positional encoding.
+         Args:
+             x (torch.Tensor): Input tensor (batch, time, `*`).
+         Returns:
+             torch.Tensor: Encoded tensor (batch, time, `*`).
+             torch.Tensor: Positional embedding tensor (1, time, `*`).
+         """
+         self.pe = self.pe.to(x.device)
+         x = x * self.xscale
+         pos_emb = self.position_encoding(offset, x.size(1), False)
+         return self.dropout(x), self.dropout(pos_emb)
+
+
+ class NoPositionalEncoding(torch.nn.Module):
+     """ No position encoding
+     """
+     def __init__(self, d_model: int, dropout_rate: float):
+         super().__init__()
+         self.d_model = d_model
+         self.dropout = torch.nn.Dropout(p=dropout_rate)
+
+     def forward(self,
+                 x: torch.Tensor,
+                 offset: Union[int, torch.Tensor] = 0) \
+             -> Tuple[torch.Tensor, torch.Tensor]:
+         """ Just return zero vector for interface compatibility
+         """
+         pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device)
+         return self.dropout(x), pos_emb
+
+     def position_encoding(
+             self, offset: Union[int, torch.Tensor], size: int) -> torch.Tensor:
+         return torch.zeros(1, size, self.d_model)
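
The sinusoidal PositionalEncoding above adds the encoding to the scaled input, while RelPositionalEncoding only scales the input and returns the encoding separately for consumption by RelPositionMultiHeadedAttention. Below is a minimal sketch of the difference, under the assumption that the import path matches the file list; the dimensions are illustrative.

import torch

from xinference.thirdparty.indextts.gpt.conformer.embedding import (
    PositionalEncoding,
    RelPositionalEncoding,
)

x = torch.randn(2, 50, 256)  # (batch, time, d_model)

pos_enc = PositionalEncoding(d_model=256, dropout_rate=0.1)
y, pe = pos_enc(x)  # encoding is added to the scaled input
assert y.shape == x.shape and pe.shape == (1, 50, 256)

rel_enc = RelPositionalEncoding(d_model=256, dropout_rate=0.1)
y_rel, pe_rel = rel_enc(x)  # input only scaled; encoding returned separately
assert y_rel.shape == x.shape and pe_rel.shape == (1, 50, 256)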