xinference 1.9.1__py3-none-any.whl → 1.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic. Click here for more details.

Files changed (334):
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +400 -3
  3. xinference/client/restful/async_restful_client.py +20 -3
  4. xinference/client/restful/restful_client.py +20 -3
  5. xinference/constants.py +2 -0
  6. xinference/core/supervisor.py +111 -49
  7. xinference/core/worker.py +10 -0
  8. xinference/deploy/cmdline.py +15 -0
  9. xinference/model/audio/core.py +26 -6
  10. xinference/model/audio/indextts2.py +166 -0
  11. xinference/model/audio/kokoro.py +1 -1
  12. xinference/model/audio/kokoro_zh.py +124 -0
  13. xinference/model/audio/model_spec.json +58 -1
  14. xinference/model/embedding/sentence_transformers/core.py +4 -4
  15. xinference/model/embedding/vllm/core.py +7 -1
  16. xinference/model/image/model_spec.json +71 -3
  17. xinference/model/image/stable_diffusion/core.py +13 -4
  18. xinference/model/llm/__init__.py +4 -0
  19. xinference/model/llm/core.py +10 -0
  20. xinference/model/llm/llama_cpp/core.py +1 -0
  21. xinference/model/llm/llm_family.json +503 -21
  22. xinference/model/llm/llm_family.py +1 -0
  23. xinference/model/llm/mlx/core.py +52 -33
  24. xinference/model/llm/sglang/core.py +32 -55
  25. xinference/model/llm/tool_parsers/__init__.py +58 -0
  26. xinference/model/llm/tool_parsers/abstract_tool_parser.py +33 -0
  27. xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +190 -0
  28. xinference/model/llm/tool_parsers/deepseek_v3_tool_parser.py +145 -0
  29. xinference/model/llm/tool_parsers/glm4_tool_parser.py +123 -0
  30. xinference/model/llm/tool_parsers/llama3_tool_parser.py +77 -0
  31. xinference/model/llm/tool_parsers/qwen_tool_parser.py +320 -0
  32. xinference/model/llm/transformers/core.py +1 -1
  33. xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
  34. xinference/model/llm/utils.py +138 -53
  35. xinference/model/llm/vllm/core.py +95 -78
  36. xinference/thirdparty/audiotools/__init__.py +10 -0
  37. xinference/thirdparty/audiotools/core/__init__.py +4 -0
  38. xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
  39. xinference/thirdparty/audiotools/core/display.py +194 -0
  40. xinference/thirdparty/audiotools/core/dsp.py +390 -0
  41. xinference/thirdparty/audiotools/core/effects.py +647 -0
  42. xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
  43. xinference/thirdparty/audiotools/core/loudness.py +320 -0
  44. xinference/thirdparty/audiotools/core/playback.py +252 -0
  45. xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
  46. xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
  47. xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
  48. xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
  49. xinference/thirdparty/audiotools/core/util.py +671 -0
  50. xinference/thirdparty/audiotools/core/whisper.py +97 -0
  51. xinference/thirdparty/audiotools/data/__init__.py +3 -0
  52. xinference/thirdparty/audiotools/data/datasets.py +517 -0
  53. xinference/thirdparty/audiotools/data/preprocess.py +81 -0
  54. xinference/thirdparty/audiotools/data/transforms.py +1592 -0
  55. xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
  56. xinference/thirdparty/audiotools/metrics/distance.py +131 -0
  57. xinference/thirdparty/audiotools/metrics/quality.py +159 -0
  58. xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
  59. xinference/thirdparty/audiotools/ml/__init__.py +5 -0
  60. xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
  61. xinference/thirdparty/audiotools/ml/decorators.py +440 -0
  62. xinference/thirdparty/audiotools/ml/experiment.py +90 -0
  63. xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
  64. xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
  65. xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
  66. xinference/thirdparty/audiotools/post.py +140 -0
  67. xinference/thirdparty/audiotools/preference.py +600 -0
  68. xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
  69. xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
  70. xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
  71. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
  72. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
  73. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
  74. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
  75. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  76. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
  77. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
  78. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
  79. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
  80. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
  81. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
  82. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
  83. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
  84. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
  85. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
  86. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
  87. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
  88. xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
  89. xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
  90. xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
  91. xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
  92. xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
  93. xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
  94. xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
  95. xinference/thirdparty/indextts/__init__.py +0 -0
  96. xinference/thirdparty/indextts/cli.py +65 -0
  97. xinference/thirdparty/indextts/gpt/__init__.py +0 -0
  98. xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
  99. xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
  100. xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
  101. xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
  102. xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
  103. xinference/thirdparty/indextts/gpt/model.py +713 -0
  104. xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
  105. xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
  106. xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
  107. xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
  108. xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
  109. xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
  110. xinference/thirdparty/indextts/infer.py +690 -0
  111. xinference/thirdparty/indextts/infer_v2.py +739 -0
  112. xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
  113. xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
  114. xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
  115. xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
  116. xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
  117. xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
  118. xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
  119. xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
  120. xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
  121. xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
  122. xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
  123. xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
  124. xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
  125. xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
  126. xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
  127. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
  128. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
  129. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
  130. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
  131. xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
  132. xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
  133. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
  134. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
  135. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  136. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
  137. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
  138. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
  139. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
  140. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
  141. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
  142. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
  143. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
  144. xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
  145. xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
  146. xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
  147. xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
  148. xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
  149. xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
  150. xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
  151. xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
  152. xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
  153. xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
  154. xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
  155. xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
  156. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
  157. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
  158. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
  159. xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
  160. xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
  161. xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
  162. xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
  163. xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
  164. xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
  165. xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
  166. xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
  167. xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
  168. xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
  169. xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
  170. xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
  171. xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
  172. xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
  173. xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
  174. xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
  175. xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
  176. xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
  177. xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
  178. xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
  179. xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
  180. xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
  181. xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
  182. xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
  183. xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
  184. xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
  185. xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
  186. xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
  187. xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
  188. xinference/thirdparty/indextts/utils/__init__.py +0 -0
  189. xinference/thirdparty/indextts/utils/arch_util.py +120 -0
  190. xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
  191. xinference/thirdparty/indextts/utils/common.py +121 -0
  192. xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
  193. xinference/thirdparty/indextts/utils/front.py +536 -0
  194. xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
  195. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
  196. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
  197. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
  198. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
  199. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
  200. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
  201. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
  202. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
  203. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
  204. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
  205. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
  206. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
  207. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
  208. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
  209. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
  210. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
  211. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
  212. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
  213. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
  214. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
  215. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
  216. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
  217. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
  218. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
  219. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
  220. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
  221. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
  222. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
  223. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
  224. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
  225. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
  226. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
  227. xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
  228. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
  229. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
  230. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
  231. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
  232. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
  233. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
  234. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
  235. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
  236. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
  237. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
  238. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
  239. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
  240. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
  241. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
  242. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
  243. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
  244. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
  245. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
  246. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
  247. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
  248. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
  249. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
  250. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
  251. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
  252. xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
  253. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
  254. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
  255. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
  256. xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
  257. xinference/thirdparty/indextts/utils/text_utils.py +41 -0
  258. xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
  259. xinference/thirdparty/indextts/utils/utils.py +93 -0
  260. xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
  261. xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
  262. xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
  263. xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
  264. xinference/types.py +105 -2
  265. xinference/ui/gradio/media_interface.py +66 -8
  266. xinference/ui/web/ui/build/asset-manifest.json +6 -6
  267. xinference/ui/web/ui/build/index.html +1 -1
  268. xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
  269. xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
  270. xinference/ui/web/ui/build/static/js/main.d192c4f3.js +3 -0
  271. xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.d192c4f3.js.LICENSE.txt} +0 -7
  272. xinference/ui/web/ui/build/static/js/main.d192c4f3.js.map +1 -0
  273. xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
  274. xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
  275. xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
  276. xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
  277. xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
  278. xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
  279. xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
  280. xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
  281. xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
  282. xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
  283. xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
  284. xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
  285. xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
  286. xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
  287. xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
  288. xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
  289. xinference/ui/web/ui/node_modules/.cache/babel-loader/f995a2425dfb0822fd07127f66ffe9b026883bc156b402eb8bd0b83d52460a93.json +1 -0
  290. xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
  291. xinference/ui/web/ui/package-lock.json +0 -34
  292. xinference/ui/web/ui/package.json +0 -1
  293. xinference/ui/web/ui/src/locales/en.json +9 -3
  294. xinference/ui/web/ui/src/locales/ja.json +9 -3
  295. xinference/ui/web/ui/src/locales/ko.json +9 -3
  296. xinference/ui/web/ui/src/locales/zh.json +9 -3
  297. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/METADATA +24 -4
  298. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/RECORD +302 -76
  299. xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
  300. xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
  301. xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
  302. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
  303. xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
  304. xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
  305. xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
  306. xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
  307. xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
  308. xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
  309. xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
  310. xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
  311. xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
  312. xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
  313. xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
  314. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
  315. xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
  316. xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
  317. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
  318. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
  319. xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
  320. xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
  321. xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
  322. xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
  323. xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
  324. xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
  325. xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
  326. xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
  327. xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
  328. xinference/ui/web/ui/node_modules/select/bower.json +0 -13
  329. xinference/ui/web/ui/node_modules/select/package.json +0 -29
  330. xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
  331. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/WHEEL +0 -0
  332. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/entry_points.txt +0 -0
  333. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/licenses/LICENSE +0 -0
  334. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,776 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import time
8
+ import random
9
+ from pathlib import Path
10
+ import re
11
+ import glob
12
+
13
+ import accelerate
14
+ import json
15
+ import numpy as np
16
+ import torch
17
+ from accelerate.utils import ProjectConfiguration
18
+ from torch.utils.data import DataLoader
19
+ from tqdm import tqdm
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ import torchaudio
24
+
25
+ from accelerate.logging import get_logger
26
+
27
+ from models.codec.facodec.facodec_dataset import FAcodecDataset, FAcodecCollator
28
+ from models.codec.codec_sampler import build_samplers
29
+ from models.codec.codec_trainer import CodecTrainer
30
+
31
+ from modules.dac.nn.loss import (
32
+ MultiScaleSTFTLoss,
33
+ MelSpectrogramLoss,
34
+ GANLoss,
35
+ L1Loss,
36
+ FocalLoss,
37
+ )
38
+ from audiotools import AudioSignal
39
+
40
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
41
+
42
+ try:
43
+ import nemo.collections.asr as nemo_asr
44
+ except ImportError:
45
+ print(
46
+ "Unable to import nemo_asr, titanet outputs will be set to random values, you may only run debugging mode. DO NOT USE THIS FOR TRAINING"
47
+ )
48
+ nemo_asr = None
49
+
50
+ from models.codec.facodec.modules.commons import (
51
+ build_model,
52
+ load_checkpoint,
53
+ load_F0_models,
54
+ log_norm,
55
+ )
56
+ from models.codec.facodec.optimizer import build_optimizer
57
+
58
+
59
+ class FAcodecTrainer(CodecTrainer):
60
def __init__(self, args, cfg):
    """Set up everything needed to train FAcodec.

    In order: accelerator, logger, training counters, config checks, random
    seed, dataloaders, models, optimizers/schedulers, frozen helper models,
    ``accelerator.prepare`` on models/optimizers/schedulers, loss criteria and,
    when ``args.resume_type`` is set, resumption from a checkpoint.

    Args:
        args: Parsed CLI arguments. Reads ``exp_name``, ``log_level``,
            ``resume_type`` and ``checkpoint``.
        cfg: Experiment config. ``cfg.exp_name`` is overwritten with
            ``args.exp_name``; ``cfg.train`` supplies ``max_epoch``,
            ``save_checkpoint_stride``, ``run_eval`` and ``random_seed``.
    """
    super().__init__()

    self.args = args
    self.cfg = cfg

    cfg.exp_name = args.exp_name

    # Init accelerator. NOTE(review): presumably this also sets ``self.exp_dir``,
    # which is read below but never assigned in this method.
    self._init_accelerator()
    self.accelerator.wait_for_everyone()

    # Init logger
    with self.accelerator.main_process_first():
        self.logger = get_logger(args.exp_name, log_level=args.log_level)

    self.logger.info("=" * 56)
    self.logger.info("||\t\t" + "New training process started." + "\t\t||")
    self.logger.info("=" * 56)
    self.logger.info("\n")
    self.logger.debug(f"Using {args.log_level.upper()} logging level.")
    self.logger.info(f"Experiment name: {args.exp_name}")
    self.logger.info(f"Experiment directory: {self.exp_dir}")
    self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
    if self.accelerator.is_main_process:
        os.makedirs(self.checkpoint_dir, exist_ok=True)
    self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")

    # Init training status
    self.batch_count: int = 0
    self.step: int = 0
    self.epoch: int = 0

    # A non-positive max_epoch means "no epoch limit".
    self.max_epoch = (
        self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf")
    )
    self.logger.info(
        "Max epoch: {}".format(
            self.max_epoch if self.max_epoch < float("inf") else "Unlimited"
        )
    )

    # Check potential errors
    if self.accelerator.is_main_process:
        self._check_basic_configs()
    # Checkpoint bookkeeping is read from cfg on every rank so the attributes
    # exist everywhere (the resume branch below also overwrites them on all ranks).
    self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride
    self.checkpoints_path = [[] for _ in range(len(self.save_checkpoint_stride))]
    self.run_eval = self.cfg.train.run_eval

    # Set random seed
    with self.accelerator.main_process_first():
        start = time.monotonic_ns()
        self._set_random_seed(self.cfg.train.random_seed)
        end = time.monotonic_ns()
        self.logger.debug(
            f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
        )
        self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")

    # Build dataloader
    with self.accelerator.main_process_first():
        self.logger.info("Building dataset...")
        start = time.monotonic_ns()
        self.train_dataloader, self.valid_dataloader = self._build_dataloader()
        end = time.monotonic_ns()
        self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")

    # Build model
    with self.accelerator.main_process_first():
        self.logger.info("Building model...")
        start = time.monotonic_ns()
        self.model = self._build_model()
        end = time.monotonic_ns()
        for model in self.model.values():
            self.logger.debug(model)
        self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms")
        self.logger.info(f"Model parameters: {self._count_parameters()/1e6:.2f}M")

    # Build optimizers and schedulers
    with self.accelerator.main_process_first():
        self.logger.info("Building optimizer and scheduler...")
        start = time.monotonic_ns()
        self.optimizer = self._build_optimizer()
        end = time.monotonic_ns()
        self.logger.info(
            f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms"
        )

    # Build frozen helper models (pitch extractor, wav2vec2, speaker model)
    with self.accelerator.main_process_first():
        self.logger.info("Building helper models...")
        start = time.monotonic_ns()
        self._built_helper_model()
        end = time.monotonic_ns()
        self.logger.info(
            f"Building helper models done in {(end - start) / 1e6:.2f}ms"
        )

    # Accelerator preparing: wrap every sub-module and its optimizer/scheduler.
    self.logger.info("Initializing accelerate...")
    start = time.monotonic_ns()
    for k in self.model:
        self.model[k] = self.accelerator.prepare(self.model[k])
    for k in list(self.optimizer.optimizers):
        self.optimizer.optimizers[k] = self.accelerator.prepare(
            self.optimizer.optimizers[k]
        )
        self.optimizer.schedulers[k] = self.accelerator.prepare(
            self.optimizer.schedulers[k]
        )
    end = time.monotonic_ns()
    self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms")

    # Build criterions
    with self.accelerator.main_process_first():
        self.logger.info("Building criterion...")
        start = time.monotonic_ns()
        self.criterions = self._build_criterion()
        end = time.monotonic_ns()
        self.logger.info(f"Building criterion done in {(end - start) / 1e6:.2f}ms")

    # Resume checkpoints (after accelerator.prepare, as _load_model requires)
    with self.accelerator.main_process_first():
        if args.resume_type:
            self.logger.info("Resuming from checkpoint...")
            start = time.monotonic_ns()
            ckpt_path = self._load_model(
                args.checkpoint, resume_type=args.resume_type
            )
            end = time.monotonic_ns()
            self.logger.info(
                f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
            )
            ckpt_path = ckpt_path or args.checkpoint
            if ckpt_path:
                # train_loop writes ckpts.json into checkpoint_dir next to the
                # .pth files, so look in the file's directory when given a file.
                ckpt_dir = (
                    ckpt_path
                    if os.path.isdir(ckpt_path)
                    else os.path.dirname(ckpt_path)
                )
                ckpts_json = os.path.join(ckpt_dir, "ckpts.json")
                if os.path.isfile(ckpts_json):
                    with open(ckpts_json, "r") as f:
                        self.checkpoints_path = json.load(f)
                else:
                    self.logger.info(
                        f"No ckpts.json found in {ckpt_dir}; "
                        "keeping empty checkpoint bookkeeping."
                    )

    # Save config
    self.config_save_path = os.path.join(self.exp_dir, "args.json")
def _build_dataset(self):
    """Return the ``(dataset_class, collator_class)`` pair used for FAcodec."""
    dataset_cls = FAcodecDataset
    collator_cls = FAcodecCollator
    return dataset_cls, collator_cls
def _build_criterion(self):
    """Create the loss modules consumed by ``_train_step``.

    Returns:
        dict: keys ``"stft"`` (multi-scale STFT loss), ``"mel"`` (mel loss over
        seven resolutions), ``"l1"`` (L1 loss applied to the waveforms) and
        ``"content"`` (focal loss with gamma=2 used for phone-id prediction).
    """
    # Seven resolutions: 5..320 mel bins paired with 32..2048-sample windows,
    # each doubling from the previous one.
    scales = range(7)

    stft = MultiScaleSTFTLoss()
    mel = MelSpectrogramLoss(
        n_mels=[5 * 2**s for s in scales],
        window_lengths=[32 * 2**s for s in scales],
        mel_fmin=[0 for _ in scales],
        mel_fmax=[None for _ in scales],
        pow=1.0,
        mag_weight=0.0,
        clamp_eps=1e-5,
    )
    content = FocalLoss(gamma=2)
    l1 = L1Loss()

    return {"stft": stft, "mel": mel, "l1": l1, "content": content}
def _build_model(self):
    """Build the FAcodec sub-modules and move each one to the accelerator device.

    Returns:
        The container produced by ``build_model(self.cfg.model_params)``. It is
        iterated by key and indexed with ``model[key]`` here.
    """
    model = build_model(self.cfg.model_params)
    device = self.accelerator.device
    # Plain loop: ``.to`` is called for its side effect, so no list is needed.
    for key in model:
        model[key].to(device)
    return model
def _built_helper_model(self):
    """Load the frozen helper models used to build training targets.

    Sets ``self.pitch_extractor`` (from ``cfg.F0_path``), ``self.w2v_processor``
    and ``self.w2v_model`` (wav2vec2 CTC, put in eval mode), and
    ``self.speaker_model`` (NeMo TitaNet in eval mode, or ``None`` when
    ``nemo_asr`` could not be imported).
    """
    target = self.accelerator.device
    w2v_id = "facebook/wav2vec2-xlsr-53-espeak-cv-ft"

    self.pitch_extractor = load_F0_models(self.cfg.F0_path).to(target)

    # Phone recognizer providing content targets
    self.w2v_processor = Wav2Vec2Processor.from_pretrained(w2v_id)
    self.w2v_model = Wav2Vec2ForCTC.from_pretrained(w2v_id).to(target)
    self.w2v_model.eval()

    # Optional speaker-verification model providing speaker pseudo-labels
    if nemo_asr is not None:
        self.speaker_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(
            "nvidia/speakerverification_en_titanet_large"
        )
        self.speaker_model = self.speaker_model.to(target)
        self.speaker_model.eval()
    else:
        self.speaker_model = None
def _build_optimizer(self):
    """Create one optimizer/scheduler pair per sub-module of ``self.model``.

    Each sub-module receives its own copy of the scheduler parameters
    (``warmup_steps`` and ``base_lr`` from ``cfg.loss_params``). The optimizer
    learning rate is ``base_lr`` cast to float.

    Returns:
        The object produced by ``build_optimizer``.
    """
    loss_params = self.cfg.loss_params
    shared = {
        "warmup_steps": loss_params.warmup_steps,
        "base_lr": loss_params.base_lr,
    }
    modules = {name: self.model[name] for name in self.model}
    per_module = {name: dict(shared) for name in self.model}
    return build_optimizer(
        modules,
        scheduler_params_dict=per_module,
        lr=float(shared["base_lr"]),
    )
def train_loop(self):
    """Run training epochs until ``self.max_epoch`` is reached.

    Each epoch calls ``_train_epoch`` and logs its losses. On the main process,
    a checkpoint is saved whenever the epoch number is divisible by any entry
    of ``self.save_checkpoint_stride``, and ``ckpts.json`` is rewritten. A final
    checkpoint is saved after the loop. Checkpoints are written to
    ``checkpoint_dir`` as ``FAcodec_epoch_<epoch>_step_<step>.pth``.
    """
    self.accelerator.wait_for_everyone()

    # Dump config
    if self.accelerator.is_main_process:
        self._dump_cfg(self.config_save_path)
    for key in self.model:
        self.model[key].train()
    self.optimizer.zero_grad()

    def _save_state():
        """Write models, optimizer/scheduler states and counters to a .pth file."""
        print("Saving..")
        state = {
            "net": {key: self.model[key].state_dict() for key in self.model},
            "optimizer": self.optimizer.state_dict(),
            "scheduler": self.optimizer.scheduler_state_dict(),
            "iters": self.step,
            "epoch": self.epoch,
        }
        # ``self.step`` is the trainer's step counter (``self.iters`` never existed).
        save_path = os.path.join(
            self.checkpoint_dir,
            "FAcodec_epoch_%05d_step_%05d.pth" % (self.epoch, self.step),
        )
        torch.save(state, save_path)
        return save_path

    # Sync and start training
    self.accelerator.wait_for_everyone()
    while self.epoch < self.max_epoch:
        self.logger.info("\n")
        self.logger.info("-" * 32)
        self.logger.info("Epoch {}: ".format(self.epoch))

        # Train
        train_total_loss, train_losses = self._train_epoch()
        for key, loss in train_losses.items():
            self.logger.info("  |- Train/{} Loss: {:.6f}".format(key, loss))
            self.accelerator.log(
                {"Epoch/Train {} Loss".format(key): loss},
                step=self.epoch,
            )
        self.accelerator.log(
            {
                "Epoch/Train Total Loss": train_total_loss,
            },
            step=self.epoch,
        )

        # Sync all ranks before checkpoint bookkeeping
        self.accelerator.wait_for_everyone()

        # Check save checkpoint interval (only the main process decides/saves)
        save_checkpoint = False
        if self.accelerator.is_main_process:
            save_checkpoint = any(
                self.epoch % num == 0 for num in self.save_checkpoint_stride
            )

        # Save checkpoints
        self.accelerator.wait_for_everyone()
        if self.accelerator.is_main_process and save_checkpoint:
            _save_state()
            with open(os.path.join(self.checkpoint_dir, "ckpts.json"), "w") as f:
                json.dump(self.checkpoints_path, f, ensure_ascii=False, indent=4)

        self.accelerator.wait_for_everyone()

        self.epoch += 1

    # Finish training
    self.accelerator.wait_for_everyone()
    if self.accelerator.is_main_process:
        _save_state()
def _train_epoch(self):
    """Train for one pass over ``self.train_dataloader``.

    Losses are accumulated only on batches where ``batch_count`` is a multiple
    of ``cfg.train.gradient_accumulation_step``; ``self.step`` advances on those
    same batches. See ``train_loop`` for usage.

    Returns:
        tuple: ``(epoch_total_loss, epoch_losses)``. Both are averaged as
        ``sum / len(dataloader) * gradient_accumulation_step``. An empty
        dataloader yields ``(0.0, {})`` instead of dividing by zero.
    """
    for key in self.model:
        self.model[key].train()

    accum_steps = self.cfg.train.gradient_accumulation_step
    epoch_losses: dict = {}
    epoch_total_loss: float = 0.0

    for batch in tqdm(
        self.train_dataloader,
        desc=f"Training Epoch {self.epoch}",
        unit="batch",
        colour="GREEN",
        leave=False,
        dynamic_ncols=True,
        smoothing=0.04,
        disable=not self.accelerator.is_main_process,
    ):
        # Get losses
        total_loss, losses = self._train_step(batch)
        self.batch_count += 1

        # Log info
        if self.batch_count % accum_steps == 0:
            self.accelerator.log(
                {
                    "Step/Learning Rate": (
                        self.optimizer.schedulers["encoder"].get_last_lr()[0]
                        if self.step != 0
                        else 0
                    )
                },
                step=self.step,
            )
            for key, value in losses.items():
                self.accelerator.log(
                    {
                        "Step/Train {} Loss".format(key): value,
                    },
                    step=self.step,
                )

            if not epoch_losses:
                # Copy so later accumulation does not mutate the step's dict.
                epoch_losses = dict(losses)
            else:
                for key, value in losses.items():
                    epoch_losses[key] += value
            epoch_total_loss += total_loss
            self.step += 1

    # Get and log total losses
    self.accelerator.wait_for_everyone()
    num_batches = len(self.train_dataloader)
    if num_batches == 0:
        return epoch_total_loss, epoch_losses
    epoch_total_loss = epoch_total_loss / num_batches * accum_steps
    for key in epoch_losses:
        epoch_losses[key] = epoch_losses[key] / num_batches * accum_steps
    return epoch_total_loss, epoch_losses
def _train_step(self, data):
    """Run one discriminator update and one generator update on a batch.

    The batch is moved to the accelerator device and unpacked as
    ``(waves, mels, wave_lengths, mel_input_length)``. A window of
    ``mel_seg_len`` mel frames (capped by ``cfg.train.max_frame_len``) is cut
    from each item at a random offset; the matching waveform span uses 300
    samples per mel frame. Frozen helper models provide targets:
    - wav2vec2 phone ids for the content losses;
    - the F0 extractor for the f0 losses;
    - the speaker model for speaker pseudo-labels, which are all 0 when it is
      ``None``.

    Returns:
        tuple: ``(total_loss, train_losses)``. ``total_loss`` is the weighted
        generator loss as a Python float; ``train_losses`` maps loss names to
        Python floats.
    """
    train_losses = {}
    device = self.accelerator.device

    # Move the batch to the device and unpack it
    data = [b.to(device, non_blocking=True) for b in data]
    waves, mels, wave_lengths, mel_input_length = data

    # Phone ids from the frozen wav2vec2 CTC model (input resampled 24k -> 16k),
    # stretched with nearest-neighbour interpolation to the mel frame count.
    waves_16k = torchaudio.functional.resample(waves, 24000, 16000)
    w2v_input = self.w2v_processor(
        waves_16k, sampling_rate=16000, return_tensors="pt"
    ).input_values.to(device)
    with torch.no_grad():
        w2v_outputs = self.w2v_model(w2v_input.squeeze(0)).logits
        predicted_ids = torch.argmax(w2v_outputs, dim=-1)
        phone_ids = (
            F.interpolate(
                predicted_ids.unsqueeze(0).float(), mels.size(-1), mode="nearest"
            )
            .long()
            .squeeze(0)
        )

    # Random clips of a common length across the batch
    mel_seg_len = min(
        [int(mel_input_length.min().item()), self.cfg.train.max_frame_len]
    )

    gt_mel_seg = []
    wav_seg = []
    w2v_seg = []
    for bib in range(len(mel_input_length)):
        mel_length = int(mel_input_length[bib].item())

        # NOTE(review): randint's upper bound is exclusive, so the window that
        # ends on the very last frame is never drawn. Left as-is: it may be
        # what keeps the 300-samples-per-frame wave slice inside ``waves``.
        random_start = (
            np.random.randint(0, mel_length - mel_seg_len)
            if mel_length != mel_seg_len
            else 0
        )
        gt_mel_seg.append(mels[bib, :, random_start : random_start + mel_seg_len])
        w2v_seg.append(phone_ids[bib, random_start : random_start + mel_seg_len])

        y = waves[bib][random_start * 300 : (random_start + mel_seg_len) * 300]
        wav_seg.append(y.to(device))

    gt_mel_seg = torch.stack(gt_mel_seg).detach()
    wav_seg = torch.stack(wav_seg).float().detach().unsqueeze(1)
    w2v_seg = torch.stack(w2v_seg).float().detach()

    with torch.no_grad():
        real_norm = log_norm(gt_mel_seg.unsqueeze(1)).squeeze(1).detach()
        F0_real, _, _ = self.pitch_extractor(gt_mel_seg.unsqueeze(1))

    # Per-utterance F0 normalization in log2 space; frames with F0 <= 5.0 are
    # treated as unvoiced and set to -10.
    f0_targets = []
    for bib in range(len(F0_real)):
        voiced_indices = F0_real[bib] > 5.0
        f0_voiced = F0_real[bib][voiced_indices]

        if len(f0_voiced) != 0:
            log_f0 = f0_voiced.log2()
            mean_f0 = log_f0.mean()
            std_f0 = log_f0.std()
            normalized_f0 = (log_f0 - mean_f0) / std_f0

            normalized_sequence = torch.zeros_like(F0_real[bib])
            normalized_sequence[voiced_indices] = normalized_f0
            normalized_sequence[~voiced_indices] = -10
        else:
            normalized_sequence = torch.zeros_like(F0_real[bib]) - 10.0

        f0_targets.append(normalized_sequence)
    f0_targets = torch.stack(f0_targets).to(device)
    # A single voiced frame gives std == NaN (and std == 0 gives inf):
    # map both to the unvoiced value.
    f0_targets[torch.isnan(f0_targets)] = -10.0
    f0_targets[torch.isinf(f0_targets)] = -10.0

    # If frame_rate is not 80, resample the f0 and phone targets from 80 frames
    # per unit to the target frame rate.
    frame_rate = self.cfg.preprocess_params.frame_rate
    if frame_rate != 80:
        target_len = mel_seg_len // 80 * frame_rate
        f0_targets = F.interpolate(
            f0_targets.unsqueeze(1),
            target_len,
            mode="nearest",
        ).squeeze(1)
        # w2v_seg is 2-D (batch, frames); 1-D interpolation needs a channel dim.
        w2v_seg = F.interpolate(
            w2v_seg.unsqueeze(1),
            target_len,
            mode="nearest",
        ).squeeze(1)

    wav_seg_input = wav_seg
    wav_seg_target = wav_seg

    z = self.model.encoder(wav_seg_input)
    z, quantized, commitment_loss, codebook_loss, timbre = self.model.quantizer(
        z, wav_seg_input, n_c=2, full_waves=waves, wave_lens=wave_lengths
    )
    preds, rev_preds = self.model.fa_predictors(quantized, timbre)

    pred_wave = self.model.decoder(z)

    # Center-crop the target if the decoder output is shorter
    len_diff = wav_seg_target.size(-1) - pred_wave.size(-1)
    if len_diff > 0:
        wav_seg_target = wav_seg_target[..., len_diff // 2 : -len_diff // 2]

    # ---- Discriminator update (least-squares adversarial loss) ----
    d_fake = self.model.discriminator(pred_wave.detach())
    d_real = self.model.discriminator(wav_seg_target)
    loss_d = 0
    for x_fake, x_real in zip(d_fake, d_real):
        loss_d += torch.mean(x_fake[-1] ** 2)
        loss_d += torch.mean((1 - x_real[-1]) ** 2)

    self.optimizer.zero_grad()
    self.accelerator.backward(loss_d)
    torch.nn.utils.clip_grad_norm_(self.model.discriminator.parameters(), 10.0)
    self.optimizer.step("discriminator")
    self.optimizer.scheduler(key="discriminator")

    # ---- Generator losses ----
    signal = AudioSignal(wav_seg_target, sample_rate=24000)
    recons = AudioSignal(pred_wave, sample_rate=24000)
    stft_loss = self.criterions["stft"](recons, signal)
    mel_loss = self.criterions["mel"](recons, signal)
    waveform_loss = self.criterions["l1"](recons, signal)

    d_fake = self.model.discriminator(pred_wave)
    d_real = self.model.discriminator(wav_seg_target)

    loss_g = 0
    for x_fake in d_fake:
        loss_g += torch.mean((1 - x_fake[-1]) ** 2)

    # Feature matching over every intermediate discriminator output
    loss_feature = 0
    for i in range(len(d_fake)):
        for j in range(len(d_fake[i]) - 1):
            loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())

    pred_f0, pred_uv = preds["f0"], preds["uv"]
    rev_pred_f0, rev_pred_uv = rev_preds["rev_f0"], rev_preds["rev_uv"]

    common_min_size = min(pred_f0.size(-2), f0_targets.size(-1))
    f0_targets = f0_targets[..., :common_min_size]
    real_norm = real_norm[..., :common_min_size]

    f0_loss = F.smooth_l1_loss(
        f0_targets, pred_f0.squeeze(-1)[..., :common_min_size]
    )
    uv_loss = F.smooth_l1_loss(
        real_norm, pred_uv.squeeze(-1)[..., :common_min_size]
    )
    rev_f0_loss = (
        F.smooth_l1_loss(f0_targets, rev_pred_f0.squeeze(-1)[..., :common_min_size])
        if rev_pred_f0 is not None
        else torch.FloatTensor([0]).to(device)
    )
    rev_uv_loss = (
        F.smooth_l1_loss(real_norm, rev_pred_uv.squeeze(-1)[..., :common_min_size])
        if rev_pred_uv is not None
        else torch.FloatTensor([0]).to(device)
    )

    tot_f0_loss = f0_loss + rev_f0_loss
    tot_uv_loss = uv_loss + rev_uv_loss

    pred_content = preds["content"]
    rev_pred_content = rev_preds["rev_content"]

    target_content_latents = w2v_seg[..., :common_min_size]

    content_loss = self.criterions["content"](
        pred_content.transpose(1, 2)[..., :common_min_size],
        target_content_latents.long(),
    )
    rev_content_loss = (
        self.criterions["content"](
            rev_pred_content.transpose(1, 2)[..., :common_min_size],
            target_content_latents.long(),
        )
        if rev_pred_content is not None
        else torch.FloatTensor([0]).to(device)
    )

    tot_content_loss = content_loss + rev_content_loss

    if self.speaker_model is not None:
        # NOTE(review): ``waves_16k`` is resampled to 16 kHz but ``wl`` comes
        # from ``wave_lengths`` unchanged; confirm whether it should be scaled.
        spk_logits = torch.cat(
            [
                self.speaker_model.infer_segment(w16.cpu()[..., :wl])[1]
                for w16, wl in zip(waves_16k, wave_lengths)
            ],
            dim=0,
        )
        spk_labels = spk_logits.argmax(dim=-1)
    else:
        spk_labels = torch.zeros([len(waves_16k)], dtype=torch.long).to(device)

    spk_pred_logits = preds["timbre"]
    spk_loss = F.cross_entropy(spk_pred_logits, spk_labels)
    x_spk_pred_logits = rev_preds["x_timbre"]

    x_spk_loss = (
        F.cross_entropy(x_spk_pred_logits, spk_labels)
        if x_spk_pred_logits is not None
        else torch.FloatTensor([0]).to(device)
    )

    tot_spk_loss = spk_loss + x_spk_loss

    loss_gen_all = (
        mel_loss * 15.0
        + loss_feature * 1.0
        + loss_g * 1.0
        + commitment_loss * 0.25
        + codebook_loss * 1.0
        + tot_f0_loss * 1.0
        + tot_uv_loss * 1.0
        + tot_content_loss * 5.0
        + tot_spk_loss * 5.0
    )

    # ---- Generator update ----
    self.optimizer.zero_grad()
    self.accelerator.backward(loss_gen_all)
    # Step every non-discriminator sub-module; otherwise the generator side
    # (encoder/quantizer/predictors/decoder) is never updated.
    for key in self.model:
        if key != "discriminator":
            self.optimizer.step(key)
            self.optimizer.scheduler(key=key)

    with torch.no_grad():
        total_loss = loss_gen_all.item()
        train_losses["stft"] = stft_loss.item()
        train_losses["mel"] = mel_loss.item()
        train_losses["l1"] = waveform_loss.item()
        train_losses["f0"] = f0_loss.item()
        train_losses["uv"] = uv_loss.item()
        train_losses["content"] = content_loss.item()
        train_losses["speaker"] = spk_loss.item()
        train_losses["rev_f0"] = rev_f0_loss.item()
        train_losses["rev_uv"] = rev_uv_loss.item()
        train_losses["rev_content"] = rev_content_loss.item()
        train_losses["rev_speaker"] = x_spk_loss.item()

        train_losses["feature"] = loss_feature.item()
        train_losses["generator"] = loss_g.item()
        train_losses["commitment"] = commitment_loss.item()
        train_losses["codebook"] = codebook_loss.item()

        # discriminators
        train_losses["discriminator"] = loss_d.item()

    return total_loss, train_losses
+ def _inference(self, eval_wave):
717
+ """Inference during training for test audios."""
718
+ z = self.model.encoder(
719
+ eval_wave[None, None, ...].to(self.accelerator.device).float()
720
+ )
721
+ z, quantized, commitment_loss, codebook_loss, timbre = self.model.quantizer(
722
+ z, eval_wave[None, None, ...], n_c=self.cfg.model_params.n_c_codebooks
723
+ )
724
+ full_pred_wave = self.model.decoder(z)
725
+ return full_pred_wave[0]
726
+
727
+ def _load_model(self, checkpoint_path=None, resume_type="resume"):
728
+ """Load model from checkpoint. If checkpoint_path is None, it will
729
+ load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
730
+ None, it will load the checkpoint specified by checkpoint_path. **Only use this
731
+ method after** ``accelerator.prepare()``.
732
+ """
733
+ if resume_type == "resume":
734
+ if checkpoint_path is None:
735
+ available_checkpoints = glob.glob(
736
+ os.path.join(self.checkpoint_dir, "FAcodc_epoch_*_step_*.pth")
737
+ )
738
+ # find the checkpoint that has the highest step number
739
+ latest_checkpoint = max(
740
+ available_checkpoints,
741
+ key=lambda x: int(x.split("_")[-1].split(".")[0]),
742
+ )
743
+ earliest_checkpoint = min(
744
+ available_checkpoints,
745
+ key=lambda x: int(x.split("_")[-1].split(".")[0]),
746
+ )
747
+ # delete the earliest checkpoint
748
+ if (
749
+ earliest_checkpoint != latest_checkpoint
750
+ and self.accelerator.is_main_process
751
+ and len(available_checkpoints) > 4
752
+ ):
753
+ os.remove(earliest_checkpoint)
754
+ print(f"Removed {earliest_checkpoint}")
755
+ else:
756
+ latest_checkpoint = checkpoint_path
757
+
758
+ self.model, self.optimizer, self.epoch, self.step = load_checkpoint(
759
+ self.model,
760
+ self.optimizer,
761
+ latest_checkpoint,
762
+ load_only_params=False,
763
+ ignore_modules=[],
764
+ is_distributed=self.accelerator.num_processes > 1,
765
+ )
766
+
767
+ else:
768
+ raise ValueError("Invalid resume type")
769
+ return checkpoint_path
770
+
771
+ def _count_parameters(self):
772
+ total_num = sum(
773
+ sum(p.numel() for p in self.model[key].parameters()) for key in self.model
774
+ )
775
+ # trainable_num = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
776
+ return total_num