xinference 1.10.0__py3-none-any.whl → 1.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic.

Files changed (317)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +11 -28
  3. xinference/client/restful/async_restful_client.py +20 -3
  4. xinference/client/restful/restful_client.py +20 -3
  5. xinference/core/supervisor.py +87 -53
  6. xinference/core/worker.py +10 -0
  7. xinference/deploy/cmdline.py +15 -0
  8. xinference/model/audio/core.py +21 -6
  9. xinference/model/audio/indextts2.py +166 -0
  10. xinference/model/audio/model_spec.json +38 -1
  11. xinference/model/image/model_spec.json +69 -0
  12. xinference/model/image/stable_diffusion/core.py +13 -4
  13. xinference/model/llm/__init__.py +4 -0
  14. xinference/model/llm/llm_family.json +464 -2
  15. xinference/model/llm/sglang/core.py +30 -11
  16. xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +94 -32
  17. xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
  18. xinference/model/llm/utils.py +12 -9
  19. xinference/model/llm/vllm/core.py +93 -17
  20. xinference/thirdparty/audiotools/__init__.py +10 -0
  21. xinference/thirdparty/audiotools/core/__init__.py +4 -0
  22. xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
  23. xinference/thirdparty/audiotools/core/display.py +194 -0
  24. xinference/thirdparty/audiotools/core/dsp.py +390 -0
  25. xinference/thirdparty/audiotools/core/effects.py +647 -0
  26. xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
  27. xinference/thirdparty/audiotools/core/loudness.py +320 -0
  28. xinference/thirdparty/audiotools/core/playback.py +252 -0
  29. xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
  30. xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
  31. xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
  32. xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
  33. xinference/thirdparty/audiotools/core/util.py +671 -0
  34. xinference/thirdparty/audiotools/core/whisper.py +97 -0
  35. xinference/thirdparty/audiotools/data/__init__.py +3 -0
  36. xinference/thirdparty/audiotools/data/datasets.py +517 -0
  37. xinference/thirdparty/audiotools/data/preprocess.py +81 -0
  38. xinference/thirdparty/audiotools/data/transforms.py +1592 -0
  39. xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
  40. xinference/thirdparty/audiotools/metrics/distance.py +131 -0
  41. xinference/thirdparty/audiotools/metrics/quality.py +159 -0
  42. xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
  43. xinference/thirdparty/audiotools/ml/__init__.py +5 -0
  44. xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
  45. xinference/thirdparty/audiotools/ml/decorators.py +440 -0
  46. xinference/thirdparty/audiotools/ml/experiment.py +90 -0
  47. xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
  48. xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
  49. xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
  50. xinference/thirdparty/audiotools/post.py +140 -0
  51. xinference/thirdparty/audiotools/preference.py +600 -0
  52. xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
  53. xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
  54. xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
  55. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
  56. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
  57. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
  58. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
  59. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  60. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
  61. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
  62. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
  63. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
  64. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
  65. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
  66. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
  67. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
  68. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
  69. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
  70. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
  71. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
  72. xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
  73. xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
  74. xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
  75. xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
  76. xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
  77. xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
  78. xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
  79. xinference/thirdparty/indextts/__init__.py +0 -0
  80. xinference/thirdparty/indextts/cli.py +65 -0
  81. xinference/thirdparty/indextts/gpt/__init__.py +0 -0
  82. xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
  83. xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
  84. xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
  85. xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
  86. xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
  87. xinference/thirdparty/indextts/gpt/model.py +713 -0
  88. xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
  89. xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
  90. xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
  91. xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
  92. xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
  93. xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
  94. xinference/thirdparty/indextts/infer.py +690 -0
  95. xinference/thirdparty/indextts/infer_v2.py +739 -0
  96. xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
  97. xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
  98. xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
  99. xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
  100. xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
  101. xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
  102. xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
  103. xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
  104. xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
  105. xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
  106. xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
  107. xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
  108. xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
  109. xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
  110. xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
  111. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
  112. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
  113. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
  114. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
  115. xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
  116. xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
  117. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
  118. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
  119. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  120. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
  121. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
  122. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
  123. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
  124. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
  125. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
  126. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
  127. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
  128. xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
  129. xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
  130. xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
  131. xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
  132. xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
  133. xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
  134. xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
  135. xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
  136. xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
  137. xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
  138. xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
  139. xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
  140. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
  141. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
  142. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
  143. xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
  144. xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
  145. xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
  146. xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
  147. xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
  148. xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
  149. xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
  150. xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
  151. xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
  152. xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
  153. xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
  154. xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
  155. xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
  156. xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
  157. xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
  158. xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
  159. xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
  160. xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
  161. xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
  162. xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
  163. xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
  164. xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
  165. xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
  166. xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
  167. xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
  168. xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
  169. xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
  170. xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
  171. xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
  172. xinference/thirdparty/indextts/utils/__init__.py +0 -0
  173. xinference/thirdparty/indextts/utils/arch_util.py +120 -0
  174. xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
  175. xinference/thirdparty/indextts/utils/common.py +121 -0
  176. xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
  177. xinference/thirdparty/indextts/utils/front.py +536 -0
  178. xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
  179. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
  180. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
  181. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
  182. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
  183. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
  184. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
  185. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
  186. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
  187. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
  188. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
  189. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
  190. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
  191. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
  192. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
  193. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
  194. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
  195. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
  196. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
  197. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
  198. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
  199. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
  200. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
  201. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
  202. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
  203. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
  204. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
  205. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
  206. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
  207. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
  208. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
  209. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
  210. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
  211. xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
  212. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
  213. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
  214. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
  215. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
  216. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
  217. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
  218. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
  219. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
  220. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
  221. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
  222. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
  223. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
  224. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
  225. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
  226. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
  227. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
  228. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
  229. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
  230. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
  231. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
  232. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
  233. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
  234. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
  235. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
  236. xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
  237. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
  238. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
  239. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
  240. xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
  241. xinference/thirdparty/indextts/utils/text_utils.py +41 -0
  242. xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
  243. xinference/thirdparty/indextts/utils/utils.py +93 -0
  244. xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
  245. xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
  246. xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
  247. xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
  248. xinference/ui/gradio/media_interface.py +66 -8
  249. xinference/ui/web/ui/build/asset-manifest.json +6 -6
  250. xinference/ui/web/ui/build/index.html +1 -1
  251. xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
  252. xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
  253. xinference/ui/web/ui/build/static/js/main.d192c4f3.js +3 -0
  254. xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.d192c4f3.js.LICENSE.txt} +0 -7
  255. xinference/ui/web/ui/build/static/js/main.d192c4f3.js.map +1 -0
  256. xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
  257. xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
  258. xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
  259. xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
  260. xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
  261. xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
  262. xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
  263. xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
  264. xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
  265. xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
  266. xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
  267. xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
  268. xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
  269. xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
  270. xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
  271. xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
  272. xinference/ui/web/ui/node_modules/.cache/babel-loader/f995a2425dfb0822fd07127f66ffe9b026883bc156b402eb8bd0b83d52460a93.json +1 -0
  273. xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
  274. xinference/ui/web/ui/package-lock.json +0 -34
  275. xinference/ui/web/ui/package.json +0 -1
  276. xinference/ui/web/ui/src/locales/en.json +9 -3
  277. xinference/ui/web/ui/src/locales/ja.json +9 -3
  278. xinference/ui/web/ui/src/locales/ko.json +9 -3
  279. xinference/ui/web/ui/src/locales/zh.json +9 -3
  280. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/METADATA +18 -2
  281. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/RECORD +285 -67
  282. xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
  283. xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
  284. xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
  285. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
  286. xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
  287. xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
  288. xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
  289. xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
  290. xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
  291. xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
  292. xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
  293. xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
  294. xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
  295. xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
  296. xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
  297. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
  298. xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
  299. xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
  300. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
  301. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
  302. xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
  303. xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
  304. xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
  305. xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
  306. xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
  307. xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
  308. xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
  309. xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
  310. xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
  311. xinference/ui/web/ui/node_modules/select/bower.json +0 -13
  312. xinference/ui/web/ui/node_modules/select/package.json +0 -29
  313. xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
  314. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/WHEEL +0 -0
  315. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/entry_points.txt +0 -0
  316. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/licenses/LICENSE +0 -0
  317. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/top_level.txt +0 -0
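
Note on trying the new audio model: items 8–10 above add IndexTTS-2 support (xinference/model/audio/indextts2.py plus a new entry in the audio model_spec.json). Below is a minimal sketch of how such an audio model is typically launched and queried through the existing xinference client; the server address and the model name string are illustrative assumptions, not something this diff specifies (check the updated model_spec.json for the registered name).

from xinference.client import Client

# Assumptions: a local xinference server on port 9997 and an audio model registered
# under the name "IndexTTS-2"; both are illustrative, not taken from this diff.
client = Client("http://localhost:9997")
model_uid = client.launch_model(model_name="IndexTTS-2", model_type="audio")
model = client.get_model(model_uid)

# Audio model handles expose OpenAI-style speech synthesis and return raw audio bytes.
audio_bytes = model.speech("Hello from xinference 1.10.1.")
with open("output.mp3", "wb") as f:
    f.write(audio_bytes)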
xinference/thirdparty/indextts/gpt/model_v2.py (new file)
@@ -0,0 +1,747 @@
+import functools
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import transformers
+from transformers import GPT2Config, LogitsProcessorList
+from indextts.gpt.transformers_gpt2 import GPT2PreTrainedModel, GPT2Model
+
+# from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList
+from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
+from transformers.utils.model_parallel_utils import (assert_device_map,
+                                                     get_device_map)
+
+from indextts.gpt.conformer_encoder import ConformerEncoder
+from indextts.gpt.perceiver import PerceiverResampler
+from indextts.utils.arch_util import AttentionBlock
+from indextts.utils.typical_sampling import TypicalLogitsWarper
+
+
+def null_position_embeddings(range, dim):
+    return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
+
+
+class ResBlock(nn.Module):
+    """
+    Basic residual convolutional block that uses GroupNorm.
+    """
+
+    def __init__(self, chan):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Conv1d(chan, chan, kernel_size=3, padding=1),
+            nn.GroupNorm(chan // 8, chan),
+            nn.ReLU(),
+            nn.Conv1d(chan, chan, kernel_size=3, padding=1),
+            nn.GroupNorm(chan // 8, chan)
+        )
+
+    def forward(self, x):
+        return F.relu(self.net(x) + x)
+
+
+class GPT2InferenceModel(GPT2PreTrainedModel):
+    def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear, kv_cache=False):
+        super().__init__(config)
+        # Note: the argument named `text_pos_emb` here actually represents the mel position embedding
+        self.transformer = gpt
+        self.text_pos_embedding = text_pos_emb
+        self.embeddings = embeddings
+        self.final_norm = norm
+        self.lm_head = nn.Sequential(norm, linear)
+        self.kv_cache = kv_cache
+
+        # Model parallel
+        self.model_parallel = False
+        self.device_map = None
+        self.cached_mel_emb = None
+
+    def parallelize(self, device_map=None):
+        self.device_map = (
+            get_device_map(len(self.transformer.h), range(max(1, torch.cuda.device_count())))
+            if device_map is None
+            else device_map
+        )
+        assert_device_map(self.device_map, len(self.transformer.h))
+        self.transformer.parallelize(self.device_map)
+        self.lm_head = self.lm_head.to(self.transformer.first_device)
+        self.model_parallel = True
+
+    def deparallelize(self):
+        self.transformer.deparallelize()
+        self.transformer = self.transformer.to("cpu")
+        self.lm_head = self.lm_head.to("cpu")
+        self.model_parallel = False
+        torch.cuda.empty_cache()
+        if torch.backends.mps.is_available():
+            torch.mps.empty_cache()
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def store_mel_emb(self, mel_emb):
+        self.cached_mel_emb = mel_emb
+
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
+        token_type_ids = kwargs.get("token_type_ids", None)  # usually None
+        if not self.kv_cache:
+            past_key_values = None
+        # only last token for inputs_ids if past is defined in kwargs
+        if past_key_values:
+            input_ids = input_ids[:, -1].unsqueeze(-1)
+            if token_type_ids is not None:
+                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
+
+        attention_mask = kwargs.get("attention_mask", None)
+        position_ids = kwargs.get("position_ids", None)
+
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 0)
+            if past_key_values:
+                position_ids = position_ids[:, -1].unsqueeze(-1)
+        else:
+            position_ids = None
+        return {
+            "input_ids": input_ids,
+            "past_key_values": past_key_values,
+            "use_cache": kwargs.get("use_cache"),
+            "position_ids": position_ids,
+            "attention_mask": attention_mask,
+            "token_type_ids": token_type_ids,
+        }
+
+    def forward(
+        self,
+        input_ids=None,
+        past_key_values=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        labels=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        assert self.cached_mel_emb is not None
+        assert inputs_embeds is None  # Not supported by this inference model.
+        assert labels is None  # Training not supported by this inference model.
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+        # Create embedding
+        mel_len = self.cached_mel_emb.shape[1]
+        if input_ids.shape[1] != 1:
+            text_inputs = input_ids[:, mel_len:]
+            text_emb = self.embeddings(text_inputs)
+            text_emb = text_emb + self.text_pos_embedding(text_emb)
+            if self.cached_mel_emb.shape[0] != text_emb.shape[0]:
+                mel_emb = self.cached_mel_emb.repeat_interleave(
+                    text_emb.shape[0] // self.cached_mel_emb.shape[0], 0
+                )
+            else:  # this outcome only occurs once per loop in most cases
+                mel_emb = self.cached_mel_emb
+            emb = torch.cat([mel_emb, text_emb], dim=1)
+        else:
+            emb = self.embeddings(input_ids)
+            emb = emb + self.text_pos_embedding.get_fixed_embedding(
+                attention_mask.shape[1] - mel_len, attention_mask.device
+            )
+        transformer_outputs = self.transformer(
+            inputs_embeds=emb,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states = transformer_outputs[0]
+
+        # Set device for model parallelism
+        if self.model_parallel:
+            if torch.backends.mps.is_available():
+                self.to(self.transformer.first_device)
+            else:
+                torch.cuda.set_device(self.transformer.first_device)
+            hidden_states = hidden_states.to(self.lm_head.weight.device)
+
+        lm_logits = self.lm_head(hidden_states)
+
+        if not return_dict:
+            return (lm_logits,) + transformer_outputs[1:]
+
+        return CausalLMOutputWithCrossAttentions(
+            loss=None,
+            logits=lm_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+            cross_attentions=transformer_outputs.cross_attentions,
+        )
+
+    @staticmethod
+    def _reorder_cache(past, beam_idx):
+        """
+        This function is used to re-order the :obj:`past_key_values` cache if
+        :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
+        called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
+        """
+        return tuple(
+            tuple(
+                past_state.index_select(0, beam_idx.to(past_state.device))
+                for past_state in layer_past
+            )
+            for layer_past in past
+        )
+
+
+class ConditioningEncoder(nn.Module):
+    def __init__(self,
+                 spec_dim,
+                 embedding_dim,
+                 attn_blocks=6,
+                 num_attn_heads=4,
+                 do_checkpointing=False,
+                 mean=False):
+        super().__init__()
+        attn = []
+        self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
+        for a in range(attn_blocks):
+            attn.append(AttentionBlock(embedding_dim, num_attn_heads))
+        self.attn = nn.Sequential(*attn)
+        self.dim = embedding_dim
+        self.do_checkpointing = do_checkpointing
+        self.mean = mean
+
+    def forward(self, x):
+        h = self.init(x)
+        h = self.attn(h)
+        if self.mean:
+            return h.mean(dim=2)
+        else:
+            return h
+            # return h[:, :, 0]
+
+
+class LearnedPositionEmbeddings(nn.Module):
+    def __init__(self, seq_len, model_dim, init=.02):
+        super().__init__()
+        self.emb = nn.Embedding(seq_len, model_dim)
+        # Initializing this way is standard for GPT-2
+        self.emb.weight.data.normal_(mean=0.0, std=init)
+
+    def forward(self, x):
+        sl = x.shape[1]
+        return self.emb(torch.arange(0, sl, device=x.device))
+
+    def get_fixed_embedding(self, ind, dev):
+        return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
+
+
+def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):
+    """
+    GPT-2 implemented by the HuggingFace library.
+    """
+    from transformers import GPT2Config, GPT2Model
+    gpt_config = GPT2Config(vocab_size=256,  # Unused.
+                            n_positions=max_mel_seq_len + max_text_seq_len,
+                            n_ctx=max_mel_seq_len + max_text_seq_len,
+                            n_embd=model_dim,
+                            n_layer=layers,
+                            n_head=heads,
+                            gradient_checkpointing=checkpointing,
+                            use_cache=not checkpointing)
+    gpt = GPT2Model(gpt_config)
+    # Override the built in positional embeddings
+    del gpt.wpe
+    gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
+    # Built-in token embeddings are unused.
+    del gpt.wte
+    return gpt, LearnedPositionEmbeddings(max_mel_seq_len, model_dim), LearnedPositionEmbeddings(max_text_seq_len, model_dim), \
+        None, None
+
+
+class MelEncoder(nn.Module):
+    def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2):
+        super().__init__()
+        self.channels = channels
+        self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels // 4, kernel_size=3, padding=1),
+                                     nn.Sequential(*[ResBlock(channels // 4) for _ in range(resblocks_per_reduction)]),
+                                     nn.Conv1d(channels // 4, channels // 2, kernel_size=3, stride=2, padding=1),
+                                     nn.GroupNorm(channels // 16, channels // 2),
+                                     nn.ReLU(),
+                                     nn.Sequential(*[ResBlock(channels // 2) for _ in range(resblocks_per_reduction)]),
+                                     nn.Conv1d(channels // 2, channels, kernel_size=3, stride=2, padding=1),
+                                     nn.GroupNorm(channels // 8, channels),
+                                     nn.ReLU(),
+                                     nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
+                                     )
+        self.reduction = 4
+
+    def forward(self, x):
+        for e in self.encoder:
+            x = e(x)
+        return x.permute(0, 2, 1)
+
+
+class UnifiedVoice(nn.Module):
+    def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1,
+                 mel_length_compression=1024, number_text_tokens=256,
+                 start_text_token=0, stop_text_token=1, number_mel_codes=8194, start_mel_token=8192, stop_mel_token=8193,
+                 train_solo_embeddings=False, use_mel_codes_as_input=True,
+                 checkpointing=True, types=1,
+                 condition_num_latent=32, condition_type="perceiver", condition_module=None, emo_condition_module=None):
+        """
+        Args:
+            layers: Number of layers in transformer stack.
+            model_dim: Operating dimensions of the transformer
+            heads: Number of transformer heads. Must be divisible by model_dim. Recommend model_dim//64
+            max_text_tokens: Maximum number of text tokens that will be encountered by model.
+            max_mel_tokens: Maximum number of MEL tokens that will be encountered by model.
+            max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s).
+            mel_length_compression: The factor between <number_input_samples> and <mel_tokens>. Used to compute MEL code padding given wav input length.
+            number_text_tokens:
+            start_text_token:
+            stop_text_token:
+            number_mel_codes:
+            start_mel_token:
+            stop_mel_token:
+            train_solo_embeddings:
+            use_mel_codes_as_input:
+            checkpointing:
+            condition_type: perceiver, gst or default encoder
+        """
+        super().__init__()
+        self.number_text_tokens = number_text_tokens
+        self.start_text_token = start_text_token
+        self.stop_text_token = stop_text_token
+        self.number_mel_codes = number_mel_codes
+        self.start_mel_token = start_mel_token
+        self.stop_mel_token = stop_mel_token
+        self.layers = layers
+        self.heads = heads
+        self.max_mel_tokens = max_mel_tokens
+        self.max_text_tokens = max_text_tokens
+        self.model_dim = model_dim
+        self.max_conditioning_inputs = max_conditioning_inputs
+        self.mel_length_compression = mel_length_compression
+        self.condition_type = condition_type
+        self.cond_num = condition_num_latent
+        self.cond_mask_pad = nn.ConstantPad1d((self.cond_num, 0), True)
+        self.emo_cond_mask_pad = nn.ConstantPad1d((1, 0), True)
+        if condition_type == "perceiver":
+            self.conditioning_encoder = ConditioningEncoder(1024, model_dim, num_attn_heads=heads)
+            self.perceiver_encoder = PerceiverResampler(model_dim, dim_context=model_dim, num_latents=self.cond_num)
+        elif condition_type == "conformer_perceiver" or condition_type == "conformer_encoder":
+            self.conditioning_encoder = ConformerEncoder(input_size=1024,
+                                                         output_size=condition_module['output_size'],
+                                                         linear_units=condition_module['linear_units'],
+                                                         attention_heads=condition_module['attention_heads'],
+                                                         num_blocks=condition_module['num_blocks'],
+                                                         input_layer=condition_module['input_layer'])
+            if condition_type == "conformer_perceiver":
+                self.perceiver_encoder = PerceiverResampler(model_dim, dim_context=condition_module['output_size'],
+                                                            ff_mult=condition_module['perceiver_mult'],
+                                                            heads=condition_module['attention_heads'],
+                                                            num_latents=self.cond_num)
+        else:
+            self.conditioning_encoder = ConditioningEncoder(1024, model_dim, num_attn_heads=heads, mean=True)
+
+        self.emo_conditioning_encoder = ConformerEncoder(input_size=1024,
+                                                         output_size=emo_condition_module['output_size'],
+                                                         linear_units=emo_condition_module['linear_units'],
+                                                         attention_heads=emo_condition_module['attention_heads'],
+                                                         num_blocks=emo_condition_module['num_blocks'],
+                                                         input_layer=emo_condition_module['input_layer'])
+        self.emo_perceiver_encoder = PerceiverResampler(1024, dim_context=emo_condition_module['output_size'],
+                                                        ff_mult=emo_condition_module['perceiver_mult'],
+                                                        heads=emo_condition_module['attention_heads'],
+                                                        num_latents=1)
+
+
+
+        self.text_embedding = nn.Embedding(self.number_text_tokens * types + 1, model_dim)
+        self.emo_layer = nn.Linear(model_dim, model_dim)
+        self.emovec_layer = nn.Linear(1024, model_dim)
+
+        if use_mel_codes_as_input:
+            self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
+        else:
+            self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
+        self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
+            build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens + 2 + self.max_conditioning_inputs,
+                                     self.max_text_tokens + 2, checkpointing)
+        if train_solo_embeddings:
+            self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
+            self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
+        else:
+            self.mel_solo_embedding = 0
+            self.text_solo_embedding = 0
+
+        self.final_norm = nn.LayerNorm(model_dim)
+        self.text_head = nn.Linear(model_dim, self.number_text_tokens * types + 1)
+        self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
+
+        self.speed_emb = nn.Embedding(2, model_dim)
+        self.speed_emb.weight.data.normal_(mean=0.0, std=0.0)
+
+        # Initialize the embeddings per the GPT-2 scheme
+        embeddings = [self.text_embedding]
+        if use_mel_codes_as_input:
+            embeddings.append(self.mel_embedding)
+        for module in embeddings:
+            module.weight.data.normal_(mean=0.0, std=.02)
+
+    def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=False, half=False):
+        seq_length = self.max_mel_tokens + self.max_text_tokens + 2
+        gpt_config = GPT2Config(
+            vocab_size=self.number_mel_codes,
+            n_positions=seq_length,
+            n_ctx=seq_length,
+            n_embd=self.model_dim,
+            n_layer=self.layers,
+            n_head=self.heads,
+            gradient_checkpointing=False,
+            use_cache=True,
+        )
+        self.inference_model = GPT2InferenceModel(
+            gpt_config,
+            self.gpt,
+            self.mel_pos_embedding,
+            self.mel_embedding,
+            self.final_norm,
+            self.mel_head,
+            kv_cache=kv_cache,
+        )
+        if use_deepspeed and half and torch.cuda.is_available():
+            import deepspeed
+            self.ds_engine = deepspeed.init_inference(model=self.inference_model,
+                                                      mp_size=1,
+                                                      replace_with_kernel_inject=True,
+                                                      dtype=torch.float16)
+            self.inference_model = self.ds_engine.module.eval()
+        elif use_deepspeed and torch.cuda.is_available():
+            import deepspeed
+            self.ds_engine = deepspeed.init_inference(model=self.inference_model,
+                                                      mp_size=1,
+                                                      replace_with_kernel_inject=True,
+                                                      dtype=torch.float32)
+            self.inference_model = self.ds_engine.module.eval()
+        else:
+            self.inference_model = self.inference_model.eval()
+
+        # self.inference_model = PrunedGPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
+        self.gpt.wte = self.mel_embedding
+
+    def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
+        inp = F.pad(input, (1, 0), value=start_token)
+        tar = F.pad(input, (0, 1), value=stop_token)
+        return inp, tar
+
+    def set_mel_padding(self, mel_input_tokens, mel_lengths):
+        """
+        Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in
+        that audio clip, reformats the tokens with STOP_MEL_TOKEN in place of the zero padding. This is required
+        preformatting to create a working TTS model.
+        """
+        for b in range(len(mel_lengths)):
+            # Due to the convolutional nature of how these tokens are generated,
+            # it would be best if the model predicts a token past the actual last token.
+            actual_end = mel_lengths[b]
+            if actual_end < mel_input_tokens.shape[-1]:
+                mel_input_tokens[b, actual_end:] = self.stop_mel_token
+        return mel_input_tokens
+
+    def set_text_padding(self, text_input_tokens, text_lengths):
+        """
+        Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in
+        that audio clip, reformats the tokens with STOP_MEL_TOKEN in place of the zero padding. This is required
+        preformatting to create a working TTS model.
+        """
+        for b in range(len(text_lengths)):
+            # Due to the convolutional nature of how these tokens are generated,
+            # it would be best if the model predicts a token past the actual last token.
+            actual_end = text_lengths[b]
+            if actual_end < text_input_tokens.shape[-1]:
+                text_input_tokens[b, actual_end:] = self.stop_text_token
+        return text_input_tokens
+
+    def get_logits(self, speech_conditioning_inputs, first_inputs, first_head, second_inputs=None, second_head=None, get_attns=False, return_latent=False):
+        if second_inputs is not None:
+            emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1)
+        else:
+            emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1)
+
+        gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
+        if get_attns:
+            return gpt_out.attentions
+
+        offset = speech_conditioning_inputs.shape[1]
+        enc = gpt_out.last_hidden_state[:, offset:]
+        enc = self.final_norm(enc)
+
+        if return_latent:
+            return enc[:, :first_inputs.shape[1]], enc[:, -second_inputs.shape[1]:]
+
+        first_logits = enc[:, :first_inputs.shape[1]]
+        first_logits = first_head(first_logits)
+        first_logits = first_logits.permute(0, 2, 1)
+        if second_inputs is not None:
+            second_logits = enc[:, -second_inputs.shape[1]:]
+            second_logits = second_head(second_logits)
+            second_logits = second_logits.permute(0, 2, 1)
+            return first_logits, second_logits
+        else:
+            return first_logits
+
+    def get_conditioning(self, speech_conditioning_input, cond_mel_lengths=None):
+        if self.condition_type == "perceiver":
+            if speech_conditioning_input.ndim == 4:
+                speech_conditioning_input = speech_conditioning_input.squeeze(1)
+            speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input)  # (b, d, s)
+            conds = self.perceiver_encoder(speech_conditioning_input.transpose(1, 2))  # (b, 32, d)
+        elif self.condition_type == "conformer_perceiver":
+            speech_conditioning_input, mask = self.conditioning_encoder(speech_conditioning_input.transpose(1, 2),
+                                                                        cond_mel_lengths)  # (b, s, d), (b, 1, s)
+            if self.condition_type == "conformer_perceiver":
+                # conds_mask = torch.cat([torch.ones((mask.shape[0], self.cond_num), dtype=torch.bool), mask.squeeze(1)], dim=1)
+                conds_mask = self.cond_mask_pad(mask.squeeze(1))
+                conds = self.perceiver_encoder(speech_conditioning_input, conds_mask)  # (b, 32, d)
+        elif self.condition_type == "gst":
+            if speech_conditioning_input.ndim == 4:
+                speech_conditioning_input = speech_conditioning_input.squeeze(1)
+            conds = self.gst_encoder(speech_conditioning_input.transpose(1, 2))  # (b, 1, d)
+        else:
+            speech_conditioning_input = (
+                speech_conditioning_input.unsqueeze(1)
+                if len(speech_conditioning_input.shape) == 3
+                else speech_conditioning_input
+            )
+            conds = []
+            for j in range(speech_conditioning_input.shape[1]):
+                conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
+            conds = torch.stack(conds, dim=1)
+            conds = conds.mean(dim=1)
+            conds = conds.unsqueeze(1)
+        return conds
+
+
+    def get_emo_conditioning(self, speech_conditioning_input, cond_mel_lengths=None):
+        speech_conditioning_input, mask = self.emo_conditioning_encoder(speech_conditioning_input.transpose(1, 2),
+                                                                        cond_mel_lengths)  # (b, s, d), (b, 1, s)
+        conds_mask = self.emo_cond_mask_pad(mask.squeeze(1))
+        conds = self.emo_perceiver_encoder(speech_conditioning_input, conds_mask)  # (b, 1, d)
+        return conds.squeeze(1)
+
+
+    def forward(self, speech_conditioning_latent, text_inputs, text_lengths, mel_codes, mel_codes_lengths, emo_speech_conditioning_latent,
+                cond_mel_lengths=None, emo_cond_mel_lengths=None, emo_vec=None, use_speed=None, do_spk_cond=False):
+        """
+        Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
+
+        speech_conditioning_input: MEL float tensor, (b,1024)
+        text_inputs: long tensor, (b,t)
+        text_lengths: long tensor, (b,)
+        mel_inputs:  long tensor, (b,m)
+        wav_lengths: long tensor, (b,)
+
+        If return_attentions is specified, only logits are returned.
+        If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
+        """
+
+        if do_spk_cond:
+            speech_conditioning_latent = self.get_conditioning(speech_conditioning_latent.transpose(1,2), cond_mel_lengths)
+        else:
+            speech_conditioning_latent = speech_conditioning_latent
+
+        if emo_vec is None:
+            emo_vec_syn_ori = self.get_emo_conditioning(emo_speech_conditioning_latent.transpose(1,2), emo_cond_mel_lengths)
+            emo_vec_syn = self.emovec_layer(emo_vec_syn_ori)
+            emo_vec = self.emo_layer(emo_vec_syn)
+
+        text_inputs = self.set_text_padding(text_inputs, text_lengths)
+        text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
+
+        mel_codes = self.set_mel_padding(mel_codes, mel_codes_lengths)
+        mel_codes = F.pad(mel_codes, (0, 1), value=self.stop_mel_token)
+
+        duration_emb = self.speed_emb(torch.zeros_like(use_speed))
+        duration_emb_half = self.speed_emb(torch.ones_like(use_speed))
+        conds = torch.cat((speech_conditioning_latent + emo_vec.unsqueeze(1), duration_emb_half.unsqueeze(1), duration_emb.unsqueeze(1)), 1)
+        text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
+        text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
+        mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
+
+        mel_emb = self.mel_embedding(mel_codes)
+        mel_emb = mel_emb + self.mel_pos_embedding(mel_codes)
+
+        text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=False, return_latent=True)
+        return mel_logits[:, :-2]  # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
+
+    def prepare_gpt_inputs(
+        self,
+        conditional_latents: torch.Tensor,
+        text_inputs: torch.Tensor,
+    ):
+
+        """
+        Prepare the inputs for the GPT2InferenceModel to generate.
+        Args:
+            conds_latent: (b, 32, dim) audio conditioning embedding by `get_conditioning()`
+            text_inputs: (b, L)
+        Returns:
+            input_ids: (b, s+1) the input ids for the GPT2InferenceModel.generate()
+            inputs_embeds: (b, s+1, dim) the input embeddings for the GPT2InferenceModel.forward()
+            attention_mask: (b, s+1) the attention mask for the GPT2InferenceModel.generate()
+        """
+        b, L = text_inputs.shape[:2]
+        device = text_inputs.device
+        single_cond = conditional_latents.ndim == 3 and conditional_latents.shape[0] == 1
+        if not single_cond:
+            assert conditional_latents.shape[0] == b, f"batch size mismatch: {conditional_latents.shape[0]} vs {b}"
+        batched_mel_emb = []
+        attention_masks = []
+        target_len = conditional_latents.shape[1] + L + 2
+        for i in range(b):
+            valid_mask = (text_inputs[i] != self.stop_text_token) & (text_inputs[i] != self.start_text_token)
+            text_input = text_inputs[i][valid_mask]
+            text_input = F.pad(text_input, (1, 0), value=self.start_text_token)
+            text_input = F.pad(text_input, (0, 1), value=self.stop_text_token)
+            text_input_pos = torch.arange(0, text_input.size(-1), device=device)
+            text_emb = self.text_embedding(text_input) + self.text_pos_embedding.emb(text_input_pos)
+            # concatenate [conditional latents][text embeddings]
+            conds_text_emb = [
+                conditional_latents.squeeze(0) if single_cond else conditional_latents[i],
+                text_emb,
+            ]
+            # +1 for the start_mel_token
+            attention_mask = torch.ones(target_len+1, dtype=torch.long, device=device)
+            # check this text input is padded
+            padding: int = L + 2 - text_input.size(-1)
+            # pad left of [cond][text] -> [pad][cond][text]
+            if padding > 0:
+                pad = torch.zeros((padding, conditional_latents.size(-1)), dtype=text_emb.dtype, device=device)  # [p, dim]
+                conds_text_emb.insert(0, pad)
+                attention_mask[:padding] = 0
+            mel_emb = torch.cat(conds_text_emb)  # [s, dim]
+            assert mel_emb.shape[0] == target_len, f"mel_emb.shape: {mel_emb.shape}, target_len: {target_len}"
+            batched_mel_emb.append(mel_emb)
+            attention_masks.append(attention_mask)
+        # [b, s, dim]
+        batched_mel_emb = torch.stack(batched_mel_emb, dim=0)
+        # [b, s+1]
+        attention_mask = torch.stack(attention_masks, dim=0)
+        # [b, s+1]
+        fake_inputs = torch.ones(
+            (
+                batched_mel_emb.shape[0],
+                batched_mel_emb.shape[1] + 1,  # +1 for the start_mel_token
+            ),
+            dtype=torch.long,
+            device=device,
+        )
+        fake_inputs[:, -1] = self.start_mel_token
+        return fake_inputs, batched_mel_emb, attention_mask
+
+    def inference_speech(self, speech_condition, text_inputs, emo_speech_condition=None, cond_lengths=None, emo_cond_lengths=None, emo_vec=None, use_speed=False, input_tokens=None, num_return_sequences=1,
+                         max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
+        """
+        Args:
+            speech_condition: (b, d, frames) or (d, frames)
+            text_inputs: (b, L)
+            cond_mel_lengths: lengths of the conditioning mel spectrograms in shape (b,) or (1,)
+            input_tokens: additional tokens for generation in shape (b, s) or (s,)
+            max_generate_length: limit the number of generated tokens
+            hf_generate_kwargs: kwargs for `GPT2InferenceModel.generate(**hf_generate_kwargs)`
+        """
+
+        if speech_condition.ndim == 2:
+            speech_condition = speech_condition.unsqueeze(0)
+        if emo_speech_condition is None:
+            emo_speech_condition = speech_condition
+        if cond_lengths is None:
+            cond_lengths = torch.tensor([speech_condition.shape[-1]], device=speech_condition.device)
+        if emo_cond_lengths is None:
+            emo_cond_lengths = torch.tensor([emo_speech_condition.shape[-1]], device=speech_condition.device)
+
+        speech_conditioning_latent = self.get_conditioning(speech_condition.transpose(1,2), cond_lengths)
+        if emo_vec is None:
+            print('compute emo vec')
+            emo_vec = self.get_emo_conditioning(emo_speech_condition.transpose(1,2), emo_cond_lengths)
+            emo_vec = self.emovec_layer(emo_vec)
+            emo_vec = self.emo_layer(emo_vec)
+        else:
+            print('Use the specified emotion vector')
+
+        tmp = torch.zeros(text_inputs.size(0)).to(text_inputs.device)
+        duration_emb = self.speed_emb(torch.zeros_like(tmp).long())
+        duration_emb_half = self.speed_emb(torch.ones_like(tmp).long())
+        conds_latent = torch.cat((speech_conditioning_latent + emo_vec.unsqueeze(1), duration_emb_half.unsqueeze(1), duration_emb.unsqueeze(1)), 1)
+        input_ids, inputs_embeds, attention_mask = self.prepare_gpt_inputs(conds_latent, text_inputs)
+        self.inference_model.store_mel_emb(inputs_embeds)
+        if input_tokens is None:
+            inputs = input_ids
+        else:
+            if input_tokens.ndim == 1:
+                input_tokens = input_tokens.unsqueeze(0)
+            assert num_return_sequences % input_tokens.shape[0] == 0, \
+                "The num_return_sequences must be divisible by the batch number of input_tokens"
+            assert num_return_sequences % text_inputs.shape[0] == 0, \
+                "The num_return_sequences must be divisible by the batch number of text_inputs"
+            b = num_return_sequences // input_ids.shape[0]
+            if b > 1:
+                input_ids = input_ids.repeat(b, 1)
+                attention_mask = attention_mask.repeat(b, 1)
+            input_tokens = input_tokens.repeat(num_return_sequences // input_tokens.shape[0], 1)
+            inputs = torch.cat([input_ids, input_tokens], dim=1)
+            attention_mask = F.pad(attention_mask, (0, input_tokens.shape[1]), value=1)
+        trunc_index = inputs.shape[1]
+        logits_processor = LogitsProcessorList()
+        if typical_sampling:
+            # employ custom typical sampling
+            if not (typical_mass > 0.0 and typical_mass < 1.0):
+                raise ValueError(f"`typical_mass` has to be a float > 0 and < 1, but is {typical_mass}")
+            min_tokens_to_keep = 2 if hf_generate_kwargs.get("num_beams", 1) > 1 else 1
+            logits_processor.append(TypicalLogitsWarper(mass=typical_mass, min_tokens_to_keep=min_tokens_to_keep))
+        max_length = (trunc_index + self.max_mel_tokens - 1) if max_generate_length is None else trunc_index + max_generate_length
+        output = self.inference_model.generate(inputs,
+                                               bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token,
+                                               eos_token_id=self.stop_mel_token, attention_mask=attention_mask,
+                                               max_length=max_length, logits_processor=logits_processor,
+                                               num_return_sequences=num_return_sequences,
+                                               **hf_generate_kwargs)
+        if isinstance(output, torch.Tensor):
+            return output[:, trunc_index:], speech_conditioning_latent
+        # GenerateOutput
+        output.sequences = output.sequences[:, trunc_index:]
+        return output, speech_conditioning_latent
+
+    def get_emovec(self, emo_speech_conditioning_latent, emo_cond_lengths):
+        emo_vec_syn_ori = self.get_emo_conditioning(emo_speech_conditioning_latent.transpose(1,2), emo_cond_lengths)
+        emo_vec_syn = self.emovec_layer(emo_vec_syn_ori)
+        emo_vec = self.emo_layer(emo_vec_syn)
+        return emo_vec
+
+    def merge_emovec(self, speech_conditioning_latent, emo_speech_conditioning_latent, cond_lengths, emo_cond_lengths, alpha = 1.0):
+        emo_vec = self.get_emovec(emo_speech_conditioning_latent, emo_cond_lengths)
+        base_vec = self.get_emovec(speech_conditioning_latent, cond_lengths)
+
+        out = base_vec + alpha * (emo_vec - base_vec)
+        return out
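
The last two methods above (get_emovec / merge_emovec) reduce to a plain linear interpolation between the emotion vector derived from the speaker prompt and the one derived from a separate emotion reference. A standalone sketch with made-up two-dimensional vectors (not part of the diff) makes the blending behaviour concrete:

import torch

# alpha = 0.0 keeps the speaker's own emotion vector, alpha = 1.0 fully adopts the
# emotion reference; intermediate values blend the two, as in merge_emovec() above.
base_vec = torch.tensor([0.0, 1.0])   # hypothetical vector from the speaker prompt
emo_vec = torch.tensor([1.0, 0.0])    # hypothetical vector from the emotion reference
for alpha in (0.0, 0.5, 1.0):
    out = base_vec + alpha * (emo_vec - base_vec)
    print(alpha, out.tolist())        # [0.0, 1.0] -> [0.5, 0.5] -> [1.0, 0.0]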