xinference 1.10.0__py3-none-any.whl → 1.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic. Click here for more details.

Files changed (317) hide show
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +11 -28
  3. xinference/client/restful/async_restful_client.py +20 -3
  4. xinference/client/restful/restful_client.py +20 -3
  5. xinference/core/supervisor.py +87 -53
  6. xinference/core/worker.py +10 -0
  7. xinference/deploy/cmdline.py +15 -0
  8. xinference/model/audio/core.py +21 -6
  9. xinference/model/audio/indextts2.py +166 -0
  10. xinference/model/audio/model_spec.json +38 -1
  11. xinference/model/image/model_spec.json +69 -0
  12. xinference/model/image/stable_diffusion/core.py +13 -4
  13. xinference/model/llm/__init__.py +4 -0
  14. xinference/model/llm/llm_family.json +464 -2
  15. xinference/model/llm/sglang/core.py +30 -11
  16. xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +94 -32
  17. xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
  18. xinference/model/llm/utils.py +12 -9
  19. xinference/model/llm/vllm/core.py +93 -17
  20. xinference/thirdparty/audiotools/__init__.py +10 -0
  21. xinference/thirdparty/audiotools/core/__init__.py +4 -0
  22. xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
  23. xinference/thirdparty/audiotools/core/display.py +194 -0
  24. xinference/thirdparty/audiotools/core/dsp.py +390 -0
  25. xinference/thirdparty/audiotools/core/effects.py +647 -0
  26. xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
  27. xinference/thirdparty/audiotools/core/loudness.py +320 -0
  28. xinference/thirdparty/audiotools/core/playback.py +252 -0
  29. xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
  30. xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
  31. xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
  32. xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
  33. xinference/thirdparty/audiotools/core/util.py +671 -0
  34. xinference/thirdparty/audiotools/core/whisper.py +97 -0
  35. xinference/thirdparty/audiotools/data/__init__.py +3 -0
  36. xinference/thirdparty/audiotools/data/datasets.py +517 -0
  37. xinference/thirdparty/audiotools/data/preprocess.py +81 -0
  38. xinference/thirdparty/audiotools/data/transforms.py +1592 -0
  39. xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
  40. xinference/thirdparty/audiotools/metrics/distance.py +131 -0
  41. xinference/thirdparty/audiotools/metrics/quality.py +159 -0
  42. xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
  43. xinference/thirdparty/audiotools/ml/__init__.py +5 -0
  44. xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
  45. xinference/thirdparty/audiotools/ml/decorators.py +440 -0
  46. xinference/thirdparty/audiotools/ml/experiment.py +90 -0
  47. xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
  48. xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
  49. xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
  50. xinference/thirdparty/audiotools/post.py +140 -0
  51. xinference/thirdparty/audiotools/preference.py +600 -0
  52. xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
  53. xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
  54. xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
  55. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
  56. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
  57. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
  58. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
  59. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  60. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
  61. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
  62. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
  63. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
  64. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
  65. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
  66. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
  67. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
  68. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
  69. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
  70. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
  71. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
  72. xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
  73. xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
  74. xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
  75. xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
  76. xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
  77. xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
  78. xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
  79. xinference/thirdparty/indextts/__init__.py +0 -0
  80. xinference/thirdparty/indextts/cli.py +65 -0
  81. xinference/thirdparty/indextts/gpt/__init__.py +0 -0
  82. xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
  83. xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
  84. xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
  85. xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
  86. xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
  87. xinference/thirdparty/indextts/gpt/model.py +713 -0
  88. xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
  89. xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
  90. xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
  91. xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
  92. xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
  93. xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
  94. xinference/thirdparty/indextts/infer.py +690 -0
  95. xinference/thirdparty/indextts/infer_v2.py +739 -0
  96. xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
  97. xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
  98. xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
  99. xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
  100. xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
  101. xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
  102. xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
  103. xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
  104. xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
  105. xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
  106. xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
  107. xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
  108. xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
  109. xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
  110. xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
  111. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
  112. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
  113. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
  114. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
  115. xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
  116. xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
  117. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
  118. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
  119. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  120. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
  121. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
  122. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
  123. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
  124. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
  125. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
  126. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
  127. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
  128. xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
  129. xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
  130. xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
  131. xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
  132. xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
  133. xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
  134. xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
  135. xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
  136. xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
  137. xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
  138. xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
  139. xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
  140. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
  141. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
  142. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
  143. xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
  144. xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
  145. xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
  146. xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
  147. xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
  148. xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
  149. xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
  150. xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
  151. xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
  152. xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
  153. xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
  154. xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
  155. xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
  156. xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
  157. xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
  158. xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
  159. xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
  160. xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
  161. xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
  162. xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
  163. xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
  164. xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
  165. xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
  166. xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
  167. xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
  168. xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
  169. xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
  170. xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
  171. xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
  172. xinference/thirdparty/indextts/utils/__init__.py +0 -0
  173. xinference/thirdparty/indextts/utils/arch_util.py +120 -0
  174. xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
  175. xinference/thirdparty/indextts/utils/common.py +121 -0
  176. xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
  177. xinference/thirdparty/indextts/utils/front.py +536 -0
  178. xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
  179. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
  180. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
  181. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
  182. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
  183. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
  184. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
  185. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
  186. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
  187. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
  188. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
  189. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
  190. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
  191. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
  192. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
  193. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
  194. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
  195. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
  196. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
  197. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
  198. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
  199. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
  200. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
  201. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
  202. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
  203. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
  204. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
  205. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
  206. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
  207. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
  208. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
  209. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
  210. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
  211. xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
  212. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
  213. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
  214. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
  215. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
  216. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
  217. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
  218. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
  219. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
  220. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
  221. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
  222. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
  223. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
  224. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
  225. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
  226. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
  227. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
  228. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
  229. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
  230. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
  231. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
  232. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
  233. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
  234. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
  235. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
  236. xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
  237. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
  238. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
  239. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
  240. xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
  241. xinference/thirdparty/indextts/utils/text_utils.py +41 -0
  242. xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
  243. xinference/thirdparty/indextts/utils/utils.py +93 -0
  244. xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
  245. xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
  246. xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
  247. xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
  248. xinference/ui/gradio/media_interface.py +66 -8
  249. xinference/ui/web/ui/build/asset-manifest.json +6 -6
  250. xinference/ui/web/ui/build/index.html +1 -1
  251. xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
  252. xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
  253. xinference/ui/web/ui/build/static/js/main.d192c4f3.js +3 -0
  254. xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.d192c4f3.js.LICENSE.txt} +0 -7
  255. xinference/ui/web/ui/build/static/js/main.d192c4f3.js.map +1 -0
  256. xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
  257. xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
  258. xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
  259. xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
  260. xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
  261. xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
  262. xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
  263. xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
  264. xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
  265. xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
  266. xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
  267. xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
  268. xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
  269. xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
  270. xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
  271. xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
  272. xinference/ui/web/ui/node_modules/.cache/babel-loader/f995a2425dfb0822fd07127f66ffe9b026883bc156b402eb8bd0b83d52460a93.json +1 -0
  273. xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
  274. xinference/ui/web/ui/package-lock.json +0 -34
  275. xinference/ui/web/ui/package.json +0 -1
  276. xinference/ui/web/ui/src/locales/en.json +9 -3
  277. xinference/ui/web/ui/src/locales/ja.json +9 -3
  278. xinference/ui/web/ui/src/locales/ko.json +9 -3
  279. xinference/ui/web/ui/src/locales/zh.json +9 -3
  280. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/METADATA +18 -2
  281. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/RECORD +285 -67
  282. xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
  283. xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
  284. xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
  285. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
  286. xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
  287. xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
  288. xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
  289. xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
  290. xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
  291. xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
  292. xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
  293. xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
  294. xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
  295. xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
  296. xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
  297. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
  298. xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
  299. xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
  300. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
  301. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
  302. xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
  303. xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
  304. xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
  305. xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
  306. xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
  307. xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
  308. xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
  309. xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
  310. xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
  311. xinference/ui/web/ui/node_modules/select/bower.json +0 -13
  312. xinference/ui/web/ui/node_modules/select/package.json +0 -29
  313. xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
  314. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/WHEEL +0 -0
  315. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/entry_points.txt +0 -0
  316. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/licenses/LICENSE +0 -0
  317. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,520 @@
1
+
2
+ from typing import Optional, Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from indextts.gpt.conformer.attention import (MultiHeadedAttention,
8
+ RelPositionMultiHeadedAttention)
9
+ from indextts.gpt.conformer.embedding import (NoPositionalEncoding,
10
+ PositionalEncoding,
11
+ RelPositionalEncoding)
12
+ from indextts.gpt.conformer.subsampling import (Conv2dSubsampling2,
13
+ Conv2dSubsampling4,
14
+ Conv2dSubsampling6,
15
+ Conv2dSubsampling8,
16
+ LinearNoSubsampling)
17
+ from indextts.utils.common import make_pad_mask
18
+
19
+
20
class PositionwiseFeedForward(torch.nn.Module):
    """Two-layer position-wise feed-forward network.

    The same ``Linear -> activation -> Dropout -> Linear`` stack is applied
    to the last dimension of the input, so the output keeps the feature
    dimension of the input.

    Args:
        idim (int): Input (and output) feature dimension.
        hidden_units (int): Width of the hidden layer.
        dropout_rate (float): Dropout probability applied after the activation.
        activation (torch.nn.Module): Activation module placed between the
            two linear layers.
    """

    def __init__(self,
                 idim: int,
                 hidden_units: int,
                 dropout_rate: float,
                 activation: torch.nn.Module = torch.nn.ReLU()):
        """Build the two linear projections, the activation and the dropout."""
        super().__init__()
        # Attribute names and creation order are kept stable so that
        # checkpoint keys and parameter initialisation order do not change.
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        self.activation = activation
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.w_2 = torch.nn.Linear(hidden_units, idim)

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward stack.

        Args:
            xs: input tensor (B, L, D)
        Returns:
            output tensor, (B, L, D)
        """
        hidden = self.w_1(xs)
        hidden = self.activation(hidden)
        hidden = self.dropout(hidden)
        return self.w_2(hidden)
54
+
55
+
56
class ConvolutionModule(nn.Module):
    """Convolution module used inside a Conformer block.

    The forward pass applies, in order: a pointwise convolution followed by
    GLU gating, a depthwise 1-D convolution, layer normalisation, the
    activation, and a second pointwise convolution.
    """

    def __init__(self,
                 channels: int,
                 kernel_size: int = 15,
                 activation: nn.Module = nn.ReLU(),
                 bias: bool = True):
        """Construct a ConvolutionModule object.

        Args:
            channels (int): The number of channels of conv layers.
            kernel_size (int): Kernel size of the depthwise conv layer.
                Must be odd so the symmetric padding keeps the time length.
            activation (nn.Module): Activation applied after the layer norm.
            bias (bool): Whether the convolution layers use a bias term.

        Raises:
            AssertionError: If ``kernel_size`` is not an odd number.
        """
        super().__init__()

        self.pointwise_conv1 = nn.Conv1d(
            channels,
            2 * channels,  # doubled because GLU halves the channel dim again
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )

        # self.lorder distinguishes causal from symmetric convolution:
        #   lorder > 0: causal, the input is left-padded with lorder frames
        #               in forward;
        #   lorder == 0: symmetric convolution.
        # This constructor always builds the symmetric variant, which needs
        # an odd kernel so that padding is the same on both sides.
        assert (kernel_size - 1) % 2 == 0, \
            "kernel_size should be an odd number for non-causal convolution"
        padding = (kernel_size - 1) // 2
        self.lorder = 0

        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=padding,
            groups=channels,  # one filter per channel => depthwise
            bias=bias,
        )

        self.use_layer_norm = True
        self.norm = nn.LayerNorm(channels)

        self.pointwise_conv2 = nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.activation = activation

    def forward(
        self,
        x: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        cache: torch.Tensor = torch.zeros((0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute convolution module.

        The input tensor is never modified; padded frames are zeroed on a
        copy.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, channels).
            mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
                (0, 0, 0) means fake mask.
            cache (torch.Tensor): left context cache, it is only
                used in causal convolution (#batch, channels, cache_t),
                (0, 0, 0) means fake cache.
        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Output tensor
            (#batch, time, channels) and the new left-context cache (an
            empty (0, 0, 0) tensor when no cache is required).
        """
        # exchange the temporal dimension and the feature dimension
        x = x.transpose(1, 2)  # (#batch, channels, time)

        # mask batch padding
        if mask_pad.size(2) > 0:  # time > 0
            # Out-of-place on purpose: `x` is a transposed *view* of the
            # caller's tensor, so an in-place fill would silently overwrite
            # the caller's data.
            x = x.masked_fill(~mask_pad, 0.0)

        if self.lorder > 0:
            if cache.size(2) == 0:  # cache_t == 0
                x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
            else:
                assert cache.size(0) == x.size(0)  # equal batch
                assert cache.size(1) == x.size(1)  # equal channel
                x = torch.cat((cache, x), dim=2)
            assert (x.size(2) > self.lorder)
            new_cache = x[:, :, -self.lorder:]
        else:
            # It's better we just return None if no cache is required,
            # However, for JIT export, here we just fake one tensor instead of
            # None.
            new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)

        # GLU mechanism
        x = self.pointwise_conv1(x)  # (batch, 2*channel, time)
        x = nn.functional.glu(x, dim=1)  # (batch, channel, time)

        # 1D Depthwise Conv
        x = self.depthwise_conv(x)
        if self.use_layer_norm:
            # LayerNorm normalises the last dim, so move channels there and
            # back again afterwards.
            x = self.activation(self.norm(x.transpose(1, 2))).transpose(1, 2)
        else:
            x = self.activation(self.norm(x))
        x = self.pointwise_conv2(x)

        # mask batch padding (in-place is safe here: `x` is a fresh tensor
        # produced by pointwise_conv2, not an alias of the input)
        if mask_pad.size(2) > 0:  # time > 0
            x.masked_fill_(~mask_pad, 0.0)

        return x.transpose(1, 2), new_cache
168
+
169
+
170
class ConformerEncoderLayer(nn.Module):
    """Encoder layer module.

    Sub-blocks are applied in this order: optional macaron feed-forward,
    self-attention, optional convolution module, optional feed-forward,
    and (only when a convolution module is present) a final LayerNorm.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
            instance can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward` instance can be used as the argument.
            When None, the feed-forward sub-block is skipped.
        feed_forward_macaron (torch.nn.Module): Additional feed-forward
            module instance.
            `PositionwiseFeedForward` instance can be used as the argument.
            When given, both feed-forward outputs are scaled by 0.5.
        conv_module (torch.nn.Module): Convolution module instance.
            `ConvolutionModule` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool):
            True: use layer_norm before each sub-block.
            False: use layer_norm after each sub-block.
        concat_after (bool): Whether to concat attention layer's input and
            output.
            True: x -> x + linear(concat(x, att(x)))
            False: x -> x + att(x)
    """

    def __init__(
        self,
        size: int,
        self_attn: torch.nn.Module,
        feed_forward: Optional[nn.Module] = None,
        feed_forward_macaron: Optional[nn.Module] = None,
        conv_module: Optional[nn.Module] = None,
        dropout_rate: float = 0.1,
        normalize_before: bool = True,
        concat_after: bool = False,
    ):
        """Construct an EncoderLayer object."""
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.conv_module = conv_module
        self.norm_ff = nn.LayerNorm(size, eps=1e-5)  # for the FFN module
        self.norm_mha = nn.LayerNorm(size, eps=1e-5)  # for the MHA module
        if feed_forward_macaron is not None:
            self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5)
            # Two half-step feed-forward blocks in macaron style.
            self.ff_scale = 0.5
        else:
            self.ff_scale = 1.0
        if self.conv_module is not None:
            self.norm_conv = nn.LayerNorm(size, eps=1e-5)  # for the CNN module
            # for the final output of the block
            self.norm_final = nn.LayerNorm(size, eps=1e-5)
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        else:
            self.concat_linear = nn.Identity()

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        pos_emb: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute encoded features.

        Args:
            x (torch.Tensor): (#batch, time, size)
            mask (torch.Tensor): Mask tensor for the input
                (#batch, time, time), (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): positional encoding, must not be None
                for ConformerEncoderLayer.
            mask_pad (torch.Tensor): batch padding mask used for conv module.
                (#batch, 1, time), (0, 0, 0) means fake mask.
            att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
                (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
            cnn_cache (torch.Tensor): Convolution cache in conformer layer
                (#batch=1, size, cache_t2)
        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time, time), returned as-is.
            torch.Tensor: att_cache tensor,
                (#batch=1, head, cache_t1 + time, d_k * 2).
            torch.Tensor: cnn_cache tensor (#batch, size, cache_t2); an
                empty (0, 0, 0) tensor when there is no conv module.
        """
        # Optional macaron-style feed-forward in front of the attention.
        if self.feed_forward_macaron is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = residual + self.ff_scale * self.dropout(
                self.feed_forward_macaron(x))
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)

        # Multi-headed self-attention module.
        residual = x
        if self.normalize_before:
            x = self.norm_mha(x)
        x_att, new_att_cache = self.self_attn(
            x, x, x, mask, pos_emb, att_cache)
        if self.concat_after:
            x = residual + self.concat_linear(torch.cat((x, x_att), dim=-1))
        else:
            x = residual + self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm_mha(x)

        # Convolution module. Start from a placeholder cache so the return
        # type stays a tensor even when there is no conv module.
        new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
            x = residual + self.dropout(x)
            if not self.normalize_before:
                x = self.norm_conv(x)

        # Feed-forward module. ``feed_forward`` defaults to None in the
        # constructor, so guard the call instead of failing with
        # "'NoneType' object is not callable".
        if self.feed_forward is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff(x)
            x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
            if not self.normalize_before:
                x = self.norm_ff(x)

        if self.conv_module is not None:
            x = self.norm_final(x)

        return x, mask, new_att_cache, new_cnn_cache
314
+
315
+
316
class BaseEncoder(torch.nn.Module):
    """Shared front-end and forward pass for the encoders in this file.

    This class builds the input embedding/subsampling layer and the final
    LayerNorm. It does not create ``self.encoders``; ``forward`` iterates
    over that attribute, so a subclass has to define it.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        pos_enc_layer_type: str = "abs_pos",
        normalize_before: bool = True,
        concat_after: bool = False,
    ):
        """
        Args:
            input_size (int): input dim
            output_size (int): dimension of attention
            attention_heads (int): the number of heads of multi head
                attention (not used by this base class)
            linear_units (int): the hidden units number of position-wise
                feed forward (not used by this base class)
            num_blocks (int): the number of encoder blocks (not used by
                this base class)
            dropout_rate (float): dropout rate, passed to both the input
                layer and the positional encoding
            input_layer (str): input layer type.
                one of [linear, conv2d2, conv2d, conv2d6, conv2d8]
            pos_enc_layer_type (str): Encoder positional encoding layer
                type. one of [abs_pos, rel_pos, no_pos]
            normalize_before (bool): when True, ``forward`` applies the
                final LayerNorm to the encoder output.
            concat_after (bool): not used by this base class.

        Raises:
            ValueError: if ``pos_enc_layer_type`` or ``input_layer`` is not
                one of the supported names.
        """
        super().__init__()
        self._output_size = output_size

        # Resolve both classes first so an unknown name fails before any
        # layer is constructed (positional encoding is validated first).
        pos_enc_class = self._resolve_pos_enc_class(pos_enc_layer_type)
        subsampling_class = self._resolve_subsampling_class(input_layer)

        self.embed = subsampling_class(
            input_size,
            output_size,
            dropout_rate,
            pos_enc_class(output_size, dropout_rate),
        )

        self.normalize_before = normalize_before
        self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)

    @staticmethod
    def _resolve_pos_enc_class(pos_enc_layer_type: str):
        """Map a positional-encoding name to its class."""
        if pos_enc_layer_type == "abs_pos":
            return PositionalEncoding
        if pos_enc_layer_type == "rel_pos":
            return RelPositionalEncoding
        if pos_enc_layer_type == "no_pos":
            return NoPositionalEncoding
        raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)

    @staticmethod
    def _resolve_subsampling_class(input_layer: str):
        """Map an input-layer name to its subsampling class."""
        if input_layer == "linear":
            return LinearNoSubsampling
        if input_layer == "conv2d2":
            return Conv2dSubsampling2
        if input_layer == "conv2d":
            return Conv2dSubsampling4
        if input_layer == "conv2d6":
            return Conv2dSubsampling6
        if input_layer == "conv2d8":
            return Conv2dSubsampling8
        raise ValueError("unknown input_layer: " + input_layer)

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs: torch.Tensor,
        xs_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode a padded batch.

        Args:
            xs: padded input tensor (B, T, D)
            xs_lens: input length (B)
        Returns:
            encoder output tensor xs, and subsampled masks
            xs: padded output tensor (B, T' ~= T/subsample_rate, D)
            masks: torch.Tensor batch padding mask after subsample
                (B, 1, T' ~= T/subsample_rate)
        """
        num_frames = xs.size(1)
        padding = make_pad_mask(xs_lens, num_frames).unsqueeze(1)
        masks = ~padding  # (B, 1, T)
        xs, pos_emb, masks = self.embed(xs, masks)

        # The same subsampled mask serves as the attention mask and as the
        # conv-module padding mask, (B, 1, T/subsample_rate).
        mask_pad = masks
        chunk_masks = masks
        for encoder_layer in self.encoders:
            xs, chunk_masks, _, _ = encoder_layer(
                xs, chunk_masks, pos_emb, mask_pad)

        if self.normalize_before:
            xs = self.after_norm(xs)

        # The mask is assumed not to change inside the encoder layers, so
        # the one computed before them is returned.
        return xs, masks
437
+
438
+
439
class ConformerEncoder(BaseEncoder):
    """Conformer encoder module."""

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        pos_enc_layer_type: str = "rel_pos",
        normalize_before: bool = True,
        concat_after: bool = False,
        macaron_style: bool = False,
        use_cnn_module: bool = True,
        cnn_module_kernel: int = 15,
    ):
        """Construct ConformerEncoder

        Args:
            input_size to concat_after, see in BaseEncoder
            macaron_style (bool): Whether to add a second (macaron)
                position-wise feed-forward module to every layer.
            use_cnn_module (bool): Whether to use convolution module.
            cnn_module_kernel (int): Kernel size of convolution module.
        """
        super().__init__(
            input_size=input_size,
            output_size=output_size,
            attention_heads=attention_heads,
            linear_units=linear_units,
            num_blocks=num_blocks,
            dropout_rate=dropout_rate,
            input_layer=input_layer,
            pos_enc_layer_type=pos_enc_layer_type,
            normalize_before=normalize_before,
            concat_after=concat_after,
        )

        # One SiLU instance is handed to every feed-forward / conv module.
        activation = torch.nn.SiLU()

        # Relative-position attention only with "rel_pos" encodings.
        if pos_enc_layer_type == "rel_pos":
            attention_cls = RelPositionMultiHeadedAttention
        else:
            attention_cls = MultiHeadedAttention
        feed_forward_cls = PositionwiseFeedForward
        conv_cls = ConvolutionModule

        def build_layer() -> ConformerEncoderLayer:
            # Sub-modules are created in a fixed order (attention,
            # feed-forward, macaron feed-forward, convolution) so that
            # parameter initialisation consumes the RNG identically for
            # every block.
            self_attn = attention_cls(attention_heads, output_size,
                                      dropout_rate)
            feed_forward = feed_forward_cls(output_size, linear_units,
                                            dropout_rate, activation)
            macaron = None
            if macaron_style:
                macaron = feed_forward_cls(output_size, linear_units,
                                           dropout_rate, activation)
            conv = None
            if use_cnn_module:
                conv = conv_cls(output_size, cnn_module_kernel, activation)
            return ConformerEncoderLayer(
                output_size,
                self_attn,
                feed_forward,
                macaron,
                conv,
                dropout_rate,
                normalize_before,
                concat_after,
            )

        self.encoders = torch.nn.ModuleList(
            [build_layer() for _ in range(num_blocks)])