xinference 1.9.1__py3-none-any.whl → 1.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic. Click here for more details.

Files changed (334)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +400 -3
  3. xinference/client/restful/async_restful_client.py +20 -3
  4. xinference/client/restful/restful_client.py +20 -3
  5. xinference/constants.py +2 -0
  6. xinference/core/supervisor.py +111 -49
  7. xinference/core/worker.py +10 -0
  8. xinference/deploy/cmdline.py +15 -0
  9. xinference/model/audio/core.py +26 -6
  10. xinference/model/audio/indextts2.py +166 -0
  11. xinference/model/audio/kokoro.py +1 -1
  12. xinference/model/audio/kokoro_zh.py +124 -0
  13. xinference/model/audio/model_spec.json +58 -1
  14. xinference/model/embedding/sentence_transformers/core.py +4 -4
  15. xinference/model/embedding/vllm/core.py +7 -1
  16. xinference/model/image/model_spec.json +71 -3
  17. xinference/model/image/stable_diffusion/core.py +13 -4
  18. xinference/model/llm/__init__.py +4 -0
  19. xinference/model/llm/core.py +10 -0
  20. xinference/model/llm/llama_cpp/core.py +1 -0
  21. xinference/model/llm/llm_family.json +503 -21
  22. xinference/model/llm/llm_family.py +1 -0
  23. xinference/model/llm/mlx/core.py +52 -33
  24. xinference/model/llm/sglang/core.py +32 -55
  25. xinference/model/llm/tool_parsers/__init__.py +58 -0
  26. xinference/model/llm/tool_parsers/abstract_tool_parser.py +33 -0
  27. xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +190 -0
  28. xinference/model/llm/tool_parsers/deepseek_v3_tool_parser.py +145 -0
  29. xinference/model/llm/tool_parsers/glm4_tool_parser.py +123 -0
  30. xinference/model/llm/tool_parsers/llama3_tool_parser.py +77 -0
  31. xinference/model/llm/tool_parsers/qwen_tool_parser.py +320 -0
  32. xinference/model/llm/transformers/core.py +1 -1
  33. xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
  34. xinference/model/llm/utils.py +138 -53
  35. xinference/model/llm/vllm/core.py +95 -78
  36. xinference/thirdparty/audiotools/__init__.py +10 -0
  37. xinference/thirdparty/audiotools/core/__init__.py +4 -0
  38. xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
  39. xinference/thirdparty/audiotools/core/display.py +194 -0
  40. xinference/thirdparty/audiotools/core/dsp.py +390 -0
  41. xinference/thirdparty/audiotools/core/effects.py +647 -0
  42. xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
  43. xinference/thirdparty/audiotools/core/loudness.py +320 -0
  44. xinference/thirdparty/audiotools/core/playback.py +252 -0
  45. xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
  46. xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
  47. xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
  48. xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
  49. xinference/thirdparty/audiotools/core/util.py +671 -0
  50. xinference/thirdparty/audiotools/core/whisper.py +97 -0
  51. xinference/thirdparty/audiotools/data/__init__.py +3 -0
  52. xinference/thirdparty/audiotools/data/datasets.py +517 -0
  53. xinference/thirdparty/audiotools/data/preprocess.py +81 -0
  54. xinference/thirdparty/audiotools/data/transforms.py +1592 -0
  55. xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
  56. xinference/thirdparty/audiotools/metrics/distance.py +131 -0
  57. xinference/thirdparty/audiotools/metrics/quality.py +159 -0
  58. xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
  59. xinference/thirdparty/audiotools/ml/__init__.py +5 -0
  60. xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
  61. xinference/thirdparty/audiotools/ml/decorators.py +440 -0
  62. xinference/thirdparty/audiotools/ml/experiment.py +90 -0
  63. xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
  64. xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
  65. xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
  66. xinference/thirdparty/audiotools/post.py +140 -0
  67. xinference/thirdparty/audiotools/preference.py +600 -0
  68. xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
  69. xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
  70. xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
  71. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
  72. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
  73. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
  74. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
  75. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  76. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
  77. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
  78. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
  79. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
  80. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
  81. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
  82. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
  83. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
  84. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
  85. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
  86. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
  87. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
  88. xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
  89. xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
  90. xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
  91. xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
  92. xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
  93. xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
  94. xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
  95. xinference/thirdparty/indextts/__init__.py +0 -0
  96. xinference/thirdparty/indextts/cli.py +65 -0
  97. xinference/thirdparty/indextts/gpt/__init__.py +0 -0
  98. xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
  99. xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
  100. xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
  101. xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
  102. xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
  103. xinference/thirdparty/indextts/gpt/model.py +713 -0
  104. xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
  105. xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
  106. xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
  107. xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
  108. xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
  109. xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
  110. xinference/thirdparty/indextts/infer.py +690 -0
  111. xinference/thirdparty/indextts/infer_v2.py +739 -0
  112. xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
  113. xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
  114. xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
  115. xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
  116. xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
  117. xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
  118. xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
  119. xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
  120. xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
  121. xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
  122. xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
  123. xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
  124. xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
  125. xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
  126. xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
  127. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
  128. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
  129. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
  130. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
  131. xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
  132. xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
  133. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
  134. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
  135. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  136. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
  137. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
  138. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
  139. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
  140. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
  141. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
  142. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
  143. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
  144. xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
  145. xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
  146. xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
  147. xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
  148. xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
  149. xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
  150. xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
  151. xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
  152. xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
  153. xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
  154. xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
  155. xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
  156. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
  157. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
  158. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
  159. xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
  160. xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
  161. xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
  162. xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
  163. xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
  164. xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
  165. xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
  166. xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
  167. xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
  168. xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
  169. xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
  170. xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
  171. xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
  172. xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
  173. xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
  174. xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
  175. xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
  176. xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
  177. xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
  178. xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
  179. xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
  180. xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
  181. xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
  182. xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
  183. xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
  184. xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
  185. xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
  186. xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
  187. xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
  188. xinference/thirdparty/indextts/utils/__init__.py +0 -0
  189. xinference/thirdparty/indextts/utils/arch_util.py +120 -0
  190. xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
  191. xinference/thirdparty/indextts/utils/common.py +121 -0
  192. xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
  193. xinference/thirdparty/indextts/utils/front.py +536 -0
  194. xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
  195. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
  196. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
  197. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
  198. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
  199. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
  200. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
  201. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
  202. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
  203. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
  204. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
  205. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
  206. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
  207. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
  208. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
  209. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
  210. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
  211. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
  212. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
  213. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
  214. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
  215. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
  216. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
  217. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
  218. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
  219. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
  220. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
  221. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
  222. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
  223. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
  224. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
  225. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
  226. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
  227. xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
  228. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
  229. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
  230. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
  231. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
  232. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
  233. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
  234. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
  235. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
  236. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
  237. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
  238. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
  239. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
  240. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
  241. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
  242. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
  243. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
  244. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
  245. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
  246. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
  247. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
  248. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
  249. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
  250. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
  251. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
  252. xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
  253. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
  254. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
  255. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
  256. xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
  257. xinference/thirdparty/indextts/utils/text_utils.py +41 -0
  258. xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
  259. xinference/thirdparty/indextts/utils/utils.py +93 -0
  260. xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
  261. xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
  262. xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
  263. xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
  264. xinference/types.py +105 -2
  265. xinference/ui/gradio/media_interface.py +66 -8
  266. xinference/ui/web/ui/build/asset-manifest.json +6 -6
  267. xinference/ui/web/ui/build/index.html +1 -1
  268. xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
  269. xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
  270. xinference/ui/web/ui/build/static/js/main.d192c4f3.js +3 -0
  271. xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.d192c4f3.js.LICENSE.txt} +0 -7
  272. xinference/ui/web/ui/build/static/js/main.d192c4f3.js.map +1 -0
  273. xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
  274. xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
  275. xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
  276. xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
  277. xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
  278. xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
  279. xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
  280. xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
  281. xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
  282. xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
  283. xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
  284. xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
  285. xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
  286. xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
  287. xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
  288. xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
  289. xinference/ui/web/ui/node_modules/.cache/babel-loader/f995a2425dfb0822fd07127f66ffe9b026883bc156b402eb8bd0b83d52460a93.json +1 -0
  290. xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
  291. xinference/ui/web/ui/package-lock.json +0 -34
  292. xinference/ui/web/ui/package.json +0 -1
  293. xinference/ui/web/ui/src/locales/en.json +9 -3
  294. xinference/ui/web/ui/src/locales/ja.json +9 -3
  295. xinference/ui/web/ui/src/locales/ko.json +9 -3
  296. xinference/ui/web/ui/src/locales/zh.json +9 -3
  297. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/METADATA +24 -4
  298. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/RECORD +302 -76
  299. xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
  300. xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
  301. xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
  302. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
  303. xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
  304. xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
  305. xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
  306. xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
  307. xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
  308. xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
  309. xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
  310. xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
  311. xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
  312. xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
  313. xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
  314. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
  315. xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
  316. xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
  317. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
  318. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
  319. xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
  320. xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
  321. xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
  322. xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
  323. xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
  324. xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
  325. xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
  326. xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
  327. xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
  328. xinference/ui/web/ui/node_modules/select/bower.json +0 -13
  329. xinference/ui/web/ui/node_modules/select/package.json +0 -29
  330. xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
  331. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/WHEEL +0 -0
  332. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/entry_points.txt +0 -0
  333. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/licenses/LICENSE +0 -0
  334. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,632 @@
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ from munch import Munch
7
+ import json
8
+ import argparse
9
+ from torch.nn.parallel import DistributedDataParallel as DDP
10
+
11
def str2bool(v):
    """Parse a command-line style boolean value.

    ``bool`` inputs are returned unchanged. Other inputs are lower-cased and
    matched against the truthy spellings ("yes", "true", "t", "y", "1") and
    the falsy spellings ("no", "false", "f", "n", "0").

    Raises:
        argparse.ArgumentTypeError: if the lowered value matches neither set.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
20
+
21
class AttrDict(dict):
    """A ``dict`` whose entries are also reachable as attributes.

    The instance ``__dict__`` is bound to the mapping itself, so ``d.key`` and
    ``d["key"]`` read and write the same storage.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Alias attribute storage to the dict contents.
        self.__dict__ = self
25
+
26
+
27
def init_weights(m, mean=0.0, std=0.01):
    """Draw normal(mean, std) weights for convolution-like modules.

    A module counts as a convolution when its class name contains "Conv". For
    such modules, ``m.weight`` is overwritten in place. All other modules are
    left unchanged. This is suitable as an ``nn.Module.apply`` callback.
    """
    is_conv = "Conv" in m.__class__.__name__
    if not is_conv:
        return
    m.weight.data.normal_(mean, std)
31
+
32
+
33
def get_padding(kernel_size, dilation=1):
    """Return the padding that keeps a dilated convolution's length unchanged.

    The result is half of ``kernel_size * dilation - dilation``, the extra
    extent introduced by the dilated kernel, truncated to an ``int``. It is
    exact for odd kernel sizes.
    """
    span = kernel_size * dilation - dilation
    half_span = span / 2
    return int(half_span)
35
+
36
+
37
def convert_pad_shape(pad_shape):
    """Flatten per-dimension ``[before, after]`` pairs into ``F.pad`` order.

    ``pad_shape`` lists one pair per dimension, first dimension first.
    ``F.pad`` expects a flat list that starts with the last dimension, so the
    pairs are walked in reverse and concatenated.

    Example: ``[[0, 0], [0, 0], [1, 0]] -> [1, 0, 0, 0, 0, 0]``.
    """
    flat = []
    for pair in pad_shape[::-1]:
        flat.extend(pair)
    return flat
41
+
42
+
43
def intersperse(lst, item):
    """Return a new list with ``item`` before, between and after each element.

    Example: ``intersperse([1, 2], 0) -> [0, 1, 0, 2, 0]``.
    """
    total = 2 * len(lst) + 1
    interleaved = [item for _ in range(total)]
    # Odd slots hold the original elements; even slots keep ``item``.
    interleaved[1::2] = lst
    return interleaved
47
+
48
+
49
def kl_divergence(m_p, logs_p, m_q, logs_q):
    """Element-wise KL(P || Q) between Gaussians given means and log-stds.

    Computes
    ``(logs_q - logs_p) - 0.5
      + 0.5 * (exp(2 * logs_p) + (m_p - m_q) ** 2) * exp(-2 * logs_q)``.
    This is the closed form for P = N(m_p, exp(logs_p)^2) and
    Q = N(m_q, exp(logs_q)^2).
    """
    var_p = torch.exp(2.0 * logs_p)
    inv_var_q = torch.exp(-2.0 * logs_q)
    mean_gap_sq = (m_p - m_q) ** 2
    result = (logs_q - logs_p) - 0.5
    result += 0.5 * (var_p + mean_gap_sq) * inv_var_q
    return result
56
+
57
+
58
def rand_gumbel(shape):
    """Sample standard Gumbel noise of the given ``shape``.

    Applies the inverse-CDF transform ``-log(-log(u))``. The uniform draw is
    rescaled into ``[1e-5, 0.99999)`` so that neither logarithm receives 0
    or 1, which protects against overflow.
    """
    u = torch.rand(shape) * 0.99998 + 0.00001
    neg_log_u = -torch.log(u)
    return -torch.log(neg_log_u)
62
+
63
+
64
def rand_gumbel_like(x):
    """Return Gumbel noise shaped like ``x``, on ``x``'s dtype and device."""
    noise = rand_gumbel(x.size())
    return noise.to(dtype=x.dtype, device=x.device)
67
+
68
+
69
def slice_segments(x, ids_str, segment_size=4):
    """Copy one fixed-length window per batch item out of axis 2 of ``x``.

    For batch index ``b``, ``x[b, :, ids_str[b]:ids_str[b] + segment_size]``
    is written into a new zero tensor shaped like ``x[:, :, :segment_size]``
    (same dtype and device as ``x``).
    """
    out = torch.zeros_like(x[:, :, :segment_size])
    for b in range(x.size(0)):
        start = ids_str[b]
        stop = start + segment_size
        out[b] = x[b, :, start:stop]
    return out
76
+
77
+
78
def slice_segments_audio(x, ids_str, segment_size=4):
    """Copy one fixed-length window per row out of axis 1 of ``x``.

    Written for ``(batch, samples)`` inputs. For row ``r``,
    ``x[r, ids_str[r]:ids_str[r] + segment_size]`` is written into a new zero
    tensor shaped like ``x[:, :segment_size]``.
    """
    out = torch.zeros_like(x[:, :segment_size])
    for r in range(x.size(0)):
        begin = ids_str[r]
        end = begin + segment_size
        out[r] = x[r, begin:end]
    return out
85
+
86
+
87
def rand_slice_segments(x, x_lengths=None, segment_size=4):
    """Slice a randomly placed window from each batch item of ``x``.

    Args:
        x: tensor unpacked as ``(batch, dim, time)``.
        x_lengths: valid lengths per item; defaults to the full ``time`` extent.
        segment_size: window length.

    Returns:
        ``(segments, ids_str)``. ``ids_str`` are long start indices computed as
        ``floor(U * (x_lengths - segment_size + 1))`` with ``U ~ [0, 1)``,
        clipped below at 0.
    """
    batch, _, time = x.size()
    if x_lengths is None:
        x_lengths = time
    max_start = x_lengths - segment_size + 1
    noise = torch.rand([batch]).to(device=x.device)
    ids_str = (noise * max_start).clip(0).to(dtype=torch.long)
    return slice_segments(x, ids_str, segment_size), ids_str
97
+
98
+
99
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    """Build a sinusoidal (tensor2tensor-style) timing signal.

    With ``n = channels // 2`` timescales, the inverse timescales are
    ``min_timescale * exp(-i * log(max_timescale / min_timescale) / max(n - 1, 1))``
    for ``i in range(n)``. The first ``n`` rows hold ``sin(position * inv)`` and
    the next ``n`` rows hold ``cos(position * inv)``. For odd ``channels`` one
    zero row is appended.

    Args:
        length: number of positions (0 .. length - 1).
        channels: number of output channels.
        min_timescale: smallest timescale.
        max_timescale: largest timescale.

    Returns:
        Float tensor of shape ``(1, channels, length)``.
    """
    position = torch.arange(length, dtype=torch.float)
    num_timescales = channels // 2
    # With a single timescale (channels == 2 or 3) the denominator would be 0
    # and raise ZeroDivisionError. Clamp it to 1, as tensor2tensor does. For
    # channels >= 4 this is unchanged, and for channels < 2 the increment is
    # unused because there are no timescales.
    log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(
        num_timescales - 1, 1
    )
    inv_timescales = min_timescale * torch.exp(
        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
    )
    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
    # For odd channel counts, append one zero row along the channel axis.
    signal = F.pad(signal, [0, 0, 0, channels % 2])
    return signal.view(1, channels, length)
113
+
114
+
115
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
    """Return ``x`` plus a sinusoidal timing signal.

    ``x`` is unpacked as ``(batch, channels, length)``. The
    ``(1, channels, length)`` signal is cast to ``x``'s dtype and device and
    then broadcast-added over the batch.
    """
    _, channels, length = x.size()
    timing = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    timing = timing.to(dtype=x.dtype, device=x.device)
    return x + timing
119
+
120
+
121
def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
    """Concatenate a sinusoidal timing signal onto ``x`` along ``axis``.

    ``x.size()`` is unpacked as (batch, channels, length); the signal from
    ``get_timing_signal_1d`` has a leading dimension of 1, so concatenation
    presumably expects batch size 1 unless ``axis`` is 0 -- confirm at callers.
    """
    _, num_channels, seq_len = x.size()
    timing = get_timing_signal_1d(seq_len, num_channels, min_timescale, max_timescale)
    timing = timing.to(dtype=x.dtype, device=x.device)
    return torch.cat([x, timing], axis)
125
+
126
+
127
def subsequent_mask(length):
    """Return a lower-triangular (causal) mask of shape ``(1, 1, length, length)``.

    Entry ``[0, 0, i, j]`` is 1.0 when ``j <= i`` and 0.0 otherwise.
    """
    square = torch.ones(length, length)
    return torch.tril(square)[None, None]
130
+
131
+
132
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    """Gated activation on the sum of two inputs.

    Computes ``s = input_a + input_b`` and returns
    ``tanh(s[:, :n]) * sigmoid(s[:, n:])`` where ``n = n_channels[0]``.

    The function is compiled with ``torch.jit.script``; un-annotated parameters
    are typed as Tensor by TorchScript, so ``n_channels`` is presumably a
    1-element integer tensor -- NOTE(review): confirm at the call sites.
    """
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    # First n_channels_int channels feed the tanh branch ...
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    # ... and the remaining channels feed the sigmoid gate.
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts
140
+
141
+
142
def convert_pad_shape(pad_shape):
    """Flatten per-dimension ``[before, after]`` pairs into ``F.pad`` order.

    ``F.pad`` expects padding for the last dimension first, so the pairs are
    reversed before being flattened, e.g. ``[[0, 0], [1, 0], [0, 0]]`` becomes
    ``[0, 0, 1, 0, 0, 0]``.
    """
    flat = []
    for pair in reversed(pad_shape):
        flat.extend(pair)
    return flat
146
+
147
+
148
def shift_1d(x):
    """Shift a 3-D tensor one step right along its last axis.

    A zero is inserted at the front of the last axis and the final element
    is dropped, so the output has the same shape as ``x``.
    """
    padded = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))
    return padded[:, :, :-1]
151
+
152
+
153
def sequence_mask(length, max_length=None):
    """Build a boolean mask of shape ``(len(length), max_length)``.

    Entry ``[b, t]`` is True when ``t < length[b]``. ``max_length`` defaults
    to ``length.max()``.
    """
    limit = length.max() if max_length is None else max_length
    steps = torch.arange(limit, dtype=length.dtype, device=length.device)
    return steps[None, :] < length[:, None]
158
+
159
+
160
def avg_with_mask(x, mask):
    """Return the mean of ``x`` over the elements selected by a float mask.

    A 2-D ``mask`` gets a singleton axis inserted at dim 1; a mask whose
    dim 1 is 1 is broadcast to ``x``'s shape. The result is
    ``sum(x * mask) / sum(mask)`` over all elements.
    """
    assert mask.dtype == torch.float, "Mask should be float"

    if mask.ndim == 2:
        mask = mask[:, None]
    if mask.shape[1] == 1:
        mask = mask.expand_as(x)

    weighted = x * mask
    return weighted.sum() / mask.sum()
170
+
171
+
172
def generate_path(duration, mask):
    """Expand per-token durations into a 0/1 alignment path.

    duration: [b, 1, t_x]
    mask: [b, 1, t_y, t_x]

    Token ``i`` is marked on output frames ``[cum[i-1], cum[i])`` where
    ``cum`` is the cumulative sum of ``duration``; the result is multiplied
    by ``mask`` and has the same shape as ``mask``.
    """
    b, _, t_y, t_x = mask.shape
    ends = torch.cumsum(duration, -1).view(b * t_x)
    # covered[b, i, t] is 1 for every frame t before token i's end.
    covered = sequence_mask(ends, t_y).to(mask.dtype).view(b, t_x, t_y)
    # Subtracting the previous token's coverage leaves only token i's frames.
    previous = F.pad(covered, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = covered - previous
    return path.unsqueeze(1).transpose(2, 3) * mask
188
+
189
+
190
def clip_grad_value_(parameters, clip_value, norm_type=2):
    """Clamp gradients elementwise in place and return their total norm.

    Args:
        parameters: a tensor or an iterable of tensors; entries whose
            ``.grad`` is None are ignored.
        clip_value: gradients are clamped to ``[-clip_value, clip_value]``;
            ``None`` disables clamping (the norm is still computed).
        norm_type: order of the norm.

    Returns:
        Python float: total ``norm_type``-norm of the gradients, measured
        before clamping.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    with_grad = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    limit = None if clip_value is None else float(clip_value)

    accumulated = 0
    for param in with_grad:
        accumulated += param.grad.data.norm(norm_type).item() ** norm_type
        if limit is not None:
            param.grad.data.clamp_(min=-limit, max=limit)
    return accumulated ** (1.0 / norm_type)
206
+
207
+
208
def log_norm(x, mean=-4, std=4, dim=2):
    """
    normalized log mel -> mel -> norm -> log(norm)

    Undo the ``(x - mean) / std`` normalization, exponentiate, take the L2
    norm along ``dim``, and return its natural log.
    """
    linear = torch.exp(x * std + mean)
    return torch.log(linear.norm(dim=dim))
214
+
215
+
216
def load_F0_models(path):
    """Build a ``JDCNet`` F0 model and load its weights from ``path``.

    The weights are read from the ``"net"`` entry of the checkpoint. The model
    is returned in training mode (``.train()`` is called); callers presumably
    switch it to eval mode if needed -- NOTE(review): confirm.
    """
    from .JDC.model import JDCNet

    f0_model = JDCNet(num_class=1, seq_len=192)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(path, map_location="cpu")
    f0_model.load_state_dict(checkpoint["net"])
    f0_model.train()
    return f0_model
226
+
227
+
228
def modify_w2v_forward(self, output_layer=15):
    """
    Build a replacement ``forward`` for a w2v encoder that stops early and
    returns the output of an intermediate layer.

    The returned closure iterates ``self.layers`` and breaks after layer index
    ``output_layer - 1``, so at most ``output_layer`` layers run and
    ``last_hidden_state`` is the output of the last layer that ran. The body
    appears to mirror the encoder forward of transformers' Wav2Vec2-BERT model
    (relative ``embed_positions`` plus a ``conv_attention_mask``) --
    NOTE(review): confirm against the pinned transformers version.

    :param self: the encoder module; the closure reads its ``dropout``,
        ``embed_positions``, ``layers``, ``config.layerdrop``, ``training``,
        ``gradient_checkpointing`` and ``_gradient_checkpointing_func``.
    :param output_layer: number of encoder layers to run (1-based index of the
        layer whose output is returned).
    :return: a ``forward(hidden_states, attention_mask=None,
        output_attentions=False, output_hidden_states=False, return_dict=True)``
        callable returning a ``BaseModelOutput`` (or a tuple when
        ``return_dict`` is False); presumably the caller assigns it over the
        encoder's ``forward`` -- verify at the call site.
    """
    from transformers.modeling_outputs import BaseModelOutput

    def forward(
        hidden_states,
        attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        # Keep the original (un-extended) mask for the layers' convolution module.
        conv_attention_mask = attention_mask
        if attention_mask is not None:
            # make sure padded tokens output 0
            hidden_states = hidden_states.masked_fill(
                ~attention_mask.bool().unsqueeze(-1), 0.0
            )

            # extend attention_mask: turn the 0/1 mask into an additive mask
            # (0 for kept positions, dtype-min for padded ones) broadcast to
            # [batch, 1, seq, seq].
            attention_mask = 1.0 - attention_mask[:, None, None, :].to(
                dtype=hidden_states.dtype
            )
            attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
            attention_mask = attention_mask.expand(
                attention_mask.shape[0],
                1,
                attention_mask.shape[-1],
                attention_mask.shape[-1],
            )

        hidden_states = self.dropout(hidden_states)

        if self.embed_positions is not None:
            relative_position_embeddings = self.embed_positions(hidden_states)
        else:
            relative_position_embeddings = None

        # Hard-coded off here, so skipped layers are never run.
        deepspeed_zero3_is_enabled = False

        for i, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = torch.rand([])

            # Layers are only dropped in training mode.
            skip_the_layer = (
                True
                if self.training and (dropout_probability < self.config.layerdrop)
                else False
            )
            if not skip_the_layer or deepspeed_zero3_is_enabled:
                # under deepspeed zero3 all gpus must run in sync
                if self.gradient_checkpointing and self.training:
                    layer_outputs = self._gradient_checkpointing_func(
                        layer.__call__,
                        hidden_states,
                        attention_mask,
                        relative_position_embeddings,
                        output_attentions,
                        conv_attention_mask,
                    )
                else:
                    layer_outputs = layer(
                        hidden_states,
                        attention_mask=attention_mask,
                        relative_position_embeddings=relative_position_embeddings,
                        output_attentions=output_attentions,
                        conv_attention_mask=conv_attention_mask,
                    )
                hidden_states = layer_outputs[0]

            if skip_the_layer:
                layer_outputs = (None, None)

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

            # Early exit: stop after the requested intermediate layer.
            if i == output_layer - 1:
                break

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, all_hidden_states, all_self_attentions]
                if v is not None
            )
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

    return forward
333
+
334
+
335
# Set to True by plot_spectrogram_to_numpy after its one-time matplotlib setup
# (switching to the Agg backend and raising the matplotlib logger to WARNING).
MATPLOTLIB_FLAG = False
336
+
337
+
338
def plot_spectrogram_to_numpy(spectrogram):
    """Render a 2-D spectrogram to an RGB image array.

    The first call switches matplotlib to the non-interactive Agg backend and
    raises the ``matplotlib`` logger threshold to WARNING.

    Args:
        spectrogram: 2-D array-like passed to ``imshow`` (x axis labelled
            "Frames", y axis "Channels", origin at the bottom).

    Returns:
        ``numpy.ndarray`` of dtype uint8 with shape (height, width, 3).
    """
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import logging

        import matplotlib

        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger("matplotlib")
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(10, 2))
    try:
        im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
        fig.colorbar(im, ax=ax)
        ax.set_xlabel("Frames")
        ax.set_ylabel("Channels")
        fig.tight_layout()

        fig.canvas.draw()
        # canvas.tostring_rgb() was deprecated in Matplotlib 3.8 and removed in
        # 3.10, and binary-mode np.fromstring is deprecated. Read the Agg RGBA
        # buffer (height, width, 4) instead, drop alpha, and copy so the result
        # stays valid after the figure is closed.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        data = np.array(rgba[:, :, :3], dtype=np.uint8)
    finally:
        # Always release the figure, even if drawing fails.
        plt.close(fig)
    return data
363
+
364
+
365
def normalize_f0(f0_sequence):
    """Z-normalize log2 F0 over voiced frames and mark unvoiced frames with -1.

    Args:
        f0_sequence: array-like of F0 values. Entries > 0 are voiced; entries
            <= 0 are unvoiced.

    Returns:
        Floating-point ``numpy.ndarray`` with the same shape as the input.
        Voiced frames hold ``(log2(f0) - mean) / std``, with mean and std taken
        over the voiced frames only. Unvoiced frames hold -1. If the voiced
        log-F0 has zero spread (e.g. a single voiced frame), the voiced frames
        are 0 rather than NaN. With no voiced frames, every <= 0 entry is -1.
        Entries that are neither > 0 nor <= 0 (NaN) stay 0.
    """
    f0_sequence = np.asarray(f0_sequence)
    voiced = f0_sequence > 0

    # Integer input would silently truncate the normalized values, so always
    # build a floating-point output (keeping the input's float dtype if it has one).
    if np.issubdtype(f0_sequence.dtype, np.floating):
        out_dtype = f0_sequence.dtype
    else:
        out_dtype = np.float64
    normalized_sequence = np.zeros(f0_sequence.shape, dtype=out_dtype)
    normalized_sequence[f0_sequence <= 0] = -1  # Assign -1 to unvoiced frames

    if not voiced.any():
        # No voiced frames: skip mean/std of an empty array (NaN + warnings).
        return normalized_sequence

    # Convert to log scale and normalize with voiced-frame statistics.
    log_f0 = np.log2(f0_sequence[voiced])
    mean_f0 = np.mean(log_f0)
    std_f0 = np.std(log_f0)
    if std_f0 > 0:
        normalized_sequence[voiced] = (log_f0 - mean_f0) / std_f0
    else:
        # Constant pitch: every voiced frame sits exactly at the mean.
        normalized_sequence[voiced] = 0
    return normalized_sequence
386
+
387
+
388
class MyModel(nn.Module):
    """Container for the s2mel CFM module and length regulator.

    All submodules live in ``self.models`` (an ``nn.ModuleDict``) under the keys
    ``'cfm'`` and ``'length_regulator'``, plus ``'gpt_layer'`` when
    ``use_gpt_latent`` is true. ``use_emovec`` is accepted but not used by
    ``__init__``; ``'emo_layer'`` and ``'emo_encoder'`` are never registered
    here, so ``forward_emovec`` / ``forward_emo_encoder`` raise KeyError unless
    those entries are added elsewhere.
    """

    def __init__(self, args, use_emovec=False, use_gpt_latent=False):
        super().__init__()
        from indextts.s2mel.modules.flow_matching import CFM
        from indextts.s2mel.modules.length_regulator import InterpolateRegulator

        lr_cfg = args.length_regulator
        # Optional config fields fall back to these defaults when absent.
        length_regulator = InterpolateRegulator(
            channels=lr_cfg.channels,
            sampling_ratios=lr_cfg.sampling_ratios,
            is_discrete=lr_cfg.is_discrete,
            in_channels=getattr(lr_cfg, "in_channels", None),
            vector_quantize=getattr(lr_cfg, "vector_quantize", False),
            codebook_size=lr_cfg.content_codebook_size,
            n_codebooks=getattr(lr_cfg, "n_codebooks", 1),
            quantizer_dropout=getattr(lr_cfg, "quantizer_dropout", 0.0),
            f0_condition=getattr(lr_cfg, "f0_condition", False),
            n_f0_bins=getattr(lr_cfg, "n_f0_bins", 512),
        )

        # Construction order (length regulator, CFM, GPT projection) and key
        # order are kept stable so parameter init and registration order match.
        submodules = {
            "cfm": CFM(args),
            "length_regulator": length_regulator,
        }
        if use_gpt_latent:
            # Linear stack 1280 -> 256 -> 128 -> 1024 used by forward_gpt.
            submodules["gpt_layer"] = torch.nn.Sequential(
                torch.nn.Linear(1280, 256),
                torch.nn.Linear(256, 128),
                torch.nn.Linear(128, 1024),
            )
        self.models = nn.ModuleDict(submodules)

    def forward(self, x, target_lengths, prompt_len, cond, y):
        """Run the CFM module with the given arguments."""
        return self.models["cfm"](x, target_lengths, prompt_len, cond, y)

    def forward2(self, S_ori, target_lengths, F0_ori):
        """Run the length regulator on ``S_ori`` with target lengths and F0."""
        return self.models["length_regulator"](S_ori, ylens=target_lengths, f0=F0_ori)

    def forward_emovec(self, x):
        """Apply ``self.models['emo_layer']`` (not registered by ``__init__``)."""
        return self.models["emo_layer"](x)

    def forward_emo_encoder(self, x):
        """Apply ``self.models['emo_encoder']`` (not registered by ``__init__``)."""
        return self.models["emo_encoder"](x)

    def forward_gpt(self, x):
        """Apply the GPT-latent projection (requires ``use_gpt_latent=True``)."""
        return self.models["gpt_layer"](x)
439
+
440
+
441
+
442
def build_model(args, stage="DiT"):
    """Construct the networks for one training/inference stage.

    Args:
        args: config object read by attribute access (e.g. ``args.DAC``,
            ``args.length_regulator``).
        stage: ``"DiT"`` (CFM + length regulator), ``"codec"`` (DAC encoder +
            FA quantizer) or ``"mel_vocos"`` (Vocos decoder).

    Returns:
        ``Munch`` mapping network names to modules.

    Raises:
        ValueError: for an unknown ``stage``.
    """
    if stage == "DiT":
        # Import from the package path used by MyModel; the bare "modules."
        # path only resolves when the s2mel directory itself is on sys.path.
        from indextts.s2mel.modules.flow_matching import CFM
        from indextts.s2mel.modules.length_regulator import InterpolateRegulator

        lr_cfg = args.length_regulator
        # Optional config fields fall back to these defaults when absent.
        length_regulator = InterpolateRegulator(
            channels=lr_cfg.channels,
            sampling_ratios=lr_cfg.sampling_ratios,
            is_discrete=lr_cfg.is_discrete,
            in_channels=getattr(lr_cfg, "in_channels", None),
            vector_quantize=getattr(lr_cfg, "vector_quantize", False),
            codebook_size=lr_cfg.content_codebook_size,
            n_codebooks=getattr(lr_cfg, "n_codebooks", 1),
            quantizer_dropout=getattr(lr_cfg, "quantizer_dropout", 0.0),
            f0_condition=getattr(lr_cfg, "f0_condition", False),
            n_f0_bins=getattr(lr_cfg, "n_f0_bins", 512),
        )
        cfm = CFM(args)
        nets = Munch(
            cfm=cfm,
            length_regulator=length_regulator,
        )

    elif stage == "codec":
        from dac.model.dac import Encoder

        # NOTE(review): "modules.quantize" is a bare top-level import; confirm it
        # is importable in this package layout.
        from modules.quantize import (
            FAquantizer,
        )

        encoder = Encoder(
            d_model=args.DAC.encoder_dim,
            strides=args.DAC.encoder_rates,
            d_latent=1024,
            causal=args.causal,
            lstm=args.lstm,
        )

        quantizer = FAquantizer(
            in_dim=1024,
            n_p_codebooks=1,
            n_c_codebooks=args.n_c_codebooks,
            n_t_codebooks=2,
            n_r_codebooks=3,
            codebook_size=1024,
            codebook_dim=8,
            quantizer_dropout=0.5,
            causal=args.causal,
            separate_prosody_encoder=args.separate_prosody_encoder,
            timbre_norm=args.timbre_norm,
        )

        nets = Munch(
            encoder=encoder,
            quantizer=quantizer,
        )

    elif stage == "mel_vocos":
        # NOTE(review): "modules.vocos" is a bare top-level import; confirm it
        # is importable in this package layout.
        from modules.vocos import Vocos

        decoder = Vocos(args)
        nets = Munch(
            decoder=decoder,
        )

    else:
        raise ValueError(f"Unknown stage: {stage}")

    return nets
509
+
510
+
511
def load_checkpoint(
    model,
    optimizer,
    path,
    load_only_params=True,
    ignore_modules=None,
    is_distributed=False,
    load_ema=False,
):
    """Load network (and optionally optimizer) state from a checkpoint file.

    Each entry ``model[name]`` is loaded, non-strictly, from ``state["net"][name]``.
    Only parameters whose names exist in the module and whose shapes match are
    loaded; the rest are reported and skipped. Every module in ``model`` is then
    put in eval mode.

    Args:
        model: mapping of name -> ``nn.Module``.
        optimizer: returned unchanged when ``load_only_params`` is true;
            otherwise its ``load_state_dict`` and ``load_scheduler_state_dict``
            are called with the stored state.
        path: checkpoint file readable by ``torch.load``.
        load_only_params: when false, also restore optimizer/scheduler state
            and the epoch/iteration counters.
        ignore_modules: names in ``model`` not to load. ``None`` means none.
        is_distributed: when false, a DDP ``module.`` prefix is stripped from
            stored parameter names.
        load_ema: when true and the checkpoint has an ``"ema"`` entry, the
            stored EMA weights replace the raw weights before loading.

    Returns:
        ``(model, optimizer, epoch, iters)``. ``epoch`` is the stored epoch + 1
        when ``load_only_params`` is false, otherwise ``epoch`` and ``iters`` are 0.
    """
    if ignore_modules is None:
        ignore_modules = ()
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(path, map_location="cpu")
    params = state["net"]
    if load_ema and "ema" in state:
        print("Loading EMA")
        for key in model:
            i = 0
            for param_name in params[key]:
                # Entries named "input_pos" are skipped without advancing i;
                # the EMA list is indexed positionally and presumably has no
                # entries for them.
                if "input_pos" in param_name:
                    continue
                assert params[key][param_name].shape == state["ema"][key][0][i].shape
                params[key][param_name] = state["ema"][key][0][i].clone()
                i += 1
    for key in model:
        if key not in params or key in ignore_modules:
            continue
        module_params = params[key]
        if not is_distributed:
            # strip prefix of DDP (module.) so names match the bare module
            for k in list(module_params.keys()):
                if k.startswith("module."):
                    module_params[k[len("module."):]] = module_params[k]
                    del module_params[k]
        model_state_dict = model[key].state_dict()
        # Keep only entries that exist in the module and whose shapes match.
        filtered_state_dict = {
            k: v
            for k, v in module_params.items()
            if k in model_state_dict and v.shape == model_state_dict[k].shape
        }
        skipped_keys = set(module_params.keys()) - set(filtered_state_dict.keys())
        if skipped_keys:
            print(
                f"Warning: Skipped loading some keys due to shape mismatch: {skipped_keys}"
            )
        print("%s loaded" % key)
        model[key].load_state_dict(filtered_state_dict, strict=False)
    for key in model:
        model[key].eval()

    if not load_only_params:
        epoch = state["epoch"] + 1
        iters = state["iters"]
        optimizer.load_state_dict(state["optimizer"])
        optimizer.load_scheduler_state_dict(state["scheduler"])

    else:
        epoch = 0
        iters = 0

    return model, optimizer, epoch, iters
567
+
568
def load_checkpoint2(
    model,
    optimizer,
    path,
    load_only_params=True,
    ignore_modules=None,
    is_distributed=False,
    load_ema=False,
):
    """Load checkpoint weights into ``model.models`` (e.g. a ``MyModel``).

    Same as ``load_checkpoint`` except that submodules are looked up in
    ``model.models`` and ``model.eval()`` is called once at the end.

    Args:
        model: module exposing a ``models`` mapping of name -> ``nn.Module``;
            each entry is loaded, non-strictly, from ``state["net"][name]``.
        optimizer: returned unchanged when ``load_only_params`` is true;
            otherwise its ``load_state_dict`` and ``load_scheduler_state_dict``
            are called with the stored state.
        path: checkpoint file readable by ``torch.load``.
        load_only_params: when false, also restore optimizer/scheduler state
            and the epoch/iteration counters.
        ignore_modules: names in ``model.models`` not to load. ``None`` means none.
        is_distributed: when false, a DDP ``module.`` prefix is stripped from
            stored parameter names.
        load_ema: when true and the checkpoint has an ``"ema"`` entry, the
            stored EMA weights replace the raw weights before loading.

    Returns:
        ``(model, optimizer, epoch, iters)``. ``epoch`` is the stored epoch + 1
        when ``load_only_params`` is false, otherwise ``epoch`` and ``iters`` are 0.
    """
    if ignore_modules is None:
        ignore_modules = ()
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(path, map_location="cpu")
    params = state["net"]
    if load_ema and "ema" in state:
        print("Loading EMA")
        for key in model.models:
            i = 0
            for param_name in params[key]:
                # Entries named "input_pos" are skipped without advancing i;
                # the EMA list is indexed positionally and presumably has no
                # entries for them.
                if "input_pos" in param_name:
                    continue
                assert params[key][param_name].shape == state["ema"][key][0][i].shape
                params[key][param_name] = state["ema"][key][0][i].clone()
                i += 1
    for key in model.models:
        if key not in params or key in ignore_modules:
            continue
        module_params = params[key]
        if not is_distributed:
            # strip prefix of DDP (module.) so names match the bare module
            for k in list(module_params.keys()):
                if k.startswith("module."):
                    module_params[k[len("module."):]] = module_params[k]
                    del module_params[k]
        model_state_dict = model.models[key].state_dict()
        # Keep only entries that exist in the module and whose shapes match.
        filtered_state_dict = {
            k: v
            for k, v in module_params.items()
            if k in model_state_dict and v.shape == model_state_dict[k].shape
        }
        skipped_keys = set(module_params.keys()) - set(filtered_state_dict.keys())
        if skipped_keys:
            print(
                f"Warning: Skipped loading some keys due to shape mismatch: {skipped_keys}"
            )
        print("%s loaded" % key)
        model.models[key].load_state_dict(filtered_state_dict, strict=False)
    model.eval()

    if not load_only_params:
        epoch = state["epoch"] + 1
        iters = state["iters"]
        optimizer.load_state_dict(state["optimizer"])
        optimizer.load_scheduler_state_dict(state["scheduler"])

    else:
        epoch = 0
        iters = 0

    return model, optimizer, epoch, iters
625
+
626
def recursive_munch(d):
    """Convert nested dicts to ``Munch`` objects, descending into lists.

    Dicts become ``Munch`` with converted values, lists become new lists of
    converted items, and any other value is returned unchanged.
    """
    if isinstance(d, list):
        return [recursive_munch(item) for item in d]
    if isinstance(d, dict):
        return Munch((key, recursive_munch(value)) for key, value in d.items())
    return d