xinference-1.9.1-py3-none-any.whl → xinference-1.10.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic; see the registry's advisory for this release for more details.

Files changed (334)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +400 -3
  3. xinference/client/restful/async_restful_client.py +20 -3
  4. xinference/client/restful/restful_client.py +20 -3
  5. xinference/constants.py +2 -0
  6. xinference/core/supervisor.py +111 -49
  7. xinference/core/worker.py +10 -0
  8. xinference/deploy/cmdline.py +15 -0
  9. xinference/model/audio/core.py +26 -6
  10. xinference/model/audio/indextts2.py +166 -0
  11. xinference/model/audio/kokoro.py +1 -1
  12. xinference/model/audio/kokoro_zh.py +124 -0
  13. xinference/model/audio/model_spec.json +58 -1
  14. xinference/model/embedding/sentence_transformers/core.py +4 -4
  15. xinference/model/embedding/vllm/core.py +7 -1
  16. xinference/model/image/model_spec.json +71 -3
  17. xinference/model/image/stable_diffusion/core.py +13 -4
  18. xinference/model/llm/__init__.py +4 -0
  19. xinference/model/llm/core.py +10 -0
  20. xinference/model/llm/llama_cpp/core.py +1 -0
  21. xinference/model/llm/llm_family.json +503 -21
  22. xinference/model/llm/llm_family.py +1 -0
  23. xinference/model/llm/mlx/core.py +52 -33
  24. xinference/model/llm/sglang/core.py +32 -55
  25. xinference/model/llm/tool_parsers/__init__.py +58 -0
  26. xinference/model/llm/tool_parsers/abstract_tool_parser.py +33 -0
  27. xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +190 -0
  28. xinference/model/llm/tool_parsers/deepseek_v3_tool_parser.py +145 -0
  29. xinference/model/llm/tool_parsers/glm4_tool_parser.py +123 -0
  30. xinference/model/llm/tool_parsers/llama3_tool_parser.py +77 -0
  31. xinference/model/llm/tool_parsers/qwen_tool_parser.py +320 -0
  32. xinference/model/llm/transformers/core.py +1 -1
  33. xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
  34. xinference/model/llm/utils.py +138 -53
  35. xinference/model/llm/vllm/core.py +95 -78
  36. xinference/thirdparty/audiotools/__init__.py +10 -0
  37. xinference/thirdparty/audiotools/core/__init__.py +4 -0
  38. xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
  39. xinference/thirdparty/audiotools/core/display.py +194 -0
  40. xinference/thirdparty/audiotools/core/dsp.py +390 -0
  41. xinference/thirdparty/audiotools/core/effects.py +647 -0
  42. xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
  43. xinference/thirdparty/audiotools/core/loudness.py +320 -0
  44. xinference/thirdparty/audiotools/core/playback.py +252 -0
  45. xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
  46. xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
  47. xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
  48. xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
  49. xinference/thirdparty/audiotools/core/util.py +671 -0
  50. xinference/thirdparty/audiotools/core/whisper.py +97 -0
  51. xinference/thirdparty/audiotools/data/__init__.py +3 -0
  52. xinference/thirdparty/audiotools/data/datasets.py +517 -0
  53. xinference/thirdparty/audiotools/data/preprocess.py +81 -0
  54. xinference/thirdparty/audiotools/data/transforms.py +1592 -0
  55. xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
  56. xinference/thirdparty/audiotools/metrics/distance.py +131 -0
  57. xinference/thirdparty/audiotools/metrics/quality.py +159 -0
  58. xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
  59. xinference/thirdparty/audiotools/ml/__init__.py +5 -0
  60. xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
  61. xinference/thirdparty/audiotools/ml/decorators.py +440 -0
  62. xinference/thirdparty/audiotools/ml/experiment.py +90 -0
  63. xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
  64. xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
  65. xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
  66. xinference/thirdparty/audiotools/post.py +140 -0
  67. xinference/thirdparty/audiotools/preference.py +600 -0
  68. xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
  69. xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
  70. xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
  71. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
  72. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
  73. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
  74. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
  75. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  76. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
  77. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
  78. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
  79. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
  80. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
  81. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
  82. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
  83. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
  84. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
  85. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
  86. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
  87. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
  88. xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
  89. xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
  90. xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
  91. xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
  92. xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
  93. xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
  94. xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
  95. xinference/thirdparty/indextts/__init__.py +0 -0
  96. xinference/thirdparty/indextts/cli.py +65 -0
  97. xinference/thirdparty/indextts/gpt/__init__.py +0 -0
  98. xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
  99. xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
  100. xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
  101. xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
  102. xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
  103. xinference/thirdparty/indextts/gpt/model.py +713 -0
  104. xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
  105. xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
  106. xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
  107. xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
  108. xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
  109. xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
  110. xinference/thirdparty/indextts/infer.py +690 -0
  111. xinference/thirdparty/indextts/infer_v2.py +739 -0
  112. xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
  113. xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
  114. xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
  115. xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
  116. xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
  117. xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
  118. xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
  119. xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
  120. xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
  121. xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
  122. xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
  123. xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
  124. xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
  125. xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
  126. xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
  127. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
  128. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
  129. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
  130. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
  131. xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
  132. xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
  133. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
  134. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
  135. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  136. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
  137. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
  138. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
  139. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
  140. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
  141. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
  142. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
  143. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
  144. xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
  145. xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
  146. xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
  147. xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
  148. xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
  149. xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
  150. xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
  151. xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
  152. xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
  153. xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
  154. xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
  155. xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
  156. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
  157. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
  158. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
  159. xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
  160. xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
  161. xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
  162. xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
  163. xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
  164. xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
  165. xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
  166. xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
  167. xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
  168. xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
  169. xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
  170. xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
  171. xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
  172. xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
  173. xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
  174. xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
  175. xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
  176. xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
  177. xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
  178. xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
  179. xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
  180. xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
  181. xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
  182. xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
  183. xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
  184. xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
  185. xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
  186. xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
  187. xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
  188. xinference/thirdparty/indextts/utils/__init__.py +0 -0
  189. xinference/thirdparty/indextts/utils/arch_util.py +120 -0
  190. xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
  191. xinference/thirdparty/indextts/utils/common.py +121 -0
  192. xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
  193. xinference/thirdparty/indextts/utils/front.py +536 -0
  194. xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
  195. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
  196. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
  197. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
  198. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
  199. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
  200. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
  201. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
  202. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
  203. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
  204. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
  205. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
  206. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
  207. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
  208. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
  209. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
  210. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
  211. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
  212. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
  213. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
  214. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
  215. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
  216. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
  217. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
  218. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
  219. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
  220. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
  221. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
  222. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
  223. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
  224. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
  225. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
  226. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
  227. xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
  228. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
  229. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
  230. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
  231. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
  232. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
  233. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
  234. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
  235. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
  236. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
  237. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
  238. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
  239. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
  240. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
  241. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
  242. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
  243. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
  244. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
  245. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
  246. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
  247. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
  248. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
  249. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
  250. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
  251. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
  252. xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
  253. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
  254. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
  255. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
  256. xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
  257. xinference/thirdparty/indextts/utils/text_utils.py +41 -0
  258. xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
  259. xinference/thirdparty/indextts/utils/utils.py +93 -0
  260. xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
  261. xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
  262. xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
  263. xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
  264. xinference/types.py +105 -2
  265. xinference/ui/gradio/media_interface.py +66 -8
  266. xinference/ui/web/ui/build/asset-manifest.json +6 -6
  267. xinference/ui/web/ui/build/index.html +1 -1
  268. xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
  269. xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
  270. xinference/ui/web/ui/build/static/js/main.d192c4f3.js +3 -0
  271. xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.d192c4f3.js.LICENSE.txt} +0 -7
  272. xinference/ui/web/ui/build/static/js/main.d192c4f3.js.map +1 -0
  273. xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
  274. xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
  275. xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
  276. xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
  277. xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
  278. xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
  279. xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
  280. xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
  281. xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
  282. xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
  283. xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
  284. xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
  285. xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
  286. xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
  287. xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
  288. xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
  289. xinference/ui/web/ui/node_modules/.cache/babel-loader/f995a2425dfb0822fd07127f66ffe9b026883bc156b402eb8bd0b83d52460a93.json +1 -0
  290. xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
  291. xinference/ui/web/ui/package-lock.json +0 -34
  292. xinference/ui/web/ui/package.json +0 -1
  293. xinference/ui/web/ui/src/locales/en.json +9 -3
  294. xinference/ui/web/ui/src/locales/ja.json +9 -3
  295. xinference/ui/web/ui/src/locales/ko.json +9 -3
  296. xinference/ui/web/ui/src/locales/zh.json +9 -3
  297. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/METADATA +24 -4
  298. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/RECORD +302 -76
  299. xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
  300. xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
  301. xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
  302. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
  303. xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
  304. xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
  305. xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
  306. xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
  307. xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
  308. xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
  309. xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
  310. xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
  311. xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
  312. xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
  313. xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
  314. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
  315. xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
  316. xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
  317. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
  318. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
  319. xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
  320. xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
  321. xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
  322. xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
  323. xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
  324. xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
  325. xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
  326. xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
  327. xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
  328. xinference/ui/web/ui/node_modules/select/bower.json +0 -13
  329. xinference/ui/web/ui/node_modules/select/package.json +0 -29
  330. xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
  331. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/WHEEL +0 -0
  332. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/entry_points.txt +0 -0
  333. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/licenses/LICENSE +0 -0
  334. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,515 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import torch
8
+ import json
9
+ import json5
10
+ import time
11
+ import accelerate
12
+ import random
13
+ import numpy as np
14
+ import shutil
15
+
16
+ from pathlib import Path
17
+ from tqdm import tqdm
18
+ from glob import glob
19
+ from accelerate.logging import get_logger
20
+ from torch.utils.data import DataLoader
21
+
22
+ from models.vocoders.vocoder_dataset import (
23
+ VocoderDataset,
24
+ VocoderCollator,
25
+ VocoderConcatDataset,
26
+ )
27
+
28
+ from models.vocoders.gan.generator import bigvgan, hifigan, melgan, nsfhifigan, apnet
29
+ from models.vocoders.flow.waveglow import waveglow
30
+ from models.vocoders.diffusion.diffwave import diffwave
31
+ from models.vocoders.autoregressive.wavenet import wavenet
32
+ from models.vocoders.autoregressive.wavernn import wavernn
33
+
34
+ from models.vocoders.gan import gan_vocoder_inference
35
+ from models.vocoders.diffusion import diffusion_vocoder_inference
36
+
37
+ from utils.io import save_audio
38
+
39
# Generator name (``cfg.model.generator``) -> vocoder model class.
_vocoders = {
    "diffwave": diffwave.DiffWave,
    "wavernn": wavernn.WaveRNN,
    "wavenet": wavenet.WaveNet,
    "waveglow": waveglow.WaveGlow,
    "nsfhifigan": nsfhifigan.NSFHiFiGAN,
    "bigvgan": bigvgan.BigVGAN,
    "hifigan": hifigan.HiFiGAN,
    "melgan": melgan.MelGAN,
    "apnet": apnet.APNet,
}

# Forward call for generalized Inferencor: generator name -> batched forward function.
# The "world", "wavernn" and "wavenet" vocoders have no entry yet.
_vocoder_forward_funcs = {
    "diffwave": diffusion_vocoder_inference.vocoder_inference,
    **{
        gan_name: gan_vocoder_inference.vocoder_inference
        for gan_name in ("nsfhifigan", "bigvgan", "melgan", "hifigan", "apnet")
    },
}

# APIs for other tasks (e.g. SVC, TTS, TTA): generator name -> audio synthesis function.
# The "world", "wavernn" and "wavenet" vocoders have no entry yet.
_vocoder_infer_funcs = {
    "diffwave": diffusion_vocoder_inference.synthesis_audios,
    **{
        gan_name: gan_vocoder_inference.synthesis_audios
        for gan_name in ("nsfhifigan", "bigvgan", "melgan", "hifigan", "apnet")
    },
}
76
+
77
+
78
class VocoderInference(object):
    """Batch inference driver for a trained neural vocoder.

    The constructor does all setup, in order:
    - sets up the accelerator and logger;
    - resets the ``<args.output_dir>/pred`` and ``<args.output_dir>/gt`` folders;
    - seeds the RNGs;
    - builds the test dataloader and the model;
    - loads the checkpoint found under ``args.vocoder_dir``.

    :meth:`inference` then writes one predicted and one ground-truth
    ``<uid>.wav`` per test utterance.

    Args:
        args: argument namespace. Attributes read in this class:
            ``log_level``, ``vocoder_dir`` and ``output_dir``. Depending on
            ``infer_type``, it also reads ``infer_datasets``,
            ``feature_folder`` or ``audio_folder``.
        cfg: experiment config accessed by attribute (``cfg.train``,
            ``cfg.preprocess``, ``cfg.inference``, ``cfg.model``).
            ``cfg.dataset`` is overwritten for the ``infer_from_*`` modes.
        infer_type: ``"infer_from_dataset"``, ``"infer_from_feature"`` or
            ``"infer_from_audio"``. Any other value, including the default
            ``"from_dataset"``, leaves ``cfg.dataset`` as configured.
    """

    def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
        super().__init__()

        self.args = args
        self.cfg = cfg
        self.infer_type = infer_type

        # Init accelerator. This single instance is also used for ``prepare`` below.
        self.accelerator = accelerate.Accelerator()
        self.accelerator.wait_for_everyone()

        # Get logger
        with self.accelerator.main_process_first():
            self.logger = get_logger("inference", log_level=args.log_level)

        # Log some info
        self.logger.info("=" * 56)
        self.logger.info("||\t\t" + "New inference process started." + "\t\t||")
        self.logger.info("=" * 56)
        self.logger.info("\n")

        self.vocoder_dir = args.vocoder_dir
        self.logger.debug(f"Vocoder dir: {args.vocoder_dir}")

        # Start from empty "pred"/"gt" folders so files from an earlier run are not mixed in.
        os.makedirs(args.output_dir, exist_ok=True)
        for sub_dir in ("pred", "gt"):
            sub_path = os.path.join(args.output_dir, sub_dir)
            if os.path.exists(sub_path):
                shutil.rmtree(sub_path)
            os.makedirs(sub_path, exist_ok=True)

        # Set random seed
        with self.accelerator.main_process_first():
            start = time.monotonic_ns()
            self._set_random_seed(self.cfg.train.random_seed)
            end = time.monotonic_ns()
            self.logger.debug(
                f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
            )
            self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")

        # Setup inference mode
        if self.infer_type == "infer_from_dataset":
            self.cfg.dataset = self.args.infer_datasets
        elif self.infer_type == "infer_from_feature":
            self._build_tmp_dataset_from_feature()
            self.cfg.dataset = ["tmp"]
        elif self.infer_type == "infer_from_audio":
            self._build_tmp_dataset_from_audio()
            self.cfg.dataset = ["tmp"]

        # Setup data loader
        with self.accelerator.main_process_first():
            self.logger.info("Building dataset...")
            start = time.monotonic_ns()
            self.test_dataloader = self._build_dataloader()
            end = time.monotonic_ns()
            self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")

        # Build model
        with self.accelerator.main_process_first():
            self.logger.info("Building model...")
            start = time.monotonic_ns()
            self.model = self._build_model()
            end = time.monotonic_ns()
            self.logger.info(f"Building model done in {(end - start) / 1e6:.3f}ms")

        # Init with accelerate. Reuse the accelerator created above instead of
        # instantiating a second one.
        self.logger.info("Initializing accelerate...")
        start = time.monotonic_ns()
        (self.model, self.test_dataloader) = self.accelerator.prepare(
            self.model, self.test_dataloader
        )
        end = time.monotonic_ns()
        self.accelerator.wait_for_everyone()
        self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.3f}ms")

        with self.accelerator.main_process_first():
            self.logger.info("Loading checkpoint...")
            start = time.monotonic_ns()
            # Prefer a "checkpoint" sub-folder of vocoder_dir when one exists;
            # otherwise vocoder_dir itself (folder or .pt file) is loaded.
            checkpoint_sub_dir = os.path.join(args.vocoder_dir, "checkpoint")
            if os.path.isdir(checkpoint_sub_dir):
                self._load_model(checkpoint_sub_dir)
            else:
                self._load_model(args.vocoder_dir)
            end = time.monotonic_ns()
            self.logger.info(f"Loading checkpoint done in {(end - start) / 1e6:.3f}ms")

        self.model.eval()
        self.accelerator.wait_for_everyone()

    @staticmethod
    def _uid_from_path(path):
        """Return the file name of ``path`` up to its first ``.``; used as the utterance id."""
        # os.path.basename (rather than splitting on "/") keeps this portable.
        return os.path.basename(path).split(".")[0]

    @staticmethod
    def _default_map_location():
        """Return the device ``torch.load`` should map tensors to: CUDA if available, else CPU."""
        return torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    def _prepare_tmp_dir(self):
        """(Re)create an empty ``<cfg.preprocess.processed_dir>/tmp`` folder and return its path."""
        tmp_dir = os.path.join(self.cfg.preprocess.processed_dir, "tmp")
        if os.path.exists(tmp_dir):
            shutil.rmtree(tmp_dir)
        os.makedirs(tmp_dir)
        return tmp_dir

    @staticmethod
    def _write_tmp_metadata(tmp_dir, utts):
        """Write ``test.json`` (the utterance list) and ``meta_info.json`` into ``tmp_dir``."""
        with open(os.path.join(tmp_dir, "test.json"), "w") as f:
            json.dump(utts, f)

        meta_info = {"dataset": "tmp", "test": {"size": len(utts)}}
        with open(os.path.join(tmp_dir, "meta_info.json"), "w") as f:
            json.dump(meta_info, f)

    def _build_tmp_dataset_from_feature(self):
        """Build the temporary ``tmp`` dataset from pre-extracted features.

        Every ``<args.feature_folder>/mels/*.npy`` file becomes one utterance.
        Every sub-folder of ``args.feature_folder`` is then copied into the
        ``tmp`` dataset folder.
        """
        tmp_dir = self._prepare_tmp_dir()

        mels = glob(os.path.join(self.args.feature_folder, "mels", "*.npy"))
        utts = [
            {"Dataset": "tmp", "Uid": self._uid_from_path(mel), "index": i}
            for i, mel in enumerate(mels)
        ]
        self._write_tmp_metadata(tmp_dir, utts)

        for feature in glob(os.path.join(self.args.feature_folder, "*")):
            if os.path.isfile(feature):
                continue
            feature_name = os.path.basename(feature)
            shutil.copytree(
                os.path.join(self.args.feature_folder, feature_name),
                os.path.join(tmp_dir, feature_name),
            )

    def _build_tmp_dataset_from_audio(self):
        """Build the temporary ``tmp`` dataset from raw audio files.

        Every entry of ``<args.audio_folder>/*`` becomes one utterance, with
        its path recorded. Acoustic features are then extracted with
        ``processors.acoustic_extractor.extract_utt_acoustic_features_serial``.
        """
        tmp_dir = self._prepare_tmp_dir()

        audios = glob(os.path.join(self.args.audio_folder, "*"))
        utts = [
            {"Dataset": "tmp", "Uid": self._uid_from_path(audio), "index": i, "Path": audio}
            for i, audio in enumerate(audios)
        ]
        self._write_tmp_metadata(tmp_dir, utts)

        from processors import acoustic_extractor

        acoustic_extractor.extract_utt_acoustic_features_serial(utts, tmp_dir, self.cfg)

    def _build_test_dataset(self):
        """Return the (dataset class, collator class) pair used for testing."""
        return VocoderDataset, VocoderCollator

    def _build_model(self):
        """Instantiate the generator named by ``cfg.model.generator``."""
        model = _vocoders[self.cfg.model.generator](self.cfg)
        return model

    def _build_dataloader(self):
        """Build dataloader which merges a series of datasets."""
        Dataset, Collator = self._build_test_dataset()

        datasets_list = [
            Dataset(self.cfg, dataset, is_valid=True) for dataset in self.cfg.dataset
        ]
        test_dataset = VocoderConcatDataset(datasets_list, full_audio_inference=False)
        test_collate = Collator(self.cfg)
        # Clamp to at least 1: DataLoader rejects batch_size=0 for an empty test set.
        test_batch_size = max(1, min(self.cfg.inference.batch_size, len(test_dataset)))
        test_dataloader = DataLoader(
            test_dataset,
            collate_fn=test_collate,
            num_workers=1,
            batch_size=test_batch_size,
            shuffle=False,
        )
        self.test_batch_size = test_batch_size
        self.test_dataset = test_dataset
        return test_dataloader

    def _load_model(self, checkpoint_dir, from_multi_gpu=False):
        """Load model from checkpoint.

        If a folder is given, load the accelerator checkpoint it contains. If
        the folder name contains both "epoch" and "step", the folder itself is
        used; otherwise its newest sub-folder is used. If a file path is
        given, it is loaded as an old ``.pt`` checkpoint.
        **Only use this method after** ``accelerator.prepare()``.

        Args:
            checkpoint_dir: checkpoint folder or ``.pt`` file path.
            from_multi_gpu: for ``.pt`` GAN checkpoints, strip ``module.``
                prefixes and load only keys whose name and shape match.

        Returns:
            str: the checkpoint path that was loaded.

        Raises:
            FileNotFoundError: if ``checkpoint_dir`` is a folder without any
                candidate checkpoint.
        """
        if os.path.isdir(checkpoint_dir):
            if "epoch" in checkpoint_dir and "step" in checkpoint_dir:
                checkpoint_path = checkpoint_dir
            else:
                # Load the latest accelerator state dicts (audio dumps are ignored).
                candidates = [
                    str(p) for p in Path(checkpoint_dir).glob("*") if "audio" not in str(p)
                ]
                if not candidates:
                    raise FileNotFoundError(f"No checkpoint found in {checkpoint_dir}")
                # NOTE(review): assumes folder names like "<prefix>-<number>_..." - confirm.
                candidates.sort(
                    key=lambda x: int(os.path.basename(x).split("_")[0].split("-")[-1]),
                    reverse=True,
                )
                checkpoint_path = candidates[0]
            accelerate.load_checkpoint_and_dispatch(
                self.accelerator.unwrap_model(self.model),
                os.path.join(checkpoint_path, "pytorch_model.bin"),
            )
            return str(checkpoint_path)

        # Load old .pt checkpoints
        map_location = self._default_map_location()
        if self.cfg.model.generator in ["bigvgan", "hifigan", "melgan", "nsfhifigan"]:
            ckpt = torch.load(checkpoint_dir, map_location=map_location)
            if from_multi_gpu:
                # Strip DataParallel/DDP "module." prefixes; skip unknown or mis-shaped keys.
                generator_dict = self.model.state_dict()
                for key, value in ckpt["generator_state_dict"].items():
                    name = key.split("module.")[-1]
                    if name in generator_dict and value.shape == generator_dict[name].shape:
                        generator_dict[name] = value
                self.model.load_state_dict(generator_dict)
            else:
                self.model.load_state_dict(ckpt["generator_state_dict"])
        else:
            self.model.load_state_dict(
                torch.load(checkpoint_dir, map_location=map_location)["state_dict"]
            )
        return str(checkpoint_dir)

    def inference(self):
        """Vocode every test batch and save predicted and ground-truth waveforms.

        Outputs go to ``<args.output_dir>/pred/<uid>.wav`` and
        ``<args.output_dir>/gt/<uid>.wav``, trimmed to
        ``target_len * cfg.preprocess.hop_size`` samples. The temporary
        ``<cfg.preprocess.processed_dir>/tmp`` dataset is removed afterwards.
        """
        hop_size = self.cfg.preprocess.hop_size
        sample_rate = self.cfg.preprocess.sample_rate
        pred_dir = os.path.join(self.args.output_dir, "pred")
        gt_dir = os.path.join(self.args.output_dir, "gt")

        for i, batch in tqdm(enumerate(self.test_dataloader)):
            pitch_kwargs = {}
            if self.cfg.preprocess.use_frame_pitch:
                pitch_kwargs["f0s"] = batch["frame_pitch"].float()
            audio_pred = _vocoder_forward_funcs[self.cfg.model.generator](
                self.cfg,
                self.model,
                batch["mel"].transpose(-1, -2),
                device=next(self.model.parameters()).device,
                **pitch_kwargs,
            )
            audio_ls = audio_pred.chunk(self.test_batch_size)
            audio_gt_ls = batch["audio"].cpu().chunk(self.test_batch_size)
            length_ls = batch["target_len"].cpu().chunk(self.test_batch_size)
            for j, (it, it_gt, length) in enumerate(zip(audio_ls, audio_gt_ls, length_ls)):
                n_keep = length.item() * hop_size
                it = it.squeeze(0).squeeze(0)[:n_keep]
                it_gt = it_gt.squeeze(0)[:n_keep]
                uid = self.test_dataset.metadata[i * self.test_batch_size + j]["Uid"]
                # Format only the file name: calling .format on the joined path would
                # misinterpret any "{" / "}" appearing in output_dir.
                save_audio(os.path.join(pred_dir, f"{uid}.wav"), it, sample_rate)
                save_audio(os.path.join(gt_dir, f"{uid}.wav"), it_gt, sample_rate)

        tmp_dir = os.path.join(self.cfg.preprocess.processed_dir, "tmp")
        if os.path.exists(tmp_dir):
            shutil.rmtree(tmp_dir)

    def _set_random_seed(self, seed):
        """Set random seed for all possible random modules."""
        random.seed(seed)
        np.random.seed(seed)
        torch.random.manual_seed(seed)

    def _count_parameters(self, model):
        """Return the total number of elements across ``model.parameters()``."""
        return sum(p.numel() for p in model.parameters())

    def _dump_cfg(self, path):
        """Write ``self.cfg`` to ``path`` as JSON5, creating parent folders as needed."""
        parent = os.path.dirname(path)
        if parent:  # os.makedirs("") would raise for a bare file name
            os.makedirs(parent, exist_ok=True)
        # Close the handle deterministically; UTF-8 because ensure_ascii=False may emit non-ASCII.
        with open(path, "w", encoding="utf-8") as f:
            json5.dump(
                self.cfg,
                f,
                indent=4,
                sort_keys=True,
                ensure_ascii=False,
                quote_keys=True,
            )
395
+
396
+
397
def load_nnvocoder(
    cfg,
    vocoder_name,
    weights_file,
    from_multi_gpu=False,
):
    """Load the specified vocoder.

    Args:
        cfg: the vocoder config, passed to the model constructor.
        vocoder_name: key into ``_vocoders`` (e.g. ``"hifigan"``).
        weights_file: either a ``.pt`` file, or a folder whose ``checkpoint``
            sub-folder holds accelerator state folders.
        from_multi_gpu: if True, strip any ``module.`` prefix from the
            generator state-dict keys. Only entries whose name and shape
            match the freshly built model are loaded.

    Returns:
        The model in eval mode, moved to CUDA when available.

    Raises:
        FileNotFoundError: if ``weights_file`` is a folder but its
            ``checkpoint`` sub-folder contains no candidate checkpoint.
    """
    print("Loading Vocoder from Weights file: {}".format(weights_file))

    # Build model
    model = _vocoders[vocoder_name](cfg)
    map_location = (
        torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    )
    if not os.path.isdir(weights_file):
        # Load from .pt file
        if vocoder_name in ["bigvgan", "hifigan", "melgan", "nsfhifigan"]:
            ckpt = torch.load(weights_file, map_location=map_location)
            if from_multi_gpu:
                # Strip DataParallel/DDP "module." prefixes; skip unknown or mis-shaped keys.
                generator_dict = model.state_dict()
                for key, value in ckpt["generator_state_dict"].items():
                    name = key.split("module.")[-1]
                    if name in generator_dict and value.shape == generator_dict[name].shape:
                        generator_dict[name] = value
                model.load_state_dict(generator_dict)
            else:
                model.load_state_dict(ckpt["generator_state_dict"])
        else:
            # Same map_location as the GAN branch so GPU-saved weights load on CPU-only hosts.
            model.load_state_dict(
                torch.load(weights_file, map_location=map_location)["state_dict"]
            )
    else:
        # Load from accelerator state dict: pick the newest checkpoint folder.
        checkpoint_root = os.path.join(weights_file, "checkpoint")
        candidates = [
            str(p) for p in Path(checkpoint_root).glob("*") if "audio" not in str(p)
        ]
        if not candidates:
            raise FileNotFoundError(f"No checkpoint found in {checkpoint_root}")
        # NOTE(review): assumes the third-to-last "_"-separated field ends in "-<number>" - confirm.
        candidates.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
        checkpoint_path = candidates[0]
        accelerator = accelerate.Accelerator()
        model = accelerator.prepare(model)
        accelerator.load_state(checkpoint_path)

    if torch.cuda.is_available():
        model = model.cuda()

    model = model.eval()
    return model
458
+
459
+
460
def tensorize(data, device, n_samples):
    """Convert a sequence of arrays into a list of tensors on ``device``.

    Args:
        data: a list (or tuple) of numpy arrays, or other inputs accepted by
            ``torch.as_tensor``.
        device: target device for every tensor.
        n_samples: if truthy, only the first ``n_samples`` items are converted.
            ``None`` or ``0`` converts everything.

    Returns:
        list of ``torch.Tensor``.
    """
    # isinstance (not ``type(...) == list``) so list subclasses and tuples are accepted.
    assert isinstance(data, (list, tuple)), "data must be a list or tuple, got {}".format(
        type(data)
    )
    if n_samples:
        data = data[:n_samples]
    return [torch.as_tensor(x, device=device) for x in data]
469
+
470
+
471
def synthesis(
    cfg,
    vocoder_weight_file,
    n_samples,
    pred,
    f0s=None,
    batch_size=64,
    fast_inference=False,
):
    """Vocode a list of predicted acoustic features into waveforms.

    Args:
        cfg: vocoder config. ``cfg.model.generator`` selects the vocoder.
        vocoder_weight_file: a folder of accelerator state dicts or a path to
            a ``.pt`` file.
        n_samples: if truthy, only the first ``n_samples`` predictions are used.
        pred: a list of numpy arrays
            ``[(seq_len1, acoustic_features_dim), (seq_len2, acoustic_features_dim), ...]``.
        f0s, batch_size, fast_inference: forwarded unchanged to the vocoder's
            registered synthesis function.

    Returns:
        Whatever the registered synthesis function for the vocoder returns.
    """
    generator_name = cfg.model.generator

    print("Synthesis audios using {} vocoder...".format(generator_name))

    # TODO: World vocoder refactor - "world" is not supported by this path yet.

    # Load the neural vocoder once, with "module." prefixes stripped from GAN checkpoints.
    vocoder = load_nnvocoder(
        cfg, generator_name, weights_file=vocoder_weight_file, from_multi_gpu=True
    )
    target_device = next(vocoder.parameters()).device

    # Each prediction is (frame_len, n_mels); the vocoder consumes (n_mels, frame_len).
    transposed = [features.T for features in pred]
    mels = tensorize(transposed, target_device, n_samples)
    print("For predicted mels, #sample = {}...".format(len(mels)))

    synthesize = _vocoder_infer_funcs[generator_name]
    return synthesize(
        cfg,
        vocoder,
        mels,
        f0s=f0s,
        batch_size=batch_size,
        fast_inference=fast_inference,
    )
@@ -0,0 +1,126 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ import random
8
+
9
+ from torch.utils.data import ConcatDataset, Dataset
10
+ from torch.utils.data.sampler import (
11
+ BatchSampler,
12
+ RandomSampler,
13
+ Sampler,
14
+ SequentialSampler,
15
+ )
16
+
17
+
18
class ScheduledSampler(Sampler):
    """A sampler that draws indices dataset-by-dataset from a ConcatDataset.

    The indices of each sub-dataset are grouped into chunks of ``batch_size``.
    The chunks from all sub-datasets are then shuffled together with the
    global ``random`` module and flattened. As a result, every consecutive
    run of ``batch_size`` indices comes from a single sub-dataset.

    Args:
        concat_dataset (ConcatDataset): a concatenated dataset consisting of all
            datasets. Every sub-dataset must provide ``get_dataset_name()``.
        batch_size (int): batch size.
        holistic_shuffle (bool): selects the index order inside each
            sub-dataset. ``True`` iterates sequentially and ``False`` uses a
            random permutation (see the NOTE in ``__iter__``).
        logger (logging.Logger, optional): logger used to warn about
            sub-datasets smaller than ``batch_size``. When ``None``, this
            module's logger is used.
        type (str): split name. For ``"valid"``, each sub-dataset's trailing
            partial chunk is kept and no size warning is emitted. For any
            other value, partial chunks are dropped.

    Usage:
        With batch_size = 3, holistic_shuffle = True and three sub-datasets of
        three items each, iterating may yield e.g. [3, 4, 5, 0, 1, 2, 6, 7, 8].
    """

    def __init__(
        self, concat_dataset, batch_size, holistic_shuffle, logger=None, type="train"
    ):
        # The ``type`` parameter shadows the builtin inside this method, so class
        # names are taken from ``__class__``; calling ``type(...)`` here would raise
        # "TypeError: 'str' object is not callable".
        if not isinstance(concat_dataset, ConcatDataset):
            raise ValueError(
                "concat_dataset must be an instance of ConcatDataset, but got {}".format(
                    concat_dataset.__class__
                )
            )
        if not isinstance(batch_size, int):
            raise ValueError(
                "batch_size must be an integer, but got {}".format(batch_size.__class__)
            )
        if not isinstance(holistic_shuffle, bool):
            raise ValueError(
                "holistic_shuffle must be a boolean, but got {}".format(
                    holistic_shuffle.__class__
                )
            )

        self.concat_dataset = concat_dataset
        self.batch_size = batch_size
        self.holistic_shuffle = holistic_shuffle
        self.type = type

        affected = []
        for dataset in concat_dataset.datasets:
            dataset_len = len(dataset)
            dataset_name = dataset.get_dataset_name()
            if dataset_len < batch_size:
                affected.append((dataset_name, dataset_len))

        if affected and type != "valid":
            if logger is None:
                # Previously a missing logger crashed with AttributeError here.
                import logging

                logger = logging.getLogger(__name__)
            for dataset_name, dataset_len in affected:
                logger.warning(
                    "The {} dataset {} has a length of {}, which is smaller than the batch size {}. This may cause unexpected behavior.".format(
                        type, dataset_name, dataset_len, batch_size
                    )
                )

    def __len__(self):
        """Return the number of indices ``__iter__`` yields."""
        lengths = [len(dataset) for dataset in self.concat_dataset.datasets]
        if self.type == "valid":
            # Validation keeps each trailing partial chunk, so every index is yielded.
            return sum(lengths)
        # Training drops each sub-dataset's trailing partial chunk.
        return sum(n // self.batch_size * self.batch_size for n in lengths)

    def __iter__(self):
        """Yield global indices, one single-sub-dataset chunk of ``batch_size`` at a time."""
        offsets = [0] + self.concat_dataset.cumulative_sizes[:-1]
        keep_partial = self.type == "valid"
        output_batches = []
        for offset, dataset in zip(offsets, self.concat_dataset.datasets):
            # NOTE(review): holistic_shuffle=True selects *sequential* order and False a
            # random one, which looks inverted relative to the name; kept unchanged to
            # preserve existing training behavior - confirm intent.
            sampler = (
                SequentialSampler(dataset)
                if self.holistic_shuffle
                else RandomSampler(dataset)
            )
            cur_batch = []
            for idx in sampler:
                cur_batch.append(idx + offset)
                if len(cur_batch) == self.batch_size:
                    output_batches.append(cur_batch)
                    cur_batch = []
            # Training drops the trailing partial chunk; validation keeps it.
            if keep_partial and cur_batch:
                output_batches.append(cur_batch)
        random.shuffle(output_batches)
        return iter([idx for batch in output_batches for idx in batch])
111
+
112
+
113
def build_samplers(concat_dataset: Dataset, cfg, logger, type):
    """Create a ScheduledSampler and the BatchSampler that wraps it.

    The batch size comes from ``cfg.train.batch_size`` and the shuffle flag
    from ``cfg.train.sampler.holistic_shuffle``. ``drop_last`` is taken from
    ``cfg.train.sampler.drop_last``, except for the ``"valid"`` split, where
    it is always ``False``.

    Returns:
        tuple: ``(sampler, batch_sampler)``.
    """
    train_cfg = cfg.train
    is_valid = type == "valid"

    sampler = ScheduledSampler(
        concat_dataset,
        train_cfg.batch_size,
        train_cfg.sampler.holistic_shuffle,
        logger,
        type,
    )

    if is_valid:
        drop_last = False
    else:
        drop_last = train_cfg.sampler.drop_last

    batch_sampler = BatchSampler(sampler, train_cfg.batch_size, drop_last)
    return sampler, batch_sampler