xinference 1.9.1__py3-none-any.whl → 1.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (334)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +400 -3
  3. xinference/client/restful/async_restful_client.py +20 -3
  4. xinference/client/restful/restful_client.py +20 -3
  5. xinference/constants.py +2 -0
  6. xinference/core/supervisor.py +111 -49
  7. xinference/core/worker.py +10 -0
  8. xinference/deploy/cmdline.py +15 -0
  9. xinference/model/audio/core.py +26 -6
  10. xinference/model/audio/indextts2.py +166 -0
  11. xinference/model/audio/kokoro.py +1 -1
  12. xinference/model/audio/kokoro_zh.py +124 -0
  13. xinference/model/audio/model_spec.json +58 -1
  14. xinference/model/embedding/sentence_transformers/core.py +4 -4
  15. xinference/model/embedding/vllm/core.py +7 -1
  16. xinference/model/image/model_spec.json +71 -3
  17. xinference/model/image/stable_diffusion/core.py +13 -4
  18. xinference/model/llm/__init__.py +4 -0
  19. xinference/model/llm/core.py +10 -0
  20. xinference/model/llm/llama_cpp/core.py +1 -0
  21. xinference/model/llm/llm_family.json +503 -21
  22. xinference/model/llm/llm_family.py +1 -0
  23. xinference/model/llm/mlx/core.py +52 -33
  24. xinference/model/llm/sglang/core.py +32 -55
  25. xinference/model/llm/tool_parsers/__init__.py +58 -0
  26. xinference/model/llm/tool_parsers/abstract_tool_parser.py +33 -0
  27. xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +190 -0
  28. xinference/model/llm/tool_parsers/deepseek_v3_tool_parser.py +145 -0
  29. xinference/model/llm/tool_parsers/glm4_tool_parser.py +123 -0
  30. xinference/model/llm/tool_parsers/llama3_tool_parser.py +77 -0
  31. xinference/model/llm/tool_parsers/qwen_tool_parser.py +320 -0
  32. xinference/model/llm/transformers/core.py +1 -1
  33. xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
  34. xinference/model/llm/utils.py +138 -53
  35. xinference/model/llm/vllm/core.py +95 -78
  36. xinference/thirdparty/audiotools/__init__.py +10 -0
  37. xinference/thirdparty/audiotools/core/__init__.py +4 -0
  38. xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
  39. xinference/thirdparty/audiotools/core/display.py +194 -0
  40. xinference/thirdparty/audiotools/core/dsp.py +390 -0
  41. xinference/thirdparty/audiotools/core/effects.py +647 -0
  42. xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
  43. xinference/thirdparty/audiotools/core/loudness.py +320 -0
  44. xinference/thirdparty/audiotools/core/playback.py +252 -0
  45. xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
  46. xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
  47. xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
  48. xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
  49. xinference/thirdparty/audiotools/core/util.py +671 -0
  50. xinference/thirdparty/audiotools/core/whisper.py +97 -0
  51. xinference/thirdparty/audiotools/data/__init__.py +3 -0
  52. xinference/thirdparty/audiotools/data/datasets.py +517 -0
  53. xinference/thirdparty/audiotools/data/preprocess.py +81 -0
  54. xinference/thirdparty/audiotools/data/transforms.py +1592 -0
  55. xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
  56. xinference/thirdparty/audiotools/metrics/distance.py +131 -0
  57. xinference/thirdparty/audiotools/metrics/quality.py +159 -0
  58. xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
  59. xinference/thirdparty/audiotools/ml/__init__.py +5 -0
  60. xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
  61. xinference/thirdparty/audiotools/ml/decorators.py +440 -0
  62. xinference/thirdparty/audiotools/ml/experiment.py +90 -0
  63. xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
  64. xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
  65. xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
  66. xinference/thirdparty/audiotools/post.py +140 -0
  67. xinference/thirdparty/audiotools/preference.py +600 -0
  68. xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
  69. xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
  70. xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
  71. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
  72. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
  73. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
  74. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
  75. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  76. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
  77. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
  78. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
  79. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
  80. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
  81. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
  82. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
  83. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
  84. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
  85. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
  86. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
  87. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
  88. xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
  89. xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
  90. xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
  91. xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
  92. xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
  93. xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
  94. xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
  95. xinference/thirdparty/indextts/__init__.py +0 -0
  96. xinference/thirdparty/indextts/cli.py +65 -0
  97. xinference/thirdparty/indextts/gpt/__init__.py +0 -0
  98. xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
  99. xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
  100. xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
  101. xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
  102. xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
  103. xinference/thirdparty/indextts/gpt/model.py +713 -0
  104. xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
  105. xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
  106. xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
  107. xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
  108. xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
  109. xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
  110. xinference/thirdparty/indextts/infer.py +690 -0
  111. xinference/thirdparty/indextts/infer_v2.py +739 -0
  112. xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
  113. xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
  114. xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
  115. xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
  116. xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
  117. xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
  118. xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
  119. xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
  120. xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
  121. xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
  122. xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
  123. xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
  124. xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
  125. xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
  126. xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
  127. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
  128. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
  129. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
  130. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
  131. xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
  132. xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
  133. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
  134. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
  135. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  136. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
  137. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
  138. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
  139. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
  140. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
  141. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
  142. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
  143. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
  144. xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
  145. xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
  146. xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
  147. xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
  148. xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
  149. xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
  150. xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
  151. xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
  152. xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
  153. xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
  154. xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
  155. xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
  156. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
  157. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
  158. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
  159. xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
  160. xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
  161. xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
  162. xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
  163. xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
  164. xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
  165. xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
  166. xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
  167. xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
  168. xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
  169. xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
  170. xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
  171. xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
  172. xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
  173. xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
  174. xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
  175. xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
  176. xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
  177. xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
  178. xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
  179. xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
  180. xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
  181. xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
  182. xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
  183. xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
  184. xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
  185. xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
  186. xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
  187. xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
  188. xinference/thirdparty/indextts/utils/__init__.py +0 -0
  189. xinference/thirdparty/indextts/utils/arch_util.py +120 -0
  190. xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
  191. xinference/thirdparty/indextts/utils/common.py +121 -0
  192. xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
  193. xinference/thirdparty/indextts/utils/front.py +536 -0
  194. xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
  195. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
  196. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
  197. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
  198. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
  199. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
  200. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
  201. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
  202. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
  203. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
  204. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
  205. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
  206. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
  207. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
  208. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
  209. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
  210. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
  211. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
  212. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
  213. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
  214. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
  215. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
  216. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
  217. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
  218. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
  219. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
  220. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
  221. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
  222. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
  223. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
  224. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
  225. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
  226. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
  227. xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
  228. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
  229. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
  230. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
  231. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
  232. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
  233. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
  234. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
  235. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
  236. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
  237. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
  238. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
  239. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
  240. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
  241. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
  242. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
  243. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
  244. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
  245. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
  246. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
  247. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
  248. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
  249. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
  250. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
  251. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
  252. xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
  253. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
  254. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
  255. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
  256. xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
  257. xinference/thirdparty/indextts/utils/text_utils.py +41 -0
  258. xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
  259. xinference/thirdparty/indextts/utils/utils.py +93 -0
  260. xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
  261. xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
  262. xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
  263. xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
  264. xinference/types.py +105 -2
  265. xinference/ui/gradio/media_interface.py +66 -8
  266. xinference/ui/web/ui/build/asset-manifest.json +6 -6
  267. xinference/ui/web/ui/build/index.html +1 -1
  268. xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
  269. xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
  270. xinference/ui/web/ui/build/static/js/main.d192c4f3.js +3 -0
  271. xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.d192c4f3.js.LICENSE.txt} +0 -7
  272. xinference/ui/web/ui/build/static/js/main.d192c4f3.js.map +1 -0
  273. xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
  274. xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
  275. xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
  276. xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
  277. xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
  278. xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
  279. xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
  280. xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
  281. xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
  282. xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
  283. xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
  284. xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
  285. xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
  286. xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
  287. xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
  288. xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
  289. xinference/ui/web/ui/node_modules/.cache/babel-loader/f995a2425dfb0822fd07127f66ffe9b026883bc156b402eb8bd0b83d52460a93.json +1 -0
  290. xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
  291. xinference/ui/web/ui/package-lock.json +0 -34
  292. xinference/ui/web/ui/package.json +0 -1
  293. xinference/ui/web/ui/src/locales/en.json +9 -3
  294. xinference/ui/web/ui/src/locales/ja.json +9 -3
  295. xinference/ui/web/ui/src/locales/ko.json +9 -3
  296. xinference/ui/web/ui/src/locales/zh.json +9 -3
  297. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/METADATA +24 -4
  298. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/RECORD +302 -76
  299. xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
  300. xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
  301. xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
  302. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
  303. xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
  304. xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
  305. xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
  306. xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
  307. xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
  308. xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
  309. xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
  310. xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
  311. xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
  312. xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
  313. xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
  314. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
  315. xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
  316. xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
  317. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
  318. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
  319. xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
  320. xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
  321. xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
  322. xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
  323. xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
  324. xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
  325. xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
  326. xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
  327. xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
  328. xinference/ui/web/ui/node_modules/select/bower.json +0 -13
  329. xinference/ui/web/ui/node_modules/select/package.json +0 -29
  330. xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
  331. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/WHEEL +0 -0
  332. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/entry_points.txt +0 -0
  333. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/licenses/LICENSE +0 -0
  334. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,35 @@
+ # Copyright (c) 2023 Amphion.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from torch.autograd import Function
+ import torch
+ from torch import nn
+
+
+ class GradientReversal(Function):
+     @staticmethod
+     def forward(ctx, x, alpha):
+         ctx.save_for_backward(x, alpha)
+         return x
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         grad_input = None
+         _, alpha = ctx.saved_tensors
+         if ctx.needs_input_grad[0]:
+             grad_input = -alpha * grad_output
+         return grad_input, None
+
+
+ revgrad = GradientReversal.apply
+
+
+ class GradientReversal(nn.Module):
+     def __init__(self, alpha):
+         super().__init__()
+         self.alpha = torch.tensor(alpha, requires_grad=False)
+
+     def forward(self, x):
+         return revgrad(x, self.alpha)
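
For context, a minimal usage sketch of the GradientReversal layer added above; the encoder features and discriminator below are illustrative assumptions, not code shipped in this package:

import torch
from torch import nn

# Hypothetical adversarial setup: activations pass through unchanged on the
# forward pass, while the gradient coming back from the classifier is negated
# (scaled by -alpha), pushing the upstream encoder toward invariant features.
features = torch.randn(8, 256, requires_grad=True)  # assumed encoder output
rev = GradientReversal(alpha=1.0)                    # the nn.Module defined above
classifier = nn.Linear(256, 2)                       # hypothetical discriminator

logits = classifier(rev(features))
loss = nn.functional.cross_entropy(logits, torch.zeros(8, dtype=torch.long))
loss.backward()  # features.grad now holds the reversed (negated) gradient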
@@ -0,0 +1,102 @@
+ import torch
+ import pyworld as pw
+ import numpy as np
+ import soundfile as sf
+ import os
+ from torchaudio.functional import pitch_shift
+ import librosa
+ from librosa.filters import mel as librosa_mel_fn
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
+     return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
+
+
+ def dynamic_range_decompression(x, C=1):
+     return np.exp(x) / C
+
+
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+     return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+ def dynamic_range_decompression_torch(x, C=1):
+     return torch.exp(x) / C
+
+
+ def spectral_normalize_torch(magnitudes):
+     output = dynamic_range_compression_torch(magnitudes)
+     return output
+
+
+ def spectral_de_normalize_torch(magnitudes):
+     output = dynamic_range_decompression_torch(magnitudes)
+     return output
+
+
+ class MelSpectrogram(nn.Module):
+     def __init__(
+         self,
+         n_fft,
+         num_mels,
+         sampling_rate,
+         hop_size,
+         win_size,
+         fmin,
+         fmax,
+         center=False,
+     ):
+         super(MelSpectrogram, self).__init__()
+         self.n_fft = n_fft
+         self.hop_size = hop_size
+         self.win_size = win_size
+         self.sampling_rate = sampling_rate
+         self.num_mels = num_mels
+         self.fmin = fmin
+         self.fmax = fmax
+         self.center = center
+
+         mel_basis = {}
+         hann_window = {}
+
+         mel = librosa_mel_fn(
+             sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
+         )
+         mel_basis = torch.from_numpy(mel).float()
+         hann_window = torch.hann_window(win_size)
+
+         self.register_buffer("mel_basis", mel_basis)
+         self.register_buffer("hann_window", hann_window)
+
+     def forward(self, y):
+         y = torch.nn.functional.pad(
+             y.unsqueeze(1),
+             (
+                 int((self.n_fft - self.hop_size) / 2),
+                 int((self.n_fft - self.hop_size) / 2),
+             ),
+             mode="reflect",
+         )
+         y = y.squeeze(1)
+         spec = torch.stft(
+             y,
+             self.n_fft,
+             hop_length=self.hop_size,
+             win_length=self.win_size,
+             window=self.hann_window,
+             center=self.center,
+             pad_mode="reflect",
+             normalized=False,
+             onesided=True,
+             return_complex=True,
+         )
+         spec = torch.view_as_real(spec)
+
+         spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
+
+         spec = torch.matmul(self.mel_basis, spec)
+         spec = spectral_normalize_torch(spec)
+
+         return spec
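
A minimal sketch of driving the MelSpectrogram module above; the hyperparameters are illustrative assumptions, not values configured by xinference:

import torch

# Assumed 22.05 kHz settings, purely for illustration.
mel_fn = MelSpectrogram(
    n_fft=1024,
    num_mels=80,
    sampling_rate=22050,
    hop_size=256,
    win_size=1024,
    fmin=0,
    fmax=8000,
)
wav = torch.randn(2, 22050)  # (batch, samples) of raw audio
mel = mel_fn(wav)            # (batch, num_mels, frames) log-compressed mel spectrogram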
@@ -0,0 +1,7 @@
+ # Copyright (c) 2023 Amphion.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from .fvq import *
+ from .rvq import *
@@ -0,0 +1,116 @@
+ # Copyright (c) 2023 Amphion.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from typing import Union
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from einops import rearrange
+ from torch.nn.utils import weight_norm
+
+
+ class FactorizedVectorQuantize(nn.Module):
+     def __init__(self, dim, codebook_size, codebook_dim, commitment, **kwargs):
+         super().__init__()
+         self.codebook_size = codebook_size
+         self.codebook_dim = codebook_dim
+         self.commitment = commitment
+
+         if dim != self.codebook_dim:
+             self.in_proj = weight_norm(nn.Linear(dim, self.codebook_dim))
+             self.out_proj = weight_norm(nn.Linear(self.codebook_dim, dim))
+         else:
+             self.in_proj = nn.Identity()
+             self.out_proj = nn.Identity()
+         self._codebook = nn.Embedding(codebook_size, self.codebook_dim)
+
+     @property
+     def codebook(self):
+         return self._codebook
+
+     def forward(self, z):
+         """Quantized the input tensor using a fixed codebook and returns
+         the corresponding codebook vectors
+
+         Parameters
+         ----------
+         z : Tensor[B x D x T]
+
+         Returns
+         -------
+         Tensor[B x D x T]
+             Quantized continuous representation of input
+         Tensor[1]
+             Commitment loss to train encoder to predict vectors closer to codebook
+             entries
+         Tensor[1]
+             Codebook loss to update the codebook
+         Tensor[B x T]
+             Codebook indices (quantized discrete representation of input)
+         Tensor[B x D x T]
+             Projected latents (continuous representation of input before quantization)
+         """
+         # transpose since we use linear
+
+         z = rearrange(z, "b d t -> b t d")
+
+         # Factorized codes project input into low-dimensional space
+         z_e = self.in_proj(z)  # z_e : (B x T x D)
+         z_e = rearrange(z_e, "b t d -> b d t")
+         z_q, indices = self.decode_latents(z_e)
+
+         if self.training:
+             commitment_loss = (
+                 F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
+                 * self.commitment
+             )
+             codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
+             commit_loss = commitment_loss + codebook_loss
+         else:
+             commit_loss = torch.zeros(z.shape[0], device=z.device)
+
+         z_q = (
+             z_e + (z_q - z_e).detach()
+         )  # noop in forward pass, straight-through gradient estimator in backward pass
+
+         z_q = rearrange(z_q, "b d t -> b t d")
+         z_q = self.out_proj(z_q)
+         z_q = rearrange(z_q, "b t d -> b d t")
+
+         return z_q, indices, commit_loss
+
+     def vq2emb(self, vq, proj=True):
+         emb = self.embed_code(vq)
+         if proj:
+             emb = self.out_proj(emb)
+         return emb.transpose(1, 2)
+
+     def get_emb(self):
+         return self.codebook.weight
+
+     def embed_code(self, embed_id):
+         return F.embedding(embed_id, self.codebook.weight)
+
+     def decode_code(self, embed_id):
+         return self.embed_code(embed_id).transpose(1, 2)
+
+     def decode_latents(self, latents):
+         encodings = rearrange(latents, "b d t -> (b t) d")
+         codebook = self.codebook.weight  # codebook: (N x D)
+         # L2 normalize encodings and codebook
+         encodings = F.normalize(encodings)
+         codebook = F.normalize(codebook)
+
+         # Compute euclidean distance with codebook
+         dist = (
+             encodings.pow(2).sum(1, keepdim=True)
+             - 2 * encodings @ codebook.t()
+             + codebook.pow(2).sum(1, keepdim=True).t()
+         )
+         indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
+         z_q = self.decode_code(indices)
+         return z_q, indices
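
A short sketch of calling the factorized quantizer above on a dummy latent; the sizes are assumptions chosen only to show the tensor contract described in its docstring:

import torch

# Assumed: 256-d latents factorized into an 8-d codebook space with 1024 entries.
vq = FactorizedVectorQuantize(dim=256, codebook_size=1024, codebook_dim=8, commitment=0.25)
z = torch.randn(4, 256, 50)        # (batch, dim, time)
z_q, indices, commit_loss = vq(z)  # quantized latents, code indices, commitment loss
print(z_q.shape, indices.shape)    # torch.Size([4, 256, 50]) torch.Size([4, 50])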
@@ -0,0 +1,87 @@
+ # Copyright (c) 2023 Amphion.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import math
+ import torch
+ from torch import nn
+ from .fvq import FactorizedVectorQuantize
+
+
+ class ResidualVQ(nn.Module):
+     """Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf"""
+
+     def __init__(self, *, num_quantizers, codebook_size, **kwargs):
+         super().__init__()
+         VQ = FactorizedVectorQuantize
+         if type(codebook_size) == int:
+             codebook_size = [codebook_size] * num_quantizers
+         self.layers = nn.ModuleList(
+             [VQ(codebook_size=2**size, **kwargs) for size in codebook_size]
+         )
+         self.num_quantizers = num_quantizers
+         self.quantizer_dropout = kwargs.get("quantizer_dropout", 0.0)
+         self.dropout_type = kwargs.get("dropout_type", None)
+
+     def forward(self, x, n_quantizers=None):
+         quantized_out = 0.0
+         residual = x
+
+         all_losses = []
+         all_indices = []
+         all_quantized = []
+
+         if n_quantizers is None:
+             n_quantizers = self.num_quantizers
+         if self.training:
+             n_quantizers = torch.ones((x.shape[0],)) * self.num_quantizers + 1
+             if self.dropout_type == "linear":
+                 dropout = torch.randint(1, self.num_quantizers + 1, (x.shape[0],))
+             elif self.dropout_type == "exp":
+                 dropout = torch.randint(
+                     1, int(math.log2(self.num_quantizers)), (x.shape[0],)
+                 )
+                 dropout = torch.pow(2, dropout)
+             n_dropout = int(x.shape[0] * self.quantizer_dropout)
+             n_quantizers[:n_dropout] = dropout[:n_dropout]
+             n_quantizers = n_quantizers.to(x.device)
+
+         for idx, layer in enumerate(self.layers):
+             if not self.training and idx >= n_quantizers:
+                 break
+             quantized, indices, loss = layer(residual)
+
+             mask = (
+                 torch.full((x.shape[0],), fill_value=idx, device=x.device)
+                 < n_quantizers
+             )
+
+             residual = residual - quantized
+
+             quantized_out = quantized_out + quantized * mask[:, None, None]
+
+             # loss
+             loss = (loss * mask).mean()
+
+             all_indices.append(indices)
+             all_losses.append(loss)
+             all_quantized.append(quantized)
+         all_losses, all_indices, all_quantized = map(
+             torch.stack, (all_losses, all_indices, all_quantized)
+         )
+         return quantized_out, all_indices, all_losses, all_quantized
+
+     def vq2emb(self, vq):
+         # vq: [n_quantizers, B, T]
+         quantized_out = 0.0
+         for idx, layer in enumerate(self.layers):
+             quantized = layer.vq2emb(vq[idx])
+             quantized_out += quantized
+         return quantized_out
+
+     def get_emb(self):
+         embs = []
+         for idx, layer in enumerate(self.layers):
+             embs.append(layer.get_emb())
+         return embs
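
A usage sketch for the ResidualVQ wrapper above; the configuration values are assumptions (note that codebook_size is treated as an exponent, each stage building a 2**codebook_size codebook):

import torch

# Assumed: 4 residual stages, each with a 2**10 = 1024-entry factorized codebook.
rvq = ResidualVQ(
    num_quantizers=4,
    codebook_size=10,
    dim=256,
    codebook_dim=8,
    commitment=0.25,
)
rvq.eval()                   # eval path; the training path additionally applies quantizer dropout
z = torch.randn(2, 256, 50)  # (batch, dim, time)
quantized, indices, losses, per_stage = rvq(z)
print(quantized.shape, indices.shape)  # torch.Size([2, 256, 50]) torch.Size([4, 2, 50])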
@@ -0,0 +1,234 @@
+ # Copyright (c) 2023 Amphion.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import math
+ from torch.nn import functional as F
+
+
+ class StyleAdaptiveLayerNorm(nn.Module):
+     def __init__(self, normalized_shape, eps=1e-5):
+         super().__init__()
+         self.in_dim = normalized_shape
+         self.norm = nn.LayerNorm(self.in_dim, eps=eps, elementwise_affine=False)
+         self.style = nn.Linear(self.in_dim, self.in_dim * 2)
+         self.style.bias.data[: self.in_dim] = 1
+         self.style.bias.data[self.in_dim :] = 0
+
+     def forward(self, x, condition):
+         # x: (B, T, d); condition: (B, T, d)
+
+         style = self.style(torch.mean(condition, dim=1, keepdim=True))
+
+         gamma, beta = style.chunk(2, -1)
+
+         out = self.norm(x)
+
+         out = gamma * out + beta
+         return out
+
+
+ class PositionalEncoding(nn.Module):
+     def __init__(self, d_model, dropout, max_len=5000):
+         super().__init__()
+
+         self.dropout = dropout
+         position = torch.arange(max_len).unsqueeze(1)
+         div_term = torch.exp(
+             torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
+         )
+         pe = torch.zeros(max_len, 1, d_model)
+         pe[:, 0, 0::2] = torch.sin(position * div_term)
+         pe[:, 0, 1::2] = torch.cos(position * div_term)
+         self.register_buffer("pe", pe)
+
+     def forward(self, x):
+         x = x + self.pe[: x.size(0)]
+         return F.dropout(x, self.dropout, training=self.training)
+
+
+ class TransformerFFNLayer(nn.Module):
+     def __init__(
+         self, encoder_hidden, conv_filter_size, conv_kernel_size, encoder_dropout
+     ):
+         super().__init__()
+
+         self.encoder_hidden = encoder_hidden
+         self.conv_filter_size = conv_filter_size
+         self.conv_kernel_size = conv_kernel_size
+         self.encoder_dropout = encoder_dropout
+
+         self.ffn_1 = nn.Conv1d(
+             self.encoder_hidden,
+             self.conv_filter_size,
+             self.conv_kernel_size,
+             padding=self.conv_kernel_size // 2,
+         )
+         self.ffn_1.weight.data.normal_(0.0, 0.02)
+         self.ffn_2 = nn.Linear(self.conv_filter_size, self.encoder_hidden)
+         self.ffn_2.weight.data.normal_(0.0, 0.02)
+
+     def forward(self, x):
+         # x: (B, T, d)
+         x = self.ffn_1(x.permute(0, 2, 1)).permute(
+             0, 2, 1
+         )  # (B, T, d) -> (B, d, T) -> (B, T, d)
+         x = F.relu(x)
+         x = F.dropout(x, self.encoder_dropout, training=self.training)
+         x = self.ffn_2(x)
+         return x
+
+
+ class TransformerEncoderLayer(nn.Module):
+     def __init__(
+         self,
+         encoder_hidden,
+         encoder_head,
+         conv_filter_size,
+         conv_kernel_size,
+         encoder_dropout,
+         use_cln,
+     ):
+         super().__init__()
+         self.encoder_hidden = encoder_hidden
+         self.encoder_head = encoder_head
+         self.conv_filter_size = conv_filter_size
+         self.conv_kernel_size = conv_kernel_size
+         self.encoder_dropout = encoder_dropout
+         self.use_cln = use_cln
+
+         if not self.use_cln:
+             self.ln_1 = nn.LayerNorm(self.encoder_hidden)
+             self.ln_2 = nn.LayerNorm(self.encoder_hidden)
+         else:
+             self.ln_1 = StyleAdaptiveLayerNorm(self.encoder_hidden)
+             self.ln_2 = StyleAdaptiveLayerNorm(self.encoder_hidden)
+
+         self.self_attn = nn.MultiheadAttention(
+             self.encoder_hidden, self.encoder_head, batch_first=True
+         )
+
+         self.ffn = TransformerFFNLayer(
+             self.encoder_hidden,
+             self.conv_filter_size,
+             self.conv_kernel_size,
+             self.encoder_dropout,
+         )
+
+     def forward(self, x, key_padding_mask, conditon=None):
+         # x: (B, T, d); key_padding_mask: (B, T), mask is 0; condition: (B, T, d)
+
+         # self attention
+         residual = x
+         if self.use_cln:
+             x = self.ln_1(x, conditon)
+         else:
+             x = self.ln_1(x)
+
+         if key_padding_mask != None:
+             key_padding_mask_input = ~(key_padding_mask.bool())
+         else:
+             key_padding_mask_input = None
+         x, _ = self.self_attn(
+             query=x, key=x, value=x, key_padding_mask=key_padding_mask_input
+         )
+         x = F.dropout(x, self.encoder_dropout, training=self.training)
+         x = residual + x
+
+         # ffn
+         residual = x
+         if self.use_cln:
+             x = self.ln_2(x, conditon)
+         else:
+             x = self.ln_2(x)
+         x = self.ffn(x)
+         x = residual + x
+
+         return x
+
+
+ class TransformerEncoder(nn.Module):
+     def __init__(
+         self,
+         enc_emb_tokens=None,
+         encoder_layer=4,
+         encoder_hidden=256,
+         encoder_head=4,
+         conv_filter_size=1024,
+         conv_kernel_size=5,
+         encoder_dropout=0.1,
+         use_cln=False,
+         cfg=None,
+     ):
+         super().__init__()
+
+         self.encoder_layer = (
+             encoder_layer if encoder_layer is not None else cfg.encoder_layer
+         )
+         self.encoder_hidden = (
+             encoder_hidden if encoder_hidden is not None else cfg.encoder_hidden
+         )
+         self.encoder_head = (
+             encoder_head if encoder_head is not None else cfg.encoder_head
+         )
+         self.conv_filter_size = (
+             conv_filter_size if conv_filter_size is not None else cfg.conv_filter_size
+         )
+         self.conv_kernel_size = (
+             conv_kernel_size if conv_kernel_size is not None else cfg.conv_kernel_size
+         )
+         self.encoder_dropout = (
+             encoder_dropout if encoder_dropout is not None else cfg.encoder_dropout
+         )
+         self.use_cln = use_cln if use_cln is not None else cfg.use_cln
+
+         if enc_emb_tokens != None:
+             self.use_enc_emb = True
+             self.enc_emb_tokens = enc_emb_tokens
+         else:
+             self.use_enc_emb = False
+
+         self.position_emb = PositionalEncoding(
+             self.encoder_hidden, self.encoder_dropout
+         )
+
+         self.layers = nn.ModuleList([])
+         self.layers.extend(
+             [
+                 TransformerEncoderLayer(
+                     self.encoder_hidden,
+                     self.encoder_head,
+                     self.conv_filter_size,
+                     self.conv_kernel_size,
+                     self.encoder_dropout,
+                     self.use_cln,
+                 )
+                 for i in range(self.encoder_layer)
+             ]
+         )
+
+         if self.use_cln:
+             self.last_ln = StyleAdaptiveLayerNorm(self.encoder_hidden)
+         else:
+             self.last_ln = nn.LayerNorm(self.encoder_hidden)
+
+     def forward(self, x, key_padding_mask, condition=None):
+         if len(x.shape) == 2 and self.use_enc_emb:
+             x = self.enc_emb_tokens(x)
+             x = self.position_emb(x)
+         else:
+             x = self.position_emb(x)  # (B, T, d)
+
+         for layer in self.layers:
+             x = layer(x, key_padding_mask, condition)
+
+         if self.use_cln:
+             x = self.last_ln(x, condition)
+         else:
+             x = self.last_ln(x)
+
+         return x
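
Finally, a minimal sketch of instantiating the TransformerEncoder above; the tensor shapes are illustrative assumptions, and the real hyperparameters come from the owning model's config (cfg):

import torch

# Assumed dimensions mirroring the constructor defaults above.
encoder = TransformerEncoder(
    encoder_layer=4,
    encoder_hidden=256,
    encoder_head=4,
    conv_filter_size=1024,
    conv_kernel_size=5,
    encoder_dropout=0.1,
    use_cln=False,
)
x = torch.randn(2, 100, 256)                 # (batch, frames, hidden)
mask = torch.ones(2, 100, dtype=torch.long)  # 1 = valid frame, 0 = padding
out = encoder(x, key_padding_mask=mask)      # -> (2, 100, 256)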