xinference 1.10.0__py3-none-any.whl → 1.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic; see the package registry page for more details.

Files changed (328)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +473 -31
  3. xinference/client/restful/async_restful_client.py +178 -8
  4. xinference/client/restful/restful_client.py +151 -3
  5. xinference/core/supervisor.py +99 -53
  6. xinference/core/worker.py +10 -0
  7. xinference/deploy/cmdline.py +15 -0
  8. xinference/model/audio/core.py +21 -6
  9. xinference/model/audio/indextts2.py +166 -0
  10. xinference/model/audio/model_spec.json +58 -21
  11. xinference/model/image/model_spec.json +159 -90
  12. xinference/model/image/stable_diffusion/core.py +13 -4
  13. xinference/model/llm/__init__.py +6 -2
  14. xinference/model/llm/llm_family.json +1299 -174
  15. xinference/model/llm/mlx/distributed_models/core.py +41 -0
  16. xinference/model/llm/mlx/distributed_models/qwen2.py +1 -2
  17. xinference/model/llm/sglang/core.py +44 -11
  18. xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +94 -32
  19. xinference/model/llm/tool_parsers/qwen_tool_parser.py +29 -4
  20. xinference/model/llm/transformers/chatglm.py +3 -0
  21. xinference/model/llm/transformers/core.py +129 -36
  22. xinference/model/llm/transformers/multimodal/minicpmv45.py +340 -0
  23. xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
  24. xinference/model/llm/transformers/utils.py +23 -0
  25. xinference/model/llm/utils.py +48 -32
  26. xinference/model/llm/vllm/core.py +207 -72
  27. xinference/model/utils.py +74 -31
  28. xinference/thirdparty/audiotools/__init__.py +10 -0
  29. xinference/thirdparty/audiotools/core/__init__.py +4 -0
  30. xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
  31. xinference/thirdparty/audiotools/core/display.py +194 -0
  32. xinference/thirdparty/audiotools/core/dsp.py +390 -0
  33. xinference/thirdparty/audiotools/core/effects.py +647 -0
  34. xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
  35. xinference/thirdparty/audiotools/core/loudness.py +320 -0
  36. xinference/thirdparty/audiotools/core/playback.py +252 -0
  37. xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
  38. xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
  39. xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
  40. xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
  41. xinference/thirdparty/audiotools/core/util.py +671 -0
  42. xinference/thirdparty/audiotools/core/whisper.py +97 -0
  43. xinference/thirdparty/audiotools/data/__init__.py +3 -0
  44. xinference/thirdparty/audiotools/data/datasets.py +517 -0
  45. xinference/thirdparty/audiotools/data/preprocess.py +81 -0
  46. xinference/thirdparty/audiotools/data/transforms.py +1592 -0
  47. xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
  48. xinference/thirdparty/audiotools/metrics/distance.py +131 -0
  49. xinference/thirdparty/audiotools/metrics/quality.py +159 -0
  50. xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
  51. xinference/thirdparty/audiotools/ml/__init__.py +5 -0
  52. xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
  53. xinference/thirdparty/audiotools/ml/decorators.py +440 -0
  54. xinference/thirdparty/audiotools/ml/experiment.py +90 -0
  55. xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
  56. xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
  57. xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
  58. xinference/thirdparty/audiotools/post.py +140 -0
  59. xinference/thirdparty/audiotools/preference.py +600 -0
  60. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +1 -1
  61. xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
  62. xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
  63. xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
  64. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
  65. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
  66. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
  67. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
  68. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  69. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
  70. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
  71. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
  72. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
  73. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
  74. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
  75. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
  76. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
  77. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
  78. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
  79. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
  80. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
  81. xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
  82. xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
  83. xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
  84. xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
  85. xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
  86. xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
  87. xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
  88. xinference/thirdparty/indextts/__init__.py +0 -0
  89. xinference/thirdparty/indextts/cli.py +65 -0
  90. xinference/thirdparty/indextts/gpt/__init__.py +0 -0
  91. xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
  92. xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
  93. xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
  94. xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
  95. xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
  96. xinference/thirdparty/indextts/gpt/model.py +713 -0
  97. xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
  98. xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
  99. xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
  100. xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
  101. xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
  102. xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
  103. xinference/thirdparty/indextts/infer.py +690 -0
  104. xinference/thirdparty/indextts/infer_v2.py +739 -0
  105. xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
  106. xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
  107. xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
  108. xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
  109. xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
  110. xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
  111. xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
  112. xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
  113. xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
  114. xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
  115. xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
  116. xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
  117. xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
  118. xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
  119. xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
  120. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
  121. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
  122. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
  123. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
  124. xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
  125. xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
  126. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
  127. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
  128. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  129. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
  130. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
  131. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
  132. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
  133. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
  134. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
  135. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
  136. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
  137. xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
  138. xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
  139. xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
  140. xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
  141. xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
  142. xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
  143. xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
  144. xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
  145. xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
  146. xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
  147. xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
  148. xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
  149. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
  150. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
  151. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
  152. xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
  153. xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
  154. xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
  155. xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
  156. xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
  157. xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
  158. xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
  159. xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
  160. xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
  161. xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
  162. xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
  163. xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
  164. xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
  165. xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
  166. xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
  167. xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
  168. xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
  169. xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
  170. xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
  171. xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
  172. xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
  173. xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
  174. xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
  175. xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
  176. xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
  177. xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
  178. xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
  179. xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
  180. xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
  181. xinference/thirdparty/indextts/utils/__init__.py +0 -0
  182. xinference/thirdparty/indextts/utils/arch_util.py +120 -0
  183. xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
  184. xinference/thirdparty/indextts/utils/common.py +121 -0
  185. xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
  186. xinference/thirdparty/indextts/utils/front.py +536 -0
  187. xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
  188. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
  189. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
  190. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
  191. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
  192. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
  193. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
  194. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
  195. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
  196. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
  197. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
  198. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
  199. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
  200. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
  201. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
  202. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
  203. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
  204. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
  205. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
  206. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
  207. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
  208. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
  209. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
  210. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
  211. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
  212. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
  213. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
  214. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
  215. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
  216. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
  217. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
  218. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
  219. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
  220. xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
  221. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
  222. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
  223. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
  224. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
  225. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
  226. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
  227. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
  228. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
  229. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
  230. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
  231. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
  232. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
  233. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
  234. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
  235. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
  236. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
  237. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
  238. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
  239. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
  240. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
  241. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
  242. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
  243. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
  244. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
  245. xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
  246. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
  247. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
  248. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
  249. xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
  250. xinference/thirdparty/indextts/utils/text_utils.py +41 -0
  251. xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
  252. xinference/thirdparty/indextts/utils/utils.py +93 -0
  253. xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
  254. xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
  255. xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
  256. xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
  257. xinference/thirdparty/melo/text/chinese_mix.py +2 -2
  258. xinference/types.py +9 -0
  259. xinference/ui/gradio/media_interface.py +66 -8
  260. xinference/ui/web/ui/build/asset-manifest.json +6 -6
  261. xinference/ui/web/ui/build/index.html +1 -1
  262. xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
  263. xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
  264. xinference/ui/web/ui/build/static/js/main.45e78536.js +3 -0
  265. xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.45e78536.js.LICENSE.txt} +0 -7
  266. xinference/ui/web/ui/build/static/js/main.45e78536.js.map +1 -0
  267. xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
  268. xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
  269. xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
  270. xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
  271. xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
  272. xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
  273. xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
  274. xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
  275. xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
  276. xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
  277. xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
  278. xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
  279. xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
  280. xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
  281. xinference/ui/web/ui/node_modules/.cache/babel-loader/ea2a26361204e70cf1018d6990fb6354bed82b3ac69690391e0f100385e7abb7.json +1 -0
  282. xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
  283. xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
  284. xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
  285. xinference/ui/web/ui/package-lock.json +0 -34
  286. xinference/ui/web/ui/package.json +0 -1
  287. xinference/ui/web/ui/src/locales/en.json +9 -3
  288. xinference/ui/web/ui/src/locales/ja.json +9 -3
  289. xinference/ui/web/ui/src/locales/ko.json +9 -3
  290. xinference/ui/web/ui/src/locales/zh.json +9 -3
  291. {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/METADATA +24 -6
  292. {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/RECORD +296 -77
  293. xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
  294. xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
  295. xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
  296. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
  297. xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
  298. xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
  299. xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
  300. xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
  301. xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
  302. xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
  303. xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
  304. xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
  305. xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
  306. xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
  307. xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
  308. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
  309. xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
  310. xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
  311. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
  312. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
  313. xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
  314. xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
  315. xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
  316. xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
  317. xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
  318. xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
  319. xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
  320. xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
  321. xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
  322. xinference/ui/web/ui/node_modules/select/bower.json +0 -13
  323. xinference/ui/web/ui/node_modules/select/package.json +0 -29
  324. xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
  325. {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/WHEEL +0 -0
  326. {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/entry_points.txt +0 -0
  327. {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/licenses/LICENSE +0 -0
  328. {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,317 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ # This source file is copied from https://github.com/facebookresearch/encodec
6
+
7
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
8
+ # All rights reserved.
9
+ #
10
+ # This source code is licensed under the license found in the
11
+ # LICENSE file in the root directory of this source tree.
12
+
13
+ """Arithmetic coder."""
14
+
15
+ import io
16
+ import math
17
+ import random
18
+ import typing as tp
19
+ import torch
20
+
21
+ from ..binary import BitPacker, BitUnpacker
22
+
23
+
24
def build_stable_quantized_cdf(
    pdf: torch.Tensor,
    total_range_bits: int,
    roundoff: float = 1e-8,
    min_range: int = 2,
    check: bool = True,
) -> torch.Tensor:
    """Turn the given PDF into a quantized CDF that splits
    ``[0, 2 ** total_range_bits - 1]`` into chunks of size roughly proportional
    to the PDF.

    Args:
        pdf (torch.Tensor): probability distribution, shape should be `[N]`.
        total_range_bits (int): see `ArithmeticCoder`, the typical range we expect
            during the coding process is `[0, 2 ** total_range_bits - 1]`.
        roundoff (float): the pdf is floored to a multiple of this value to remove
            differences coming from e.g. evaluating the Language Model on different
            architectures. A falsy value disables the rounding.
        min_range (int): minimum range width. Should always be at least 2 for numerical
            stability. Use this to avoid pathological behavior if a value
            that is expected to be rare actually happens in real life.
        check (bool): if True, checks that nothing bad happened, can be deactivated for speed.

    Returns:
        torch.Tensor: integer (int64) cumulative sums of the per-symbol range widths.
        As consumed by `ArithmeticCoder.push`, symbol ``i`` owns
        ``[cdf[i - 1], cdf[i] - 1]``, with the lower bound read as 0 for ``i == 0``.

    Raises:
        ValueError: if ``min_range < 2``, or (when ``check``) if some symbol ends up
            with a range narrower than ``min_range``.
    """
    # Validate the argument up front rather than after doing all the work.
    if min_range < 2:
        raise ValueError("min_range must be at least 2.")
    pdf = pdf.detach()
    if roundoff:
        pdf = (pdf / roundoff).floor() * roundoff
    # interpolate with uniform distribution to achieve desired minimum probability.
    total_range = 2**total_range_bits
    cardinality = len(pdf)
    # Fraction of the total range reserved to give every symbol `min_range` slots.
    alpha = min_range * cardinality / total_range
    assert alpha <= 1, "you must reduce min_range"
    ranges = (((1 - alpha) * total_range) * pdf).floor().long()
    ranges += min_range
    quantized_cdf = torch.cumsum(ranges, dim=-1)
    if check:
        assert quantized_cdf[-1] <= 2**total_range_bits, quantized_cdf[-1]
        if (
            (quantized_cdf[1:] - quantized_cdf[:-1]) < min_range
        ).any() or quantized_cdf[0] < min_range:
            raise ValueError("You must increase your total_range_bits.")
    return quantized_cdf
66
+
67
+
68
class ArithmeticCoder:
    """ArithmeticCoder,
    Let us take a distribution `p` over `N` symbols, and assume we have a stream
    of random variables `s_t` sampled from `p`. Let us assume that we have a budget
    of `B` bits that we can afford to write on device. There are `2**B` possible numbers,
    corresponding to the range `[0, 2 ** B - 1]`. We can map each of those number to a single
    sequence `(s_t)` by doing the following:

    1) Initialize the current range to `[0, 2 ** B - 1]`.
    2) For each time step t, split the current range into contiguous chunks,
        one for each possible outcome, with size roughly proportional to `p`.
        For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks
        would be `{[0, 2], [3, 3]}`.
    3) Select the chunk corresponding to `s_t`, and replace the current range with this.
    4) When done encoding all the values, just select any value remaining in the range.

    You will notice that this procedure can fail: for instance if at any point in time
    the range is smaller than `N`, then we can no longer assign a non-empty chunk to each
    possible outcome. Intuitively, the more likely a value is, the less the range width
    will reduce, and the longer we can go on encoding values. This makes sense: for any efficient
    coding scheme, likely outcomes would take less bits, and more of them can be coded
    with a fixed budget.

    In practice, we do not know `B` ahead of time, but we have a way to inject new bits
    when the current range decreases below a given limit (given by `total_range_bits`), without
    having to redo all the computations. If we encode mostly likely values, we will seldom
    need to inject new bits, but a single rare value can deplete our stock of entropy!

    In this explanation, we assumed that the distribution `p` was constant. In fact, the present
    code works for any sequence `(p_t)` possibly different for each timestep.
    We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller
    the KL between the true distribution and `p_t`, the more efficient the coding will be.

    Args:
        fo (IO[bytes]): file-like object to which the bytes will be written to.
        total_range_bits (int): the range `M` described above is `2 ** total_range_bits`.
            Any time the current range width fall under this limit, new bits will
            be injected to rescale the initial range.
        debug (bool): if True, record the `(low, high)` range after every `push` in
            `_dbg` / `_dbg2`. Off by default so that encoding a long stream does not
            accumulate an ever-growing trace.
    """

    def __init__(
        self, fo: tp.IO[bytes], total_range_bits: int = 24, debug: bool = False
    ):
        assert total_range_bits <= 30
        self.total_range_bits = total_range_bits
        self.packer = BitPacker(bits=1, fo=fo)  # we push single bits at a time.
        self.low: int = 0
        self.high: int = 0
        self.max_bit: int = -1
        # Debug traces; only appended to when `debug` is enabled.
        self._debug = debug
        self._dbg: tp.List[tp.Any] = []
        self._dbg2: tp.List[tp.Any] = []

    @property
    def delta(self) -> int:
        """Return the current range width."""
        return self.high - self.low + 1

    def _flush_common_prefix(self):
        # If self.low and self.high start with the sames bits,
        # those won't change anymore as we always just increase the range
        # by powers of 2, and we can flush them out to the bit stream.
        assert self.high >= self.low, (self.low, self.high)
        assert self.high < 2 ** (self.max_bit + 1)
        while self.max_bit >= 0:
            b1 = self.low >> self.max_bit
            b2 = self.high >> self.max_bit
            if b1 == b2:
                self.low -= b1 << self.max_bit
                self.high -= b1 << self.max_bit
                assert self.high >= self.low, (self.high, self.low, self.max_bit)
                assert self.low >= 0
                self.max_bit -= 1
                self.packer.push(b1)
            else:
                break

    def push(self, symbol: int, quantized_cdf: torch.Tensor):
        """Push the given symbol on the stream, flushing out bits
        if possible.

        Args:
            symbol (int): symbol to encode with the AC.
            quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
                to build this from your pdf estimate.
        """
        # Widen the range one bit at a time (low gets a 0, high a 1) until it is
        # at least 2 ** total_range_bits wide.
        while self.delta < 2**self.total_range_bits:
            self.low *= 2
            self.high = self.high * 2 + 1
            self.max_bit += 1

        # Chunk of [0, 2 ** total_range_bits) owned by `symbol`, rescaled to the
        # current range width.
        range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
        range_high = quantized_cdf[symbol].item() - 1
        effective_low = int(
            math.ceil(range_low * (self.delta / (2**self.total_range_bits)))
        )
        effective_high = int(
            math.floor(range_high * (self.delta / (2**self.total_range_bits)))
        )
        assert self.low <= self.high
        self.high = self.low + effective_high
        self.low = self.low + effective_low
        assert self.low <= self.high, (
            effective_low,
            effective_high,
            range_low,
            range_high,
        )
        if self._debug:
            self._dbg.append((self.low, self.high))
            self._dbg2.append((self.low, self.high))
        self._flush_common_prefix()
        assert self.low <= self.high
        assert self.max_bit >= -1
        assert self.max_bit <= 61, self.max_bit

    def flush(self):
        """Flush the remaining information to the stream."""
        # Emit the remaining bits of `low`, which lies inside the final range.
        while self.max_bit >= 0:
            b1 = (self.low >> self.max_bit) & 1
            self.packer.push(b1)
            self.max_bit -= 1
        self.packer.flush()
188
+
189
+
190
class ArithmeticDecoder:
    """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation.

    Note that this must be called with **exactly** the same parameters and sequence
    of quantized cdf as the arithmetic encoder or the wrong values will be decoded.

    If the AC encoder current range is [L, H], with `L` and `H` having the same common
    prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream.
    For instance, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside
    `[b1 b2 b3 0 ... 0, b1 b2 b3 1 ... 1]`. Now this specific sub-range can only be obtained
    for a specific sequence of symbols and a binary-search allows us to decode those symbols.
    At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols,
    and we will need to read new bits from the stream and repeat the process.

    Args:
        fo (IO[bytes]): file-like object from which the bits are read.
        total_range_bits (int): must match the value used by the encoder.
        debug (bool): if True, record the decoder state after every `pull` in
            `_dbg` / `_dbg2`. Off by default so that decoding a long stream does not
            accumulate an ever-growing trace.
    """

    def __init__(
        self, fo: tp.IO[bytes], total_range_bits: int = 24, debug: bool = False
    ):
        self.total_range_bits = total_range_bits
        self.low: int = 0
        self.high: int = 0
        self.current: int = 0
        self.max_bit: int = -1
        self.unpacker = BitUnpacker(bits=1, fo=fo)  # we pull single bits at a time.
        # Following is for debugging; the lists are only appended to when `debug` is set.
        self._debug = debug
        self._dbg: tp.List[tp.Any] = []
        self._dbg2: tp.List[tp.Any] = []
        # Snapshot of (low, high, current, max_bit) taken at the start of the last `pull`.
        self._last: tp.Any = None

    @property
    def delta(self) -> int:
        """Return the current range width."""
        return self.high - self.low + 1

    def _flush_common_prefix(self):
        # Given the current range [L, H], if both have a common prefix,
        # we know we can remove it from our representation to avoid handling large numbers.
        while self.max_bit >= 0:
            b1 = self.low >> self.max_bit
            b2 = self.high >> self.max_bit
            if b1 == b2:
                self.low -= b1 << self.max_bit
                self.high -= b1 << self.max_bit
                self.current -= b1 << self.max_bit
                assert self.high >= self.low
                assert self.low >= 0
                self.max_bit -= 1
            else:
                break

    def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]:
        """Pull a symbol, reading as many bits from the stream as required.
        This returns `None` when the stream has been exhausted.

        Note that when the stream runs out part-way through the rescaling loop,
        `low`/`high`/`current` have already been partially widened.

        Args:
            quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
                to build this from your pdf estimate. This must be **exactly**
                the same cdf as the one used at encoding time.
        """
        # Mirror the encoder's rescaling, feeding real stream bits into `current`.
        while self.delta < 2**self.total_range_bits:
            bit = self.unpacker.pull()
            if bit is None:
                return None
            self.low *= 2
            self.high = self.high * 2 + 1
            self.current = self.current * 2 + bit
            self.max_bit += 1

        def bin_search(low_idx: int, high_idx: int):
            # Find the symbol whose rescaled chunk contains `self.current`.
            if high_idx < low_idx:
                raise RuntimeError("Binary search failed")
            mid = (low_idx + high_idx) // 2
            range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
            range_high = quantized_cdf[mid].item() - 1
            effective_low = int(
                math.ceil(range_low * (self.delta / (2**self.total_range_bits)))
            )
            effective_high = int(
                math.floor(range_high * (self.delta / (2**self.total_range_bits)))
            )
            low = effective_low + self.low
            high = effective_high + self.low
            if self.current >= low:
                if self.current <= high:
                    return (mid, low, high, self.current)
                else:
                    return bin_search(mid + 1, high_idx)
            else:
                return bin_search(low_idx, mid - 1)

        self._last = (self.low, self.high, self.current, self.max_bit)
        sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1)
        if self._debug:
            self._dbg.append((self.low, self.high, self.current))
        self._flush_common_prefix()
        if self._debug:
            self._dbg2.append((self.low, self.high, self.current))

        return sym
286
+
287
+
288
def test():
    """Round-trip self-test: encode random symbols, decode them, and compare.

    Seeds are fixed, so the run is deterministic.
    """
    torch.manual_seed(1234)
    random.seed(1234)
    for _ in range(4):
        pdfs = []
        # randrange(4000) may return 0, and an empty distribution cannot be
        # sampled; clamping keeps at least one symbol without changing the
        # sequence of random draws.
        cardinality = max(1, random.randrange(4000))
        steps = random.randrange(100, 500)
        fo = io.BytesIO()
        encoder = ArithmeticCoder(fo)
        symbols = []
        for _step in range(steps):
            pdf = torch.softmax(torch.randn(cardinality), dim=0)
            pdfs.append(pdf)
            q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
            symbol = torch.multinomial(pdf, 1).item()
            symbols.append(symbol)
            encoder.push(symbol, q_cdf)
        encoder.flush()

        fo.seek(0)
        # The decoder must use the same range size as the encoder.
        decoder = ArithmeticDecoder(fo, encoder.total_range_bits)
        for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)):
            q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
            decoded_symbol = decoder.pull(q_cdf)
            assert decoded_symbol == symbol, idx
        # Once every encoded symbol is consumed, the stream is exhausted.
        assert decoder.pull(torch.zeros(1)) is None


if __name__ == "__main__":
    test()
@@ -0,0 +1,388 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ # This source file is copied from https://github.com/facebookresearch/encodec
6
+
7
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
8
+ # All rights reserved.
9
+ #
10
+ # This source code is licensed under the license found in the
11
+ # LICENSE file in the root directory of this source tree.
12
+ #
13
+ # This implementation is inspired from
14
+ # https://github.com/lucidrains/vector-quantize-pytorch
15
+ # which is released under MIT License. Hereafter, the original license:
16
+ # MIT License
17
+ #
18
+ # Copyright (c) 2020 Phil Wang
19
+ #
20
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
21
+ # of this software and associated documentation files (the "Software"), to deal
22
+ # in the Software without restriction, including without limitation the rights
23
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
24
+ # copies of the Software, and to permit persons to whom the Software is
25
+ # furnished to do so, subject to the following conditions:
26
+ #
27
+ # The above copyright notice and this permission notice shall be included in all
28
+ # copies or substantial portions of the Software.
29
+ #
30
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
31
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
32
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
33
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
34
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
35
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
36
+ # SOFTWARE.
37
+
38
+ """Core vector quantization implementation."""
39
+ import typing as tp
40
+
41
+ from einops import rearrange, repeat
42
+ import torch
43
+ from torch import nn
44
+ import torch.nn.functional as F
45
+
46
+ from .distrib import broadcast_tensors, rank
47
+
48
+
49
def default(val: tp.Any, d: tp.Any) -> tp.Any:
    """Return ``val``, falling back to ``d`` only when ``val`` is ``None``.

    Other falsy values (``0``, ``""``, empty containers) are returned unchanged.
    """
    if val is None:
        return d
    return val
51
+
52
+
53
def ema_inplace(moving_avg, new, decay: float):
    """Blend ``new`` into ``moving_avg`` in place.

    After the call, ``moving_avg == decay * moving_avg + (1 - decay) * new``.
    Works on ``.data`` so the update is not tracked by autograd. Returns ``None``.
    """
    storage = moving_avg.data
    storage.mul_(decay)
    storage.add_(new, alpha=(1 - decay))
55
+
56
+
57
def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
    """Normalize counts ``x`` into proportions with additive (Laplace) smoothing.

    Every entry gets ``epsilon`` added, and the denominator grows by
    ``n_categories * epsilon`` to match, so zero counts stay strictly positive.
    """
    numerator = x + epsilon
    denominator = x.sum() + n_categories * epsilon
    return numerator / denominator
59
+
60
+
61
def uniform_init(*shape: int):
    """Allocate a tensor of ``shape`` filled with Kaiming-uniform random values."""
    # kaiming_uniform_ fills its argument in place and returns that same tensor.
    return nn.init.kaiming_uniform_(torch.empty(shape))
65
+
66
+
67
def sample_vectors(samples, num: int):
    """Pick ``num`` rows from ``samples``.

    Rows are drawn without replacement when there are enough of them, and with
    replacement otherwise.
    """
    total = samples.shape[0]
    dev = samples.device

    if total < num:
        # Not enough rows for a permutation: draw indices independently.
        picks = torch.randint(0, total, (num,), device=dev)
    else:
        picks = torch.randperm(total, device=dev)[:num]

    return samples[picks]
76
+
77
+
78
def kmeans(samples, num_clusters: int, num_iters: int = 10):
    """Cluster the rows of ``samples`` with Lloyd's k-means.

    Args:
        samples: tensor of points, one per row. The ``"n d"`` rearrange patterns
            below require it to be 2-D.
        num_clusters (int): number of centroids.
        num_iters (int): number of assignment/update iterations; must be >= 1.

    Returns:
        (means, bins): the ``(num_clusters, d)`` centroids, and the number of
        samples assigned to each centroid in the last iteration.

    Raises:
        ValueError: if ``num_iters < 1``. The assignment counts are only computed
            inside the loop, so at least one iteration is required.
    """
    if num_iters < 1:
        raise ValueError(f"num_iters must be at least 1, got {num_iters}.")
    dim, dtype = samples.shape[-1], samples.dtype

    means = sample_vectors(samples, num_clusters)

    for _ in range(num_iters):
        # Pairwise differences; this materializes an (n, num_clusters, d) tensor.
        diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
        # Negated squared distances, so the max picks the nearest centroid.
        dists = -(diffs**2).sum(dim=-1)

        buckets = dists.max(dim=-1).indices
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        # Avoid dividing by zero for empty clusters; their result is discarded below.
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
        new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
        new_means = new_means / bins_min_clamped[..., None]

        # Empty clusters keep their previous centroid.
        means = torch.where(zero_mask[..., None], means, new_means)

    return means, bins
99
+
100
+
101
class EuclideanCodebook(nn.Module):
    """Codebook with Euclidean distance.

    Args:
        dim (int): Dimension.
        codebook_size (int): Codebook size.
        kmeans_init (bool): Whether to use k-means to initialize the codebooks.
            If set to true, the codebook starts at zero and the k-means algorithm is
            run on the first batch passed through `forward`; the learned centroids
            are used as initialization.
        kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified
            threshold with randomly selected vector from the current batch.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        kmeans_init: bool = False,
        kmeans_iters: int = 10,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        threshold_ema_dead_code: int = 2,
    ):
        super().__init__()
        self.decay = decay
        init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = (
            uniform_init if not kmeans_init else torch.zeros
        )
        embed = init_fn(codebook_size, dim)  # (codebook_size, dim)

        self.codebook_size = codebook_size

        self.kmeans_iters = kmeans_iters
        self.epsilon = epsilon
        self.threshold_ema_dead_code = threshold_ema_dead_code

        # Registered buffers appear in the state_dict under these names.
        self.register_buffer("inited", torch.Tensor([not kmeans_init]))
        self.register_buffer("cluster_size", torch.zeros(codebook_size))
        self.register_buffer("embed", embed)
        self.register_buffer("embed_avg", embed.clone())

    @torch.jit.ignore
    def init_embed_(self, data):
        """Initialize the codebook with k-means on ``data`` the first time it is called.

        This is a no-op once the ``inited`` buffer is set.
        """
        if self.inited:
            return

        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed.clone())
        self.cluster_size.data.copy_(cluster_size)
        self.inited.data.copy_(torch.Tensor([True]))
        # NOTE: cross-worker buffer synchronization (broadcast_tensors over
        # self.buffers()) is disabled in this copy.

    def replace_(self, samples, mask):
        """Overwrite the codebook rows selected by ``mask`` with vectors drawn from ``samples``."""
        modified_codebook = torch.where(
            mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
        )
        self.embed.data.copy_(modified_codebook)

    def expire_codes_(self, batch_samples):
        """Re-seed codes whose EMA cluster size is below ``threshold_ema_dead_code``."""
        if self.threshold_ema_dead_code == 0:
            return

        expired_codes = self.cluster_size < self.threshold_ema_dead_code
        if not torch.any(expired_codes):
            return

        batch_samples = rearrange(batch_samples, "... d -> (...) d")
        self.replace_(batch_samples, mask=expired_codes)
        # NOTE: cross-worker buffer synchronization is disabled in this copy.

    def preprocess(self, x):
        """Flatten all leading dimensions: ``(..., d) -> (N, d)``."""
        x = rearrange(x, "... d -> (...) d")
        return x

    def quantize(self, x):
        """Return, for each row of the 2-D ``x``, the index of the nearest codebook entry."""
        embed = self.embed.t()
        # Negated squared Euclidean distance -(|x|^2 - 2 x.e + |e|^2);
        # the argmax over codes is therefore the nearest code.
        dist = -(
            x.pow(2).sum(1, keepdim=True)
            - 2 * x @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )
        embed_ind = dist.max(dim=-1).indices
        return embed_ind

    def postprocess_emb(self, embed_ind, shape):
        """Reshape flat indices back to ``shape[:-1]``."""
        return embed_ind.view(*shape[:-1])

    def dequantize(self, embed_ind):
        """Look up the codebook vectors for the given indices."""
        quantize = F.embedding(embed_ind, self.embed)
        return quantize

    def encode(self, x):
        """Map ``x`` of shape ``(..., d)`` to code indices of shape ``(...)``."""
        shape = x.shape
        # pre-process
        x = self.preprocess(x)
        # quantize
        embed_ind = self.quantize(x)
        # post-process
        embed_ind = self.postprocess_emb(embed_ind, shape)
        return embed_ind

    def decode(self, embed_ind):
        """Map code indices back to codebook vectors."""
        quantize = self.dequantize(embed_ind)
        return quantize

    def forward(self, x):
        """Quantize ``x``; in training mode, also update the codebook by EMA.

        Returns:
            (quantize, embed_ind): the quantized vectors and the code indices,
            with indices shaped like ``x.shape[:-1]``.
        """
        shape, dtype = x.shape, x.dtype
        x = self.preprocess(x)

        self.init_embed_(x)

        embed_ind = self.quantize(x)
        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
        embed_ind = self.postprocess_emb(embed_ind, shape)
        quantize = self.dequantize(embed_ind)

        if self.training:
            # Dead codes are expired before the EMA update. NOTE(review): the original
            # rationale is that buffers are in sync so all workers decide alike, but
            # buffer broadcasting is disabled in this copy; confirm for multi-worker
            # training.
            self.expire_codes_(x)
            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
            embed_sum = x.t() @ embed_onehot
            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
            # Laplace-smoothed cluster sizes, rescaled to the same total, keep the
            # division below away from zero for codes with no assignments.
            cluster_size = (
                laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
                * self.cluster_size.sum()
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            self.embed.data.copy_(embed_normalized)

        return quantize, embed_ind
237
+
238
+
239
class VectorQuantization(nn.Module):
    """Single-codebook vector quantizer (Euclidean distance only).

    Inputs are laid out as ``(b, d, n)``. They are transposed to ``(b, n, d)``,
    optionally projected to ``codebook_dim``, quantized, then projected back.

    Args:
        dim (int): Dimension
        codebook_size (int): Codebook size
        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified
            threshold with randomly selected vector from the current batch.
        commitment_weight (float): Weight for commitment loss.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: tp.Optional[int] = None,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: int = 2,
        commitment_weight: float = 1.0,
    ):
        super().__init__()
        inner_dim: int = default(codebook_dim, dim)

        # Linear projections are only needed when the codebook lives in a
        # different dimension than the input.
        if inner_dim != dim:
            self.project_in = nn.Linear(dim, inner_dim)
            self.project_out = nn.Linear(inner_dim, dim)
        else:
            self.project_in = nn.Identity()
            self.project_out = nn.Identity()

        self.epsilon = epsilon
        self.commitment_weight = commitment_weight

        self._codebook = EuclideanCodebook(
            dim=inner_dim,
            codebook_size=codebook_size,
            kmeans_init=kmeans_init,
            kmeans_iters=kmeans_iters,
            decay=decay,
            epsilon=epsilon,
            threshold_ema_dead_code=threshold_ema_dead_code,
        )
        self.codebook_size = codebook_size

    @property
    def codebook(self):
        """The underlying codebook tensor."""
        return self._codebook.embed

    def encode(self, x):
        """Map ``(b, d, n)`` input to code indices."""
        projected = self.project_in(rearrange(x, "b d n -> b n d"))
        return self._codebook.encode(projected)

    def decode(self, embed_ind):
        """Map code indices back to ``(b, d, n)`` vectors."""
        vectors = self.project_out(self._codebook.decode(embed_ind))
        return rearrange(vectors, "b n d -> b d n")

    def forward(self, x):
        """Quantize ``x``.

        Returns:
            (quantize, embed_ind, loss): quantized ``(b, d, n)`` tensor, code indices,
            and a one-element loss tensor (the weighted commitment loss in training,
            zero otherwise).
        """
        device = x.device
        x = self.project_in(rearrange(x, "b d n -> b n d"))

        quantize, embed_ind = self._codebook(x)

        loss = torch.tensor([0.0], device=device, requires_grad=self.training)
        if self.training:
            # Straight-through estimator: the forward value is the quantized vector,
            # while gradients flow to x unchanged.
            quantize = x + (quantize - x).detach()
            if self.commitment_weight > 0:
                commit_loss = F.mse_loss(quantize.detach(), x)
                loss = loss + commit_loss * self.commitment_weight

        restored = self.project_out(quantize)
        return rearrange(restored, "b n d -> b d n"), embed_ind, loss
329
+
330
+
331
class ResidualVectorQuantization(nn.Module):
    """Residual vector quantization implementation.

    Each stage quantizes what the previous stages left over (the residual), and the
    stage outputs are summed. Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
    """

    def __init__(self, *, num_quantizers, **kwargs):
        super().__init__()
        stages = [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
        self.layers = nn.ModuleList(stages)

    def forward(
        self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None
    ):
        """Run the first ``n_q`` stages (all of them when ``n_q`` is falsy).

        Returns:
            (quantized_out, indices, losses, selected): the summed quantization,
            per-stage indices and losses stacked along a new leading dim, and the
            individual stage outputs whose index appears in ``layers``.
        """
        active = n_q or len(self.layers)

        residual = x
        quantized_out = 0.0
        losses, codes, selected = [], [], []

        for idx, stage in enumerate(self.layers[:active]):
            q, stage_codes, stage_loss = stage(residual)
            residual = residual - q
            quantized_out = quantized_out + q

            codes.append(stage_codes)
            losses.append(stage_loss)
            if layers and idx in layers:
                selected.append(q)

        return quantized_out, torch.stack(codes), torch.stack(losses), selected

    def encode(
        self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
    ) -> torch.Tensor:
        """Encode ``x`` with stages ``st`` (default 0) up to ``n_q`` (default all)."""
        stop = n_q or len(self.layers)
        start = st or 0

        residual = x
        codes = []
        for stage in self.layers[start:stop]:
            stage_codes = stage.encode(residual)
            codes.append(stage_codes)
            residual = residual - stage.decode(stage_codes)
        return torch.stack(codes)

    def decode(self, q_indices: torch.Tensor, st: int = 0) -> torch.Tensor:
        """Sum the decoded outputs of stages ``st, st + 1, ...``, one per row of ``q_indices``."""
        quantized_out = torch.tensor(0.0, device=q_indices.device)
        for offset, stage_codes in enumerate(q_indices):
            quantized_out = quantized_out + self.layers[st + offset].decode(stage_codes)
        return quantized_out