xinference 1.10.0__py3-none-any.whl → 1.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.
Files changed (317)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +11 -28
  3. xinference/client/restful/async_restful_client.py +20 -3
  4. xinference/client/restful/restful_client.py +20 -3
  5. xinference/core/supervisor.py +87 -53
  6. xinference/core/worker.py +10 -0
  7. xinference/deploy/cmdline.py +15 -0
  8. xinference/model/audio/core.py +21 -6
  9. xinference/model/audio/indextts2.py +166 -0
  10. xinference/model/audio/model_spec.json +38 -1
  11. xinference/model/image/model_spec.json +69 -0
  12. xinference/model/image/stable_diffusion/core.py +13 -4
  13. xinference/model/llm/__init__.py +4 -0
  14. xinference/model/llm/llm_family.json +464 -2
  15. xinference/model/llm/sglang/core.py +30 -11
  16. xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +94 -32
  17. xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
  18. xinference/model/llm/utils.py +12 -9
  19. xinference/model/llm/vllm/core.py +93 -17
  20. xinference/thirdparty/audiotools/__init__.py +10 -0
  21. xinference/thirdparty/audiotools/core/__init__.py +4 -0
  22. xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
  23. xinference/thirdparty/audiotools/core/display.py +194 -0
  24. xinference/thirdparty/audiotools/core/dsp.py +390 -0
  25. xinference/thirdparty/audiotools/core/effects.py +647 -0
  26. xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
  27. xinference/thirdparty/audiotools/core/loudness.py +320 -0
  28. xinference/thirdparty/audiotools/core/playback.py +252 -0
  29. xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
  30. xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
  31. xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
  32. xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
  33. xinference/thirdparty/audiotools/core/util.py +671 -0
  34. xinference/thirdparty/audiotools/core/whisper.py +97 -0
  35. xinference/thirdparty/audiotools/data/__init__.py +3 -0
  36. xinference/thirdparty/audiotools/data/datasets.py +517 -0
  37. xinference/thirdparty/audiotools/data/preprocess.py +81 -0
  38. xinference/thirdparty/audiotools/data/transforms.py +1592 -0
  39. xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
  40. xinference/thirdparty/audiotools/metrics/distance.py +131 -0
  41. xinference/thirdparty/audiotools/metrics/quality.py +159 -0
  42. xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
  43. xinference/thirdparty/audiotools/ml/__init__.py +5 -0
  44. xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
  45. xinference/thirdparty/audiotools/ml/decorators.py +440 -0
  46. xinference/thirdparty/audiotools/ml/experiment.py +90 -0
  47. xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
  48. xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
  49. xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
  50. xinference/thirdparty/audiotools/post.py +140 -0
  51. xinference/thirdparty/audiotools/preference.py +600 -0
  52. xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
  53. xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
  54. xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
  55. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
  56. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
  57. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
  58. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
  59. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  60. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
  61. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
  62. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
  63. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
  64. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
  65. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
  66. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
  67. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
  68. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
  69. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
  70. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
  71. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
  72. xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
  73. xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
  74. xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
  75. xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
  76. xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
  77. xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
  78. xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
  79. xinference/thirdparty/indextts/__init__.py +0 -0
  80. xinference/thirdparty/indextts/cli.py +65 -0
  81. xinference/thirdparty/indextts/gpt/__init__.py +0 -0
  82. xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
  83. xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
  84. xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
  85. xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
  86. xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
  87. xinference/thirdparty/indextts/gpt/model.py +713 -0
  88. xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
  89. xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
  90. xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
  91. xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
  92. xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
  93. xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
  94. xinference/thirdparty/indextts/infer.py +690 -0
  95. xinference/thirdparty/indextts/infer_v2.py +739 -0
  96. xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
  97. xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
  98. xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
  99. xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
  100. xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
  101. xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
  102. xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
  103. xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
  104. xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
  105. xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
  106. xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
  107. xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
  108. xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
  109. xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
  110. xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
  111. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
  112. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
  113. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
  114. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
  115. xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
  116. xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
  117. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
  118. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
  119. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  120. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
  121. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
  122. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
  123. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
  124. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
  125. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
  126. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
  127. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
  128. xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
  129. xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
  130. xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
  131. xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
  132. xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
  133. xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
  134. xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
  135. xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
  136. xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
  137. xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
  138. xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
  139. xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
  140. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
  141. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
  142. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
  143. xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
  144. xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
  145. xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
  146. xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
  147. xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
  148. xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
  149. xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
  150. xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
  151. xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
  152. xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
  153. xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
  154. xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
  155. xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
  156. xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
  157. xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
  158. xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
  159. xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
  160. xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
  161. xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
  162. xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
  163. xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
  164. xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
  165. xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
  166. xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
  167. xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
  168. xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
  169. xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
  170. xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
  171. xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
  172. xinference/thirdparty/indextts/utils/__init__.py +0 -0
  173. xinference/thirdparty/indextts/utils/arch_util.py +120 -0
  174. xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
  175. xinference/thirdparty/indextts/utils/common.py +121 -0
  176. xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
  177. xinference/thirdparty/indextts/utils/front.py +536 -0
  178. xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
  179. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
  180. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
  181. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
  182. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
  183. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
  184. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
  185. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
  186. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
  187. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
  188. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
  189. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
  190. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
  191. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
  192. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
  193. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
  194. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
  195. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
  196. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
  197. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
  198. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
  199. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
  200. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
  201. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
  202. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
  203. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
  204. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
  205. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
  206. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
  207. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
  208. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
  209. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
  210. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
  211. xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
  212. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
  213. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
  214. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
  215. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
  216. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
  217. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
  218. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
  219. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
  220. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
  221. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
  222. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
  223. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
  224. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
  225. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
  226. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
  227. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
  228. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
  229. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
  230. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
  231. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
  232. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
  233. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
  234. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
  235. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
  236. xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
  237. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
  238. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
  239. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
  240. xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
  241. xinference/thirdparty/indextts/utils/text_utils.py +41 -0
  242. xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
  243. xinference/thirdparty/indextts/utils/utils.py +93 -0
  244. xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
  245. xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
  246. xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
  247. xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
  248. xinference/ui/gradio/media_interface.py +66 -8
  249. xinference/ui/web/ui/build/asset-manifest.json +6 -6
  250. xinference/ui/web/ui/build/index.html +1 -1
  251. xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
  252. xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
  253. xinference/ui/web/ui/build/static/js/main.d192c4f3.js +3 -0
  254. xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.d192c4f3.js.LICENSE.txt} +0 -7
  255. xinference/ui/web/ui/build/static/js/main.d192c4f3.js.map +1 -0
  256. xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
  257. xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
  258. xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
  259. xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
  260. xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
  261. xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
  262. xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
  263. xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
  264. xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
  265. xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
  266. xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
  267. xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
  268. xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
  269. xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
  270. xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
  271. xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
  272. xinference/ui/web/ui/node_modules/.cache/babel-loader/f995a2425dfb0822fd07127f66ffe9b026883bc156b402eb8bd0b83d52460a93.json +1 -0
  273. xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
  274. xinference/ui/web/ui/package-lock.json +0 -34
  275. xinference/ui/web/ui/package.json +0 -1
  276. xinference/ui/web/ui/src/locales/en.json +9 -3
  277. xinference/ui/web/ui/src/locales/ja.json +9 -3
  278. xinference/ui/web/ui/src/locales/ko.json +9 -3
  279. xinference/ui/web/ui/src/locales/zh.json +9 -3
  280. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/METADATA +18 -2
  281. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/RECORD +285 -67
  282. xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
  283. xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
  284. xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
  285. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
  286. xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
  287. xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
  288. xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
  289. xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
  290. xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
  291. xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
  292. xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
  293. xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
  294. xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
  295. xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
  296. xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
  297. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
  298. xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
  299. xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
  300. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
  301. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
  302. xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
  303. xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
  304. xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
  305. xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
  306. xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
  307. xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
  308. xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
  309. xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
  310. xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
  311. xinference/ui/web/ui/node_modules/select/bower.json +0 -13
  312. xinference/ui/web/ui/node_modules/select/package.json +0 -29
  313. xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
  314. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/WHEEL +0 -0
  315. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/entry_points.txt +0 -0
  316. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/licenses/LICENSE +0 -0
  317. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/top_level.txt +0 -0
xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py
@@ -0,0 +1,1222 @@
+ # Copyright (c) 2023 Amphion.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import numpy as np
+ import torch
+ from torch import nn, sin, pow
+ from torch.nn import Parameter
+ import torch.nn.functional as F
+ from torch.nn.utils import weight_norm
+ from .alias_free_torch import *
+ from .quantize import *
+ from einops import rearrange
+ from einops.layers.torch import Rearrange
+ from .transformer import TransformerEncoder
+ from .gradient_reversal import GradientReversal
+ from .melspec import MelSpectrogram
+
+
+ def init_weights(m):
+     if isinstance(m, nn.Conv1d):
+         nn.init.trunc_normal_(m.weight, std=0.02)
+         nn.init.constant_(m.bias, 0)
+
+
+ def WNConv1d(*args, **kwargs):
+     return weight_norm(nn.Conv1d(*args, **kwargs))
+
+
+ def WNConvTranspose1d(*args, **kwargs):
+     return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
+
+
+ class CNNLSTM(nn.Module):
+     def __init__(self, indim, outdim, head, global_pred=False):
+         super().__init__()
+         self.global_pred = global_pred
+         self.model = nn.Sequential(
+             ResidualUnit(indim, dilation=1),
+             ResidualUnit(indim, dilation=2),
+             ResidualUnit(indim, dilation=3),
+             Activation1d(activation=SnakeBeta(indim, alpha_logscale=True)),
+             Rearrange("b c t -> b t c"),
+         )
+         self.heads = nn.ModuleList([nn.Linear(indim, outdim) for i in range(head)])
+
+     def forward(self, x):
+         # x: [B, C, T]
+         x = self.model(x)
+         if self.global_pred:
+             x = torch.mean(x, dim=1, keepdim=False)
+         outs = [head(x) for head in self.heads]
+         return outs
+
+
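Despite its name, CNNLSTM as defined here is purely convolutional: three ResidualUnits, a SnakeBeta activation, and one linear head per output. Each head emits its own prediction, so a two-head instance unpacks into two tensors, which is how the f0/uv predictor further down consumes it. A minimal usage sketch (editorial, not part of the diff; shapes follow the comments in the code):

import torch

# Mirrors how FACodecDecoder wires CNNLSTM(in_channels, 1, 2) for f0/uv.
pred = CNNLSTM(indim=256, outdim=1, head=2)
z = torch.randn(4, 256, 80)   # (B, C, T) quantized features
f0, uv = pred(z)              # two heads -> two (B, T, 1) tensors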
+ class SnakeBeta(nn.Module):
+     """
+     A modified Snake function which uses separate parameters for the magnitude of the periodic components
+     Shape:
+         - Input: (B, C, T)
+         - Output: (B, C, T), same shape as the input
+     Parameters:
+         - alpha - trainable parameter that controls frequency
+         - beta - trainable parameter that controls magnitude
+     References:
+         - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+         https://arxiv.org/abs/2006.08195
+     Examples:
+         >>> a1 = SnakeBeta(256)
+         >>> x = torch.randn(256)
+         >>> x = a1(x)
+     """
+
+     def __init__(
+         self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
+     ):
+         """
+         Initialization.
+         INPUT:
+             - in_features: shape of the input
+             - alpha - trainable parameter that controls frequency
+             - beta - trainable parameter that controls magnitude
+             alpha is initialized to 1 by default, higher values = higher-frequency.
+             beta is initialized to 1 by default, higher values = higher-magnitude.
+             alpha will be trained along with the rest of your model.
+         """
+         super(SnakeBeta, self).__init__()
+         self.in_features = in_features
+
+         # initialize alpha
+         self.alpha_logscale = alpha_logscale
+         if self.alpha_logscale:  # log scale alphas initialized to zeros
+             self.alpha = Parameter(torch.zeros(in_features) * alpha)
+             self.beta = Parameter(torch.zeros(in_features) * alpha)
+         else:  # linear scale alphas initialized to ones
+             self.alpha = Parameter(torch.ones(in_features) * alpha)
+             self.beta = Parameter(torch.ones(in_features) * alpha)
+
+         self.alpha.requires_grad = alpha_trainable
+         self.beta.requires_grad = alpha_trainable
+
+         self.no_div_by_zero = 0.000000001
+
+     def forward(self, x):
+         """
+         Forward pass of the function.
+         Applies the function to the input elementwise.
+         SnakeBeta := x + 1/b * sin^2 (xa)
+         """
+         alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
+         beta = self.beta.unsqueeze(0).unsqueeze(-1)
+         if self.alpha_logscale:
+             alpha = torch.exp(alpha)
+             beta = torch.exp(beta)
+         x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
+
+         return x
+
+
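The forward pass above computes x + (1/beta) * sin^2(alpha * x) per channel. A quick sanity check of the per-channel broadcast (an editorial sketch, assuming the module is importable):

import torch

act = SnakeBeta(in_features=3)    # linear scale: alpha = beta = 1 per channel
x = torch.zeros(1, 3, 5)          # (B, C, T)
assert torch.allclose(act(x), x)  # sin(0)^2 == 0, so zeros pass through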
+ class ResidualUnit(nn.Module):
+     def __init__(self, dim: int = 16, dilation: int = 1):
+         super().__init__()
+         pad = ((7 - 1) * dilation) // 2
+         self.block = nn.Sequential(
+             Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
+             WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
+             Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
+             WNConv1d(dim, dim, kernel_size=1),
+         )
+
+     def forward(self, x):
+         return x + self.block(x)
+
+
+ class EncoderBlock(nn.Module):
+     def __init__(self, dim: int = 16, stride: int = 1):
+         super().__init__()
+         self.block = nn.Sequential(
+             ResidualUnit(dim // 2, dilation=1),
+             ResidualUnit(dim // 2, dilation=3),
+             ResidualUnit(dim // 2, dilation=9),
+             Activation1d(activation=SnakeBeta(dim // 2, alpha_logscale=True)),
+             WNConv1d(
+                 dim // 2,
+                 dim,
+                 kernel_size=2 * stride,
+                 stride=stride,
+                 padding=stride // 2 + stride % 2,
+             ),
+         )
+
+     def forward(self, x):
+         return self.block(x)
+
+
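The strided convolution's padding, stride // 2 + stride % 2 (i.e. ceil(stride / 2)), pairs with kernel_size 2 * stride so that time is downsampled by exactly stride: Conv1d yields floor((T + 2p - k) / s) + 1 = T / s when s divides T. A shape check (editorial sketch):

import torch

# For stride 5: p = 3, k = 10, so T = 100 -> floor((100 + 6 - 10) / 5) + 1 = 20.
blk = EncoderBlock(dim=32, stride=5)
x = torch.randn(1, 16, 100)       # input channels are dim // 2
assert blk(x).shape == (1, 32, 20)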
+ class FACodecEncoder(nn.Module):
+     def __init__(
+         self,
+         ngf=32,
+         up_ratios=(2, 4, 5, 5),
+         out_channels=1024,
+     ):
+         super().__init__()
+         self.hop_length = np.prod(up_ratios)
+         self.up_ratios = up_ratios
+
+         # Create first convolution
+         d_model = ngf
+         self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
+
+         # Create EncoderBlocks that double channels as they downsample by `stride`
+         for stride in up_ratios:
+             d_model *= 2
+             self.block += [EncoderBlock(d_model, stride=stride)]
+
+         # Create last convolution
+         self.block += [
+             Activation1d(activation=SnakeBeta(d_model, alpha_logscale=True)),
+             WNConv1d(d_model, out_channels, kernel_size=3, padding=1),
+         ]
+
+         # Wrap block into nn.Sequential
+         self.block = nn.Sequential(*self.block)
+         self.enc_dim = d_model
+
+         self.reset_parameters()
+
+     def forward(self, x):
+         out = self.block(x)
+         return out
+
+     def inference(self, x):
+         return self.block(x)
+
+     def remove_weight_norm(self):
+         """Remove weight normalization from all of the layers."""
+
+         def _remove_weight_norm(m):
+             try:
+                 torch.nn.utils.remove_weight_norm(m)
+             except ValueError:  # this module didn't have weight norm
+                 return
+
+         self.apply(_remove_weight_norm)
+
+     def apply_weight_norm(self):
+         """Apply weight normalization to all of the layers."""
+
+         def _apply_weight_norm(m):
+             if isinstance(m, nn.Conv1d):
+                 torch.nn.utils.weight_norm(m)
+
+         self.apply(_apply_weight_norm)
+
+     def reset_parameters(self):
+         self.apply(init_weights)
+
+
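With the default up_ratios of (2, 4, 5, 5), hop_length is np.prod((2, 4, 5, 5)) = 200, so 16 kHz audio is encoded at 80 frames per second while channels double at each stage (32 up to 512 before the output projection). A shape sketch (editorial; out_channels=256 is an illustrative value, not a claim about the shipped configs):

import torch

enc = FACodecEncoder(ngf=32, up_ratios=(2, 4, 5, 5), out_channels=256)
wav = torch.randn(1, 1, 16000)                        # one second at 16 kHz
z = enc(wav)
assert z.shape == (1, 256, 16000 // enc.hop_length)   # (1, 256, 80)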
+ class DecoderBlock(nn.Module):
+     def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
+         super().__init__()
+         self.block = nn.Sequential(
+             Activation1d(activation=SnakeBeta(input_dim, alpha_logscale=True)),
+             WNConvTranspose1d(
+                 input_dim,
+                 output_dim,
+                 kernel_size=2 * stride,
+                 stride=stride,
+                 padding=stride // 2 + stride % 2,
+                 output_padding=stride % 2,
+             ),
+             ResidualUnit(output_dim, dilation=1),
+             ResidualUnit(output_dim, dilation=3),
+             ResidualUnit(output_dim, dilation=9),
+         )
+
+     def forward(self, x):
+         return self.block(x)
+
+
+ class FACodecDecoder(nn.Module):
+     def __init__(
+         self,
+         in_channels=256,
+         upsample_initial_channel=1536,
+         ngf=32,
+         up_ratios=(5, 5, 4, 2),
+         vq_num_q_c=2,
+         vq_num_q_p=1,
+         vq_num_q_r=3,
+         vq_dim=1024,
+         vq_commit_weight=0.005,
+         vq_weight_init=False,
+         vq_full_commit_loss=False,
+         codebook_dim=8,
+         codebook_size_prosody=10,  # true codebook size is equal to 2^codebook_size
+         codebook_size_content=10,
+         codebook_size_residual=10,
+         quantizer_dropout=0.0,
+         dropout_type="linear",
+         use_gr_content_f0=False,
+         use_gr_prosody_phone=False,
+         use_gr_residual_f0=False,
+         use_gr_residual_phone=False,
+         use_gr_x_timbre=False,
+         use_random_mask_residual=True,
+         prob_random_mask_residual=0.75,
+     ):
+         super().__init__()
+         self.hop_length = np.prod(up_ratios)
+         self.ngf = ngf
+         self.up_ratios = up_ratios
+
+         self.use_random_mask_residual = use_random_mask_residual
+         self.prob_random_mask_residual = prob_random_mask_residual
+
+         self.vq_num_q_p = vq_num_q_p
+         self.vq_num_q_c = vq_num_q_c
+         self.vq_num_q_r = vq_num_q_r
+
+         self.codebook_size_prosody = codebook_size_prosody
+         self.codebook_size_content = codebook_size_content
+         self.codebook_size_residual = codebook_size_residual
+
+         quantizer_class = ResidualVQ
+
+         self.quantizer = nn.ModuleList()
+
+         # prosody
+         quantizer = quantizer_class(
+             num_quantizers=vq_num_q_p,
+             dim=vq_dim,
+             codebook_size=codebook_size_prosody,
+             codebook_dim=codebook_dim,
+             threshold_ema_dead_code=2,
+             commitment=vq_commit_weight,
+             weight_init=vq_weight_init,
+             full_commit_loss=vq_full_commit_loss,
+             quantizer_dropout=quantizer_dropout,
+             dropout_type=dropout_type,
+         )
+         self.quantizer.append(quantizer)
+
+         # phone
+         quantizer = quantizer_class(
+             num_quantizers=vq_num_q_c,
+             dim=vq_dim,
+             codebook_size=codebook_size_content,
+             codebook_dim=codebook_dim,
+             threshold_ema_dead_code=2,
+             commitment=vq_commit_weight,
+             weight_init=vq_weight_init,
+             full_commit_loss=vq_full_commit_loss,
+             quantizer_dropout=quantizer_dropout,
+             dropout_type=dropout_type,
+         )
+         self.quantizer.append(quantizer)
+
+         # residual
+         if self.vq_num_q_r > 0:
+             quantizer = quantizer_class(
+                 num_quantizers=vq_num_q_r,
+                 dim=vq_dim,
+                 codebook_size=codebook_size_residual,
+                 codebook_dim=codebook_dim,
+                 threshold_ema_dead_code=2,
+                 commitment=vq_commit_weight,
+                 weight_init=vq_weight_init,
+                 full_commit_loss=vq_full_commit_loss,
+                 quantizer_dropout=quantizer_dropout,
+                 dropout_type=dropout_type,
+             )
+             self.quantizer.append(quantizer)
+
+         # Add first conv layer
+         channels = upsample_initial_channel
+         layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)]
+
+         # Add upsampling + MRF blocks
+         for i, stride in enumerate(up_ratios):
+             input_dim = channels // 2**i
+             output_dim = channels // 2 ** (i + 1)
+             layers += [DecoderBlock(input_dim, output_dim, stride)]
+
+         # Add final conv layer
+         layers += [
+             Activation1d(activation=SnakeBeta(output_dim, alpha_logscale=True)),
+             WNConv1d(output_dim, 1, kernel_size=7, padding=3),
+             nn.Tanh(),
+         ]
+
+         self.model = nn.Sequential(*layers)
+
+         self.timbre_encoder = TransformerEncoder(
+             enc_emb_tokens=None,
+             encoder_layer=4,
+             encoder_hidden=256,
+             encoder_head=4,
+             conv_filter_size=1024,
+             conv_kernel_size=5,
+             encoder_dropout=0.1,
+             use_cln=False,
+         )
+
+         self.timbre_linear = nn.Linear(in_channels, in_channels * 2)
+         self.timbre_linear.bias.data[:in_channels] = 1
+         self.timbre_linear.bias.data[in_channels:] = 0
+         self.timbre_norm = nn.LayerNorm(in_channels, elementwise_affine=False)
+
+         self.f0_predictor = CNNLSTM(in_channels, 1, 2)
+         self.phone_predictor = CNNLSTM(in_channels, 5003, 1)
+
+         self.use_gr_content_f0 = use_gr_content_f0
+         self.use_gr_prosody_phone = use_gr_prosody_phone
+         self.use_gr_residual_f0 = use_gr_residual_f0
+         self.use_gr_residual_phone = use_gr_residual_phone
+         self.use_gr_x_timbre = use_gr_x_timbre
+
+         if self.vq_num_q_r > 0 and self.use_gr_residual_f0:
+             self.res_f0_predictor = nn.Sequential(
+                 GradientReversal(alpha=1.0), CNNLSTM(in_channels, 1, 2)
+             )
+
+         if self.vq_num_q_r > 0 and self.use_gr_residual_phone > 0:
+             self.res_phone_predictor = nn.Sequential(
+                 GradientReversal(alpha=1.0), CNNLSTM(in_channels, 5003, 1)
+             )
+
+         if self.use_gr_content_f0:
+             self.content_f0_predictor = nn.Sequential(
+                 GradientReversal(alpha=1.0), CNNLSTM(in_channels, 1, 2)
+             )
+
+         if self.use_gr_prosody_phone:
+             self.prosody_phone_predictor = nn.Sequential(
+                 GradientReversal(alpha=1.0), CNNLSTM(in_channels, 5003, 1)
+             )
+
+         if self.use_gr_x_timbre:
+             self.x_timbre_predictor = nn.Sequential(
+                 GradientReversal(alpha=1),
+                 CNNLSTM(in_channels, 245200, 1, global_pred=True),
+             )
+
+         self.reset_parameters()
+
+     def quantize(self, x, n_quantizers=None):
+         outs, qs, commit_loss, quantized_buf = 0, [], [], []
+
+         # prosody
+         f0_input = x  # (B, d, T)
+         f0_quantizer = self.quantizer[0]
+         out, q, commit, quantized = f0_quantizer(f0_input, n_quantizers=n_quantizers)
+         outs += out
+         qs.append(q)
+         quantized_buf.append(quantized.sum(0))
+         commit_loss.append(commit)
+
+         # phone
+         phone_input = x
+         phone_quantizer = self.quantizer[1]
+         out, q, commit, quantized = phone_quantizer(
+             phone_input, n_quantizers=n_quantizers
+         )
+         outs += out
+         qs.append(q)
+         quantized_buf.append(quantized.sum(0))
+         commit_loss.append(commit)
+
+         # residual
+         if self.vq_num_q_r > 0:
+             residual_quantizer = self.quantizer[2]
+             residual_input = x - (quantized_buf[0] + quantized_buf[1]).detach()
+             out, q, commit, quantized = residual_quantizer(
+                 residual_input, n_quantizers=n_quantizers
+             )
+             outs += out
+             qs.append(q)
+             quantized_buf.append(quantized.sum(0))  # [L, B, C, T] -> [B, C, T]
+             commit_loss.append(commit)
+
+         qs = torch.cat(qs, dim=0)
+         commit_loss = torch.cat(commit_loss, dim=0)
+         return outs, qs, commit_loss, quantized_buf
+
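Since each codebook_size_* is an exponent ("true codebook size is equal to 2^codebook_size"), the defaults give 2**10 = 1024 entries per quantizer, and quantize() runs the prosody, content, and residual streams over the same input, quantizing only the leftover signal in the residual branch. A rough sketch of the return layout (editorial; it assumes the ResidualVQ from .quantize returns per-stream code tensors of shape (num_quantizers, B, T), as the vq2emb comment below suggests):

import torch

dec = FACodecDecoder(in_channels=256, vq_dim=256)   # illustrative sizes
z = torch.randn(2, 256, 80)                         # encoder output, (B, d, T)
outs, qs, commit, bufs = dec.quantize(z)
# qs stacks 1 + 2 + 3 = 6 index tensors -> (6, B, T), each code in [0, 2**10)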
+     def forward(
+         self,
+         x,
+         vq=True,
+         get_vq=False,
+         eval_vq=True,
+         speaker_embedding=None,
+         n_quantizers=None,
+         quantized=None,
+     ):
+         if get_vq:
+             return self.quantizer.get_emb()
+         if vq is True:
+             if eval_vq:
+                 self.quantizer.eval()
+             x_timbre = x
+             outs, qs, commit_loss, quantized_buf = self.quantize(
+                 x, n_quantizers=n_quantizers
+             )
+
+             x_timbre = x_timbre.transpose(1, 2)
+             x_timbre = self.timbre_encoder(x_timbre, None, None)
+             x_timbre = x_timbre.transpose(1, 2)
+             spk_embs = torch.mean(x_timbre, dim=2)
+             return outs, qs, commit_loss, quantized_buf, spk_embs
+
+         out = {}
+
+         layer_0 = quantized[0]
+         f0, uv = self.f0_predictor(layer_0)
+         f0 = rearrange(f0, "... 1 -> ...")
+         uv = rearrange(uv, "... 1 -> ...")
+
+         layer_1 = quantized[1]
+         (phone,) = self.phone_predictor(layer_1)
+
+         out = {"f0": f0, "uv": uv, "phone": phone}
+
+         if self.use_gr_prosody_phone:
+             (prosody_phone,) = self.prosody_phone_predictor(layer_0)
+             out["prosody_phone"] = prosody_phone
+
+         if self.use_gr_content_f0:
+             content_f0, content_uv = self.content_f0_predictor(layer_1)
+             content_f0 = rearrange(content_f0, "... 1 -> ...")
+             content_uv = rearrange(content_uv, "... 1 -> ...")
+             out["content_f0"] = content_f0
+             out["content_uv"] = content_uv
+
+         if self.vq_num_q_r > 0:
+             layer_2 = quantized[2]
+
+             if self.use_gr_residual_f0:
+                 res_f0, res_uv = self.res_f0_predictor(layer_2)
+                 res_f0 = rearrange(res_f0, "... 1 -> ...")
+                 res_uv = rearrange(res_uv, "... 1 -> ...")
+                 out["res_f0"] = res_f0
+                 out["res_uv"] = res_uv
+
+             if self.use_gr_residual_phone:
+                 (res_phone,) = self.res_phone_predictor(layer_2)
+                 out["res_phone"] = res_phone
+
+         style = self.timbre_linear(speaker_embedding).unsqueeze(2)  # (B, 2d, 1)
+         gamma, beta = style.chunk(2, 1)  # (B, d, 1)
+         if self.vq_num_q_r > 0:
+             if self.use_random_mask_residual:
+                 bsz = quantized[2].shape[0]
+                 res_mask = np.random.choice(
+                     [0, 1],
+                     size=bsz,
+                     p=[
+                         self.prob_random_mask_residual,
+                         1 - self.prob_random_mask_residual,
+                     ],
+                 )
+                 res_mask = (
+                     torch.from_numpy(res_mask).unsqueeze(1).unsqueeze(1)
+                 )  # (B, 1, 1)
+                 res_mask = res_mask.to(
+                     device=quantized[2].device, dtype=quantized[2].dtype
+                 )
+                 x = (
+                     quantized[0].detach()
+                     + quantized[1].detach()
+                     + quantized[2] * res_mask
+                 )
+                 # x = quantized_perturbe[0].detach() + quantized[1].detach() + quantized[2] * res_mask
+             else:
+                 x = quantized[0].detach() + quantized[1].detach() + quantized[2]
+                 # x = quantized_perturbe[0].detach() + quantized[1].detach() + quantized[2]
+         else:
+             x = quantized[0].detach() + quantized[1].detach()
+             # x = quantized_perturbe[0].detach() + quantized[1].detach()
+
+         if self.use_gr_x_timbre:
+             (x_timbre,) = self.x_timbre_predictor(x)
+             out["x_timbre"] = x_timbre
+
+         x = x.transpose(1, 2)
+         x = self.timbre_norm(x)
+         x = x.transpose(1, 2)
+         x = x * gamma + beta
+
+         x = self.model(x)
+         out["audio"] = x
+
+         return out
+
+     def vq2emb(self, vq, use_residual_code=True):
+         # vq: [num_quantizer, B, T]
+         self.quantizer = self.quantizer.eval()
+         out = 0
+         out += self.quantizer[0].vq2emb(vq[0 : self.vq_num_q_p])
+         out += self.quantizer[1].vq2emb(
+             vq[self.vq_num_q_p : self.vq_num_q_p + self.vq_num_q_c]
+         )
+         if self.vq_num_q_r > 0 and use_residual_code:
+             out += self.quantizer[2].vq2emb(vq[self.vq_num_q_p + self.vq_num_q_c :])
+         return out
+
+     def inference(self, x, speaker_embedding):
+         style = self.timbre_linear(speaker_embedding).unsqueeze(2)  # (B, 2d, 1)
+         gamma, beta = style.chunk(2, 1)  # (B, d, 1)
+         x = x.transpose(1, 2)
+         x = self.timbre_norm(x)
+         x = x.transpose(1, 2)
+         x = x * gamma + beta
+         x = self.model(x)
+         return x
+
+     def remove_weight_norm(self):
+         """Remove weight normalization from all of the layers."""
+
+         def _remove_weight_norm(m):
+             try:
+                 torch.nn.utils.remove_weight_norm(m)
+             except ValueError:  # this module didn't have weight norm
+                 return
+
+         self.apply(_remove_weight_norm)
+
+     def apply_weight_norm(self):
+         """Apply weight normalization to all of the layers."""
+
+         def _apply_weight_norm(m):
+             if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d):
+                 torch.nn.utils.weight_norm(m)
+
+         self.apply(_apply_weight_norm)
+
+     def reset_parameters(self):
+         self.apply(init_weights)
+
+
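vq2emb slices the stacked code tensor by stream: with the defaults vq_num_q_p=1, vq_num_q_c=2, vq_num_q_r=3, rows 0, 1-2, and 3-5 hold the prosody, content, and residual codes respectively. A usage sketch (editorial; the code values are random placeholders and the embedding lookup is delegated to ResidualVQ.vq2emb):

import torch

dec = FACodecDecoder()                       # 1 + 2 + 3 = 6 quantizers
vq = torch.randint(0, 2 ** 10, (6, 2, 80))   # [num_quantizer, B, T]
emb = dec.vq2emb(vq)                         # sum of all three streams
emb_no_residual = dec.vq2emb(vq, use_residual_code=False)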
+ class FACodecRedecoder(nn.Module):
+     def __init__(
+         self,
+         in_channels=256,
+         upsample_initial_channel=1280,
+         up_ratios=(5, 5, 4, 2),
+         vq_num_q_c=2,
+         vq_num_q_p=1,
+         vq_num_q_r=3,
+         vq_dim=256,
+         codebook_size_prosody=10,
+         codebook_size_content=10,
+         codebook_size_residual=10,
+     ):
+         super().__init__()
+         self.hop_length = np.prod(up_ratios)
+         self.up_ratios = up_ratios
+
+         self.vq_num_q_p = vq_num_q_p
+         self.vq_num_q_c = vq_num_q_c
+         self.vq_num_q_r = vq_num_q_r
+
+         self.vq_dim = vq_dim
+
+         self.codebook_size_prosody = codebook_size_prosody
+         self.codebook_size_content = codebook_size_content
+         self.codebook_size_residual = codebook_size_residual
+
+         self.prosody_embs = nn.ModuleList()
+         for i in range(self.vq_num_q_p):
+             emb_tokens = nn.Embedding(
+                 num_embeddings=2**self.codebook_size_prosody,
+                 embedding_dim=self.vq_dim,
+             )
+             emb_tokens.weight.data.normal_(mean=0.0, std=1e-5)
+             self.prosody_embs.append(emb_tokens)
+         self.content_embs = nn.ModuleList()
+         for i in range(self.vq_num_q_c):
+             emb_tokens = nn.Embedding(
+                 num_embeddings=2**self.codebook_size_content,
+                 embedding_dim=self.vq_dim,
+             )
+             emb_tokens.weight.data.normal_(mean=0.0, std=1e-5)
+             self.content_embs.append(emb_tokens)
+         self.residual_embs = nn.ModuleList()
+         for i in range(self.vq_num_q_r):
+             emb_tokens = nn.Embedding(
+                 num_embeddings=2**self.codebook_size_residual,
+                 embedding_dim=self.vq_dim,
+             )
+             emb_tokens.weight.data.normal_(mean=0.0, std=1e-5)
+             self.residual_embs.append(emb_tokens)
+
+         # Add first conv layer
+         channels = upsample_initial_channel
+         layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)]
+
+         # Add upsampling + MRF blocks
+         for i, stride in enumerate(up_ratios):
+             input_dim = channels // 2**i
+             output_dim = channels // 2 ** (i + 1)
+             layers += [DecoderBlock(input_dim, output_dim, stride)]
+
+         # Add final conv layer
+         layers += [
+             Activation1d(activation=SnakeBeta(output_dim, alpha_logscale=True)),
+             WNConv1d(output_dim, 1, kernel_size=7, padding=3),
+             nn.Tanh(),
+         ]
+
+         self.model = nn.Sequential(*layers)
+
+         self.timbre_linear = nn.Linear(in_channels, in_channels * 2)
+         self.timbre_linear.bias.data[:in_channels] = 1
+         self.timbre_linear.bias.data[in_channels:] = 0
+         self.timbre_norm = nn.LayerNorm(in_channels, elementwise_affine=False)
+
+         self.timbre_cond_prosody_enc = TransformerEncoder(
+             enc_emb_tokens=None,
+             encoder_layer=4,
+             encoder_hidden=256,
+             encoder_head=4,
+             conv_filter_size=1024,
+             conv_kernel_size=5,
+             encoder_dropout=0.1,
+             use_cln=True,
+             cfg=None,
+         )
+
+     def forward(
+         self,
+         vq,
+         speaker_embedding,
+         use_residual_code=False,
+     ):
+
+         x = 0
+
+         x_p = 0
+         for i in range(self.vq_num_q_p):
+             x_p = x_p + self.prosody_embs[i](vq[i])  # (B, T, d)
+         spk_cond = speaker_embedding.unsqueeze(1).expand(-1, x_p.shape[1], -1)
+         x_p = self.timbre_cond_prosody_enc(
+             x_p, key_padding_mask=None, condition=spk_cond
+         )
+         x = x + x_p
+
+         x_c = 0
+         for i in range(self.vq_num_q_c):
+             x_c = x_c + self.content_embs[i](vq[self.vq_num_q_p + i])
+
+         x = x + x_c
+
+         if use_residual_code:
+
+             x_r = 0
+             for i in range(self.vq_num_q_r):
+                 x_r = x_r + self.residual_embs[i](
+                     vq[self.vq_num_q_p + self.vq_num_q_c + i]
+                 )
+             x = x + x_r
+
+         style = self.timbre_linear(speaker_embedding).unsqueeze(2)  # (B, 2d, 1)
+         gamma, beta = style.chunk(2, 1)  # (B, d, 1)
+         x = x.transpose(1, 2)
+         x = self.timbre_norm(x)
+         x = x.transpose(1, 2)
+         x = x * gamma + beta
+         x = self.model(x)
+
+         return x
+
+     def vq2emb(self, vq, speaker_embedding, use_residual=True):
+
+         out = 0
+
+         x_t = 0
+         for i in range(self.vq_num_q_p):
+             x_t += self.prosody_embs[i](vq[i])  # (B, T, d)
+         spk_cond = speaker_embedding.unsqueeze(1).expand(-1, x_t.shape[1], -1)
+         x_t = self.timbre_cond_prosody_enc(
+             x_t, key_padding_mask=None, condition=spk_cond
+         )
+
+         # prosody
+         out += x_t
+
+         # content
+         for i in range(self.vq_num_q_c):
+             out += self.content_embs[i](vq[self.vq_num_q_p + i])
+
+         # residual
+         if use_residual:
+             for i in range(self.vq_num_q_r):
+                 out += self.residual_embs[i](vq[self.vq_num_q_p + self.vq_num_q_c + i])
+
+         out = out.transpose(1, 2)  # (B, T, d) -> (B, d, T)
+         return out
+
+     def inference(self, x, speaker_embedding):
+         style = self.timbre_linear(speaker_embedding).unsqueeze(2)  # (B, 2d, 1)
+         gamma, beta = style.chunk(2, 1)  # (B, d, 1)
+         x = x.transpose(1, 2)
+         x = self.timbre_norm(x)
+         x = x.transpose(1, 2)
+         x = x * gamma + beta
+         x = self.model(x)
+         return x
+
+
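Both decoders inject speaker identity the same way: timbre_linear maps the (B, d) speaker embedding to (B, 2d), chunk(2, 1) splits that into a per-channel scale gamma and shift beta, and the pair modulates the normalized features (a FiLM-style conditioning; the ones/zeros bias init above makes it start as an identity transform). The step in isolation (editorial sketch, fully self-contained):

import torch
from torch import nn

d = 256
timbre_linear = nn.Linear(d, 2 * d)
timbre_norm = nn.LayerNorm(d, elementwise_affine=False)

spk = torch.randn(2, d)                  # speaker embedding
x = torch.randn(2, d, 80)                # (B, d, T) features
style = timbre_linear(spk).unsqueeze(2)  # (B, 2d, 1)
gamma, beta = style.chunk(2, 1)          # each (B, d, 1)
x = timbre_norm(x.transpose(1, 2)).transpose(1, 2) * gamma + beta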
+ class FACodecEncoderV2(nn.Module):
+     def __init__(
+         self,
+         ngf=32,
+         up_ratios=(2, 4, 5, 5),
+         out_channels=1024,
+     ):
+         super().__init__()
+         self.hop_length = np.prod(up_ratios)
+         self.up_ratios = up_ratios
+
+         # Create first convolution
+         d_model = ngf
+         self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
+
+         # Create EncoderBlocks that double channels as they downsample by `stride`
+         for stride in up_ratios:
+             d_model *= 2
+             self.block += [EncoderBlock(d_model, stride=stride)]
+
+         # Create last convolution
+         self.block += [
+             Activation1d(activation=SnakeBeta(d_model, alpha_logscale=True)),
+             WNConv1d(d_model, out_channels, kernel_size=3, padding=1),
+         ]
+
+         # Wrap block into nn.Sequential
+         self.block = nn.Sequential(*self.block)
+         self.enc_dim = d_model
+
+         self.mel_transform = MelSpectrogram(
+             n_fft=1024,
+             num_mels=80,
+             sampling_rate=16000,
+             hop_size=200,
+             win_size=800,
+             fmin=0,
+             fmax=8000,
+         )
+
+         self.reset_parameters()
+
+     def forward(self, x):
+         out = self.block(x)
+         return out
+
+     def inference(self, x):
+         return self.block(x)
+
+     def get_prosody_feature(self, x):
+         return self.mel_transform(x.squeeze(1))[:, :20, :]
+
+     def remove_weight_norm(self):
+         """Remove weight normalization from all of the layers."""
+
+         def _remove_weight_norm(m):
+             try:
+                 torch.nn.utils.remove_weight_norm(m)
+             except ValueError:  # this module didn't have weight norm
+                 return
+
+         self.apply(_remove_weight_norm)
+
+     def apply_weight_norm(self):
+         """Apply weight normalization to all of the layers."""
+
+         def _apply_weight_norm(m):
+             if isinstance(m, nn.Conv1d):
+                 torch.nn.utils.weight_norm(m)
+
+         self.apply(_apply_weight_norm)
+
+     def reset_parameters(self):
+         self.apply(init_weights)
+
+
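get_prosody_feature keeps only the lowest 20 of the 80 mel bands, a coarse spectral envelope that FACodecDecoderV2 below projects (melspec_linear is nn.Linear(20, 256)) and feeds to its prosody quantizer; the mel hop_size of 200 matches the encoder's hop_length, so both streams share a time axis. A sketch (editorial; the exact frame count depends on the MelSpectrogram padding):

import torch

enc = FACodecEncoderV2(out_channels=256)
wav = torch.randn(1, 1, 16000)
pros = enc.get_prosody_feature(wav)   # (B, 20, T), T about 16000 / 200 = 80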
+class FACodecDecoderV2(nn.Module):
+    def __init__(
+        self,
+        in_channels=256,
+        upsample_initial_channel=1536,
+        ngf=32,
+        up_ratios=(5, 5, 4, 2),
+        vq_num_q_c=2,
+        vq_num_q_p=1,
+        vq_num_q_r=3,
+        vq_dim=1024,
+        vq_commit_weight=0.005,
+        vq_weight_init=False,
+        vq_full_commit_loss=False,
+        codebook_dim=8,
+        codebook_size_prosody=10,  # true codebook size is equal to 2^codebook_size
+        codebook_size_content=10,
+        codebook_size_residual=10,
+        quantizer_dropout=0.0,
+        dropout_type="linear",
+        use_gr_content_f0=False,
+        use_gr_prosody_phone=False,
+        use_gr_residual_f0=False,
+        use_gr_residual_phone=False,
+        use_gr_x_timbre=False,
+        use_random_mask_residual=True,
+        prob_random_mask_residual=0.75,
+    ):
+        super().__init__()
+        self.hop_length = np.prod(up_ratios)
+        self.ngf = ngf
+        self.up_ratios = up_ratios
+
+        self.use_random_mask_residual = use_random_mask_residual
+        self.prob_random_mask_residual = prob_random_mask_residual
+
+        self.vq_num_q_p = vq_num_q_p
+        self.vq_num_q_c = vq_num_q_c
+        self.vq_num_q_r = vq_num_q_r
+
+        self.codebook_size_prosody = codebook_size_prosody
+        self.codebook_size_content = codebook_size_content
+        self.codebook_size_residual = codebook_size_residual
+
+        quantizer_class = ResidualVQ
+
+        self.quantizer = nn.ModuleList()
+
+        # prosody
+        quantizer = quantizer_class(
+            num_quantizers=vq_num_q_p,
+            dim=vq_dim,
+            codebook_size=codebook_size_prosody,
+            codebook_dim=codebook_dim,
+            threshold_ema_dead_code=2,
+            commitment=vq_commit_weight,
+            weight_init=vq_weight_init,
+            full_commit_loss=vq_full_commit_loss,
+            quantizer_dropout=quantizer_dropout,
+            dropout_type=dropout_type,
+        )
+        self.quantizer.append(quantizer)
+
+        # phone
+        quantizer = quantizer_class(
+            num_quantizers=vq_num_q_c,
+            dim=vq_dim,
+            codebook_size=codebook_size_content,
+            codebook_dim=codebook_dim,
+            threshold_ema_dead_code=2,
+            commitment=vq_commit_weight,
+            weight_init=vq_weight_init,
+            full_commit_loss=vq_full_commit_loss,
+            quantizer_dropout=quantizer_dropout,
+            dropout_type=dropout_type,
+        )
+        self.quantizer.append(quantizer)
+
+        # residual
+        if self.vq_num_q_r > 0:
+            quantizer = quantizer_class(
+                num_quantizers=vq_num_q_r,
+                dim=vq_dim,
+                codebook_size=codebook_size_residual,
+                codebook_dim=codebook_dim,
+                threshold_ema_dead_code=2,
+                commitment=vq_commit_weight,
+                weight_init=vq_weight_init,
+                full_commit_loss=vq_full_commit_loss,
+                quantizer_dropout=quantizer_dropout,
+                dropout_type=dropout_type,
+            )
+            self.quantizer.append(quantizer)
+
+        # Add first conv layer
+        channels = upsample_initial_channel
+        layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)]
+
+        # Add upsampling + MRF blocks
+        for i, stride in enumerate(up_ratios):
+            input_dim = channels // 2**i
+            output_dim = channels // 2 ** (i + 1)
+            layers += [DecoderBlock(input_dim, output_dim, stride)]
+
+        # Add final conv layer
+        layers += [
+            Activation1d(activation=SnakeBeta(output_dim, alpha_logscale=True)),
+            WNConv1d(output_dim, 1, kernel_size=7, padding=3),
+            nn.Tanh(),
+        ]
+
+        self.model = nn.Sequential(*layers)
+
+        self.timbre_encoder = TransformerEncoder(
+            enc_emb_tokens=None,
+            encoder_layer=4,
+            encoder_hidden=256,
+            encoder_head=4,
+            conv_filter_size=1024,
+            conv_kernel_size=5,
+            encoder_dropout=0.1,
+            use_cln=False,
+        )
+
+        self.timbre_linear = nn.Linear(in_channels, in_channels * 2)
+        self.timbre_linear.bias.data[:in_channels] = 1
+        self.timbre_linear.bias.data[in_channels:] = 0
+        self.timbre_norm = nn.LayerNorm(in_channels, elementwise_affine=False)
+
+        self.f0_predictor = CNNLSTM(in_channels, 1, 2)
+        self.phone_predictor = CNNLSTM(in_channels, 5003, 1)
+
+        self.use_gr_content_f0 = use_gr_content_f0
+        self.use_gr_prosody_phone = use_gr_prosody_phone
+        self.use_gr_residual_f0 = use_gr_residual_f0
+        self.use_gr_residual_phone = use_gr_residual_phone
+        self.use_gr_x_timbre = use_gr_x_timbre
+
+        if self.vq_num_q_r > 0 and self.use_gr_residual_f0:
+            self.res_f0_predictor = nn.Sequential(
+                GradientReversal(alpha=1.0), CNNLSTM(in_channels, 1, 2)
+            )
+
+        if self.vq_num_q_r > 0 and self.use_gr_residual_phone:
+            self.res_phone_predictor = nn.Sequential(
+                GradientReversal(alpha=1.0), CNNLSTM(in_channels, 5003, 1)
+            )
+
+        if self.use_gr_content_f0:
+            self.content_f0_predictor = nn.Sequential(
+                GradientReversal(alpha=1.0), CNNLSTM(in_channels, 1, 2)
+            )
+
+        if self.use_gr_prosody_phone:
+            self.prosody_phone_predictor = nn.Sequential(
+                GradientReversal(alpha=1.0), CNNLSTM(in_channels, 5003, 1)
+            )
+
+        if self.use_gr_x_timbre:
+            self.x_timbre_predictor = nn.Sequential(
+                GradientReversal(alpha=1),
+                CNNLSTM(in_channels, 245200, 1, global_pred=True),
+            )
+
+        self.melspec_linear = nn.Linear(20, 256)
+        self.melspec_encoder = TransformerEncoder(
+            enc_emb_tokens=None,
+            encoder_layer=4,
+            encoder_hidden=256,
+            encoder_head=4,
+            conv_filter_size=1024,
+            conv_kernel_size=5,
+            encoder_dropout=0.1,
+            use_cln=False,
+            cfg=None,
+        )
+
+        self.reset_parameters()
+
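A quick reading aid for the defaults above: the codebook_size_* arguments are exponents, as the inline comment says, so 10 means a 2 ** 10 = 1024-entry codebook per quantizer, and hop_length = np.prod((5, 5, 4, 2)) = 200, i.e. each latent frame decodes to 200 audio samples.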
+    def quantize(self, x, prosody_feature, n_quantizers=None):
+        outs, qs, commit_loss, quantized_buf = 0, [], [], []
+
+        # prosody
+        f0_input = prosody_feature.transpose(1, 2)  # (B, T, 20)
+        f0_input = self.melspec_linear(f0_input)
+        f0_input = self.melspec_encoder(f0_input, None, None)
+        f0_input = f0_input.transpose(1, 2)
+        f0_quantizer = self.quantizer[0]
+        out, q, commit, quantized = f0_quantizer(f0_input, n_quantizers=n_quantizers)
+        outs += out
+        qs.append(q)
+        quantized_buf.append(quantized.sum(0))
+        commit_loss.append(commit)
+
+        # phone
+        phone_input = x
+        phone_quantizer = self.quantizer[1]
+        out, q, commit, quantized = phone_quantizer(
+            phone_input, n_quantizers=n_quantizers
+        )
+        outs += out
+        qs.append(q)
+        quantized_buf.append(quantized.sum(0))
+        commit_loss.append(commit)
+
+        # residual
+        if self.vq_num_q_r > 0:
+            residual_quantizer = self.quantizer[2]
+            residual_input = x - (quantized_buf[0] + quantized_buf[1]).detach()
+            out, q, commit, quantized = residual_quantizer(
+                residual_input, n_quantizers=n_quantizers
+            )
+            outs += out
+            qs.append(q)
+            quantized_buf.append(quantized.sum(0))  # [L, B, C, T] -> [B, C, T]
+            commit_loss.append(commit)
+
+        qs = torch.cat(qs, dim=0)
+        commit_loss = torch.cat(commit_loss, dim=0)
+        return outs, qs, commit_loss, quantized_buf
+
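For orientation (an editorial sketch, not part of the diff): quantize splits the latent into three streams. Prosody codes come from the projected mel/prosody feature, phone codes from the encoder output itself, and the residual quantizer only sees whatever the first two streams failed to reconstruct. A minimal illustration of that residual arithmetic, with hypothetical shapes:

import torch

B, C, T = 2, 1024, 50             # hypothetical batch/channel/frame sizes
x = torch.randn(B, C, T)          # encoder output
q_prosody = torch.randn(B, C, T)  # stand-in for quantized_buf[0]
q_phone = torch.randn(B, C, T)    # stand-in for quantized_buf[1]

# detach() keeps the residual path from back-propagating into the
# prosody and phone codebooks
residual_input = x - (q_prosody + q_phone).detach()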
+    def forward(
+        self,
+        x,
+        prosody_feature,
+        vq=True,
+        get_vq=False,
+        eval_vq=True,
+        speaker_embedding=None,
+        n_quantizers=None,
+        quantized=None,
+    ):
+        if get_vq:
+            return self.quantizer.get_emb()
+        if vq is True:
+            if eval_vq:
+                self.quantizer.eval()
+            x_timbre = x
+            outs, qs, commit_loss, quantized_buf = self.quantize(
+                x, prosody_feature, n_quantizers=n_quantizers
+            )
+
+            x_timbre = x_timbre.transpose(1, 2)
+            x_timbre = self.timbre_encoder(x_timbre, None, None)
+            x_timbre = x_timbre.transpose(1, 2)
+            spk_embs = torch.mean(x_timbre, dim=2)
+            return outs, qs, commit_loss, quantized_buf, spk_embs
+
+        layer_0 = quantized[0]
+        f0, uv = self.f0_predictor(layer_0)
+        f0 = rearrange(f0, "... 1 -> ...")
+        uv = rearrange(uv, "... 1 -> ...")
+
+        layer_1 = quantized[1]
+        (phone,) = self.phone_predictor(layer_1)
+
+        out = {"f0": f0, "uv": uv, "phone": phone}
+
+        if self.use_gr_prosody_phone:
+            (prosody_phone,) = self.prosody_phone_predictor(layer_0)
+            out["prosody_phone"] = prosody_phone
+
+        if self.use_gr_content_f0:
+            content_f0, content_uv = self.content_f0_predictor(layer_1)
+            content_f0 = rearrange(content_f0, "... 1 -> ...")
+            content_uv = rearrange(content_uv, "... 1 -> ...")
+            out["content_f0"] = content_f0
+            out["content_uv"] = content_uv
+
+        if self.vq_num_q_r > 0:
+            layer_2 = quantized[2]
+
+            if self.use_gr_residual_f0:
+                res_f0, res_uv = self.res_f0_predictor(layer_2)
+                res_f0 = rearrange(res_f0, "... 1 -> ...")
+                res_uv = rearrange(res_uv, "... 1 -> ...")
+                out["res_f0"] = res_f0
+                out["res_uv"] = res_uv
+
+            if self.use_gr_residual_phone:
+                (res_phone,) = self.res_phone_predictor(layer_2)
+                out["res_phone"] = res_phone
+
+        style = self.timbre_linear(speaker_embedding).unsqueeze(2)  # (B, 2d, 1)
+        gamma, beta = style.chunk(2, 1)  # (B, d, 1)
+        if self.vq_num_q_r > 0:
+            if self.use_random_mask_residual:
+                bsz = quantized[2].shape[0]
+                res_mask = np.random.choice(
+                    [0, 1],
+                    size=bsz,
+                    p=[
+                        self.prob_random_mask_residual,
+                        1 - self.prob_random_mask_residual,
+                    ],
+                )
+                res_mask = (
+                    torch.from_numpy(res_mask).unsqueeze(1).unsqueeze(1)
+                )  # (B, 1, 1)
+                res_mask = res_mask.to(
+                    device=quantized[2].device, dtype=quantized[2].dtype
+                )
+                x = (
+                    quantized[0].detach()
+                    + quantized[1].detach()
+                    + quantized[2] * res_mask
+                )
+                # x = quantized_perturbe[0].detach() + quantized[1].detach() + quantized[2] * res_mask
+            else:
+                x = quantized[0].detach() + quantized[1].detach() + quantized[2]
+                # x = quantized_perturbe[0].detach() + quantized[1].detach() + quantized[2]
+        else:
+            x = quantized[0].detach() + quantized[1].detach()
+            # x = quantized_perturbe[0].detach() + quantized[1].detach()
+
+        if self.use_gr_x_timbre:
+            (x_timbre,) = self.x_timbre_predictor(x)
+            out["x_timbre"] = x_timbre
+
+        x = x.transpose(1, 2)
+        x = self.timbre_norm(x)
+        x = x.transpose(1, 2)
+        x = x * gamma + beta
+
+        x = self.model(x)
+        out["audio"] = x
+
+        return out
+
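For orientation (an editorial sketch, not part of the diff): the timbre conditioning above is a feature-wise affine modulation in which the speaker embedding is mapped to a per-channel scale and shift. Because __init__ sets the first half of timbre_linear's bias to 1 and the second half to 0, the affine part starts out as an identity map for a zero speaker embedding. A self-contained illustration, with hypothetical shapes:

import torch
import torch.nn as nn

d = 256                                  # matches the in_channels default above
timbre_linear = nn.Linear(d, d * 2)
timbre_linear.bias.data[:d] = 1          # gamma initialized to 1
timbre_linear.bias.data[d:] = 0          # beta initialized to 0
timbre_norm = nn.LayerNorm(d, elementwise_affine=False)

spk = torch.zeros(2, d)                  # hypothetical speaker embeddings
x = torch.randn(2, d, 50)                # (B, C, T) decoder features

style = timbre_linear(spk).unsqueeze(2)  # (B, 2d, 1)
gamma, beta = style.chunk(2, 1)          # each (B, d, 1)
x = timbre_norm(x.transpose(1, 2)).transpose(1, 2)
x = x * gamma + beta                     # per-channel scale and shift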
+    def vq2emb(self, vq, use_residual=True):
+        # vq: [num_quantizer, B, T]
+        self.quantizer = self.quantizer.eval()
+        out = 0
+        out += self.quantizer[0].vq2emb(vq[0 : self.vq_num_q_p])
+        out += self.quantizer[1].vq2emb(
+            vq[self.vq_num_q_p : self.vq_num_q_p + self.vq_num_q_c]
+        )
+        if self.vq_num_q_r > 0 and use_residual:
+            out += self.quantizer[2].vq2emb(vq[self.vq_num_q_p + self.vq_num_q_c :])
+        return out
+
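An editorial aside on the layout vq2emb assumes: the code tensor stacks quantizer indices along its first axis in the order prosody, content, residual. With the defaults above (vq_num_q_p=1, vq_num_q_c=2, vq_num_q_r=3) the slices work out as follows; the shapes are hypothetical:

import torch

vq_num_q_p, vq_num_q_c, vq_num_q_r = 1, 2, 3      # defaults from __init__
B, T = 2, 50                                      # hypothetical batch/frames
num_q = vq_num_q_p + vq_num_q_c + vq_num_q_r
vq = torch.randint(0, 2**10, (num_q, B, T))       # 2**10 entries per codebook

prosody_codes = vq[:vq_num_q_p]                            # (1, B, T)
content_codes = vq[vq_num_q_p : vq_num_q_p + vq_num_q_c]   # (2, B, T)
residual_codes = vq[vq_num_q_p + vq_num_q_c :]             # (3, B, T)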
+    def inference(self, x, speaker_embedding):
+        style = self.timbre_linear(speaker_embedding).unsqueeze(2)  # (B, 2d, 1)
+        gamma, beta = style.chunk(2, 1)  # (B, d, 1)
+        x = x.transpose(1, 2)
+        x = self.timbre_norm(x)
+        x = x.transpose(1, 2)
+        x = x * gamma + beta
+        x = self.model(x)
+        return x
+
+    def remove_weight_norm(self):
+        """Remove weight normalization from all layers."""
+
+        def _remove_weight_norm(m):
+            try:
+                torch.nn.utils.remove_weight_norm(m)
+            except ValueError:  # this module didn't have weight norm
+                return
+
+        self.apply(_remove_weight_norm)
+
+    def apply_weight_norm(self):
+        """Apply weight normalization to all layers."""
+
+        def _apply_weight_norm(m):
+            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)):
+                torch.nn.utils.weight_norm(m)
+
+        self.apply(_apply_weight_norm)
+
+    def reset_parameters(self):
+        self.apply(init_weights)
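
To close out the class, a minimal usage sketch (editorial, not from the package; the import path is hypothetical and the tensors are random stand-ins):

import torch
from facodec import FACodecDecoderV2  # hypothetical import path

decoder = FACodecDecoderV2()
decoder.eval()

B, T = 1, 100
x = torch.randn(B, 256, T)   # decoder-side features, (B, in_channels, T)
spk = torch.randn(B, 256)    # speaker embedding

with torch.no_grad():
    wav = decoder.inference(x, spk)
# with the default up_ratios (5, 5, 4, 2), hop_length is 200,
# so wav should come out as (B, 1, T * 200)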