xinference: xinference-1.10.0-py3-none-any.whl → xinference-1.10.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic. Click here for more details.

Files changed (317) hide show
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +11 -28
  3. xinference/client/restful/async_restful_client.py +20 -3
  4. xinference/client/restful/restful_client.py +20 -3
  5. xinference/core/supervisor.py +87 -53
  6. xinference/core/worker.py +10 -0
  7. xinference/deploy/cmdline.py +15 -0
  8. xinference/model/audio/core.py +21 -6
  9. xinference/model/audio/indextts2.py +166 -0
  10. xinference/model/audio/model_spec.json +38 -1
  11. xinference/model/image/model_spec.json +69 -0
  12. xinference/model/image/stable_diffusion/core.py +13 -4
  13. xinference/model/llm/__init__.py +4 -0
  14. xinference/model/llm/llm_family.json +464 -2
  15. xinference/model/llm/sglang/core.py +30 -11
  16. xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +94 -32
  17. xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
  18. xinference/model/llm/utils.py +12 -9
  19. xinference/model/llm/vllm/core.py +93 -17
  20. xinference/thirdparty/audiotools/__init__.py +10 -0
  21. xinference/thirdparty/audiotools/core/__init__.py +4 -0
  22. xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
  23. xinference/thirdparty/audiotools/core/display.py +194 -0
  24. xinference/thirdparty/audiotools/core/dsp.py +390 -0
  25. xinference/thirdparty/audiotools/core/effects.py +647 -0
  26. xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
  27. xinference/thirdparty/audiotools/core/loudness.py +320 -0
  28. xinference/thirdparty/audiotools/core/playback.py +252 -0
  29. xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
  30. xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
  31. xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
  32. xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
  33. xinference/thirdparty/audiotools/core/util.py +671 -0
  34. xinference/thirdparty/audiotools/core/whisper.py +97 -0
  35. xinference/thirdparty/audiotools/data/__init__.py +3 -0
  36. xinference/thirdparty/audiotools/data/datasets.py +517 -0
  37. xinference/thirdparty/audiotools/data/preprocess.py +81 -0
  38. xinference/thirdparty/audiotools/data/transforms.py +1592 -0
  39. xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
  40. xinference/thirdparty/audiotools/metrics/distance.py +131 -0
  41. xinference/thirdparty/audiotools/metrics/quality.py +159 -0
  42. xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
  43. xinference/thirdparty/audiotools/ml/__init__.py +5 -0
  44. xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
  45. xinference/thirdparty/audiotools/ml/decorators.py +440 -0
  46. xinference/thirdparty/audiotools/ml/experiment.py +90 -0
  47. xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
  48. xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
  49. xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
  50. xinference/thirdparty/audiotools/post.py +140 -0
  51. xinference/thirdparty/audiotools/preference.py +600 -0
  52. xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
  53. xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
  54. xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
  55. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
  56. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
  57. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
  58. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
  59. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  60. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
  61. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
  62. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
  63. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
  64. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
  65. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
  66. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
  67. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
  68. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
  69. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
  70. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
  71. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
  72. xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
  73. xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
  74. xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
  75. xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
  76. xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
  77. xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
  78. xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
  79. xinference/thirdparty/indextts/__init__.py +0 -0
  80. xinference/thirdparty/indextts/cli.py +65 -0
  81. xinference/thirdparty/indextts/gpt/__init__.py +0 -0
  82. xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
  83. xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
  84. xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
  85. xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
  86. xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
  87. xinference/thirdparty/indextts/gpt/model.py +713 -0
  88. xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
  89. xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
  90. xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
  91. xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
  92. xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
  93. xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
  94. xinference/thirdparty/indextts/infer.py +690 -0
  95. xinference/thirdparty/indextts/infer_v2.py +739 -0
  96. xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
  97. xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
  98. xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
  99. xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
  100. xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
  101. xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
  102. xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
  103. xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
  104. xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
  105. xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
  106. xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
  107. xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
  108. xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
  109. xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
  110. xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
  111. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
  112. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
  113. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
  114. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
  115. xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
  116. xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
  117. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
  118. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
  119. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  120. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
  121. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
  122. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
  123. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
  124. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
  125. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
  126. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
  127. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
  128. xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
  129. xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
  130. xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
  131. xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
  132. xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
  133. xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
  134. xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
  135. xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
  136. xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
  137. xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
  138. xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
  139. xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
  140. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
  141. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
  142. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
  143. xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
  144. xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
  145. xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
  146. xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
  147. xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
  148. xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
  149. xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
  150. xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
  151. xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
  152. xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
  153. xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
  154. xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
  155. xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
  156. xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
  157. xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
  158. xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
  159. xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
  160. xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
  161. xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
  162. xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
  163. xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
  164. xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
  165. xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
  166. xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
  167. xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
  168. xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
  169. xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
  170. xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
  171. xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
  172. xinference/thirdparty/indextts/utils/__init__.py +0 -0
  173. xinference/thirdparty/indextts/utils/arch_util.py +120 -0
  174. xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
  175. xinference/thirdparty/indextts/utils/common.py +121 -0
  176. xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
  177. xinference/thirdparty/indextts/utils/front.py +536 -0
  178. xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
  179. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
  180. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
  181. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
  182. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
  183. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
  184. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
  185. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
  186. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
  187. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
  188. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
  189. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
  190. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
  191. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
  192. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
  193. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
  194. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
  195. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
  196. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
  197. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
  198. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
  199. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
  200. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
  201. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
  202. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
  203. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
  204. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
  205. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
  206. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
  207. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
  208. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
  209. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
  210. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
  211. xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
  212. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
  213. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
  214. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
  215. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
  216. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
  217. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
  218. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
  219. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
  220. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
  221. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
  222. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
  223. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
  224. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
  225. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
  226. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
  227. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
  228. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
  229. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
  230. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
  231. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
  232. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
  233. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
  234. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
  235. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
  236. xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
  237. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
  238. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
  239. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
  240. xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
  241. xinference/thirdparty/indextts/utils/text_utils.py +41 -0
  242. xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
  243. xinference/thirdparty/indextts/utils/utils.py +93 -0
  244. xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
  245. xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
  246. xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
  247. xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
  248. xinference/ui/gradio/media_interface.py +66 -8
  249. xinference/ui/web/ui/build/asset-manifest.json +6 -6
  250. xinference/ui/web/ui/build/index.html +1 -1
  251. xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
  252. xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
  253. xinference/ui/web/ui/build/static/js/main.d192c4f3.js +3 -0
  254. xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.d192c4f3.js.LICENSE.txt} +0 -7
  255. xinference/ui/web/ui/build/static/js/main.d192c4f3.js.map +1 -0
  256. xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
  257. xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
  258. xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
  259. xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
  260. xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
  261. xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
  262. xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
  263. xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
  264. xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
  265. xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
  266. xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
  267. xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
  268. xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
  269. xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
  270. xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
  271. xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
  272. xinference/ui/web/ui/node_modules/.cache/babel-loader/f995a2425dfb0822fd07127f66ffe9b026883bc156b402eb8bd0b83d52460a93.json +1 -0
  273. xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
  274. xinference/ui/web/ui/package-lock.json +0 -34
  275. xinference/ui/web/ui/package.json +0 -1
  276. xinference/ui/web/ui/src/locales/en.json +9 -3
  277. xinference/ui/web/ui/src/locales/ja.json +9 -3
  278. xinference/ui/web/ui/src/locales/ko.json +9 -3
  279. xinference/ui/web/ui/src/locales/zh.json +9 -3
  280. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/METADATA +18 -2
  281. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/RECORD +285 -67
  282. xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
  283. xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
  284. xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
  285. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
  286. xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
  287. xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
  288. xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
  289. xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
  290. xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
  291. xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
  292. xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
  293. xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
  294. xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
  295. xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
  296. xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
  297. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
  298. xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
  299. xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
  300. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
  301. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
  302. xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
  303. xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
  304. xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
  305. xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
  306. xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
  307. xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
  308. xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
  309. xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
  310. xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
  311. xinference/ui/web/ui/node_modules/select/bower.json +0 -13
  312. xinference/ui/web/ui/node_modules/select/package.json +0 -29
  313. xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
  314. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/WHEEL +0 -0
  315. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/entry_points.txt +0 -0
  316. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/licenses/LICENSE +0 -0
  317. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,257 @@
1
+ import torch
2
+ from torch import nn
3
+ import math
4
+
5
+ from indextts.s2mel.modules.gpt_fast.model import ModelArgs, Transformer
6
+ from indextts.s2mel.modules.wavenet import WN
7
+ from indextts.s2mel.modules.commons import sequence_mask
8
+
9
+ from torch.nn.utils import weight_norm
10
+
11
def modulate(x, shift, scale):
    """
    Apply an affine modulation to ``x``.

    ``shift`` and ``scale`` each get a new axis inserted at position 1 so that
    they broadcast against ``x`` along that axis. The result is
    ``x * (1 + scale) + shift``.
    """
    gain = 1 + scale.unsqueeze(1)
    offset = shift.unsqueeze(1)
    return x * gain + offset
13
+
14
+
15
+ #################################################################################
16
+ # Embedding Layers for Timesteps and Class Labels #
17
+ #################################################################################
18
+
19
class TimestepEmbedder(nn.Module):
    """
    Turns scalar timesteps into ``hidden_size``-dimensional vectors.

    A timestep is first expanded into sinusoidal features of width
    ``frequency_embedding_size`` and then passed through a
    Linear -> SiLU -> Linear MLP.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size
        self.max_period = 10000
        self.scale = 1000

        # Frequency table computed once up front: exp(-ln(max_period) * k / half)
        # for k = 0 .. half - 1, stored as a buffer so it follows the module's device.
        half = frequency_embedding_size // 2
        positions = torch.arange(start=0, end=half, dtype=torch.float32)
        exponents = -math.log(self.max_period) * positions / half
        self.register_buffer("freqs", torch.exp(exponents))

    def timestep_embedding(self, t):
        """
        Create sinusoidal timestep embeddings.

        The timesteps are multiplied by ``self.scale`` and by the precomputed
        ``self.freqs`` table; cosines come first, sines second.

        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :return: an (N, frequency_embedding_size) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py

        phases = self.scale * t.unsqueeze(1).float() * self.freqs.unsqueeze(0)
        embedding = torch.cat((torch.cos(phases), torch.sin(phases)), dim=-1)
        if self.frequency_embedding_size % 2 == 0:
            return embedding
        # Odd width: cos/sin only fill 2 * half columns, so add one zero column.
        zero_column = torch.zeros_like(embedding[:, :1])
        return torch.cat((embedding, zero_column), dim=-1)

    def forward(self, t):
        """Return the MLP projection of the sinusoidal features of ``t``."""
        return self.mlp(self.timestep_embedding(t))
61
+
62
+
63
class StyleEmbedder(nn.Module):
    """
    Projects style vectors into the hidden space through a weight-normalised
    linear layer. Also handles style dropout for classifier-free guidance.

    When ``dropout_prob > 0`` the module owns a one-row ``embedding_table``
    that holds a learned "null" style; otherwise the table has zero rows.
    """
    def __init__(self, input_size, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(int(use_cfg_embedding), hidden_size)
        self.style_in = weight_norm(nn.Linear(input_size, hidden_size, bias=True))
        self.input_size = input_size
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Project ``labels`` and replace dropped batch entries with the null style.

        NOTE(review): the original code called this method without defining it.
        The behavior here follows DiT's label dropout, adapted to continuous
        style vectors; confirm it matches the intended training recipe.

        :param labels: style vectors whose last dimension is ``input_size``;
                       the first dimension is treated as the batch.
        :param force_drop_ids: optional per-batch-element flags; entries equal
                               to 1 are dropped. When ``None``, each entry is
                               dropped at random with probability ``dropout_prob``.
        :return: projected styles, with dropped entries set to the null embedding.
        :raises ValueError: if the module was built with ``dropout_prob <= 0``
                            and therefore has no null embedding to substitute.
        """
        if self.embedding_table.num_embeddings == 0:
            raise ValueError(
                "StyleEmbedder was created with dropout_prob <= 0 and has no "
                "null style embedding; cannot drop styles."
            )
        projected = self.style_in(labels)
        if force_drop_ids is None:
            drop_ids = torch.rand(projected.shape[0], device=projected.device) < self.dropout_prob
        else:
            drop_ids = torch.as_tensor(force_drop_ids, device=projected.device) == 1
        # Reshape the per-sample mask so it broadcasts over all trailing dims.
        drop_ids = drop_ids.view(-1, *([1] * (projected.dim() - 1)))
        null_style = self.embedding_table.weight[0].to(projected.dtype)
        return torch.where(drop_ids, null_style, projected)

    def forward(self, labels, train, force_drop_ids=None):
        """
        Embed ``labels``; apply style dropout when training with
        ``dropout_prob > 0`` or when ``force_drop_ids`` is supplied.
        """
        use_dropout = self.dropout_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            return self.token_drop(labels, force_drop_ids)
        return self.style_in(labels)
83
+
84
class FinalLayer(nn.Module):
    """
    The final layer of DiT: a non-affine LayerNorm whose output is modulated by
    a conditioning-derived shift and scale, followed by a weight-normalised
    linear projection to ``patch_size * patch_size * out_channels`` features.
    """

    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        projected_dim = patch_size * patch_size * out_channels
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = weight_norm(nn.Linear(hidden_size, projected_dim, bias=True))
        # Produces 2 * hidden_size values from the conditioning: shift, then scale.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True),
        )

    def forward(self, x, c):
        """Normalise ``x``, modulate it with shift/scale derived from ``c``, then project."""
        modulation = self.adaLN_modulation(c)
        shift, scale = modulation.chunk(2, dim=1)
        normalized = self.norm_final(x)
        return self.linear(modulate(normalized, shift, scale))
102
+
103
class DiT(torch.nn.Module):
    """Diffusion-transformer backbone with either an MLP or a WaveNet output head.

    Hyper-parameters are read from ``args.DiT``; ``args.wavenet`` is consulted
    only when ``args.DiT.final_layer_type == 'wavenet'`` and
    ``args.style_encoder.dim`` gives the width of the style vector.
    """

    def __init__(
        self,
        args
    ):
        super(DiT, self).__init__()
        dit_cfg = args.DiT

        # Optional switches: an absent attribute means the feature is off.
        self.time_as_token = getattr(dit_cfg, 'time_as_token', False)
        self.style_as_token = getattr(dit_cfg, 'style_as_token', False)
        self.uvit_skip_connection = getattr(dit_cfg, 'uvit_skip_connection', False)

        transformer_cfg = ModelArgs(
            block_size=16384,  # hard-coded; the configured block size is not read
            n_layer=dit_cfg.depth,
            n_head=dit_cfg.num_heads,
            dim=dit_cfg.hidden_dim,
            head_dim=dit_cfg.hidden_dim // dit_cfg.num_heads,
            vocab_size=1024,
            uvit_skip_connection=self.uvit_skip_connection,
            time_as_token=self.time_as_token,
        )
        self.transformer = Transformer(transformer_cfg)
        self.in_channels = dit_cfg.in_channels
        self.out_channels = dit_cfg.in_channels
        self.num_heads = dit_cfg.num_heads

        self.x_embedder = weight_norm(nn.Linear(dit_cfg.in_channels, dit_cfg.hidden_dim, bias=True))

        self.content_type = dit_cfg.content_type  # 'discrete' or 'continuous'
        self.content_codebook_size = dit_cfg.content_codebook_size  # for discrete content
        self.content_dim = dit_cfg.content_dim  # for continuous content
        # Both content front-ends are instantiated, but forward() only ever
        # uses `cond_projection`.
        self.cond_embedder = nn.Embedding(dit_cfg.content_codebook_size, dit_cfg.hidden_dim)
        self.cond_projection = nn.Linear(dit_cfg.content_dim, dit_cfg.hidden_dim, bias=True)

        self.is_causal = dit_cfg.is_causal

        self.t_embedder = TimestepEmbedder(dit_cfg.hidden_dim)

        # Position indices, sliced to the actual sequence length in forward().
        self.register_buffer("input_pos", torch.arange(16384))

        # 'wavenet' selects the WaveNet head; any other value selects the MLP head.
        self.final_layer_type = dit_cfg.final_layer_type
        if self.final_layer_type == 'wavenet':
            wn_cfg = args.wavenet
            self.t_embedder2 = TimestepEmbedder(wn_cfg.hidden_dim)
            self.conv1 = nn.Linear(dit_cfg.hidden_dim, wn_cfg.hidden_dim)
            self.conv2 = nn.Conv1d(wn_cfg.hidden_dim, dit_cfg.in_channels, 1)
            self.wavenet = WN(hidden_channels=wn_cfg.hidden_dim,
                              kernel_size=wn_cfg.kernel_size,
                              dilation_rate=wn_cfg.dilation_rate,
                              n_layers=wn_cfg.num_layers,
                              gin_channels=wn_cfg.hidden_dim,
                              p_dropout=wn_cfg.p_dropout,
                              causal=False)
            self.final_layer = FinalLayer(wn_cfg.hidden_dim, 1, wn_cfg.hidden_dim)
            # Residual path from the transformer output to the WaveNet output.
            self.res_projection = nn.Linear(dit_cfg.hidden_dim, wn_cfg.hidden_dim)
            self.wavenet_style_condition = wn_cfg.style_condition
            assert dit_cfg.style_condition == wn_cfg.style_condition
        else:
            self.final_mlp = nn.Sequential(
                nn.Linear(dit_cfg.hidden_dim, dit_cfg.hidden_dim),
                nn.SiLU(),
                nn.Linear(dit_cfg.hidden_dim, dit_cfg.in_channels),
            )
        self.transformer_style_condition = dit_cfg.style_condition

        self.class_dropout_prob = dit_cfg.class_dropout_prob
        self.content_mask_embedder = nn.Embedding(1, dit_cfg.hidden_dim)

        self.long_skip_connection = dit_cfg.long_skip_connection
        self.skip_linear = nn.Linear(dit_cfg.hidden_dim + dit_cfg.in_channels, dit_cfg.hidden_dim)

        # Input width of the merge layer: projected content (hidden_dim), the
        # noisy input and the prompt (in_channels each), plus the style vector
        # when it is concatenated per frame rather than prepended as a token.
        self.cond_x_merge_linear = nn.Linear(
            dit_cfg.hidden_dim + dit_cfg.in_channels * 2 +
            args.style_encoder.dim * self.transformer_style_condition * (not self.style_as_token),
            dit_cfg.hidden_dim)
        if self.style_as_token:
            self.style_in = nn.Linear(args.style_encoder.dim, dit_cfg.hidden_dim)

    def setup_caches(self, max_batch_size, max_seq_length):
        """Prepare the transformer caches with the KV cache disabled."""
        self.transformer.setup_caches(max_batch_size, max_seq_length, use_kv_cache=False)

    def forward(self, x, prompt_x, x_lens, t, style, cond, mask_content=False):
        """Run one DiT pass.

        Args:
            x: Noisy input, 3-D with time last; it is transposed to
                channels-last before being merged with the conditioning.
            prompt_x: Prompt features in the same layout as ``x`` (the original
                notes describe it as reference mel followed by zeros).
            x_lens: Per-item lengths passed to ``sequence_mask``.
            t: Diffusion time, passed to the timestep embedders.
            style: Global style vector, 2-D (indexed as ``style[:, None, :]``).
            cond: Content features with ``content_dim`` as the last dimension.
            mask_content: In eval mode, truthy values zero all conditioning.

        Returns:
            The head output transposed back to a channels-before-time layout.
        """
        # In training, conditioning is dropped at random with probability
        # `class_dropout_prob`; in eval it is dropped only on request.
        if self.training:
            class_dropout = bool(torch.rand(1) < self.class_dropout_prob)
        else:
            class_dropout = bool(mask_content)

        _, _, n_frames = x.size()

        t_emb = self.t_embedder(t)
        cond = self.cond_projection(cond)

        # Switch to channels-last so everything can be concatenated per frame.
        x = x.transpose(1, 2)
        prompt_x = prompt_x.transpose(1, 2)

        x_in = torch.cat([x, prompt_x, cond], dim=-1)

        if self.transformer_style_condition and not self.style_as_token:
            per_frame_style = style[:, None, :].repeat(1, n_frames, 1)
            x_in = torch.cat([x_in, per_frame_style], dim=-1)

        if class_dropout:
            # Zero everything after the first `in_channels` features, i.e.
            # keep only the noisy input.
            x_in[..., self.in_channels:] = x_in[..., self.in_channels:] * 0

        x_in = self.cond_x_merge_linear(x_in)

        if self.style_as_token:
            style = self.style_in(style)
            if class_dropout:
                style = torch.zeros_like(style)
            x_in = torch.cat([style.unsqueeze(1), x_in], dim=1)

        if self.time_as_token:
            x_in = torch.cat([t_emb.unsqueeze(1), x_in], dim=1)

        # Lengths grow by one for each prepended token.
        x_mask = sequence_mask(x_lens + self.style_as_token + self.time_as_token).to(x.device).unsqueeze(1)
        input_pos = self.input_pos[:x_in.size(1)]
        if self.is_causal:
            x_mask_expanded = None
        else:
            x_mask_expanded = x_mask[:, None, :].repeat(1, 1, x_in.size(1), 1)
        x_res = self.transformer(x_in, t_emb.unsqueeze(1), input_pos, x_mask_expanded)

        # Strip the prepended tokens again (time first, then style).
        if self.time_as_token:
            x_res = x_res[:, 1:]
        if self.style_as_token:
            x_res = x_res[:, 1:]

        if self.long_skip_connection:
            x_res = self.skip_linear(torch.cat([x_res, x], dim=-1))

        if self.final_layer_type == 'wavenet':
            x = self.conv1(x_res)
            x = x.transpose(1, 2)
            t2 = self.t_embedder2(t)
            # Long residual connection around the WaveNet.
            x = self.wavenet(x, x_mask, g=t2.unsqueeze(2)).transpose(1, 2) + self.res_projection(
                x_res)
            x = self.final_layer(x, t_emb).transpose(1, 2)
            x = self.conv2(x)
        else:
            x = self.final_mlp(x_res)
            x = x.transpose(1, 2)
        return x
@@ -0,0 +1,292 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Convolutional layers wrappers and utilities."""
8
+
9
+ import math
10
+ import typing as tp
11
+ import warnings
12
+
13
+ import torch
14
+ from torch import nn
15
+ from torch.nn import functional as F
16
+ from torch.nn.utils import spectral_norm, weight_norm
17
+
18
+ import typing as tp
19
+
20
+ import einops
21
+
22
+
23
class ConvLayerNorm(nn.LayerNorm):
    """
    Convolution-friendly LayerNorm that moves the last (time) axis next to the
    batch axis before running the normalization, so the channel dimensions end
    up last, and moves it back to its original position right after.
    """
    def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
        super().__init__(normalized_shape, **kwargs)

    def forward(self, x):
        """Normalise ``x`` of layout ``(batch, ..., time)`` and return it in that layout.

        Args:
            x: Tensor with batch first and time last.

        Returns:
            The normalised tensor, same shape as ``x``.
        """
        # (b, ..., t) -> (b, t, ...): LayerNorm normalises trailing dimensions.
        x = x.movedim(-1, 1)
        x = super().forward(x)
        # (b, t, ...) -> (b, ..., t). The result must be returned: a bare
        # `return` here would hand `None` to every downstream layer.
        return x.movedim(1, -1)
36
+
37
+
38
# Every normalization name accepted by the conv wrappers in this module.
CONV_NORMALIZATIONS = frozenset({
    'none',
    'weight_norm',
    'spectral_norm',
    'time_layer_norm',
    'layer_norm',
    'time_group_norm',
})


def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
    """Wrap ``module`` with a weight re-parametrization if ``norm`` asks for one.

    Args:
        module: Module to (possibly) re-parametrize.
        norm: One of ``CONV_NORMALIZATIONS``.

    Returns:
        ``weight_norm(module)`` or ``spectral_norm(module)`` for those two
        names; ``module`` itself for every other accepted name.

    Raises:
        AssertionError: If ``norm`` is not in ``CONV_NORMALIZATIONS``.
    """
    assert norm in CONV_NORMALIZATIONS
    reparametrizations = {
        'weight_norm': weight_norm,
        'spectral_norm': spectral_norm,
    }
    wrap = reparametrizations.get(norm)
    # Names without an entry are known-valid but need no re-parametrization.
    return module if wrap is None else wrap(module)
52
+
53
+
54
def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
    """Build the normalization layer that follows ``module``.

    Args:
        module: The convolution being normalised; must be a ``_ConvNd`` for
            ``'layer_norm'`` and ``'time_group_norm'``.
        causal: Whether the surrounding layer is causal.
        norm: One of ``CONV_NORMALIZATIONS``.
        **norm_kwargs: Forwarded to the normalization constructor.

    Returns:
        ``ConvLayerNorm`` for ``'layer_norm'``, a single-group ``nn.GroupNorm``
        for ``'time_group_norm'``, and ``nn.Identity`` for every other name.

    Raises:
        AssertionError: If ``norm`` is unknown or ``module`` is not a conv.
        ValueError: If ``'time_group_norm'`` is requested with ``causal=True``.
    """
    assert norm in CONV_NORMALIZATIONS
    if norm not in ('layer_norm', 'time_group_norm'):
        return nn.Identity()
    if norm == 'time_group_norm' and causal:
        raise ValueError("GroupNorm doesn't support causal evaluation.")
    assert isinstance(module, nn.modules.conv._ConvNd)
    if norm == 'layer_norm':
        return ConvLayerNorm(module.out_channels, **norm_kwargs)
    return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
69
+
70
+
71
def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
                                 padding_total: int = 0) -> int:
    """Return how many samples to append so the last conv window is full.

    Args:
        x: Input whose last dimension is the time axis.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding_total: Fixed padding already planned for the convolution.

    Returns:
        The difference between the shortest length that yields a whole number
        of frames and the current length. See `pad_for_conv1d`.
    """
    current = x.shape[-1]
    fractional_frames = (current - kernel_size + padding_total) / stride + 1
    whole_frames = math.ceil(fractional_frames)
    required = (whole_frames - 1) * stride + (kernel_size - padding_total)
    return required - current
79
+
80
+
81
def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
    """Right-pad ``x`` with zeros so that the last convolution window is full.

    Without this, a strided convolution can drop trailing time steps that a
    matching transposed convolution cannot recover, even with padding.
    Example with total padding = 4, kernel size = 4, stride = 2:
            0 0 1 2 3 4 5 0 0   # (0s are padding)
            1   2   3           # (output frames of a convolution, last 0 is never used)
            0 0 1 2 3 4 5 0     # (output of tr. conv., but pos. 5 is going to get removed as padding)
                1 2 3 4         # once you removed padding, we are missing one time step !

    Args:
        x: Input whose last dimension is the time axis.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding_total: Fixed padding already planned for the convolution.

    Returns:
        ``x`` with the extra samples appended on the right of the last axis.
    """
    missing = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
    return F.pad(x, (0, missing))
94
+
95
+
96
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
    """Tiny wrapper around F.pad, just to allow for reflect padding on small input.

    When the input is not longer than the requested reflect padding, extra
    zeros are first appended on the right so that the reflection is possible,
    and the same number of samples is trimmed from the end afterwards.

    Args:
        x: Input whose last dimension is padded.
        paddings: ``(left, right)`` padding amounts, both non-negative.
        mode: ``'zero'`` (treated as ``F.pad``'s ``'constant'`` mode) or any
            mode name accepted by ``F.pad``.
        value: Fill value forwarded to ``F.pad``.

    Returns:
        The padded tensor.

    Raises:
        AssertionError: If either padding amount is negative.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode == 'zero':
        # `F.pad` has no mode called 'zero'; without this alias the default
        # arguments of this function would always fail.
        mode = 'constant'
    if mode != 'reflect':
        return F.pad(x, paddings, mode, value)

    max_pad = max(padding_left, padding_right)
    extra_pad = 0
    if length <= max_pad:
        # Reflection needs the input to be longer than the padding.
        extra_pad = max_pad - length + 1
        x = F.pad(x, (0, extra_pad))
    padded = F.pad(x, paddings, mode, value)
    end = padded.shape[-1] - extra_pad
    return padded[..., :end]
114
+
115
+
116
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Strip ``(left, right)`` samples from the last axis of ``x``. Only for 1d!

    The end index is computed explicitly so that a right padding of zero
    keeps the whole tail instead of producing an empty slice.

    Raises:
        AssertionError: If a padding is negative or the paddings together
            exceed the length of the last axis.
    """
    left, right = paddings
    assert left >= 0 and right >= 0, (left, right)
    total = x.shape[-1]
    assert (left + right) <= total
    return x[..., left:total - right]
123
+
124
+
125
class NormConv1d(nn.Module):
    """``nn.Conv1d`` bundled with its normalization.

    Gives every normalization flavour the same interface: weight
    re-parametrizations wrap the convolution itself, while layer/group norms
    are applied to its output.
    """
    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        """Positional ``args`` and extra ``kwargs`` go straight to ``nn.Conv1d``;
        ``norm_kwargs`` go to the normalization module."""
        super().__init__()
        raw_conv = nn.Conv1d(*args, **kwargs)
        self.conv = apply_parametrization_norm(raw_conv, norm)
        self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        """Convolve ``x`` and normalise the result."""
        return self.norm(self.conv(x))
140
+
141
+
142
class NormConv2d(nn.Module):
    """``nn.Conv2d`` bundled with its normalization.

    Same interface as the 1-D wrapper, except that the normalization is
    always built as non-causal.
    """
    def __init__(self, *args, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        """Positional ``args`` and extra ``kwargs`` go straight to ``nn.Conv2d``;
        ``norm_kwargs`` go to the normalization module."""
        super().__init__()
        raw_conv = nn.Conv2d(*args, **kwargs)
        self.conv = apply_parametrization_norm(raw_conv, norm)
        self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        """Convolve ``x`` and normalise the result."""
        return self.norm(self.conv(x))
157
+
158
+
159
class NormConvTranspose1d(nn.Module):
    """``nn.ConvTranspose1d`` bundled with its normalization.

    Gives every normalization flavour the same interface: weight
    re-parametrizations wrap the transposed convolution itself, while
    layer/group norms are applied to its output.
    """
    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        """Positional ``args`` and extra ``kwargs`` go straight to
        ``nn.ConvTranspose1d``; ``norm_kwargs`` go to the normalization module."""
        super().__init__()
        raw_convtr = nn.ConvTranspose1d(*args, **kwargs)
        self.convtr = apply_parametrization_norm(raw_convtr, norm)
        self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        """Apply the transposed convolution to ``x`` and normalise the result."""
        return self.norm(self.convtr(x))
174
+
175
+
176
class NormConvTranspose2d(nn.Module):
    """Wrapper around ConvTranspose2d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        """Positional ``args`` and extra ``kwargs`` go straight to
        ``nn.ConvTranspose2d``; ``norm_kwargs`` go to the normalization module,
        which is always built as non-causal."""
        super().__init__()
        self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
        # Recorded like in the other Norm* wrappers so callers can inspect
        # the normalization kind uniformly.
        self.norm_type = norm

    def forward(self, x):
        """Apply the transposed convolution to ``x`` and normalise the result."""
        x = self.convtr(x)
        x = self.norm(x)
        return x
190
+
191
+
192
class SConv1d(nn.Module):
    """1-D convolution that pads its own input.

    Handles asymmetric or causal padding and the optional normalization, so
    callers only pass the raw ``(batch, channels, time)`` tensor.
    """
    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, dilation: int = 1,
                 groups: int = 1, bias: bool = True, causal: bool = False,
                 norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
                 pad_mode: str = 'reflect', **kwargs):
        """Build the wrapped convolution.

        ``causal`` selects left-only fixed padding, ``pad_mode`` is the mode
        handed to ``pad1d``. Extra ``kwargs`` are accepted but not used.
        """
        super().__init__()
        # Combining stride and dilation is unusual enough to warn about.
        if stride > 1 and dilation > 1:
            warnings.warn('SConv1d has been initialized with stride > 1 and dilation > 1'
                          f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).')
        self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
                               dilation=dilation, groups=groups, bias=bias, causal=causal,
                               norm=norm, norm_kwargs=norm_kwargs)
        self.causal = causal
        self.pad_mode = pad_mode

    def forward(self, x):
        """Pad ``x`` for the configured kernel/stride/dilation and convolve it."""
        _batch, _channels, _frames = x.shape  # the input has to be 3-D
        inner = self.conv.conv
        stride = inner.stride[0]
        dilation = inner.dilation[0]
        # Effective kernel size once dilation is taken into account.
        effective_kernel = (inner.kernel_size[0] - 1) * dilation + 1
        padding_total = effective_kernel - stride
        extra_padding = get_extra_padding_for_conv1d(x, effective_kernel, stride, padding_total)
        if self.causal:
            # All fixed padding goes on the left for causal convolutions.
            pads = (padding_total, extra_padding)
        else:
            # Odd totals split unevenly: the left side gets the larger half.
            right = padding_total // 2
            pads = (padding_total - right, right + extra_padding)
        return self.conv(pad1d(x, pads, mode=self.pad_mode))
229
+
230
+
231
class SConvTranspose1d(nn.Module):
    """1-D transposed convolution that trims its own fixed padding.

    Handles asymmetric or causal trimming and the optional normalization.
    """
    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, causal: bool = False,
                 norm: str = 'none', trim_right_ratio: float = 1.,
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        """Build the wrapped transposed convolution.

        ``trim_right_ratio`` is the share of the fixed padding removed from
        the right in causal mode; it must lie in ``[0, 1]`` and may differ
        from 1 only when ``causal`` is set. Extra ``kwargs`` are accepted but
        not used.
        """
        super().__init__()
        self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
                                          causal=causal, norm=norm, norm_kwargs=norm_kwargs)
        self.causal = causal
        self.trim_right_ratio = trim_right_ratio
        assert self.causal or self.trim_right_ratio == 1., \
            "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
        assert 0. <= self.trim_right_ratio <= 1.

    def forward(self, x):
        """Apply the transposed convolution and remove the fixed padding."""
        inner = self.convtr.convtr
        padding_total = inner.kernel_size[0] - inner.stride[0]

        y = self.convtr(x)

        # Only the fixed padding is trimmed here. Extra padding added by
        # `pad_for_conv1d` is left for the caller to cut when restoring the
        # final output length, because removing it here would require knowing
        # the length seen by the matching encoder layer.
        if self.causal:
            # With trim_right_ratio == 1.0 everything is trimmed from the right.
            right = math.ceil(padding_total * self.trim_right_ratio)
        else:
            # Odd totals split unevenly: the left side gets the larger half.
            right = padding_total // 2
        return unpad1d(y, (padding_total - right, right))
271
+
272
class SLSTM(nn.Module):
    """
    LSTM that takes and returns the convolutional layout ``(batch, channels, time)``.

    In training mode each call starts from a fresh state. In eval mode the
    recurrent state returned by the previous call is stored on ``self.hidden``
    and fed into the next call.
    """
    def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
        """Create an LSTM with ``dimension`` input and hidden features.

        Args:
            dimension: Channel count of the input and of the output.
            num_layers: Number of stacked LSTM layers.
            skip: Whether to add the input to the LSTM output.
        """
        super().__init__()
        self.skip = skip
        self.lstm = nn.LSTM(dimension, dimension, num_layers)
        # Recurrent state carried between eval-mode calls.
        self.hidden = None

    def forward(self, x):
        """Run the LSTM over the time axis of ``x``."""
        # (batch, channels, time) -> (time, batch, channels), as nn.LSTM expects.
        seq = x.permute(2, 0, 1)
        if self.training:
            out, _ = self.lstm(seq)
        else:
            out, self.hidden = self.lstm(seq, self.hidden)
        if self.skip:
            out = out + seq
        return out.permute(1, 2, 0)