xinference 1.10.0__py3-none-any.whl → 1.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This release of xinference has been flagged as potentially problematic.

Files changed (328)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +473 -31
  3. xinference/client/restful/async_restful_client.py +178 -8
  4. xinference/client/restful/restful_client.py +151 -3
  5. xinference/core/supervisor.py +99 -53
  6. xinference/core/worker.py +10 -0
  7. xinference/deploy/cmdline.py +15 -0
  8. xinference/model/audio/core.py +21 -6
  9. xinference/model/audio/indextts2.py +166 -0
  10. xinference/model/audio/model_spec.json +58 -21
  11. xinference/model/image/model_spec.json +159 -90
  12. xinference/model/image/stable_diffusion/core.py +13 -4
  13. xinference/model/llm/__init__.py +6 -2
  14. xinference/model/llm/llm_family.json +1299 -174
  15. xinference/model/llm/mlx/distributed_models/core.py +41 -0
  16. xinference/model/llm/mlx/distributed_models/qwen2.py +1 -2
  17. xinference/model/llm/sglang/core.py +44 -11
  18. xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +94 -32
  19. xinference/model/llm/tool_parsers/qwen_tool_parser.py +29 -4
  20. xinference/model/llm/transformers/chatglm.py +3 -0
  21. xinference/model/llm/transformers/core.py +129 -36
  22. xinference/model/llm/transformers/multimodal/minicpmv45.py +340 -0
  23. xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
  24. xinference/model/llm/transformers/utils.py +23 -0
  25. xinference/model/llm/utils.py +48 -32
  26. xinference/model/llm/vllm/core.py +207 -72
  27. xinference/model/utils.py +74 -31
  28. xinference/thirdparty/audiotools/__init__.py +10 -0
  29. xinference/thirdparty/audiotools/core/__init__.py +4 -0
  30. xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
  31. xinference/thirdparty/audiotools/core/display.py +194 -0
  32. xinference/thirdparty/audiotools/core/dsp.py +390 -0
  33. xinference/thirdparty/audiotools/core/effects.py +647 -0
  34. xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
  35. xinference/thirdparty/audiotools/core/loudness.py +320 -0
  36. xinference/thirdparty/audiotools/core/playback.py +252 -0
  37. xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
  38. xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
  39. xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
  40. xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
  41. xinference/thirdparty/audiotools/core/util.py +671 -0
  42. xinference/thirdparty/audiotools/core/whisper.py +97 -0
  43. xinference/thirdparty/audiotools/data/__init__.py +3 -0
  44. xinference/thirdparty/audiotools/data/datasets.py +517 -0
  45. xinference/thirdparty/audiotools/data/preprocess.py +81 -0
  46. xinference/thirdparty/audiotools/data/transforms.py +1592 -0
  47. xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
  48. xinference/thirdparty/audiotools/metrics/distance.py +131 -0
  49. xinference/thirdparty/audiotools/metrics/quality.py +159 -0
  50. xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
  51. xinference/thirdparty/audiotools/ml/__init__.py +5 -0
  52. xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
  53. xinference/thirdparty/audiotools/ml/decorators.py +440 -0
  54. xinference/thirdparty/audiotools/ml/experiment.py +90 -0
  55. xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
  56. xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
  57. xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
  58. xinference/thirdparty/audiotools/post.py +140 -0
  59. xinference/thirdparty/audiotools/preference.py +600 -0
  60. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +1 -1
  61. xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
  62. xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
  63. xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
  64. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
  65. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
  66. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
  67. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
  68. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  69. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
  70. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
  71. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
  72. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
  73. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
  74. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
  75. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
  76. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
  77. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
  78. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
  79. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
  80. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
  81. xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
  82. xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
  83. xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
  84. xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
  85. xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
  86. xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
  87. xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
  88. xinference/thirdparty/indextts/__init__.py +0 -0
  89. xinference/thirdparty/indextts/cli.py +65 -0
  90. xinference/thirdparty/indextts/gpt/__init__.py +0 -0
  91. xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
  92. xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
  93. xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
  94. xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
  95. xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
  96. xinference/thirdparty/indextts/gpt/model.py +713 -0
  97. xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
  98. xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
  99. xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
  100. xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
  101. xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
  102. xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
  103. xinference/thirdparty/indextts/infer.py +690 -0
  104. xinference/thirdparty/indextts/infer_v2.py +739 -0
  105. xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
  106. xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
  107. xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
  108. xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
  109. xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
  110. xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
  111. xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
  112. xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
  113. xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
  114. xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
  115. xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
  116. xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
  117. xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
  118. xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
  119. xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
  120. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
  121. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
  122. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
  123. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
  124. xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
  125. xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
  126. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
  127. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
  128. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  129. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
  130. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
  131. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
  132. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
  133. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
  134. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
  135. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
  136. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
  137. xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
  138. xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
  139. xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
  140. xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
  141. xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
  142. xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
  143. xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
  144. xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
  145. xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
  146. xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
  147. xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
  148. xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
  149. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
  150. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
  151. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
  152. xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
  153. xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
  154. xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
  155. xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
  156. xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
  157. xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
  158. xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
  159. xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
  160. xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
  161. xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
  162. xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
  163. xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
  164. xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
  165. xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
  166. xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
  167. xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
  168. xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
  169. xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
  170. xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
  171. xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
  172. xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
  173. xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
  174. xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
  175. xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
  176. xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
  177. xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
  178. xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
  179. xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
  180. xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
  181. xinference/thirdparty/indextts/utils/__init__.py +0 -0
  182. xinference/thirdparty/indextts/utils/arch_util.py +120 -0
  183. xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
  184. xinference/thirdparty/indextts/utils/common.py +121 -0
  185. xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
  186. xinference/thirdparty/indextts/utils/front.py +536 -0
  187. xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
  188. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
  189. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
  190. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
  191. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
  192. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
  193. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
  194. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
  195. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
  196. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
  197. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
  198. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
  199. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
  200. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
  201. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
  202. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
  203. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
  204. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
  205. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
  206. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
  207. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
  208. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
  209. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
  210. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
  211. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
  212. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
  213. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
  214. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
  215. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
  216. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
  217. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
  218. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
  219. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
  220. xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
  221. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
  222. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
  223. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
  224. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
  225. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
  226. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
  227. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
  228. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
  229. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
  230. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
  231. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
  232. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
  233. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
  234. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
  235. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
  236. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
  237. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
  238. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
  239. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
  240. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
  241. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
  242. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
  243. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
  244. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
  245. xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
  246. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
  247. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
  248. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
  249. xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
  250. xinference/thirdparty/indextts/utils/text_utils.py +41 -0
  251. xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
  252. xinference/thirdparty/indextts/utils/utils.py +93 -0
  253. xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
  254. xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
  255. xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
  256. xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
  257. xinference/thirdparty/melo/text/chinese_mix.py +2 -2
  258. xinference/types.py +9 -0
  259. xinference/ui/gradio/media_interface.py +66 -8
  260. xinference/ui/web/ui/build/asset-manifest.json +6 -6
  261. xinference/ui/web/ui/build/index.html +1 -1
  262. xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
  263. xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
  264. xinference/ui/web/ui/build/static/js/main.45e78536.js +3 -0
  265. xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.45e78536.js.LICENSE.txt} +0 -7
  266. xinference/ui/web/ui/build/static/js/main.45e78536.js.map +1 -0
  267. xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
  268. xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
  269. xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
  270. xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
  271. xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
  272. xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
  273. xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
  274. xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
  275. xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
  276. xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
  277. xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
  278. xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
  279. xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
  280. xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
  281. xinference/ui/web/ui/node_modules/.cache/babel-loader/ea2a26361204e70cf1018d6990fb6354bed82b3ac69690391e0f100385e7abb7.json +1 -0
  282. xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
  283. xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
  284. xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
  285. xinference/ui/web/ui/package-lock.json +0 -34
  286. xinference/ui/web/ui/package.json +0 -1
  287. xinference/ui/web/ui/src/locales/en.json +9 -3
  288. xinference/ui/web/ui/src/locales/ja.json +9 -3
  289. xinference/ui/web/ui/src/locales/ko.json +9 -3
  290. xinference/ui/web/ui/src/locales/zh.json +9 -3
  291. {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/METADATA +24 -6
  292. {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/RECORD +296 -77
  293. xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
  294. xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
  295. xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
  296. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
  297. xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
  298. xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
  299. xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
  300. xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
  301. xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
  302. xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
  303. xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
  304. xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
  305. xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
  306. xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
  307. xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
  308. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
  309. xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
  310. xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
  311. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
  312. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
  313. xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
  314. xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
  315. xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
  316. xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
  317. xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
  318. xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
  319. xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
  320. xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
  321. xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
  322. xinference/ui/web/ui/node_modules/select/bower.json +0 -13
  323. xinference/ui/web/ui/node_modules/select/package.json +0 -29
  324. xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
  325. {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/WHEEL +0 -0
  326. {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/entry_points.txt +0 -0
  327. {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/licenses/LICENSE +0 -0
  328. {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/top_level.txt +0 -0
xinference/thirdparty/indextts/gpt/model.py (new file, +713 lines)
@@ -0,0 +1,713 @@
+ import functools
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ import transformers
+ from transformers import GPT2Config, LogitsProcessorList
+ from indextts.gpt.transformers_gpt2 import GPT2PreTrainedModel, GPT2Model
+
+ # from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
+ from transformers.utils.model_parallel_utils import (assert_device_map,
+                                                       get_device_map)
+
+ from indextts.gpt.conformer_encoder import ConformerEncoder
+ from indextts.gpt.perceiver import PerceiverResampler
+ from indextts.utils.arch_util import AttentionBlock
+ from indextts.utils.typical_sampling import TypicalLogitsWarper
+
+
+ def null_position_embeddings(range, dim):
+     return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
+
+
+ class ResBlock(nn.Module):
+     """
+     Basic residual convolutional block that uses GroupNorm.
+     """
+
+     def __init__(self, chan):
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Conv1d(chan, chan, kernel_size=3, padding=1),
+             nn.GroupNorm(chan // 8, chan),
+             nn.ReLU(),
+             nn.Conv1d(chan, chan, kernel_size=3, padding=1),
+             nn.GroupNorm(chan // 8, chan)
+         )
+
+     def forward(self, x):
+         return F.relu(self.net(x) + x)
+
+
+ class GPT2InferenceModel(GPT2PreTrainedModel):
+     def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear, kv_cache=False):
+         super().__init__(config)
+         # Note: the argument named `text_pos_emb` here actually represents the mel position embedding
+         self.transformer = gpt
+         self.text_pos_embedding = text_pos_emb
+         self.embeddings = embeddings
+         self.final_norm = norm
+         self.lm_head = nn.Sequential(norm, linear)
+         self.kv_cache = kv_cache
+
+         # Model parallel
+         self.model_parallel = False
+         self.device_map = None
+         self.cached_mel_emb = None
+
+     def parallelize(self, device_map=None):
+         self.device_map = (
+             get_device_map(len(self.transformer.h), range(max(1, torch.cuda.device_count())))
+             if device_map is None
+             else device_map
+         )
+         assert_device_map(self.device_map, len(self.transformer.h))
+         self.transformer.parallelize(self.device_map)
+         self.lm_head = self.lm_head.to(self.transformer.first_device)
+         self.model_parallel = True
+
+     def deparallelize(self):
+         self.transformer.deparallelize()
+         self.transformer = self.transformer.to("cpu")
+         self.lm_head = self.lm_head.to("cpu")
+         self.model_parallel = False
+         torch.cuda.empty_cache()
+         if torch.backends.mps.is_available():
+             torch.mps.empty_cache()
+
+     def get_output_embeddings(self):
+         return self.lm_head
+
+     def set_output_embeddings(self, new_embeddings):
+         self.lm_head = new_embeddings
+
+     def store_mel_emb(self, mel_emb):
+         self.cached_mel_emb = mel_emb
+
+     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
+         token_type_ids = kwargs.get("token_type_ids", None)  # usually None
+         if not self.kv_cache:
+             past_key_values = None
+         # only last token for inputs_ids if past is defined in kwargs
+         if past_key_values:
+             input_ids = input_ids[:, -1].unsqueeze(-1)
+             if token_type_ids is not None:
+                 token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
+
+         attention_mask = kwargs.get("attention_mask", None)
+         position_ids = kwargs.get("position_ids", None)
+
+         if attention_mask is not None and position_ids is None:
+             # create position_ids on the fly for batch generation
+             position_ids = attention_mask.long().cumsum(-1) - 1
+             position_ids.masked_fill_(attention_mask == 0, 0)
+             if past_key_values:
+                 position_ids = position_ids[:, -1].unsqueeze(-1)
+         else:
+             position_ids = None
+         return {
+             "input_ids": input_ids,
+             "past_key_values": past_key_values,
+             "use_cache": kwargs.get("use_cache"),
+             "position_ids": position_ids,
+             "attention_mask": attention_mask,
+             "token_type_ids": token_type_ids,
+         }
+
+     def forward(
+         self,
+         input_ids=None,
+         past_key_values=None,
+         attention_mask=None,
+         token_type_ids=None,
+         position_ids=None,
+         head_mask=None,
+         inputs_embeds=None,
+         encoder_hidden_states=None,
+         encoder_attention_mask=None,
+         labels=None,
+         use_cache=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+     ):
+         assert self.cached_mel_emb is not None
+         assert inputs_embeds is None  # Not supported by this inference model.
+         assert labels is None  # Training not supported by this inference model.
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+         # Create embedding
+         mel_len = self.cached_mel_emb.shape[1]
+         if input_ids.shape[1] != 1:
+             text_inputs = input_ids[:, mel_len:]
+             text_emb = self.embeddings(text_inputs)
+             text_emb = text_emb + self.text_pos_embedding(text_emb)
+             if self.cached_mel_emb.shape[0] != text_emb.shape[0]:
+                 mel_emb = self.cached_mel_emb.repeat_interleave(
+                     text_emb.shape[0] // self.cached_mel_emb.shape[0], 0
+                 )
+             else:  # this outcome only occurs once per loop in most cases
+                 mel_emb = self.cached_mel_emb
+             emb = torch.cat([mel_emb, text_emb], dim=1)
+         else:
+             emb = self.embeddings(input_ids)
+             emb = emb + self.text_pos_embedding.get_fixed_embedding(
+                 attention_mask.shape[1] - mel_len, attention_mask.device
+             )
+         transformer_outputs = self.transformer(
+             inputs_embeds=emb,
+             past_key_values=past_key_values,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             encoder_hidden_states=encoder_hidden_states,
+             encoder_attention_mask=encoder_attention_mask,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         hidden_states = transformer_outputs[0]
+
+         # Set device for model parallelism
+         if self.model_parallel:
+             if torch.backends.mps.is_available():
+                 self.to(self.transformer.first_device)
+             else:
+                 torch.cuda.set_device(self.transformer.first_device)
+             hidden_states = hidden_states.to(self.lm_head.weight.device)
+
+         lm_logits = self.lm_head(hidden_states)
+
+         if not return_dict:
+             return (lm_logits,) + transformer_outputs[1:]
+
+         return CausalLMOutputWithCrossAttentions(
+             loss=None,
+             logits=lm_logits,
+             past_key_values=transformer_outputs.past_key_values,
+             hidden_states=transformer_outputs.hidden_states,
+             attentions=transformer_outputs.attentions,
+             cross_attentions=transformer_outputs.cross_attentions,
+         )
+
+     @staticmethod
+     def _reorder_cache(past, beam_idx):
+         """
+         This function is used to re-order the :obj:`past_key_values` cache if
+         :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
+         called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
+         """
+         return tuple(
+             tuple(
+                 past_state.index_select(0, beam_idx.to(past_state.device))
+                 for past_state in layer_past
+             )
+             for layer_past in past
+         )
+
+
+ class ConditioningEncoder(nn.Module):
+     def __init__(self,
+                  spec_dim,
+                  embedding_dim,
+                  attn_blocks=6,
+                  num_attn_heads=4,
+                  do_checkpointing=False,
+                  mean=False):
+         super().__init__()
+         attn = []
+         self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
+         for a in range(attn_blocks):
+             attn.append(AttentionBlock(embedding_dim, num_attn_heads))
+         self.attn = nn.Sequential(*attn)
+         self.dim = embedding_dim
+         self.do_checkpointing = do_checkpointing
+         self.mean = mean
+
+     def forward(self, x):
+         h = self.init(x)
+         h = self.attn(h)
+         if self.mean:
+             return h.mean(dim=2)
+         else:
+             return h
+             # return h[:, :, 0]
+
+
+ class LearnedPositionEmbeddings(nn.Module):
+     def __init__(self, seq_len, model_dim, init=.02):
+         super().__init__()
+         self.emb = nn.Embedding(seq_len, model_dim)
+         # Initializing this way is standard for GPT-2
+         self.emb.weight.data.normal_(mean=0.0, std=init)
+
+     def forward(self, x):
+         sl = x.shape[1]
+         return self.emb(torch.arange(0, sl, device=x.device))
+
+     def get_fixed_embedding(self, ind, dev):
+         return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
+
+
+ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing, activation_function):
+     """
+     GPT-2 implemented by the HuggingFace library.
+     """
+     from transformers import GPT2Config, GPT2Model
+     gpt_config = GPT2Config(vocab_size=256,  # Unused.
+                             n_positions=max_mel_seq_len + max_text_seq_len,
+                             n_ctx=max_mel_seq_len + max_text_seq_len,
+                             n_embd=model_dim,
+                             n_layer=layers,
+                             n_head=heads,
+                             activation_function=activation_function or "gelu_new",
+                             gradient_checkpointing=checkpointing,
+                             use_cache=not checkpointing)
+     gpt = GPT2Model(gpt_config)
+     # Override the built in positional embeddings
+     del gpt.wpe
+     gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
+     # Built-in token embeddings are unused.
+     del gpt.wte
+     return gpt, LearnedPositionEmbeddings(max_mel_seq_len, model_dim), LearnedPositionEmbeddings(max_text_seq_len, model_dim), \
+         None, None
+
+
+ class MelEncoder(nn.Module):
+     def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2):
+         super().__init__()
+         self.channels = channels
+         self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels // 4, kernel_size=3, padding=1),
+                                      nn.Sequential(*[ResBlock(channels // 4) for _ in range(resblocks_per_reduction)]),
+                                      nn.Conv1d(channels // 4, channels // 2, kernel_size=3, stride=2, padding=1),
+                                      nn.GroupNorm(channels // 16, channels // 2),
+                                      nn.ReLU(),
+                                      nn.Sequential(*[ResBlock(channels // 2) for _ in range(resblocks_per_reduction)]),
+                                      nn.Conv1d(channels // 2, channels, kernel_size=3, stride=2, padding=1),
+                                      nn.GroupNorm(channels // 8, channels),
+                                      nn.ReLU(),
+                                      nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
+                                      )
+         self.reduction = 4
+
+     def forward(self, x):
+         for e in self.encoder:
+             x = e(x)
+         return x.permute(0, 2, 1)
+
+
305
+ class UnifiedVoice(nn.Module):
306
+ def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1,
307
+ mel_length_compression=1024, number_text_tokens=256,
308
+ start_text_token=0, stop_text_token=1, number_mel_codes=8194, start_mel_token=8192, stop_mel_token=8193,
309
+ train_solo_embeddings=False, use_mel_codes_as_input=True,
310
+ checkpointing=True, types=1, activation_function=None,
311
+ condition_num_latent=32, condition_type="perceiver", condition_module=None):
312
+ """
313
+ Args:
314
+ layers: Number of layers in transformer stack.
315
+ model_dim: Operating dimensions of the transformer
316
+ heads: Number of transformer heads. Must be divisible by model_dim. Recommend model_dim//64
317
+ max_text_tokens: Maximum number of text tokens that will be encountered by model.
318
+ max_mel_tokens: Maximum number of MEL tokens that will be encountered by model.
319
+ max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s).
320
+ mel_length_compression: The factor between <number_input_samples> and <mel_tokens>. Used to compute MEL code padding given wav input length.
321
+ number_text_tokens:
322
+ start_text_token:
323
+ stop_text_token:
324
+ number_mel_codes:
325
+ start_mel_token:
326
+ stop_mel_token:
327
+ train_solo_embeddings:
328
+ use_mel_codes_as_input:
329
+ checkpointing:
330
+ condition_type: perceiver, gst or default encoder
331
+ """
332
+ super().__init__()
333
+ self.number_text_tokens = number_text_tokens
334
+ self.start_text_token = start_text_token
335
+ self.stop_text_token = stop_text_token
336
+ self.number_mel_codes = number_mel_codes
337
+ self.start_mel_token = start_mel_token
338
+ self.stop_mel_token = stop_mel_token
339
+ self.layers = layers
340
+ self.heads = heads
341
+ self.max_mel_tokens = max_mel_tokens
342
+ self.max_text_tokens = max_text_tokens
343
+ self.model_dim = model_dim
344
+ self.max_conditioning_inputs = max_conditioning_inputs
345
+ self.mel_length_compression = mel_length_compression
346
+ self.condition_type = condition_type
347
+ self.cond_num = condition_num_latent
348
+ self.cond_mask_pad = nn.ConstantPad1d((self.cond_num, 0), True)
349
+ if condition_type == "perceiver":
350
+ self.conditioning_encoder = ConditioningEncoder(100, model_dim, num_attn_heads=heads)
351
+ self.perceiver_encoder = PerceiverResampler(model_dim, dim_context=model_dim, num_latents=self.cond_num)
352
+ elif condition_type == "conformer_perceiver" or condition_type == "conformer_encoder":
353
+ self.conditioning_encoder = ConformerEncoder(input_size=100,
354
+ output_size=condition_module['output_size'],
355
+ linear_units=condition_module['linear_units'],
356
+ attention_heads=condition_module['attention_heads'],
357
+ num_blocks=condition_module['num_blocks'],
358
+ input_layer=condition_module['input_layer'])
359
+ if condition_type == "conformer_perceiver":
360
+ self.perceiver_encoder = PerceiverResampler(model_dim, dim_context=condition_module['output_size'],
361
+ ff_mult=condition_module['perceiver_mult'],
362
+ heads=condition_module['attention_heads'],
363
+ num_latents=self.cond_num)
364
+ else:
365
+ self.conditioning_encoder = ConditioningEncoder(100, model_dim, num_attn_heads=heads, mean=True)
366
+
367
+ self.text_embedding = nn.Embedding(self.number_text_tokens * types + 1, model_dim)
368
+ if use_mel_codes_as_input:
369
+ self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
370
+ else:
371
+ self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
372
+ self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
373
+ build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens + 2 + self.max_conditioning_inputs,
374
+ self.max_text_tokens + 2, checkpointing, activation_function)
375
+ if train_solo_embeddings:
376
+ self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
377
+ self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
378
+ else:
379
+ self.mel_solo_embedding = 0
380
+ self.text_solo_embedding = 0
381
+
382
+ self.final_norm = nn.LayerNorm(model_dim)
383
+ self.text_head = nn.Linear(model_dim, self.number_text_tokens * types + 1)
384
+ self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
385
+
386
+ # Initialize the embeddings per the GPT-2 scheme
387
+ embeddings = [self.text_embedding]
388
+ if use_mel_codes_as_input:
389
+ embeddings.append(self.mel_embedding)
390
+ for module in embeddings:
391
+ module.weight.data.normal_(mean=0.0, std=.02)
392
+
393
+ def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=False, half=False):
394
+ seq_length = self.max_mel_tokens + self.max_text_tokens + 2
395
+ gpt_config = GPT2Config(
396
+ vocab_size=self.number_mel_codes,
397
+ n_positions=seq_length,
398
+ n_ctx=seq_length,
399
+ n_embd=self.model_dim,
400
+ n_layer=self.layers,
401
+ n_head=self.heads,
402
+ gradient_checkpointing=False,
403
+ use_cache=True,
404
+ )
405
+ self.inference_model = GPT2InferenceModel(
406
+ gpt_config,
407
+ self.gpt,
408
+ self.mel_pos_embedding,
409
+ self.mel_embedding,
410
+ self.final_norm,
411
+ self.mel_head,
412
+ kv_cache=kv_cache,
413
+ )
414
+ if use_deepspeed and half and torch.cuda.is_available():
415
+ import deepspeed
416
+ self.ds_engine = deepspeed.init_inference(model=self.inference_model,
417
+ mp_size=1,
418
+ replace_with_kernel_inject=False,
419
+ dtype=torch.float16)
420
+ self.inference_model = self.ds_engine.module.eval()
421
+ elif use_deepspeed and torch.cuda.is_available():
422
+ import deepspeed
423
+ self.ds_engine = deepspeed.init_inference(model=self.inference_model,
424
+ mp_size=1,
425
+ replace_with_kernel_inject=False,
426
+ dtype=torch.float32)
427
+ self.inference_model = self.ds_engine.module.eval()
428
+ else:
429
+ self.inference_model = self.inference_model.eval()
430
+
431
+ # self.inference_model = PrunedGPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
432
+ self.gpt.wte = self.mel_embedding
433
+
+     def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
+         inp = F.pad(input, (1, 0), value=start_token)
+         tar = F.pad(input, (0, 1), value=stop_token)
+         return inp, tar
+
+     def set_mel_padding(self, mel_input_tokens, mel_lengths):
+         """
+         Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in
+         that audio clip, reformats the tokens with STOP_MEL_TOKEN in place of the zero padding. This is required
+         preformatting to create a working TTS model.
+         """
+         for b in range(len(mel_lengths)):
+             # Due to the convolutional nature of how these tokens are generated,
+             # it would be best if the model predicts a token past the actual last token.
+             actual_end = mel_lengths[b]
+             if actual_end < mel_input_tokens.shape[-1]:
+                 mel_input_tokens[b, actual_end:] = self.stop_mel_token
+         return mel_input_tokens
+
+     def set_text_padding(self, text_input_tokens, text_lengths):
+         """
+         Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in
+         that audio clip, reformats the tokens with STOP_MEL_TOKEN in place of the zero padding. This is required
+         preformatting to create a working TTS model.
+         """
+         for b in range(len(text_lengths)):
+             # Due to the convolutional nature of how these tokens are generated,
+             # it would be best if the model predicts a token past the actual last token.
+             actual_end = text_lengths[b]
+             if actual_end < text_input_tokens.shape[-1]:
+                 text_input_tokens[b, actual_end:] = self.stop_text_token
+         return text_input_tokens
+
+     def get_logits(self, speech_conditioning_inputs, first_inputs, first_head, second_inputs=None, second_head=None, get_attns=False, return_latent=False):
+         if second_inputs is not None:
+             emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1)
+         else:
+             emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1)
+
+         gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
+         if get_attns:
+             return gpt_out.attentions
+
+         offset = speech_conditioning_inputs.shape[1]
+         enc = gpt_out.last_hidden_state[:, offset:]
+         enc = self.final_norm(enc)
+
+         if return_latent:
+             return enc[:, :first_inputs.shape[1]], enc[:, -second_inputs.shape[1]:]
+
+         first_logits = enc[:, :first_inputs.shape[1]]
+         first_logits = first_head(first_logits)
+         first_logits = first_logits.permute(0, 2, 1)
+         if second_inputs is not None:
+             second_logits = enc[:, -second_inputs.shape[1]:]
+             second_logits = second_head(second_logits)
+             second_logits = second_logits.permute(0, 2, 1)
+             return first_logits, second_logits
+         else:
+             return first_logits
+
+     def get_conditioning(self, speech_conditioning_input, cond_mel_lengths=None):
+         if self.condition_type == "perceiver":
+             if speech_conditioning_input.ndim == 4:
+                 speech_conditioning_input = speech_conditioning_input.squeeze(1)
+             speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input)  # (b, d, s)
+             conds = self.perceiver_encoder(speech_conditioning_input.transpose(1, 2))  # (b, 32, d)
+         elif self.condition_type == "conformer_perceiver":
+             speech_conditioning_input, mask = self.conditioning_encoder(speech_conditioning_input.transpose(1, 2),
+                                                                         cond_mel_lengths)  # (b, s, d), (b, 1, s)
+             if self.condition_type == "conformer_perceiver":
+                 # conds_mask = torch.cat([torch.ones((mask.shape[0], self.cond_num), dtype=torch.bool), mask.squeeze(1)], dim=1)
+                 conds_mask = self.cond_mask_pad(mask.squeeze(1))
+                 conds = self.perceiver_encoder(speech_conditioning_input, conds_mask)  # (b, 32, d)
+         elif self.condition_type == "gst":
+             if speech_conditioning_input.ndim == 4:
+                 speech_conditioning_input = speech_conditioning_input.squeeze(1)
+             conds = self.gst_encoder(speech_conditioning_input.transpose(1, 2))  # (b, 1, d)
+         else:
+             speech_conditioning_input = (
+                 speech_conditioning_input.unsqueeze(1)
+                 if len(speech_conditioning_input.shape) == 3
+                 else speech_conditioning_input
+             )
+             conds = []
+             for j in range(speech_conditioning_input.shape[1]):
+                 conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
+             conds = torch.stack(conds, dim=1)
+             conds = conds.mean(dim=1)
+             conds = conds.unsqueeze(1)
+         return conds
+
+     def forward(self, speech_conditioning_latent, text_inputs, text_lengths, mel_codes, wav_lengths,
+                 cond_mel_lengths=None, types=None, text_first=True, raw_mels=None, return_attentions=False,
+                 return_latent=False, clip_inputs=False):
+         """
+         Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
+         (actuated by `text_first`).
+
+         speech_conditioning_input: MEL float tensor, (b,1024)
+         text_inputs: long tensor, (b,t)
+         text_lengths: long tensor, (b,)
+         mel_inputs: long tensor, (b,m)
+         wav_lengths: long tensor, (b,)
+         raw_mels: MEL float tensor (b,80,s)
+
+         If return_attentions is specified, only logits are returned.
+         If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
+         If clip_inputs is True, the inputs will be clipped to the smallest input size across each input modality.
+         """
+
+         speech_conditioning_latent = self.get_conditioning(speech_conditioning_latent, cond_mel_lengths)
+         # Types are expressed by expanding the text embedding space.
+         if types is not None:
+             text_inputs = text_inputs * (1 + types).unsqueeze(-1)
+
+         if clip_inputs:
+             # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
+             # chopping the inputs by the maximum actual length.
+             max_text_len = text_lengths.max()
+             text_inputs = text_inputs[:, :max_text_len]
+             max_mel_len = wav_lengths.max() // self.mel_length_compression
+             mel_codes = mel_codes[:, :max_mel_len]
+             if raw_mels is not None:
+                 raw_mels = raw_mels[:, :, :max_mel_len * 4]
+
+         # Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
+         # mel_codes_lengths = torch.div(wav_lengths, self.mel_length_compression, rounding_mode='trunc')
+         mel_codes_lengths = torch.ceil(wav_lengths / self.mel_length_compression).long() + 1
+         mel_codes = self.set_mel_padding(mel_codes, mel_codes_lengths)
+         text_inputs = self.set_text_padding(text_inputs, text_lengths)
+         text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
+         mel_codes = F.pad(mel_codes, (0, 1), value=self.stop_mel_token)
+
+         conds = speech_conditioning_latent
+         text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
+         text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
+         mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
+         if raw_mels is not None:
+             mel_inp = F.pad(raw_mels, (0, 8))
+         else:
+             mel_inp = mel_codes
+         mel_emb = self.mel_embedding(mel_inp)
+         mel_emb = mel_emb + self.mel_pos_embedding(mel_codes)
+
+         if text_first:
+             # print(f"conds: {conds.shape}, text_emb: {text_emb.shape}, mel_emb: {mel_emb.shape}")
+             text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions, return_latent=return_latent)
+             if return_latent:
+                 return mel_logits[:, :-2]  # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
+         else:
+             mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions, return_latent=return_latent)
+             if return_latent:
+                 return text_logits[:, :-2]  # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
+
+         if return_attentions:
+             return mel_logits
+
+         loss_text = F.cross_entropy(text_logits, text_targets.long())
+         loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
+         return loss_text.mean(), loss_mel.mean(), mel_logits
+
596
+ def prepare_gpt_inputs(
597
+ self,
598
+ conditional_latents: torch.Tensor,
599
+ text_inputs: torch.Tensor,
600
+ ):
601
+
602
+ """
603
+ Prepare the inputs for the GPT2InferenceModel to generate.
604
+ Args:
605
+ conds_latent: (b, 32, dim) audio conditioning embedding by `get_conditioning()`
606
+ text_inputs: (b, L)
607
+ Returns:
608
+ input_ids: (b, s+1) the input ids for the GPT2InferenceModel.generate()
609
+ inputs_embeds: (b, s+1, dim) the input embeddings for the GPT2InferenceModel.forward()
610
+ attention_mask: (b, s+1) the attention mask for the GPT2InferenceModel.generate()
611
+ """
612
+ b, L = text_inputs.shape[:2]
613
+ device = text_inputs.device
614
+ single_cond = conditional_latents.ndim == 3 and conditional_latents.shape[0] == 1
615
+ if not single_cond:
616
+ assert conditional_latents.shape[0] == b, f"batch size mismatch: {conditional_latents.shape[0]} vs {b}"
617
+ batched_mel_emb = []
618
+ attention_masks = []
619
+ target_len = conditional_latents.shape[1] + L + 2
620
+ for i in range(b):
621
+ valid_mask = (text_inputs[i] != self.stop_text_token) & (text_inputs[i] != self.start_text_token)
622
+ text_input = text_inputs[i][valid_mask]
623
+ text_input = F.pad(text_input, (1, 0), value=self.start_text_token)
624
+ text_input = F.pad(text_input, (0, 1), value=self.stop_text_token)
625
+ text_input_pos = torch.arange(0, text_input.size(-1), device=device)
626
+ text_emb = self.text_embedding(text_input) + self.text_pos_embedding.emb(text_input_pos)
627
+ # concatenate [conditional latents][text embeddings]
628
+ conds_text_emb = [
629
+ conditional_latents.squeeze(0) if single_cond else conditional_latents[i],
630
+ text_emb,
631
+ ]
632
+ # +1 for the start_mel_token
633
+ attention_mask = torch.ones(target_len+1, dtype=torch.long, device=device)
634
+ # check this text input is padded
635
+ padding: int = L + 2 - text_input.size(-1)
636
+ # pad left of [cond][text] -> [pad][cond][text]
637
+ if padding > 0:
638
+ pad = torch.zeros((padding, conditional_latents.size(-1)), dtype=text_emb.dtype, device=device) # [p, dim]
639
+ conds_text_emb.insert(0, pad)
640
+ attention_mask[:padding] = 0
641
+ mel_emb = torch.cat(conds_text_emb) #[s, dim]
642
+ assert mel_emb.shape[0] == target_len, f"mel_emb.shape: {mel_emb.shape}, target_len: {target_len}"
643
+ batched_mel_emb.append(mel_emb)
644
+ attention_masks.append(attention_mask)
645
+ # [b, s, dim]
646
+ batched_mel_emb = torch.stack(batched_mel_emb, dim=0)
647
+ # [b, s+1]
648
+ attention_mask = torch.stack(attention_masks, dim=0)
649
+ # [b, s+1]
650
+ fake_inputs = torch.ones(
651
+ (
652
+ batched_mel_emb.shape[0],
653
+ batched_mel_emb.shape[1] + 1, # +1 for the start_mel_token
654
+ ),
655
+ dtype=torch.long,
656
+ device=device,
657
+ )
658
+ fake_inputs[:, -1] = self.start_mel_token
659
+ return fake_inputs, batched_mel_emb, attention_mask
660
+ def inference_speech(self, speech_conditioning_mel, text_inputs, cond_mel_lengths=None, input_tokens=None, num_return_sequences=1,
661
+ max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
662
+ """
663
+ Args:
664
+ speech_conditioning_mel: (b, n_mels, frames) or (n_mels, frames)
665
+ text_inputs: (b, L)
666
+ cond_mel_lengths: lengths of the conditioning mel spectrograms in shape (b,) or (1,)
667
+ input_tokens: additional tokens for generation in shape (b, s) or (s,)
668
+ max_generate_length: limit the number of generated tokens
669
+ hf_generate_kwargs: kwargs for `GPT2InferenceModel.generate(**hf_generate_kwargs)`
670
+ """
671
+ if speech_conditioning_mel.ndim == 2:
672
+ speech_conditioning_mel = speech_conditioning_mel.unsqueeze(0)
673
+ if cond_mel_lengths is None:
674
+ cond_mel_lengths = torch.tensor([speech_conditioning_mel.shape[-1]], device=speech_conditioning_mel.device)
675
+ conds_latent = self.get_conditioning(speech_conditioning_mel, cond_mel_lengths)
676
+ input_ids, inputs_embeds, attention_mask = self.prepare_gpt_inputs(conds_latent, text_inputs)
677
+ self.inference_model.store_mel_emb(inputs_embeds)
678
+ if input_tokens is None:
679
+ inputs = input_ids
680
+ else:
681
+ if input_tokens.ndim == 1:
682
+ input_tokens = input_tokens.unsqueeze(0)
683
+ assert num_return_sequences % input_tokens.shape[0] == 0, \
684
+ "The num_return_sequences must be divisible by the batch number of input_tokens"
685
+ assert num_return_sequences % text_inputs.shape[0] == 0, \
686
+ "The num_return_sequences must be divisible by the batch number of text_inputs"
687
+ b = num_return_sequences // input_ids.shape[0]
688
+ if b > 1:
689
+ input_ids = input_ids.repeat(b, 1)
690
+ attention_mask = attention_mask.repeat(b, 1)
691
+ input_tokens = input_tokens.repeat(num_return_sequences // input_tokens.shape[0], 1)
692
+ inputs = torch.cat([input_ids, input_tokens], dim=1)
693
+ attention_mask = F.pad(attention_mask, (0, input_tokens.shape[1]), value=1)
694
+ trunc_index = inputs.shape[1]
695
+ logits_processor = LogitsProcessorList()
696
+ if typical_sampling:
697
+ # employ custom typical sampling
698
+ if not (typical_mass > 0.0 and typical_mass < 1.0):
699
+ raise ValueError(f"`typical_mass` has to be a float > 0 and < 1, but is {typical_mass}")
700
+ min_tokens_to_keep = 2 if hf_generate_kwargs.get("num_beams", 1) > 1 else 1
701
+ logits_processor.append(TypicalLogitsWarper(mass=typical_mass, min_tokens_to_keep=min_tokens_to_keep))
702
+ max_length = (trunc_index + self.max_mel_tokens - 1) if max_generate_length is None else trunc_index + max_generate_length
703
+ output = self.inference_model.generate(inputs,
704
+ bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token,
705
+ eos_token_id=self.stop_mel_token, attention_mask=attention_mask,
706
+ max_length=max_length, logits_processor=logits_processor,
707
+ num_return_sequences=num_return_sequences,
708
+ **hf_generate_kwargs)
709
+ if isinstance(output, torch.Tensor):
710
+ return output[:, trunc_index:]
711
+ # GenerateOutput
712
+ output.sequences = output.sequences[:, trunc_index:]
713
+ return output
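
For orientation, the sketch below shows one way the UnifiedVoice class added in this file could be driven end-to-end: a conditioning mel spectrogram and tokenized text go in, autoregressive mel token codes come out of inference_speech(). It is an illustration only, not code shipped in this release; the import path, toy hyperparameters, random inputs, and sampling settings are assumptions, and within xinference this code is consumed through the IndexTTS audio integration (see xinference/model/audio/indextts2.py in the file list above) rather than called directly.

# Hypothetical usage sketch for the UnifiedVoice class added above.
# The import path and every value below are illustrative assumptions, not shipped defaults.
import torch
from indextts.gpt.model import UnifiedVoice  # vendored under xinference/thirdparty

model = UnifiedVoice(layers=8, model_dim=512, heads=8)          # toy config for illustration
model.post_init_gpt2_config(use_deepspeed=False, kv_cache=True) # builds the GPT2InferenceModel wrapper
model.eval()

cond_mel = torch.randn(1, 100, 300)            # (b, n_mels=100, frames) conditioning spectrogram
text_tokens = torch.randint(2, 255, (1, 40))   # (b, L) text token ids, avoiding start/stop ids 0 and 1

with torch.no_grad():
    mel_codes = model.inference_speech(
        cond_mel, text_tokens,
        max_generate_length=200,
        do_sample=True, top_p=0.8, temperature=1.0,  # forwarded to HF generate()
    )
print(mel_codes.shape)  # (num_return_sequences, generated_length) of mel token ids

With randomly initialized weights the generated codes are meaningless; in practice the model would be loaded from an IndexTTS checkpoint before calling inference_speech(), and the resulting mel codes would be passed on to the downstream s2mel/BigVGAN stages also added in this release.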