xinference 1.10.0__py3-none-any.whl → 1.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +473 -31
- xinference/client/restful/async_restful_client.py +178 -8
- xinference/client/restful/restful_client.py +151 -3
- xinference/core/supervisor.py +99 -53
- xinference/core/worker.py +10 -0
- xinference/deploy/cmdline.py +15 -0
- xinference/model/audio/core.py +21 -6
- xinference/model/audio/indextts2.py +166 -0
- xinference/model/audio/model_spec.json +58 -21
- xinference/model/image/model_spec.json +159 -90
- xinference/model/image/stable_diffusion/core.py +13 -4
- xinference/model/llm/__init__.py +6 -2
- xinference/model/llm/llm_family.json +1299 -174
- xinference/model/llm/mlx/distributed_models/core.py +41 -0
- xinference/model/llm/mlx/distributed_models/qwen2.py +1 -2
- xinference/model/llm/sglang/core.py +44 -11
- xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +94 -32
- xinference/model/llm/tool_parsers/qwen_tool_parser.py +29 -4
- xinference/model/llm/transformers/chatglm.py +3 -0
- xinference/model/llm/transformers/core.py +129 -36
- xinference/model/llm/transformers/multimodal/minicpmv45.py +340 -0
- xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
- xinference/model/llm/transformers/utils.py +23 -0
- xinference/model/llm/utils.py +48 -32
- xinference/model/llm/vllm/core.py +207 -72
- xinference/model/utils.py +74 -31
- xinference/thirdparty/audiotools/__init__.py +10 -0
- xinference/thirdparty/audiotools/core/__init__.py +4 -0
- xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
- xinference/thirdparty/audiotools/core/display.py +194 -0
- xinference/thirdparty/audiotools/core/dsp.py +390 -0
- xinference/thirdparty/audiotools/core/effects.py +647 -0
- xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
- xinference/thirdparty/audiotools/core/loudness.py +320 -0
- xinference/thirdparty/audiotools/core/playback.py +252 -0
- xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
- xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
- xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
- xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
- xinference/thirdparty/audiotools/core/util.py +671 -0
- xinference/thirdparty/audiotools/core/whisper.py +97 -0
- xinference/thirdparty/audiotools/data/__init__.py +3 -0
- xinference/thirdparty/audiotools/data/datasets.py +517 -0
- xinference/thirdparty/audiotools/data/preprocess.py +81 -0
- xinference/thirdparty/audiotools/data/transforms.py +1592 -0
- xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
- xinference/thirdparty/audiotools/metrics/distance.py +131 -0
- xinference/thirdparty/audiotools/metrics/quality.py +159 -0
- xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
- xinference/thirdparty/audiotools/ml/__init__.py +5 -0
- xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
- xinference/thirdparty/audiotools/ml/decorators.py +440 -0
- xinference/thirdparty/audiotools/ml/experiment.py +90 -0
- xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
- xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
- xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
- xinference/thirdparty/audiotools/post.py +140 -0
- xinference/thirdparty/audiotools/preference.py +600 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +1 -1
- xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
- xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
- xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
- xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
- xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
- xinference/thirdparty/indextts/__init__.py +0 -0
- xinference/thirdparty/indextts/cli.py +65 -0
- xinference/thirdparty/indextts/gpt/__init__.py +0 -0
- xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
- xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
- xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
- xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
- xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
- xinference/thirdparty/indextts/gpt/model.py +713 -0
- xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
- xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
- xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
- xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
- xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
- xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
- xinference/thirdparty/indextts/infer.py +690 -0
- xinference/thirdparty/indextts/infer_v2.py +739 -0
- xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
- xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
- xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
- xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
- xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
- xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
- xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
- xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
- xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
- xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
- xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
- xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
- xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
- xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
- xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
- xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
- xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
- xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
- xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
- xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
- xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
- xinference/thirdparty/indextts/utils/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/arch_util.py +120 -0
- xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
- xinference/thirdparty/indextts/utils/common.py +121 -0
- xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
- xinference/thirdparty/indextts/utils/front.py +536 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
- xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
- xinference/thirdparty/indextts/utils/text_utils.py +41 -0
- xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
- xinference/thirdparty/indextts/utils/utils.py +93 -0
- xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
- xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
- xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
- xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
- xinference/thirdparty/melo/text/chinese_mix.py +2 -2
- xinference/types.py +9 -0
- xinference/ui/gradio/media_interface.py +66 -8
- xinference/ui/web/ui/build/asset-manifest.json +6 -6
- xinference/ui/web/ui/build/index.html +1 -1
- xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
- xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
- xinference/ui/web/ui/build/static/js/main.45e78536.js +3 -0
- xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.45e78536.js.LICENSE.txt} +0 -7
- xinference/ui/web/ui/build/static/js/main.45e78536.js.map +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/ea2a26361204e70cf1018d6990fb6354bed82b3ac69690391e0f100385e7abb7.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
- xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
- xinference/ui/web/ui/package-lock.json +0 -34
- xinference/ui/web/ui/package.json +0 -1
- xinference/ui/web/ui/src/locales/en.json +9 -3
- xinference/ui/web/ui/src/locales/ja.json +9 -3
- xinference/ui/web/ui/src/locales/ko.json +9 -3
- xinference/ui/web/ui/src/locales/zh.json +9 -3
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/METADATA +24 -6
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/RECORD +296 -77
- xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
- xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
- xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
- xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
- xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
- xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
- xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
- xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
- xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
- xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
- xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
- xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
- xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
- xinference/ui/web/ui/node_modules/select/bower.json +0 -13
- xinference/ui/web/ui/node_modules/select/package.json +0 -29
- xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/WHEEL +0 -0
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/top_level.txt +0 -0
xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
1
|
+
# Copyright (c) 2023 Amphion.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
# This source file is copied from https://github.com/facebookresearch/encodec
|
|
6
|
+
|
|
7
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
8
|
+
# All rights reserved.
|
|
9
|
+
#
|
|
10
|
+
# This source code is licensed under the license found in the
|
|
11
|
+
# LICENSE file in the root directory of this source tree.
|
|
12
|
+
|
|
13
|
+
"""Arithmetic coder."""
|
|
14
|
+
|
|
15
|
+
import io
|
|
16
|
+
import math
|
|
17
|
+
import random
|
|
18
|
+
import typing as tp
|
|
19
|
+
import torch
|
|
20
|
+
|
|
21
|
+
from ..binary import BitPacker, BitUnpacker
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def build_stable_quantized_cdf(
    pdf: torch.Tensor,
    total_range_bits: int,
    roundoff: float = 1e-8,
    min_range: int = 2,
    check: bool = True,
) -> torch.Tensor:
    """Turn the given PDF into a quantized CDF that splits
    `[0, 2 ** total_range_bits - 1]` into chunks of size roughly proportional
    to the PDF.

    Args:
        pdf (torch.Tensor): probability distribution, shape should be `[N]`.
        total_range_bits (int): see `ArithmeticCoder`, the typical range we expect
            during the coding process is `[0, 2 ** total_range_bits - 1]`.
        roundoff (float): the pdf is rounded *down* to a multiple of this value to
            remove differences coming from e.g. evaluating the Language Model on
            different architectures. A falsy value (e.g. 0) disables rounding.
        min_range (int): minimum range width. Must be at least 2 for numerical
            stability. Use this to avoid pathological behavior if a value
            that is expected to be rare actually happens in real life.
        check (bool): if True, checks that nothing bad happened, can be
            deactivated for speed.

    Returns:
        torch.Tensor: int64 tensor with the same shape as `pdf`, holding the
        running sum of the per-symbol range widths (each width is at least
        `min_range`).

    Raises:
        ValueError: if `min_range` is below 2, or (when `check` is True) if some
            chunk ended up narrower than `min_range`.
        AssertionError: if `min_range * N` exceeds `2 ** total_range_bits`, or
            (when `check` is True) if the cumulative total overflows that range.
    """
    # Validate up front: `min_range` feeds directly into `alpha` and the range
    # widths below, so checking it afterwards would let a bad value surface as
    # an unrelated error (or do pointless work first).
    if min_range < 2:
        raise ValueError("min_range must be at least 2.")
    pdf = pdf.detach()
    if roundoff:
        # Snap probabilities to a grid so tiny cross-platform float differences
        # do not change the resulting integer CDF.
        pdf = (pdf / roundoff).floor() * roundoff
    # interpolate with uniform distribution to achieve desired minimum probability.
    total_range = 2**total_range_bits
    cardinality = len(pdf)
    # Fraction of the total range reserved for the guaranteed `min_range`
    # width of every symbol.
    alpha = min_range * cardinality / total_range
    assert alpha <= 1, "you must reduce min_range"
    ranges = (((1 - alpha) * total_range) * pdf).floor().long()
    ranges += min_range
    quantized_cdf = torch.cumsum(ranges, dim=-1)
    if check:
        assert quantized_cdf[-1] <= total_range, quantized_cdf[-1]
        if (
            (quantized_cdf[1:] - quantized_cdf[:-1]) < min_range
        ).any() or quantized_cdf[0] < min_range:
            raise ValueError("You must increase your total_range_bits.")
    return quantized_cdf
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class ArithmeticCoder:
|
|
69
|
+
"""ArithmeticCoder,
|
|
70
|
+
Let us take a distribution `p` over `N` symbols, and assume we have a stream
|
|
71
|
+
of random variables `s_t` sampled from `p`. Let us assume that we have a budget
|
|
72
|
+
of `B` bits that we can afford to write on device. There are `2**B` possible numbers,
|
|
73
|
+
corresponding to the range `[0, 2 ** B - 1]`. We can map each of those number to a single
|
|
74
|
+
sequence `(s_t)` by doing the following:
|
|
75
|
+
|
|
76
|
+
1) Initialize the current range to` [0 ** 2 B - 1]`.
|
|
77
|
+
2) For each time step t, split the current range into contiguous chunks,
|
|
78
|
+
one for each possible outcome, with size roughly proportional to `p`.
|
|
79
|
+
For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks
|
|
80
|
+
would be `{[0, 2], [3, 3]}`.
|
|
81
|
+
3) Select the chunk corresponding to `s_t`, and replace the current range with this.
|
|
82
|
+
4) When done encoding all the values, just select any value remaining in the range.
|
|
83
|
+
|
|
84
|
+
You will notice that this procedure can fail: for instance if at any point in time
|
|
85
|
+
the range is smaller than `N`, then we can no longer assign a non-empty chunk to each
|
|
86
|
+
possible outcome. Intuitively, the more likely a value is, the less the range width
|
|
87
|
+
will reduce, and the longer we can go on encoding values. This makes sense: for any efficient
|
|
88
|
+
coding scheme, likely outcomes would take less bits, and more of them can be coded
|
|
89
|
+
with a fixed budget.
|
|
90
|
+
|
|
91
|
+
In practice, we do not know `B` ahead of time, but we have a way to inject new bits
|
|
92
|
+
when the current range decreases below a given limit (given by `total_range_bits`), without
|
|
93
|
+
having to redo all the computations. If we encode mostly likely values, we will seldom
|
|
94
|
+
need to inject new bits, but a single rare value can deplete our stock of entropy!
|
|
95
|
+
|
|
96
|
+
In this explanation, we assumed that the distribution `p` was constant. In fact, the present
|
|
97
|
+
code works for any sequence `(p_t)` possibly different for each timestep.
|
|
98
|
+
We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller
|
|
99
|
+
the KL between the true distribution and `p_t`, the most efficient the coding will be.
|
|
100
|
+
|
|
101
|
+
Args:
|
|
102
|
+
fo (IO[bytes]): file-like object to which the bytes will be written to.
|
|
103
|
+
total_range_bits (int): the range `M` described above is `2 ** total_range_bits.
|
|
104
|
+
Any time the current range width fall under this limit, new bits will
|
|
105
|
+
be injected to rescale the initial range.
|
|
106
|
+
"""
|
|
107
|
+
|
|
108
|
+
def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
    """Create an encoder that writes its output bit by bit into ``fo``.

    Args:
        fo: binary file-like object receiving the encoded bits.
        total_range_bits: log2 of the working range ``M``; must not exceed 30.
    """
    assert total_range_bits <= 30
    self.total_range_bits = total_range_bits
    # Bits are emitted one at a time, as soon as the interval's leading bits settle.
    self.packer = BitPacker(bits=1, fo=fo)
    # Inclusive interval [low, high] plus the index of its highest pending bit
    # (-1 means no bit is pending yet).
    self.low, self.high, self.max_bit = 0, 0, -1
    # Debug traces of the interval, appended to on every push.
    self._dbg, self._dbg2 = [], []
|
|
117
|
+
|
|
118
|
+
@property
def delta(self) -> int:
    """Number of integers in the inclusive interval ``[low, high]``."""
    width = self.high - self.low
    return width + 1
|
|
122
|
+
|
|
123
|
+
def _flush_common_prefix(self) -> None:
    """Emit the leading bits shared by `low` and `high`, then strip them off both."""
    # If self.low and self.high start with the same bits,
    # those won't change anymore as we always just increase the range
    # by powers of 2, and we can flush them out to the bit stream.
    assert self.high >= self.low, (self.low, self.high)
    # Both bounds must fit in max_bit + 1 bits, so the shifts below yield 0 or 1.
    assert self.high < 2 ** (self.max_bit + 1)
    while self.max_bit >= 0:
        # Most significant live bit of each bound.
        b1 = self.low >> self.max_bit
        b2 = self.high >> self.max_bit
        if b1 == b2:
            # Shared bit: clear it from both bounds, drop one bit of width, and emit it.
            self.low -= b1 << self.max_bit
            self.high -= b1 << self.max_bit
            assert self.high >= self.low, (self.high, self.low, self.max_bit)
            assert self.low >= 0
            self.max_bit -= 1
            self.packer.push(b1)
        else:
            # Bounds disagree on this bit: no further prefix can be flushed yet.
            break
|
|
141
|
+
|
|
142
|
+
def push(self, symbol: int, quantized_cdf: torch.Tensor):
    """Push the given symbol on the stream, flushing out bits
    if possible.

    Args:
        symbol (int): symbol to encode with the AC.
        quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
            to build this from your pdf estimate.

    Returns:
        Whatever `_flush_common_prefix` returns; that method has no return
        statement, so this is always ``None``.
    """
    # Rescale: while the interval is narrower than 2**total_range_bits, double it by
    # appending a 0 bit to `low` and a 1 bit to `high` (one more pending bit each time).
    while self.delta < 2**self.total_range_bits:
        self.low *= 2
        self.high = self.high * 2 + 1
        self.max_bit += 1

    # Symbol's sub-range in quantized units: [cdf[symbol - 1], cdf[symbol] - 1],
    # with an implicit cdf[-1] of 0 for the first symbol.
    range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
    range_high = quantized_cdf[symbol].item() - 1
    # Scale the sub-range onto the current interval width. Ceil on the lower edge and
    # floor on the upper edge keep neighbouring symbols' sub-ranges disjoint, since the
    # scale factor is >= 1 after the rescaling loop above.
    effective_low = int(
        math.ceil(range_low * (self.delta / (2**self.total_range_bits)))
    )
    effective_high = int(
        math.floor(range_high * (self.delta / (2**self.total_range_bits)))
    )
    assert self.low <= self.high
    # `high` is updated first: both new bounds are offsets from the OLD `low`.
    self.high = self.low + effective_high
    self.low = self.low + effective_low
    assert self.low <= self.high, (
        effective_low,
        effective_high,
        range_low,
        range_high,
    )
    # NOTE(review): these debug traces gain one tuple per pushed symbol and are never
    # cleared, so memory use grows linearly with the stream length.
    self._dbg.append((self.low, self.high))
    self._dbg2.append((self.low, self.high))
    outs = self._flush_common_prefix()
    assert self.low <= self.high
    assert self.max_bit >= -1
    assert self.max_bit <= 61, self.max_bit
    return outs
|
|
180
|
+
|
|
181
|
+
def flush(self):
    """Emit every still-pending bit of `low` (most significant first), then flush the packer."""
    for position in range(self.max_bit, -1, -1):
        self.packer.push((self.low >> position) & 1)
        # Track progress exactly as a per-bit decrement would.
        self.max_bit = position - 1
    self.packer.flush()
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
class ArithmeticDecoder:
    """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation.

    Note that this must be called with **exactly** the same parameters and sequence
    of quantized cdf as the arithmetic encoder or the wrong values will be decoded.

    If the AC encoder current range is [L, H], with `L` and `H` having the same common
    prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream.
    For instance, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside
    `[b1 b2 b3 0 ... 0, b1 b2 b3 1 ... 1]`. Now this specific sub-range can only be obtained
    for a specific sequence of symbols and a binary-search allows us to decode those symbols.
    At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols,
    and we will need to read new bits from the stream and repeat the process.

    Args:
        fo (IO[bytes]): file-like object from which the encoded bits are read.
        total_range_bits (int): must equal the value used by the encoder.
    """

    def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
        self.total_range_bits = total_range_bits
        # Interval [low, high] (inclusive) tracked in lockstep with the encoder, and
        # `current`, the value formed by the bits read so far. `max_bit` is the index
        # of the highest bit still held in these integers (-1 when none).
        self.low: int = 0
        self.high: int = 0
        self.current: int = 0
        self.max_bit: int = -1
        self.unpacker = BitUnpacker(bits=1, fo=fo)  # we pull single bits at a time.
        # Following is for debugging
        # NOTE(review): _dbg/_dbg2 gain one tuple per decoded symbol and are never cleared.
        self._dbg: tp.List[tp.Any] = []
        self._dbg2: tp.List[tp.Any] = []
        self._last: tp.Any = None

    @property
    def delta(self) -> int:
        """Number of integers in the inclusive interval ``[low, high]``."""
        return self.high - self.low + 1

    def _flush_common_prefix(self) -> None:
        # Given the current range [L, H], if both have a common prefix,
        # we know we can remove it from our representation to avoid handling large numbers.
        # The same prefix is subtracted from `current` so it stays aligned with [low, high].
        while self.max_bit >= 0:
            b1 = self.low >> self.max_bit
            b2 = self.high >> self.max_bit
            if b1 == b2:
                self.low -= b1 << self.max_bit
                self.high -= b1 << self.max_bit
                self.current -= b1 << self.max_bit
                assert self.high >= self.low
                assert self.low >= 0
                self.max_bit -= 1
            else:
                break

    def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]:
        """Pull a symbol, reading as many bits from the stream as required.
        This returns `None` when the stream has been exhausted.

        Args:
            quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
                to build this from your pdf estimate. This must be **exactly**
                the same cdf as the one used at encoding time.
        """
        # Widen the interval, reading one stream bit into `current` per doubling, until
        # it spans at least 2**total_range_bits values (same loop condition as the encoder).
        while self.delta < 2**self.total_range_bits:
            bit = self.unpacker.pull()
            if bit is None:
                # Stream exhausted. Earlier iterations of this loop may already have
                # shifted low/high/current/max_bit; they are not rolled back.
                return None
            self.low *= 2
            self.high = self.high * 2 + 1
            self.current = self.current * 2 + bit
            self.max_bit += 1

        def bin_search(low_idx: int, high_idx: int):
            # Binary search is not just for coding interviews :)
            # Find the symbol index in [low_idx, high_idx] whose sub-range, scaled onto
            # the current interval with the same formula as the encoder, contains
            # `self.current`. Returns (symbol, new_low, new_high, current).
            if high_idx < low_idx:
                raise RuntimeError("Binary search failed")
            mid = (low_idx + high_idx) // 2
            range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
            range_high = quantized_cdf[mid].item() - 1
            effective_low = int(
                math.ceil(range_low * (self.delta / (2**self.total_range_bits)))
            )
            effective_high = int(
                math.floor(range_high * (self.delta / (2**self.total_range_bits)))
            )
            low = effective_low + self.low
            high = effective_high + self.low
            if self.current >= low:
                if self.current <= high:
                    return (mid, low, high, self.current)
                else:
                    return bin_search(mid + 1, high_idx)
            else:
                return bin_search(low_idx, mid - 1)

        self._last = (self.low, self.high, self.current, self.max_bit)
        # Narrow [low, high] to the decoded symbol's sub-range, exactly as the encoder did.
        sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1)
        self._dbg.append((self.low, self.high, self.current))
        self._flush_common_prefix()
        self._dbg2.append((self.low, self.high, self.current))

        return sym
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
def test():
    """Round-trip random symbol streams through ArithmeticCoder / ArithmeticDecoder.

    Runs 4 seeded trials. Each trial encodes `steps` symbols, each sampled from a
    fresh random distribution over `cardinality` values. It then decodes them with
    the same quantized cdfs, checks that every symbol matches, and checks that the
    decoder reports an exhausted stream (``None``) afterwards.
    """
    torch.manual_seed(1234)
    random.seed(1234)
    for _ in range(4):
        pdfs = []
        # random.randrange(4000) can return 0, which yields an empty distribution that
        # torch.multinomial cannot sample from. Clamp to at least one symbol; using
        # max() leaves the sequence of random draws unchanged.
        cardinality = max(1, random.randrange(4000))
        steps = random.randrange(100, 500)
        fo = io.BytesIO()
        encoder = ArithmeticCoder(fo)
        symbols = []
        for _ in range(steps):
            pdf = torch.softmax(torch.randn(cardinality), dim=0)
            pdfs.append(pdf)
            q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
            symbol = torch.multinomial(pdf, 1).item()
            symbols.append(symbol)
            encoder.push(symbol, q_cdf)
        encoder.flush()

        fo.seek(0)
        decoder = ArithmeticDecoder(fo)
        for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)):
            # The decoder must see exactly the same quantized cdf as the encoder did.
            q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
            decoded_symbol = decoder.pull(q_cdf)
            assert decoded_symbol == symbol, idx
        # Everything has been consumed, so the next pull must report exhaustion.
        assert decoder.pull(torch.zeros(1)) is None
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
# Run the arithmetic-coding round-trip self-test when this module is executed as a script.
if __name__ == "__main__":
    test()
|
|
@@ -0,0 +1,388 @@
|
|
|
1
|
+
# Copyright (c) 2023 Amphion.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
# This source file is copied from https://github.com/facebookresearch/encodec
|
|
6
|
+
|
|
7
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
8
|
+
# All rights reserved.
|
|
9
|
+
#
|
|
10
|
+
# This source code is licensed under the license found in the
|
|
11
|
+
# LICENSE file in the root directory of this source tree.
|
|
12
|
+
#
|
|
13
|
+
# This implementation is inspired from
|
|
14
|
+
# https://github.com/lucidrains/vector-quantize-pytorch
|
|
15
|
+
# which is released under MIT License. Hereafter, the original license:
|
|
16
|
+
# MIT License
|
|
17
|
+
#
|
|
18
|
+
# Copyright (c) 2020 Phil Wang
|
|
19
|
+
#
|
|
20
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
21
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
22
|
+
# in the Software without restriction, including without limitation the rights
|
|
23
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
24
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
25
|
+
# furnished to do so, subject to the following conditions:
|
|
26
|
+
#
|
|
27
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
28
|
+
# copies or substantial portions of the Software.
|
|
29
|
+
#
|
|
30
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
31
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
32
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
33
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
34
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
35
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
36
|
+
# SOFTWARE.
|
|
37
|
+
|
|
38
|
+
"""Core vector quantization implementation."""
|
|
39
|
+
import typing as tp
|
|
40
|
+
|
|
41
|
+
from einops import rearrange, repeat
|
|
42
|
+
import torch
|
|
43
|
+
from torch import nn
|
|
44
|
+
import torch.nn.functional as F
|
|
45
|
+
|
|
46
|
+
from .distrib import broadcast_tensors, rank
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def default(val: tp.Any, d: tp.Any) -> tp.Any:
    """Return `val`, falling back to `d` only when `val` is ``None``.

    Other falsy values (``0``, ``""``, ``[]``) are returned unchanged.
    """
    if val is None:
        return d
    return val
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def ema_inplace(moving_avg, new, decay: float):
    """Update `moving_avg` in place to ``decay * moving_avg + (1 - decay) * new``.

    Works on the tensor's ``.data`` so the update is not tracked by autograd.
    Returns ``None``.
    """
    weight_new = 1 - decay
    buf = moving_avg.data
    buf.mul_(decay)
    buf.add_(new, alpha=weight_new)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
    """Additive (Laplace) smoothing: add `epsilon` to every count, then normalise.

    The result sums to 1 when `x` has exactly `n_categories` entries.
    """
    numerator = x + epsilon
    denominator = x.sum() + n_categories * epsilon
    return numerator / denominator
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def uniform_init(*shape: int):
    """Allocate a tensor of `shape` filled with Kaiming-uniform random values."""
    # kaiming_uniform_ fills the tensor in place and hands the same tensor back.
    return nn.init.kaiming_uniform_(torch.empty(shape))
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def sample_vectors(samples, num: int):
    """Randomly pick `num` rows of `samples` (indexed along dim 0).

    Rows are drawn without replacement when there are enough of them, and with
    replacement otherwise.
    """
    available = samples.shape[0]
    device = samples.device

    if available < num:
        # Too few rows: sample indices with replacement.
        picks = torch.randint(0, available, (num,), device=device)
    else:
        # Enough rows: take a prefix of a random permutation (no repeats).
        picks = torch.randperm(available, device=device)[:num]

    return samples[picks]
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def kmeans(samples, num_clusters: int, num_iters: int = 10):
    """Cluster `samples` (shape ``(n, d)``) into `num_clusters` centroids (Lloyd's algorithm).

    Centroids start from ``sample_vectors(samples, num_clusters)``. A centroid whose
    cluster is empty in an iteration keeps its previous value.

    Returns:
        (means, bins): the ``(num_clusters, d)`` centroids and the per-cluster
        assignment counts from the last iteration. ``bins`` is only bound inside
        the loop, so `num_iters` must be at least 1.
    """
    dtype = samples.dtype
    dim = samples.shape[-1]

    means = sample_vectors(samples, num_clusters)

    for _ in range(num_iters):
        # Pairwise differences (n, c, d) -> negative squared distances (n, c);
        # the maximum over clusters therefore picks the nearest centroid.
        pairwise = rearrange(samples, "n d -> n () d") - rearrange(
            means, "c d -> () c d"
        )
        neg_sq_dist = -(pairwise**2).sum(dim=-1)
        assignment = neg_sq_dist.max(dim=-1).indices

        bins = torch.bincount(assignment, minlength=num_clusters)
        empty = bins == 0
        denom = bins.masked_fill(empty, 1)  # avoid dividing by zero for empty clusters

        totals = assignment.new_zeros(num_clusters, dim, dtype=dtype)
        totals.scatter_add_(0, repeat(assignment, "n -> n d", d=dim), samples)
        centroids = totals / denom[..., None]

        means = torch.where(empty[..., None], means, centroids)

    return means, bins
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class EuclideanCodebook(nn.Module):
    """Codebook with Euclidean distance.

    The codebook (``embed``) and its statistics (``cluster_size``, ``embed_avg``,
    ``inited``) are registered as buffers, not parameters. In training mode,
    ``forward`` overwrites ``embed`` from exponential moving averages of those
    statistics.

    Args:
        dim (int): Dimension.
        codebook_size (int): Codebook size.
        kmeans_init (bool): Whether to use k-means to initialize the codebooks.
            If set to true, run the k-means algorithm on the first batch passed to
            ``forward`` (in either training or eval mode) and use the learned
            centroids as initialization.
        kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        kmeans_init: bool = False,
        kmeans_iters: int = 10,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        threshold_ema_dead_code: int = 2,
    ):
        super().__init__()
        self.decay = decay
        # Kaiming-uniform random codebook, or zeros when k-means will fill it in
        # on the first forward pass (see init_embed_).
        init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = (
            uniform_init if not kmeans_init else torch.zeros
        )
        embed = init_fn(codebook_size, dim)

        self.codebook_size = codebook_size

        self.kmeans_iters = kmeans_iters
        self.epsilon = epsilon
        self.threshold_ema_dead_code = threshold_ema_dead_code

        # `inited` is already true unless k-means initialization is pending.
        self.register_buffer("inited", torch.Tensor([not kmeans_init]))
        self.register_buffer("cluster_size", torch.zeros(codebook_size))
        self.register_buffer("embed", embed)  # (codebook_size, dim)
        self.register_buffer("embed_avg", embed.clone())

    @torch.jit.ignore
    def init_embed_(self, data):
        """Initialise the buffers with k-means on `data`; a no-op once `inited` is set."""
        if self.inited:
            return

        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed.clone())
        self.cluster_size.data.copy_(cluster_size)
        self.inited.data.copy_(torch.Tensor([True]))
        # Make sure all buffers across workers are in sync after initialization
        # NOTE(review): the broadcast below is disabled, so buffers are NOT synchronised here.
        # broadcast_tensors(self.buffers())

    def replace_(self, samples, mask):
        """Overwrite codebook rows where `mask` is True with rows sampled from `samples`.

        Only `embed` is written; `embed_avg` and `cluster_size` are left untouched.
        """
        modified_codebook = torch.where(
            mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
        )
        self.embed.data.copy_(modified_codebook)

    def expire_codes_(self, batch_samples):
        """Re-seed codes whose EMA cluster size is below `threshold_ema_dead_code`."""
        if self.threshold_ema_dead_code == 0:
            return

        expired_codes = self.cluster_size < self.threshold_ema_dead_code
        if not torch.any(expired_codes):
            return

        batch_samples = rearrange(batch_samples, "... d -> (...) d")
        self.replace_(batch_samples, mask=expired_codes)
        # broadcast_tensors(self.buffers())

    def preprocess(self, x):
        """Flatten all leading dimensions: (..., d) -> (N, d)."""
        x = rearrange(x, "... d -> (...) d")
        return x

    def quantize(self, x):
        """Return, for each row of `x` (N, dim), the index of the nearest codebook entry."""
        embed = self.embed.t()
        # Negative squared Euclidean distance -(|x|^2 - 2 x.e + |e|^2), shape (N, codebook_size);
        # the argmax is therefore the nearest entry.
        dist = -(
            x.pow(2).sum(1, keepdim=True)
            - 2 * x @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )
        embed_ind = dist.max(dim=-1).indices
        return embed_ind

    def postprocess_emb(self, embed_ind, shape):
        """Restore the leading dimensions of the original input `shape` onto flat indices."""
        return embed_ind.view(*shape[:-1])

    def dequantize(self, embed_ind):
        """Look up codebook vectors for `embed_ind`; output shape is embed_ind.shape + (dim,)."""
        quantize = F.embedding(embed_ind, self.embed)
        return quantize

    def encode(self, x):
        """Return nearest-codebook indices for `x`, shaped like ``x.shape[:-1]``."""
        shape = x.shape
        # pre-process
        x = self.preprocess(x)
        # quantize
        embed_ind = self.quantize(x)
        # post-process
        embed_ind = self.postprocess_emb(embed_ind, shape)
        return embed_ind

    def decode(self, embed_ind):
        """Map indices back to codebook vectors."""
        quantize = self.dequantize(embed_ind)
        return quantize

    def forward(self, x):
        """Quantize `x` (last dim must equal the codebook dim); update the codebook in training.

        Returns:
            (quantize, embed_ind): the selected codebook vectors, looked up before any
            update in this call, and their indices shaped like ``x.shape[:-1]``.
        """
        shape, dtype = x.shape, x.dtype
        x = self.preprocess(x)

        # Lazy k-means initialisation (runs only while `inited` is false).
        self.init_embed_(x)

        embed_ind = self.quantize(x)
        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
        embed_ind = self.postprocess_emb(embed_ind, shape)
        quantize = self.dequantize(embed_ind)

        if self.training:
            # We do the expiry of code at that point as buffers are in sync
            # and all the workers will take the same decision.
            # NOTE(review): expire_codes_ only rewrites `embed`, and `embed` is overwritten
            # below from `embed_avg` / `cluster_size`, so the replacement appears to be
            # discarded within this same call. Confirm whether this is intended.
            self.expire_codes_(x)
            # EMA of per-code assignment counts and of per-code summed inputs.
            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
            embed_sum = x.t() @ embed_onehot
            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
            # Smoothed counts (same total mass) avoid dividing by zero for unused codes.
            cluster_size = (
                laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
                * self.cluster_size.sum()
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            self.embed.data.copy_(embed_normalized)

        return quantize, embed_ind
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
class VectorQuantization(nn.Module):
    """Vector quantization implementation.
    Currently supports only euclidean distance.

    Inputs to ``forward`` / ``encode`` are laid out as ``(b, d, n)`` and are
    rearranged to ``(b, n, d)`` before quantization. Outputs of ``forward`` /
    ``decode`` are rearranged back to ``(b, d, n)``.

    Args:
        dim (int): Dimension
        codebook_size (int): Codebook size
        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
            When it differs from ``dim``, linear layers project into and out of the codebook space.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
        commitment_weight (float): Weight for commitment loss.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: tp.Optional[int] = None,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: int = 2,
        commitment_weight: float = 1.0,
    ):
        super().__init__()
        inner_dim: int = dim if codebook_dim is None else codebook_dim

        # Projections exist only when the codebook lives in a different space.
        # Construction order (project_in, project_out, codebook) is deliberate: it fixes
        # the random-init order and the submodule registration order.
        if inner_dim != dim:
            self.project_in = nn.Linear(dim, inner_dim)
            self.project_out = nn.Linear(inner_dim, dim)
        else:
            self.project_in = nn.Identity()
            self.project_out = nn.Identity()

        self.epsilon = epsilon
        self.commitment_weight = commitment_weight

        self._codebook = EuclideanCodebook(
            dim=inner_dim,
            codebook_size=codebook_size,
            kmeans_init=kmeans_init,
            kmeans_iters=kmeans_iters,
            decay=decay,
            epsilon=epsilon,
            threshold_ema_dead_code=threshold_ema_dead_code,
        )
        self.codebook_size = codebook_size

    @property
    def codebook(self):
        """The codebook buffer (``embed``) of the underlying EuclideanCodebook."""
        return self._codebook.embed

    def encode(self, x):
        """Return codebook indices for `x` laid out as ``(b, d, n)``."""
        projected = self.project_in(rearrange(x, "b d n -> b n d"))
        return self._codebook.encode(projected)

    def decode(self, embed_ind):
        """Map codebook indices back to vectors laid out as ``(b, d, n)``."""
        vectors = self.project_out(self._codebook.decode(embed_ind))
        return rearrange(vectors, "b n d -> b d n")

    def forward(self, x):
        """Quantize `x` ``(b, d, n)``; return ``(quantize, embed_ind, loss)``.

        `loss` is a 1-element tensor. It is zero in eval mode; in training mode it is
        the commitment loss scaled by `commitment_weight` (when that weight is > 0).
        """
        device = x.device
        x = self.project_in(rearrange(x, "b d n -> b n d"))
        quantize, embed_ind = self._codebook(x)

        loss = torch.tensor([0.0], device=device, requires_grad=self.training)
        if self.training:
            # Straight-through estimator: the forward value is the codebook vector,
            # while gradients flow to `x` unchanged.
            quantize = x + (quantize - x).detach()
            if self.commitment_weight > 0:
                commit_loss = F.mse_loss(quantize.detach(), x)
                loss = loss + commit_loss * self.commitment_weight

        quantize = rearrange(self.project_out(quantize), "b n d -> b d n")
        return quantize, embed_ind, loss
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
class ResidualVectorQuantization(nn.Module):
    """Residual vector quantization implementation.
    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf

    Each stage quantizes whatever the previous stages left unexplained (the
    residual). The reconstruction is the sum of all stage outputs.
    """

    def __init__(self, *, num_quantizers, **kwargs):
        super().__init__()
        self.layers = nn.ModuleList(
            VectorQuantization(**kwargs) for _ in range(num_quantizers)
        )

    def forward(
        self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None
    ):
        """Run the first `n_q` stages (all when falsy) on `x`.

        Returns:
            (quantized_out, out_indices, out_losses, out_quantized): the summed
            reconstruction, per-stage indices and losses stacked along a new leading
            dim, and the individual outputs of the stages whose index is in `layers`.
        """
        limit = n_q or len(self.layers)
        residual = x
        quantized_out = 0.0
        losses, indices_per_stage, selected = [], [], []

        for stage, layer in enumerate(self.layers[:limit]):
            quantized, indices, loss = layer(residual)
            residual = residual - quantized
            quantized_out = quantized_out + quantized

            indices_per_stage.append(indices)
            losses.append(loss)
            if layers and stage in layers:
                selected.append(quantized)

        out_losses = torch.stack(losses)
        out_indices = torch.stack(indices_per_stage)
        return quantized_out, out_indices, out_losses, selected

    def encode(
        self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
    ) -> torch.Tensor:
        """Encode `x` with stages ``st`` (default 0) up to, but excluding, ``n_q``
        (default: all stages); indices are stacked along a new leading dim."""
        stop = n_q or len(self.layers)
        start = st or 0
        residual = x
        codes = []
        for layer in self.layers[start:stop]:
            stage_codes = layer.encode(residual)
            residual = residual - layer.decode(stage_codes)
            codes.append(stage_codes)
        return torch.stack(codes)

    def decode(self, q_indices: torch.Tensor, st: int = 0) -> torch.Tensor:
        """Sum the reconstructions of ``q_indices[i]`` through stage ``st + i``."""
        quantized_out = torch.tensor(0.0, device=q_indices.device)
        for offset, indices in enumerate(q_indices):
            quantized_out = quantized_out + self.layers[st + offset].decode(indices)
        return quantized_out
|