xinference 1.10.0__py3-none-any.whl → 1.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic. Click here for more details.

Files changed (317)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +11 -28
  3. xinference/client/restful/async_restful_client.py +20 -3
  4. xinference/client/restful/restful_client.py +20 -3
  5. xinference/core/supervisor.py +87 -53
  6. xinference/core/worker.py +10 -0
  7. xinference/deploy/cmdline.py +15 -0
  8. xinference/model/audio/core.py +21 -6
  9. xinference/model/audio/indextts2.py +166 -0
  10. xinference/model/audio/model_spec.json +38 -1
  11. xinference/model/image/model_spec.json +69 -0
  12. xinference/model/image/stable_diffusion/core.py +13 -4
  13. xinference/model/llm/__init__.py +4 -0
  14. xinference/model/llm/llm_family.json +464 -2
  15. xinference/model/llm/sglang/core.py +30 -11
  16. xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +94 -32
  17. xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
  18. xinference/model/llm/utils.py +12 -9
  19. xinference/model/llm/vllm/core.py +93 -17
  20. xinference/thirdparty/audiotools/__init__.py +10 -0
  21. xinference/thirdparty/audiotools/core/__init__.py +4 -0
  22. xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
  23. xinference/thirdparty/audiotools/core/display.py +194 -0
  24. xinference/thirdparty/audiotools/core/dsp.py +390 -0
  25. xinference/thirdparty/audiotools/core/effects.py +647 -0
  26. xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
  27. xinference/thirdparty/audiotools/core/loudness.py +320 -0
  28. xinference/thirdparty/audiotools/core/playback.py +252 -0
  29. xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
  30. xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
  31. xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
  32. xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
  33. xinference/thirdparty/audiotools/core/util.py +671 -0
  34. xinference/thirdparty/audiotools/core/whisper.py +97 -0
  35. xinference/thirdparty/audiotools/data/__init__.py +3 -0
  36. xinference/thirdparty/audiotools/data/datasets.py +517 -0
  37. xinference/thirdparty/audiotools/data/preprocess.py +81 -0
  38. xinference/thirdparty/audiotools/data/transforms.py +1592 -0
  39. xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
  40. xinference/thirdparty/audiotools/metrics/distance.py +131 -0
  41. xinference/thirdparty/audiotools/metrics/quality.py +159 -0
  42. xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
  43. xinference/thirdparty/audiotools/ml/__init__.py +5 -0
  44. xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
  45. xinference/thirdparty/audiotools/ml/decorators.py +440 -0
  46. xinference/thirdparty/audiotools/ml/experiment.py +90 -0
  47. xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
  48. xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
  49. xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
  50. xinference/thirdparty/audiotools/post.py +140 -0
  51. xinference/thirdparty/audiotools/preference.py +600 -0
  52. xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
  53. xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
  54. xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
  55. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
  56. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
  57. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
  58. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
  59. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  60. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
  61. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
  62. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
  63. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
  64. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
  65. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
  66. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
  67. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
  68. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
  69. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
  70. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
  71. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
  72. xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
  73. xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
  74. xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
  75. xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
  76. xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
  77. xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
  78. xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
  79. xinference/thirdparty/indextts/__init__.py +0 -0
  80. xinference/thirdparty/indextts/cli.py +65 -0
  81. xinference/thirdparty/indextts/gpt/__init__.py +0 -0
  82. xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
  83. xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
  84. xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
  85. xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
  86. xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
  87. xinference/thirdparty/indextts/gpt/model.py +713 -0
  88. xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
  89. xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
  90. xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
  91. xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
  92. xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
  93. xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
  94. xinference/thirdparty/indextts/infer.py +690 -0
  95. xinference/thirdparty/indextts/infer_v2.py +739 -0
  96. xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
  97. xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
  98. xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
  99. xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
  100. xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
  101. xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
  102. xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
  103. xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
  104. xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
  105. xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
  106. xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
  107. xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
  108. xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
  109. xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
  110. xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
  111. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
  112. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
  113. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
  114. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
  115. xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
  116. xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
  117. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
  118. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
  119. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  120. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
  121. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
  122. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
  123. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
  124. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
  125. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
  126. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
  127. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
  128. xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
  129. xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
  130. xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
  131. xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
  132. xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
  133. xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
  134. xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
  135. xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
  136. xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
  137. xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
  138. xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
  139. xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
  140. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
  141. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
  142. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
  143. xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
  144. xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
  145. xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
  146. xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
  147. xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
  148. xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
  149. xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
  150. xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
  151. xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
  152. xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
  153. xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
  154. xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
  155. xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
  156. xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
  157. xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
  158. xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
  159. xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
  160. xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
  161. xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
  162. xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
  163. xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
  164. xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
  165. xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
  166. xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
  167. xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
  168. xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
  169. xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
  170. xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
  171. xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
  172. xinference/thirdparty/indextts/utils/__init__.py +0 -0
  173. xinference/thirdparty/indextts/utils/arch_util.py +120 -0
  174. xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
  175. xinference/thirdparty/indextts/utils/common.py +121 -0
  176. xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
  177. xinference/thirdparty/indextts/utils/front.py +536 -0
  178. xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
  179. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
  180. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
  181. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
  182. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
  183. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
  184. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
  185. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
  186. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
  187. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
  188. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
  189. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
  190. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
  191. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
  192. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
  193. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
  194. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
  195. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
  196. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
  197. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
  198. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
  199. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
  200. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
  201. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
  202. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
  203. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
  204. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
  205. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
  206. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
  207. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
  208. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
  209. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
  210. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
  211. xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
  212. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
  213. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
  214. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
  215. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
  216. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
  217. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
  218. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
  219. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
  220. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
  221. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
  222. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
  223. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
  224. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
  225. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
  226. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
  227. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
  228. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
  229. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
  230. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
  231. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
  232. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
  233. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
  234. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
  235. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
  236. xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
  237. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
  238. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
  239. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
  240. xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
  241. xinference/thirdparty/indextts/utils/text_utils.py +41 -0
  242. xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
  243. xinference/thirdparty/indextts/utils/utils.py +93 -0
  244. xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
  245. xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
  246. xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
  247. xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
  248. xinference/ui/gradio/media_interface.py +66 -8
  249. xinference/ui/web/ui/build/asset-manifest.json +6 -6
  250. xinference/ui/web/ui/build/index.html +1 -1
  251. xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
  252. xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
  253. xinference/ui/web/ui/build/static/js/main.d192c4f3.js +3 -0
  254. xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.d192c4f3.js.LICENSE.txt} +0 -7
  255. xinference/ui/web/ui/build/static/js/main.d192c4f3.js.map +1 -0
  256. xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
  257. xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
  258. xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
  259. xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
  260. xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
  261. xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
  262. xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
  263. xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
  264. xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
  265. xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
  266. xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
  267. xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
  268. xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
  269. xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
  270. xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
  271. xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
  272. xinference/ui/web/ui/node_modules/.cache/babel-loader/f995a2425dfb0822fd07127f66ffe9b026883bc156b402eb8bd0b83d52460a93.json +1 -0
  273. xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
  274. xinference/ui/web/ui/package-lock.json +0 -34
  275. xinference/ui/web/ui/package.json +0 -1
  276. xinference/ui/web/ui/src/locales/en.json +9 -3
  277. xinference/ui/web/ui/src/locales/ja.json +9 -3
  278. xinference/ui/web/ui/src/locales/ko.json +9 -3
  279. xinference/ui/web/ui/src/locales/zh.json +9 -3
  280. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/METADATA +18 -2
  281. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/RECORD +285 -67
  282. xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
  283. xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
  284. xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
  285. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
  286. xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
  287. xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
  288. xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
  289. xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
  290. xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
  291. xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
  292. xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
  293. xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
  294. xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
  295. xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
  296. xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
  297. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
  298. xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
  299. xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
  300. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
  301. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
  302. xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
  303. xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
  304. xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
  305. xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
  306. xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
  307. xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
  308. xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
  309. xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
  310. xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
  311. xinference/ui/web/ui/node_modules/select/bower.json +0 -13
  312. xinference/ui/web/ui/node_modules/select/package.json +0 -29
  313. xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
  314. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/WHEEL +0 -0
  315. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/entry_points.txt +0 -0
  316. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/licenses/LICENSE +0 -0
  317. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,177 @@
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from typing import Union
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from einops import rearrange
13
+ from torch.nn.utils import weight_norm
14
+
15
+ from indextts.utils.maskgct.models.codec.amphion_codec.quantize.factorized_vector_quantize import (
16
+ FactorizedVectorQuantize,
17
+ )
18
+ from indextts.utils.maskgct.models.codec.amphion_codec.quantize.vector_quantize import VectorQuantize
19
+ from indextts.utils.maskgct.models.codec.amphion_codec.quantize.lookup_free_quantize import LookupFreeQuantize
20
+
21
+
22
class ResidualVQ(nn.Module):
    """
    Residual vector quantizer, introduced in SoundStream: An end2end neural
    audio codec (https://arxiv.org/abs/2107.03312).

    ``num_quantizers`` quantizers of the same type are applied in sequence:
    each stage quantizes the residual left over by the previous stages, and
    the per-stage quantized outputs are summed into the final representation.
    """

    def __init__(
        self,
        input_dim: int = 256,
        num_quantizers: int = 8,
        codebook_size: int = 1024,
        codebook_dim: int = 256,
        quantizer_type: str = "vq",  # "vq" or "fvq" or "lfq"
        quantizer_dropout: float = 0.5,
        **kwargs,
    ):
        """
        Parameters
        ----------
        input_dim : int
            Input channel dimension, forwarded to every quantizer.
        num_quantizers : int
            Number of residual quantization stages.
        codebook_size : int
            Number of codebook entries, forwarded to every quantizer.
        codebook_dim : int
            Codebook vector dimension, forwarded to every quantizer.
        quantizer_type : str
            "vq" (VectorQuantize), "fvq" (FactorizedVectorQuantize) or
            "lfq" (LookupFreeQuantize).
        quantizer_dropout : float
            Fraction of the batch that, in training mode, uses a random
            number of stages instead of all of them (see ``forward``).
        **kwargs
            Extra keyword arguments forwarded to every quantizer constructor.

        Raises
        ------
        ValueError
            If ``quantizer_type`` is not one of the supported values.
        """
        super().__init__()

        self.input_dim = input_dim
        self.num_quantizers = num_quantizers
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.quantizer_type = quantizer_type
        self.quantizer_dropout = quantizer_dropout

        if quantizer_type == "vq":
            VQ = VectorQuantize
        elif quantizer_type == "fvq":
            VQ = FactorizedVectorQuantize
        elif quantizer_type == "lfq":
            VQ = LookupFreeQuantize
        else:
            raise ValueError(f"Unknown quantizer type {quantizer_type}")

        self.quantizers = nn.ModuleList(
            [
                VQ(
                    input_dim=input_dim,
                    codebook_size=codebook_size,
                    codebook_dim=codebook_dim,
                    **kwargs,
                )
                for _ in range(num_quantizers)
            ]
        )

    def forward(self, z, n_quantizers: Union[int, None] = None):
        """
        Quantize ``z`` with the residual quantizer stack.

        Parameters
        ----------
        z : Tensor[B x D x T]
        n_quantizers : int, optional
            Number of leading stages to use in eval mode; defaults to
            ``self.num_quantizers``. Ignored in training mode. There, a
            per-sample count is drawn instead: the first
            ``int(B * self.quantizer_dropout)`` samples of the batch get a
            random count in ``[1, self.num_quantizers]``, and the remaining
            samples use every stage.

        Returns
        -------
        tuple
            ``(quantized_out, all_indices, all_commit_losses,
            all_codebook_losses, all_quantized)``:

            quantized_out : Tensor[B x D x T]
                Sum of the per-stage quantized outputs. In training mode,
                each sample only sums the stages its count allows.
            all_indices : Tensor[N x B x T]
                Codebook indices for each stage that was run.
            all_commit_losses : Tensor[N]
                Per-stage commitment loss, masked per sample, then averaged.
            all_codebook_losses : Tensor[N]
                Per-stage codebook loss, masked per sample, then averaged.
            all_quantized : Tensor[N x B x D x T]
                Unmasked per-stage quantized outputs.
        """
        quantized_out = 0.0
        residual = z

        all_commit_losses = []
        all_codebook_losses = []
        all_indices = []
        all_quantized = []

        if n_quantizers is None:
            n_quantizers = self.num_quantizers

        if self.training:
            # Quantizer dropout: build a per-sample stage count. The default
            # value (num_quantizers + 1) exceeds every stage index, so those
            # samples keep all stages; the first n_dropout samples get a
            # random count instead.
            n_quantizers = torch.ones((z.shape[0],)) * self.num_quantizers + 1
            dropout = torch.randint(1, self.num_quantizers + 1, (z.shape[0],))
            n_dropout = int(z.shape[0] * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(z.device)

        for i, quantizer in enumerate(self.quantizers):
            # In eval mode n_quantizers is a plain int, so stop early.
            if not self.training and i >= n_quantizers:
                break

            z_q_i, commit_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
                residual
            )

            # Per-sample mask: True where stage i is within that sample's
            # stage count. It only gates the summed output and the losses;
            # the residual keeps being refined for every sample.
            mask = (
                torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
            )
            quantized_out = quantized_out + z_q_i * mask[:, None, None]
            residual = residual - z_q_i

            commit_loss_i = (commit_loss_i * mask).mean()
            codebook_loss_i = (codebook_loss_i * mask).mean()

            all_commit_losses.append(commit_loss_i)
            all_codebook_losses.append(codebook_loss_i)
            all_indices.append(indices_i)
            all_quantized.append(z_q_i)

        all_commit_losses, all_codebook_losses, all_indices, all_quantized = map(
            torch.stack,
            (all_commit_losses, all_codebook_losses, all_indices, all_quantized),
        )

        return (
            quantized_out,
            all_indices,
            all_commit_losses,
            all_codebook_losses,
            all_quantized,
        )

    def vq2emb(self, vq, n_quantizers=None):
        """
        Rebuild the summed embedding from per-stage codebook indices.

        Parameters
        ----------
        vq : indexable by stage
            ``vq[i]`` holds the indices for stage ``i``; presumably laid out
            like ``all_indices`` from ``forward`` (Tensor[N x B x T]) -- TODO
            confirm against callers.
        n_quantizers : int, optional
            Number of leading stages to decode; defaults to all stages.

        Returns
        -------
        The sum of ``quantizer.vq2emb(vq[i])`` over the decoded stages
        (the float ``0.0`` if no stage is decoded).
        """
        quantized_out = 0.0
        if n_quantizers is None:
            n_quantizers = self.num_quantizers
        for idx, quantizer in enumerate(self.quantizers):
            if idx >= n_quantizers:
                break
            quantized_out = quantized_out + quantizer.vq2emb(vq[idx])
        return quantized_out

    def latent2dist(self, z, n_quantizers=None):
        """
        Run the residual stack and collect each stage's distances and indices.

        Parameters
        ----------
        z : Tensor
            Latent to quantize (same layout as ``forward``'s input).
        n_quantizers : int, optional
            Number of leading stages to run in eval mode; defaults to all
            stages. As in ``forward``, it is ignored in training mode, where
            every stage runs.

        Returns
        -------
        (all_dists, all_indices)
            The per-stage outputs of ``quantizer.latent2dist``, each stacked
            along a new leading stage dimension.
        """
        residual = z

        all_dists = []
        all_indices = []

        if n_quantizers is None:
            n_quantizers = self.num_quantizers

        for i, quantizer in enumerate(self.quantizers):
            if not self.training and i >= n_quantizers:
                break
            dist_i, indices_i, z_q_i = quantizer.latent2dist(residual)
            all_dists.append(dist_i)
            all_indices.append(indices_i)

            # Each later stage quantizes what the earlier stages left over.
            residual = residual - z_q_i

        all_dists = torch.stack(all_dists)
        all_indices = torch.stack(all_indices)

        return all_dists, all_indices
@@ -0,0 +1,401 @@
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from einops import rearrange, repeat
11
+ from torch.nn.utils import weight_norm
12
+
13
+
14
def WNConv1d(*args, **kwargs):
    """Build an ``nn.Conv1d`` wrapped with ``weight_norm``.

    Every argument is forwarded unchanged to ``nn.Conv1d``.
    """
    conv = nn.Conv1d(*args, **kwargs)
    return weight_norm(conv)
16
+
17
+
18
def WNConvTranspose1d(*args, **kwargs):
    """Build an ``nn.ConvTranspose1d`` wrapped with ``weight_norm``.

    Every argument is forwarded unchanged to ``nn.ConvTranspose1d``.
    """
    deconv = nn.ConvTranspose1d(*args, **kwargs)
    return weight_norm(deconv)
20
+
21
+
22
def l2norm(t):
    """Scale ``t`` to unit L2 norm along its last dimension (via ``F.normalize``)."""
    unit = F.normalize(t, dim=-1, p=2)
    return unit
24
+
25
+
26
def ema_inplace(moving_avg, new, decay):
    """Set ``moving_avg`` in place to ``decay * moving_avg + (1 - decay) * new``.

    The update goes through ``.data``, so autograd does not record it.
    Nothing is returned.
    """
    buf = moving_avg.data
    buf.mul_(decay)
    buf.add_(new, alpha=(1 - decay))
28
+
29
+
30
def laplace_smoothing(x, n_categories, eps=1e-5):
    """Add ``eps`` to every count in ``x`` and renormalize by the smoothed total.

    Returns ``(x + eps) / (x.sum() + n_categories * eps)``.
    """
    denominator = x.sum() + n_categories * eps
    return (x + eps) / denominator
32
+
33
+
34
def sample_vectors(samples, num):
    """Randomly pick ``num`` rows of ``samples`` (indexed along dim 0).

    When there are fewer rows than requested, rows are drawn with
    replacement (``torch.randint``). Otherwise a random permutation is
    truncated, so the picked rows are distinct.
    """
    count = samples.shape[0]
    dev = samples.device

    if count < num:
        picks = torch.randint(0, count, (num,), device=dev)
    else:
        picks = torch.randperm(count, device=dev)[:num]

    return samples[picks]
43
+
44
+
45
def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):
    """Run Lloyd-style k-means on the rows of a 2-D ``samples`` matrix.

    Centroids start as ``num_clusters`` rows chosen by ``sample_vectors``.
    Each iteration assigns every sample to its highest-scoring centroid and
    then moves each centroid to the mean of its members. The score is the
    dot product when ``use_cosine_sim`` is set, otherwise the negative
    squared Euclidean distance. A centroid that received no members keeps
    its previous value. With ``use_cosine_sim``, the updated means are
    L2-normalized.

    Returns:
        ``(means, bins)``: the centroids after the last iteration and the
        per-centroid member counts from the last assignment step.
    """
    n_dim, n_dtype = samples.shape[-1], samples.dtype
    centroids = sample_vectors(samples, num_clusters)

    for _ in range(num_iters):
        if use_cosine_sim:
            scores = samples @ centroids.t()
        else:
            # (N, 1, D) - (1, C, D) -> (N, C, D) pairwise differences.
            delta = samples.unsqueeze(1) - centroids.unsqueeze(0)
            scores = -(delta**2).sum(dim=-1)

        assignment = torch.max(scores, dim=-1).indices
        counts = torch.bincount(assignment, minlength=num_clusters)
        empty = counts == 0
        # Avoid dividing by zero; empty clusters are restored below anyway.
        safe_counts = counts.masked_fill(empty, 1)

        sums = assignment.new_zeros(num_clusters, n_dim, dtype=n_dtype)
        sums.scatter_add_(0, assignment.unsqueeze(1).expand(-1, n_dim), samples)
        updated = sums / safe_counts[..., None]

        if use_cosine_sim:
            updated = l2norm(updated)

        centroids = torch.where(empty[..., None], centroids, updated)

    return centroids, counts
74
+
75
+
76
class EuclideanCodebook(nn.Module):
    """Nearest-code lookup by squared Euclidean distance, with EMA code updates.

    The codes are stored in the ``embed`` buffer (codebook_size, dim), so no
    gradient reaches them. In training mode, ``forward`` instead updates
    them from exponential moving averages of per-code assignment counts
    (``cluster_size``) and assigned-vector sums (``embed_avg``). It then
    re-samples from the current batch any code whose EMA count falls below
    ``threshold_ema_dead_code``.

    Args:
        dim: dimensionality of each code vector.
        codebook_size: number of codes.
        kmeans_init: if True, ``initted`` starts at 0 and the codes are fit
            with k-means on the first input given to ``forward`` or
            ``latent2dist``.
        kmeans_iters: k-means iterations used for that initialization.
        decay: EMA decay applied to ``cluster_size`` and ``embed_avg``.
        eps: Laplace-smoothing epsilon used when normalizing ``embed_avg``.
        threshold_ema_dead_code: codes whose EMA count is below this value
            are replaced; 0 disables replacement.
        weight_init: if True, codes start uniform in
            [-1/codebook_size, 1/codebook_size] instead of standard normal.
    """

    def __init__(
        self,
        dim,
        codebook_size,
        kmeans_init=False,
        kmeans_iters=10,
        decay=0.8,
        eps=1e-5,
        threshold_ema_dead_code=2,
        weight_init=False,
    ):
        super().__init__()

        self.decay = decay
        init_fn = torch.randn if not weight_init else torch.zeros
        embed = init_fn(codebook_size, dim)

        if weight_init:
            nn.init.uniform_(embed, -1 / codebook_size, 1 / codebook_size)

        self.codebook_size = codebook_size
        self.kmeans_iters = kmeans_iters
        self.eps = eps
        self.threshold_ema_dead_code = threshold_ema_dead_code

        # initted == 0 means "run k-means on the first batch" (kmeans_init=True).
        self.register_buffer("initted", torch.Tensor([not kmeans_init]))
        self.register_buffer("cluster_size", torch.zeros(codebook_size))
        self.register_buffer("embed", embed)
        self.register_buffer("embed_avg", embed.clone())

    def init_embed_(self, data):
        """Fit codes with k-means on ``data`` (N, dim); seed EMA stats and mark initted."""
        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed)
        self.cluster_size.data.copy_(cluster_size)
        self.initted.data.copy_(torch.Tensor([True]))

    def replace(self, samples, mask):
        """Overwrite the codes selected by boolean ``mask`` with randomly drawn ``samples`` rows."""
        modified_codebook = torch.where(
            mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
        )
        self.embed.data.copy_(modified_codebook)

    def expire_codes_(self, batch_samples):
        """Re-sample codes whose EMA cluster size is below ``threshold_ema_dead_code``."""
        if self.threshold_ema_dead_code == 0:
            return

        expired_codes = self.cluster_size < self.threshold_ema_dead_code
        if not torch.any(expired_codes):
            return
        batch_samples = batch_samples.reshape(-1, batch_samples.shape[-1])
        self.replace(batch_samples, mask=expired_codes)

    def _nearest(self, flatten):
        """Return ``(neg_sq_dist, indices)`` of the closest code for each row of ``flatten``.

        ``flatten`` is (N, dim); ``neg_sq_dist`` is (N, codebook_size) and
        holds the negated squared L2 distances, so the argmax is the
        nearest code.
        """
        embed = self.embed.t()  # (codebook_size, dim) -> (dim, codebook_size)
        # ||x||^2 - 2 x.e + ||e||^2, negated.
        dist = -(
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )
        return dist, dist.max(dim=-1).indices

    def forward(self, x):
        """Quantize the last dimension of ``x`` to the nearest code.

        Returns:
            ``(quantize, embed_ind)``: the looked-up codes (``x.shape``) and
            their indices (``x.shape[:-1]``). In training mode, this also
            performs the EMA codebook update and dead-code replacement.
        """
        shape, dtype = x.shape, x.dtype
        flatten = x.reshape(-1, shape[-1])

        if not self.initted:
            self.init_embed_(flatten)

        _, flat_ind = self._nearest(flatten)
        embed_ind = flat_ind.view(*shape[:-1])
        quantize = F.embedding(embed_ind, self.embed)

        if self.training:
            # The dense (N, codebook_size) one-hot is only needed for the EMA update.
            embed_onehot = F.one_hot(flat_ind, self.codebook_size).type(dtype)
            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
            embed_sum = (
                flatten.t() @ embed_onehot
            )  # (dim, N) @ (N, codebook_size) -> (dim, codebook_size)
            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
            cluster_size = (
                laplace_smoothing(self.cluster_size, self.codebook_size, self.eps)
                * self.cluster_size.sum()
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            self.embed.data.copy_(embed_normalized)
            self.expire_codes_(x)

        return quantize, embed_ind

    def vq2emb(self, vq):
        """Look up the code vectors for integer indices ``vq``."""
        quantize = F.embedding(vq, self.embed)
        return quantize

    def latent2dist(self, x):
        """Like ``forward`` without any codebook update, and also returning distances.

        Returns:
            ``(dist, embed_ind, quantize)`` where ``dist`` has shape
            ``x.shape[:-1] + (codebook_size,)`` and holds negated squared
            distances.
        """
        shape = x.shape
        flatten = x.reshape(-1, shape[-1])

        if not self.initted:
            self.init_embed_(flatten)

        dist, flat_ind = self._nearest(flatten)
        embed_ind = flat_ind.view(*shape[:-1])
        quantize = F.embedding(embed_ind, self.embed)

        dist = dist.view(*shape[:-1], -1)

        return dist, embed_ind, quantize
192
+
193
+
194
class SimpleCodebook(nn.Module):
    """Nearest-code lookup backed by a learnable ``nn.Embedding`` table.

    Args:
        dim: dimensionality of each code vector.
        codebook_size: number of codes.
        use_l2_normlize: if True, inputs and codes are L2-normalized (each
            code individually) before distances are computed. The returned
            ``quantize`` is still looked up from the raw, un-normalized
            embedding weight.
    """

    def __init__(
        self,
        dim,
        codebook_size,
        use_l2_normlize=False,
    ):
        super().__init__()

        self.dim = dim
        self.codebook_size = codebook_size
        self.use_l2_normlize = use_l2_normlize

        self.embed = nn.Embedding(self.codebook_size, self.dim)

    def _nearest(self, flatten):
        """Return ``(neg_sq_dist, indices)`` of the closest code for each row of ``flatten`` (N, dim)."""
        codes = self.embed.weight  # (codebook_size, dim)

        if self.use_l2_normlize:
            flatten = F.normalize(flatten, dim=-1)
            # Normalize each code vector; normalizing the transposed (dim,
            # codebook_size) matrix along its default dim would mix codes.
            codes = F.normalize(codes, dim=-1)

        embed = codes.t()  # (codebook_size, dim) -> (dim, codebook_size)
        dist = -(
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )
        return dist, dist.max(dim=-1).indices

    def forward(self, x):
        """Quantize the last dimension of ``x``; return ``(quantize, embed_ind)``."""
        shape = x.shape
        flatten = x.reshape(-1, shape[-1])

        _, flat_ind = self._nearest(flatten)
        embed_ind = flat_ind.view(*shape[:-1])
        # F.embedding needs the weight tensor, not the nn.Embedding module.
        quantize = F.embedding(embed_ind, self.embed.weight)

        return quantize, embed_ind

    def vq2emb(self, vq):
        """Look up the code vectors for integer indices ``vq``."""
        quantize = F.embedding(vq, self.embed.weight)
        return quantize

    def latent2dist(self, x):
        """Like ``forward``, and also return negated squared distances to every code.

        Returns:
            ``(dist, embed_ind, quantize)`` where ``dist`` has shape
            ``x.shape[:-1] + (codebook_size,)``.
        """
        shape = x.shape
        flatten = x.reshape(-1, shape[-1])

        dist, flat_ind = self._nearest(flatten)
        embed_ind = flat_ind.view(*shape[:-1])
        quantize = F.embedding(embed_ind, self.embed.weight)

        dist = dist.view(*shape[:-1], -1)

        return dist, embed_ind, quantize
256
+
257
+
258
class VectorQuantize(nn.Module):
    """Vector quantization and factorized vector quantization.

    When ``input_dim != codebook_dim``, inputs are projected into the codebook
    space by a weight-normalized 1x1 convolution before lookup and projected
    back afterwards (factorized VQ). Otherwise both projections are
    ``nn.Identity``.

    Args:
        input_dim (int): Dimension of input.
        codebook_size (int): Codebook size.
        codebook_dim (int): Codebook dimension. We suggest codebook_dim = input_dim
            if codebook_type == "euclidean"; for factorized vector quantization,
            use a small codebook_dim (e.g. 8 or 32).
        commitment (float): Weight for commitment loss.
        codebook_loss_weight (float): Weight for codebook loss.
        use_l2_normlize (bool): Whether to use L2-normalized codes for factorized vector
            quantization (passed to the "simple" codebook only); we suggest True if you
            want to use factorized vector quantization.
        codebook_type (str): "euclidean" (EuclideanCodebook) or "simple" (SimpleCodebook).
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        decay (float): Decay for exponential moving average over the codebooks.
        eps (float): Epsilon value for numerical stability.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified
            threshold with randomly selected vectors from the current batch.
        weight_init (bool): Passed to EuclideanCodebook to select its initial code values.

    Raises:
        NotImplementedError: if ``codebook_type`` is not "euclidean" or "simple".
    """

    def __init__(
        self,
        input_dim,
        codebook_size,
        codebook_dim,
        commitment=0.005,
        codebook_loss_weight=1.0,
        use_l2_normlize=False,
        codebook_type="euclidean",  # "euclidean" or "simple"
        kmeans_init=False,
        kmeans_iters=10,
        decay=0.8,
        eps=1e-5,
        threshold_ema_dead_code=2,
        weight_init=False,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.commitment = commitment
        self.codebook_loss_weight = codebook_loss_weight
        self.use_l2_normlize = use_l2_normlize
        self.codebook_type = codebook_type
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.decay = decay
        self.eps = eps
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.weight_init = weight_init

        if self.input_dim != self.codebook_dim:
            self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
            self.out_project = WNConv1d(
                self.codebook_dim, self.input_dim, kernel_size=1
            )

        else:
            self.in_project = nn.Identity()
            self.out_project = nn.Identity()

        if self.codebook_type == "euclidean":
            self.codebook = EuclideanCodebook(
                self.codebook_dim,
                codebook_size=self.codebook_size,
                kmeans_init=self.kmeans_init,
                kmeans_iters=self.kmeans_iters,
                decay=self.decay,
                eps=self.eps,
                threshold_ema_dead_code=self.threshold_ema_dead_code,
                weight_init=self.weight_init,
            )
        elif self.codebook_type == "simple":
            self.codebook = SimpleCodebook(
                self.codebook_dim,
                codebook_size=self.codebook_size,
                use_l2_normlize=self.use_l2_normlize,
            )
        else:
            raise NotImplementedError(
                f"codebook_type {self.codebook_type} is not implemented!"
            )

    def forward(self, z):
        """
        Parameters
        ----------
        z: torch.Tensor[B x D x T]

        Returns
        -------
        z_q: torch.Tensor[B x D x T]
            Quantized continuous representation of input
        commit_loss: Tensor[B]
            Commitment loss to train encoder to predict vectors closer to codebook entries
            (zeros outside training mode)
        codebook_loss: Tensor[B]
            Codebook loss to update the codebook (zeros outside training mode)
        indices: torch.Tensor[B x T]
            Codebook indices (quantized discrete representation of input)
        z_e: torch.Tensor[B x D x T]
            Projected latents (continuous representation of input before quantization)
        """

        # Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim
        z_e = self.in_project(z)
        z_q, indices = self.decode_latents(z_e)

        # Compute commitment loss and codebook loss
        if self.training:
            commit_loss = (
                F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
                * self.commitment
            )
            codebook_loss = (
                F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
                * self.codebook_loss_weight
            )
        else:
            commit_loss = torch.zeros(z.shape[0], device=z.device)
            codebook_loss = torch.zeros(z.shape[0], device=z.device)

        # Straight-through estimator: forward value is z_q, gradient flows to z_e.
        z_q = z_e + (z_q - z_e).detach()

        z_q = self.out_project(z_q)

        return z_q, commit_loss, codebook_loss, indices, z_e

    def decode_latents(self, latents):
        """Quantize codebook-space ``latents`` (B, D, T); return ``(z_q (B, D, T), indices)``."""
        encodings = latents.transpose(1, 2)  # (B, D, T) -> (B, T, D)
        z_q, indices = self.codebook(encodings)
        z_q = z_q.transpose(1, 2)
        return z_q, indices

    def vq2emb(self, vq, out_proj=True):
        """Decode code indices ``vq`` to embeddings, optionally projected back to input_dim."""
        emb = self.codebook.vq2emb(vq)
        emb = emb.transpose(1, 2)
        if out_proj:
            emb = self.out_project(emb)
        return emb

    def latent2dist(self, latents):
        """Distances, indices and quantized output for input-space ``latents`` (B, D, T).

        As in ``forward``, the latents are projected into the codebook space
        before lookup, and the quantized result is projected back. This keeps
        the result in the same space as ``latents``, so callers can subtract
        it to form a residual. With identity projections
        (input_dim == codebook_dim), this is a plain lookup.

        Returns:
            ``(dist, embed_ind, quantize)``: ``dist`` and ``embed_ind`` as
            produced by the codebook's ``latent2dist``, and ``quantize`` as
            (B, input_dim, T).
        """
        z_e = self.in_project(latents)
        encodings = z_e.transpose(1, 2)  # (B, D, T) -> (B, T, D)
        dist, embed_ind, quantize = self.codebook.latent2dist(encodings)
        return dist, embed_ind, self.out_project(quantize.transpose(1, 2))