xinference 1.10.0__py3-none-any.whl → 1.10.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +11 -28
- xinference/client/restful/async_restful_client.py +20 -3
- xinference/client/restful/restful_client.py +20 -3
- xinference/core/supervisor.py +87 -53
- xinference/core/worker.py +10 -0
- xinference/deploy/cmdline.py +15 -0
- xinference/model/audio/core.py +21 -6
- xinference/model/audio/indextts2.py +166 -0
- xinference/model/audio/model_spec.json +38 -1
- xinference/model/image/model_spec.json +69 -0
- xinference/model/image/stable_diffusion/core.py +13 -4
- xinference/model/llm/__init__.py +4 -0
- xinference/model/llm/llm_family.json +464 -2
- xinference/model/llm/sglang/core.py +30 -11
- xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +94 -32
- xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
- xinference/model/llm/utils.py +12 -9
- xinference/model/llm/vllm/core.py +93 -17
- xinference/thirdparty/audiotools/__init__.py +10 -0
- xinference/thirdparty/audiotools/core/__init__.py +4 -0
- xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
- xinference/thirdparty/audiotools/core/display.py +194 -0
- xinference/thirdparty/audiotools/core/dsp.py +390 -0
- xinference/thirdparty/audiotools/core/effects.py +647 -0
- xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
- xinference/thirdparty/audiotools/core/loudness.py +320 -0
- xinference/thirdparty/audiotools/core/playback.py +252 -0
- xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
- xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
- xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
- xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
- xinference/thirdparty/audiotools/core/util.py +671 -0
- xinference/thirdparty/audiotools/core/whisper.py +97 -0
- xinference/thirdparty/audiotools/data/__init__.py +3 -0
- xinference/thirdparty/audiotools/data/datasets.py +517 -0
- xinference/thirdparty/audiotools/data/preprocess.py +81 -0
- xinference/thirdparty/audiotools/data/transforms.py +1592 -0
- xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
- xinference/thirdparty/audiotools/metrics/distance.py +131 -0
- xinference/thirdparty/audiotools/metrics/quality.py +159 -0
- xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
- xinference/thirdparty/audiotools/ml/__init__.py +5 -0
- xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
- xinference/thirdparty/audiotools/ml/decorators.py +440 -0
- xinference/thirdparty/audiotools/ml/experiment.py +90 -0
- xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
- xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
- xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
- xinference/thirdparty/audiotools/post.py +140 -0
- xinference/thirdparty/audiotools/preference.py +600 -0
- xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
- xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
- xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
- xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
- xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
- xinference/thirdparty/indextts/__init__.py +0 -0
- xinference/thirdparty/indextts/cli.py +65 -0
- xinference/thirdparty/indextts/gpt/__init__.py +0 -0
- xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
- xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
- xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
- xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
- xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
- xinference/thirdparty/indextts/gpt/model.py +713 -0
- xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
- xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
- xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
- xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
- xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
- xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
- xinference/thirdparty/indextts/infer.py +690 -0
- xinference/thirdparty/indextts/infer_v2.py +739 -0
- xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
- xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
- xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
- xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
- xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
- xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
- xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
- xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
- xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
- xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
- xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
- xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
- xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
- xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
- xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
- xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
- xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
- xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
- xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
- xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
- xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
- xinference/thirdparty/indextts/utils/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/arch_util.py +120 -0
- xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
- xinference/thirdparty/indextts/utils/common.py +121 -0
- xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
- xinference/thirdparty/indextts/utils/front.py +536 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
- xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
- xinference/thirdparty/indextts/utils/text_utils.py +41 -0
- xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
- xinference/thirdparty/indextts/utils/utils.py +93 -0
- xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
- xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
- xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
- xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
- xinference/ui/gradio/media_interface.py +66 -8
- xinference/ui/web/ui/build/asset-manifest.json +6 -6
- xinference/ui/web/ui/build/index.html +1 -1
- xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
- xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
- xinference/ui/web/ui/build/static/js/main.d192c4f3.js +3 -0
- xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.d192c4f3.js.LICENSE.txt} +0 -7
- xinference/ui/web/ui/build/static/js/main.d192c4f3.js.map +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f995a2425dfb0822fd07127f66ffe9b026883bc156b402eb8bd0b83d52460a93.json +1 -0
- xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
- xinference/ui/web/ui/package-lock.json +0 -34
- xinference/ui/web/ui/package.json +0 -1
- xinference/ui/web/ui/src/locales/en.json +9 -3
- xinference/ui/web/ui/src/locales/ja.json +9 -3
- xinference/ui/web/ui/src/locales/ko.json +9 -3
- xinference/ui/web/ui/src/locales/zh.json +9 -3
- {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/METADATA +18 -2
- {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/RECORD +285 -67
- xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
- xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
- xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
- xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
- xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
- xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
- xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
- xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
- xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
- xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
- xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
- xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
- xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
- xinference/ui/web/ui/node_modules/select/bower.json +0 -13
- xinference/ui/web/ui/node_modules/select/package.json +0 -29
- xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
- {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/WHEEL +0 -0
- {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/entry_points.txt +0 -0
- {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,622 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
|
|
4
|
+
# This source code is licensed under the license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
import time
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
import torch.nn as nn
|
|
11
|
+
import torch.nn.functional as F
|
|
12
|
+
from tokenizer import get_tokenizer
|
|
13
|
+
|
|
14
|
+
try:
|
|
15
|
+
from GPTQ import GenericGPTQRunner, InputRecorder
|
|
16
|
+
from eval import get_task_dict, evaluate, lm_eval
|
|
17
|
+
except:
|
|
18
|
+
pass
|
|
19
|
+
|
|
20
|
+
from model import Transformer
|
|
21
|
+
|
|
22
|
+
##### Quantization Primitives ######
|
|
23
|
+
|
|
24
|
+
def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
    """Symmetrically quantize ``x`` with one scale per index along dim 0.

    The min/max statistics are reduced over ``dim=1``, so every row gets its
    own scale.  Zero points are always zero (symmetric quantization).

    Assumptions carried over from the original implementation: the channel
    axis is 0 and ``x`` uses a dense memory format.
    TODO(future): relax these as needed.

    Args:
        x: Floating-point tensor to quantize (the broadcasting below is
            written for a 2-D tensor).
        quant_min: Smallest value allowed in the quantized output.
        quant_max: Largest value allowed in the quantized output.
        target_dtype: dtype the clamped result is cast to.

    Returns:
        ``(quant, scales, zero_points)``: the quantized tensor in
        ``target_dtype``, the per-row scales in ``x.dtype`` and all-zero
        ``int64`` zero points shaped like the scales.
    """
    # Lower bound for the scales so that all-zero rows never divide by zero.
    smallest_scale = torch.finfo(torch.float32).eps

    row_min, row_max = torch.aminmax(x, dim=1)

    # Widen each row's range so it always contains zero, then keep the larger
    # magnitude of the two ends as the symmetric range.
    zero = torch.zeros_like(row_min)
    neg_end = torch.min(row_min, zero)
    pos_end = torch.max(row_max, zero)
    magnitude = torch.max(-neg_end, pos_end)

    half_span = float(quant_max - quant_min) / 2
    # Clamp first, then cast so the scales share the input tensor's dtype.
    scales = torch.clamp(magnitude / half_span, min=smallest_scale).to(x.dtype)
    zero_points = torch.zeros(
        neg_end.size(), dtype=torch.int64, device=neg_end.device
    )

    # Scale, round, shift by the zero point, then clamp into the target range.
    shifted = torch.round(x / scales.unsqueeze(-1)) + zero_points.unsqueeze(-1)
    quant = torch.clamp(shifted, quant_min, quant_max).to(target_dtype)

    return quant, scales, zero_points
|
|
57
|
+
|
|
58
|
+
def get_group_qparams(w, n_bit=4, groupsize=128):
    """Compute per-group scales and zeros for a 2-D weight tensor.

    ``w`` is viewed as consecutive groups of ``groupsize`` elements along its
    last dimension.  For every group the scale is the (clamped) value range
    divided by ``2**n_bit - 1`` and the zero is the group minimum plus
    ``scale * 2**(n_bit - 1)``.

    Args:
        w: 2-D tensor without NaNs whose last dimension is a multiple of the
            effective group size.
        n_bit: Bit width of the quantized values.
        groupsize: Number of elements per group; reduced to the width of
            ``w`` when larger than it.

    Returns:
        ``(scales, zeros)``, both ``bfloat16`` and reshaped to
        ``(w.shape[0], -1)``.
    """
    # Needed for GPTQ with padding: a group can never be wider than a row.
    groupsize = min(groupsize, w.shape[-1])
    assert groupsize > 1
    assert w.shape[-1] % groupsize == 0
    assert w.dim() == 2

    groups = w.reshape(-1, groupsize)
    assert torch.isnan(groups).sum() == 0

    group_max = groups.amax(dim=1, keepdim=True)
    group_min = groups.amin(dim=1, keepdim=True)
    levels = 2**n_bit - 1
    # The clamp keeps constant groups from producing a zero scale.
    scales = (group_max - group_min).clamp(min=1e-6) / levels
    zeros = group_min + scales * (2 ** (n_bit - 1))

    rows = w.shape[0]
    scales_bf16 = scales.to(torch.bfloat16).reshape(rows, -1)
    zeros_bf16 = zeros.to(torch.bfloat16).reshape(rows, -1)
    return scales_bf16, zeros_bf16
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def pack_scales_and_zeros(scales, zeros):
    """Interleave scales and zeros into one contiguous tensor.

    Each input is reshaped to ``(size(0), size(1), 1)``, the two are
    concatenated along the new last dimension (scales at index 0, zeros at
    index 1) and the first two dimensions are then swapped.

    Args:
        scales: ``bfloat16`` tensor of scales.
        zeros: ``bfloat16`` tensor of zeros with the same shape as ``scales``.

    Returns:
        Contiguous tensor of shape ``(size(1), size(0), 2)``.
    """
    assert scales.shape == zeros.shape
    assert scales.dtype == torch.bfloat16
    assert zeros.dtype == torch.bfloat16

    scales_col = scales.reshape(scales.size(0), scales.size(1), 1)
    zeros_col = zeros.reshape(zeros.size(0), zeros.size(1), 1)
    paired = torch.cat([scales_col, zeros_col], 2)
    return paired.transpose(0, 1).contiguous()
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def unpack_scales_and_zeros(scales_and_zeros):
    """Split a packed scales/zeros tensor back into its two components.

    The first two dimensions are swapped and the last dimension (of size 2)
    is split into two slices of size 1: scales first, zeros second.

    Args:
        scales_and_zeros: 3-D tensor whose last dimension has size 2, in
            ``float32`` or ``bfloat16``.

    Returns:
        Tuple ``(scales, zeros)``; each keeps a trailing dimension of size 1.
    """
    assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
    # pack_scales_and_zeros only ever emits bfloat16, so requiring float32
    # alone made a pack -> unpack round trip fail; float32 stays accepted for
    # existing callers.
    assert scales_and_zeros.dtype in (torch.float, torch.bfloat16)
    return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
    """Quantize ``w`` group-wise using precomputed scales and zeros.

    Each group of ``groupsize`` consecutive elements is shifted by its
    minimum (``zero - scale * 2**(n_bit - 1)``), divided by its scale,
    rounded and clamped to ``[0, 2**n_bit - 1]``.

    Args:
        w: 2-D tensor without NaNs.
        scales: Per-group scales; flattened to one value per group.
        zeros: Per-group zeros; flattened to one value per group.
        n_bit: Bit width of the quantized values.
        groupsize: Elements per group; must be greater than 1.

    Returns:
        ``int32`` tensor with the same shape as ``w``.
    """
    assert groupsize > 1
    # Needed for GPTQ single column quantize: with a single scale per row the
    # group shrinks to the (narrower) width of ``w``.
    if groupsize > w.shape[-1] and scales.shape[-1] == 1:
        groupsize = w.shape[-1]

    assert w.shape[-1] % groupsize == 0
    assert w.dim() == 2

    grouped = w.reshape(-1, groupsize)
    assert torch.isnan(grouped).sum() == 0

    scale_col = scales.reshape(-1, 1)
    zero_col = zeros.reshape(-1, 1)
    group_floor = zero_col - scale_col * (2 ** (n_bit - 1))
    top_level = 2**n_bit - 1

    levels = torch.round((grouped - group_floor) / scale_col)
    levels = levels.clamp(0, top_level)
    return levels.to(torch.int32).reshape_as(w)
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def group_quantize_tensor(w, n_bit=4, groupsize=128):
    """Group-quantize ``w`` and return it with its packed qparams.

    Computes the group scales/zeros with ``get_group_qparams``, quantizes
    ``w`` with them and packs the qparams with ``pack_scales_and_zeros``.

    Args:
        w: Weight tensor to quantize.
        n_bit: Bit width of the quantized values.
        groupsize: Number of elements per quantization group.

    Returns:
        ``(w_int32, scales_and_zeros)``.
    """
    qparams = get_group_qparams(w, n_bit, groupsize)
    quantized = group_quantize_tensor_from_qparams(w, *qparams, n_bit, groupsize)
    return quantized, pack_scales_and_zeros(*qparams)
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def group_dequantize_tensor_from_qparams(
    w_int32, scales, zeros, n_bit=4, groupsize=128
):
    """Dequantize a group-quantized tensor using its scales and zeros.

    For every group of ``groupsize`` consecutive elements this computes
    ``(q - 2**(n_bit - 1)) * scale + zero``.

    Args:
        w_int32: 2-D tensor of quantized values.
        scales: Per-group scales; flattened to one value per group.
        zeros: Per-group zeros; flattened to one value per group.
        n_bit: Bit width the values were quantized with.
        groupsize: Elements per group; must be greater than 1.

    Returns:
        Dequantized tensor with the same shape as ``w_int32``.
    """
    assert groupsize > 1
    # Needed for GPTQ single column dequantize: with a single scale per row
    # the group shrinks to the (narrower) width of ``w_int32``.
    if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
        groupsize = w_int32.shape[-1]
    assert w_int32.shape[-1] % groupsize == 0
    assert w_int32.dim() == 2

    grouped = w_int32.reshape(-1, groupsize)
    scale_col = scales.reshape(-1, 1)
    zero_col = zeros.reshape(-1, 1)

    midpoint = 2 ** (n_bit - 1)
    restored = (grouped - midpoint) * scale_col + zero_col
    return restored.reshape_as(w_int32)
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128):
    """Dequantize ``w_int32`` using packed scales and zeros.

    The packed parameters are split with ``unpack_scales_and_zeros`` and then
    handed to ``group_dequantize_tensor_from_qparams``.

    Args:
        w_int32: quantized weight tensor.
        scales_and_zeros: packed qparams as produced by ``pack_scales_and_zeros``.
        n_bit: bit width used during quantization.
        groupsize: group size used during quantization.

    Returns:
        The dequantized tensor returned by
        ``group_dequantize_tensor_from_qparams``.
    """
    unpacked = unpack_scales_and_zeros(scales_and_zeros)
    scales, zeros = unpacked
    dequantized = group_dequantize_tensor_from_qparams(
        w_int32, scales, zeros, n_bit, groupsize
    )
    return dequantized
|
|
163
|
+
|
|
164
|
+
class QuantHandler:
    """Minimal base class for quantization handlers.

    It only stores the wrapped module; both hooks are no-ops that return
    ``None`` and are meant to be overridden by subclasses.
    """

    def __init__(self, mod):
        # The model (or sub-module) this handler operates on.
        self.mod = mod

    def create_quantized_state_dict(self) -> "StateDict":
        """Hook for producing a quantized state dict (no-op here)."""
        return None

    def convert_for_runtime(self) -> "nn.Module":
        """Hook for swapping in runtime modules (no-op here)."""
        return None
|
|
173
|
+
|
|
174
|
+
class GPTQQuantHandler(QuantHandler):
    """
    This class implements a GPTQ QuantHandler that can be used to apply GPTQ to a model in concert with the GenericGPTQRunner class.
    Unlike the base QuantHandler class, the user does not need to implement the create_quantized_state_dict, instead they have to reimplement
    __init__ such that it defines the functions for the quantization mode. User is expected to reimplement convert_for_runtime.

    The following functions (which must be defined in __init__) are used to define the quantization mode for both GPTQ and
    create_quantized_state_dict. Here is a description of each function.

    get_qparams_func:
        A function that calculates the quantization qparams for an input tensor.
        Args:
            weight: A 2d weight tensor with non-integer dtype.
        Returns:
            qparams: it can have any format but will need to be handled by the other defined functions below.

    quantize_func:
        A function that applies quantization to an input tensor. It should be noted
        that this function needs to be able to handle quantizing the entire weight tensor, a single group,
        or a single column.
        Args:
            weight: A 2d weight tensor with non-integer dtype.
            qparams: the output from get_qparams_func
        Returns:
            quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)


    dequantize_func:
        A function that dequantizes an input quantized weight tensor. It should be noted
        that this function needs to be able to handle dequantizing the entire weight tensor, a single group,
        or a single column.
        Args:
            quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)
            qparams: the output from get_qparams_func
        Returns:
            weight: A 2d weight tensor with non-integer dtype.

    combine_qparams_list_func:
        A function that combines several qparams into one qparam.
        Args:
            qparams_list: a list of qparams objects, each obtained by calling get_qparams_func
            on a single group from a weight tensor
        Returns:
            qparams: an object of the same format as the qparams above.

    skip_layer_func:
        A function that determines which linear layers should be skipped during GPTQ
        Args:
            weight: A 2d weight tensor with non-integer dtype.
        Returns:
            skip: boolean indicating whether layer should be skipped

    make_names_and_values_dict_func:
        A function that prepares the qparams and quantized_weight and creates a dictionary indicating how they
        should be inserted into the state_dict. Generally any packing of the weight and qparams should be done here.
        Args:
            quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)
            qparams: the output from get_qparams_func
        Returns:
            names_and_values_dict: a dictionary mapping the name of the parameters of the quantized module to the
            corresponding quantized weights and qparams.
    """
    def __init__(self):
        # Subclasses must set these attributes BEFORE calling super().__init__().
        assert self.mod is not None
        assert self.get_qparams_func is not None
        assert self.quantize_func is not None
        assert self.dequantize_func is not None
        assert self.combine_qparams_list_func is not None
        assert self.make_names_and_values_dict_func is not None

    @staticmethod
    def get_inputs(model, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length, pad_calibration_inputs) -> "MultiInput":
        """Record calibration inputs by running the evaluation tasks through an InputRecorder.

        Args:
            model: model handed to ``InputRecorder``.
            tokenizer: tokenizer handed to ``InputRecorder``.
            calibration_tasks: task names passed to ``get_task_dict``.
            calibration_limit: ``limit`` argument passed to ``evaluate``.
            calibration_seq_length: sequence length handed to ``InputRecorder``.
            pad_calibration_inputs: padding flag handed to ``InputRecorder``.

        Returns:
            The inputs reported by ``InputRecorder.get_recorded_inputs()``.

        Raises:
            AssertionError: if no inputs were recorded.
        """
        input_recorder = InputRecorder(
            model,
            tokenizer,
            calibration_seq_length,
            pad_calibration_inputs,
        )

        try:
            lm_eval.tasks.initialize_tasks()
        except Exception:
            # Best-effort initialisation: failures are deliberately ignored.
            # Only ``Exception`` is caught so that KeyboardInterrupt and
            # SystemExit still propagate.
            pass
        task_dict = get_task_dict(calibration_tasks)
        print("Obtaining GPTQ calibration inputs on: ", calibration_tasks)

        evaluate(
            input_recorder,
            task_dict,
            limit=calibration_limit,
        )
        inputs = input_recorder.get_recorded_inputs()
        assert inputs is not None, (
            f"No inputs were collected, use a task other than {calibration_tasks}, "+
            f"use option pad_calibration_inputs, or decrease calibration_sequence_length (currently "+
            f"{calibration_seq_length})"
        )
        print(f"Obtained {len(inputs[0].values)} calibration samples")
        return inputs

    @torch.no_grad()
    def create_quantized_state_dict(
        self,
        tokenizer,
        blocksize,
        percdamp,
        groupsize,
        calibration_tasks,
        calibration_limit,
        calibration_seq_length,
        pad_calibration_inputs,
    ) -> "StateDict":
        """Collect calibration inputs, run GPTQ and return the quantized state dict.

        The calibration inputs come from ``get_inputs``; the quantization
        callables configured on ``self`` (including ``skip_layer_func``) are
        handed to ``GenericGPTQRunner.configure_quantization_mode``.
        """
        inputs = GPTQQuantHandler.get_inputs(self.mod, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length, pad_calibration_inputs)
        print("Tracing model for GPTQ")
        GPTQ_runner = GenericGPTQRunner(
            self.mod,
            inputs,
            blocksize,
            percdamp,
            groupsize,
        ).configure_quantization_mode(
            self.get_qparams_func,
            self.quantize_func,
            self.dequantize_func,
            self.combine_qparams_list_func,
            self.make_names_and_values_dict_func,
            self.skip_layer_func
        )

        print("Applying GPTQ to weights")
        GPTQ_runner.run()
        return GPTQ_runner.get_quantized_state_dict()

    def convert_for_runtime(self) -> "nn.Module":
        """Hook to be reimplemented by subclasses (no-op here)."""
        pass
|
|
309
|
+
|
|
310
|
+
##### Weight-only int8 per-channel quantized code ######
|
|
311
|
+
|
|
312
|
+
def replace_linear_weight_only_int8_per_channel(module):
    """Recursively swap every ``nn.Linear`` child for a ``WeightOnlyInt8Linear``.

    The replacement is built from the child's ``in_features`` and
    ``out_features`` only. Children that are not ``nn.Linear`` are descended
    into. The module is modified in place and nothing is returned.
    """
    for child_name, submodule in module.named_children():
        if not isinstance(submodule, nn.Linear):
            replace_linear_weight_only_int8_per_channel(submodule)
            continue
        int8_linear = WeightOnlyInt8Linear(
            submodule.in_features, submodule.out_features
        )
        setattr(module, child_name, int8_linear)
|
|
318
|
+
|
|
319
|
+
class WeightOnlyInt8QuantHandler:
    """Handler for weight-only int8 per-channel quantization of a module."""

    def __init__(self, mod):
        # Module whose ``torch.nn.Linear`` layers get quantized.
        self.mod = mod

    @torch.no_grad()
    def create_quantized_state_dict(self):
        """Return the module's state dict with int8 weights for linear layers.

        For every ``torch.nn.Linear`` found by ``named_modules`` the
        ``<fqn>.weight`` entry is replaced by the int8 weight from
        ``dynamically_quantize_per_channel`` and a ``<fqn>.scales`` entry is
        added, cast to the original weight dtype.
        """
        state = self.mod.state_dict()
        for fqn, layer in self.mod.named_modules():
            if not isinstance(layer, torch.nn.Linear):
                continue
            # The third value returned by the helper is not needed here.
            int8_weight, scales, _unused = dynamically_quantize_per_channel(
                layer.weight.float(), -128, 127, torch.int8
            )
            state[f"{fqn}.weight"] = int8_weight
            state[f"{fqn}.scales"] = scales.to(layer.weight.dtype)
        return state

    def convert_for_runtime(self):
        """Swap linear layers for ``WeightOnlyInt8Linear`` in place and return the module."""
        replace_linear_weight_only_int8_per_channel(self.mod)
        return self.mod
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
class WeightOnlyInt8Linear(torch.nn.Module):
    """Linear layer that stores an int8 weight and a per-output-channel scale.

    ``bias``, ``device`` and ``dtype`` are accepted for signature
    compatibility but are not used by this implementation: no bias is
    registered and the buffers are always created with fixed dtypes.
    """

    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Buffers (not parameters) so they follow the module in the state
        # dict; ``weight`` is uninitialised until a state dict is loaded.
        int8_weight = torch.empty((out_features, in_features), dtype=torch.int8)
        unit_scales = torch.ones(out_features, dtype=torch.bfloat16)
        self.register_buffer("weight", int8_weight)
        self.register_buffer("scales", unit_scales)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the linear map in the input's dtype, then scale each output channel."""
        weight_as_input_dtype = self.weight.to(dtype=input.dtype)
        projected = F.linear(input, weight_as_input_dtype)
        return projected * self.scales
|
|
356
|
+
|
|
357
|
+
##### weight only int4 per channel groupwise quantized code ######
|
|
358
|
+
|
|
359
|
+
def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
    """Quantize a weight to 4 bits and pack it for the int4 matmul kernel.

    The weight is group-quantized with ``group_quantize_tensor`` (``n_bit=4``)
    and the integer result is packed by
    ``torch.ops.aten._convert_weight_to_int4pack``.

    Returns:
        A ``(weight_int4pack, scales_and_zeros)`` tuple.
    """
    quantized = group_quantize_tensor(weight_bf16, n_bit=4, groupsize=groupsize)
    weight_int32, scales_and_zeros = quantized
    packed_weight = torch.ops.aten._convert_weight_to_int4pack(
        weight_int32, inner_k_tiles
    )
    return packed_weight, scales_and_zeros
|
|
365
|
+
|
|
366
|
+
|
|
367
|
+
def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
    """Multiply ``x`` by an int4-packed weight, preserving leading dimensions.

    ``x`` is flattened to 2-d for ``torch.ops.aten._weight_int4pack_mm`` and
    the result is reshaped so that the last dimension becomes ``out_features``
    while all leading dimensions of ``x`` are kept.
    """
    input_shape = x.size()
    flat_input = x.reshape(-1, input_shape[-1])
    product = torch.ops.aten._weight_int4pack_mm(
        flat_input, weight_int4pack, groupsize, scales_and_zeros
    )
    output_shape = input_shape[:-1] + (out_features,)
    return product.reshape(output_shape)
|
|
374
|
+
|
|
375
|
+
|
|
376
|
+
def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = 1):
|
|
377
|
+
return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
|
|
378
|
+
|
|
379
|
+
def replace_linear_int4(module, groupsize, inner_k_tiles, padding):
    """Recursively swap ``nn.Linear`` children for ``WeightOnlyInt4Linear``.

    A linear child whose ``in_features`` passes ``_check_linear_int4_k`` is
    replaced with ``padding=False``. One that fails the check is replaced with
    ``padding=True`` when ``padding`` is set, and left untouched otherwise.
    Non-linear children are descended into. The module is modified in place.
    """
    for child_name, submodule in module.named_children():
        if not isinstance(submodule, nn.Linear):
            replace_linear_int4(submodule, groupsize, inner_k_tiles, padding)
            continue

        fits_kernel = _check_linear_int4_k(
            submodule.in_features, groupsize, inner_k_tiles
        )
        if not fits_kernel and not padding:
            # Incompatible size and padding disabled: keep the original layer.
            continue

        replacement = WeightOnlyInt4Linear(
            submodule.in_features, submodule.out_features, bias=False,
            groupsize=groupsize, inner_k_tiles=inner_k_tiles,
            padding=not fits_kernel,
        )
        setattr(module, child_name, replacement)
|
|
394
|
+
|
|
395
|
+
|
|
396
|
+
class WeightOnlyInt4QuantHandler:
    """Handler for weight-only int4 group-wise quantization of a module.

    Args:
        mod: module whose ``torch.nn.Linear`` layers get quantized.
        groupsize: quantization group size; must be 32, 64, 128 or 256.
        inner_k_tiles: packing parameter; must be 2, 4 or 8.
        padding: when True, layers whose ``in_features`` do not fit the kernel
            constraints are padded instead of skipped.
    """

    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
        self.mod = mod
        self.groupsize = groupsize
        self.inner_k_tiles = inner_k_tiles
        self.padding = padding
        assert groupsize in [32, 64, 128, 256]
        assert inner_k_tiles in [2, 4, 8]

    @torch.no_grad()
    def create_quantized_state_dict(self, use_cuda = True):
        """Return the module's state dict with int4-packed linear weights.

        For each ``torch.nn.Linear`` the ``<fqn>.weight`` entry is replaced by
        the packed int4 weight and a ``<fqn>.scales_and_zeros`` entry is added;
        both are moved to the CPU. Layers that do not satisfy
        ``_check_linear_int4_k`` are padded up to a multiple of 1024 input
        features when ``self.padding`` is set, otherwise they are skipped.

        Args:
            use_cuda: run the quantization on ``"cuda"`` when truthy, else on
                ``"cpu"``.

        Raises:
            AssertionError: if a linear layer has a bias or its
                ``out_features`` is not a multiple of 8.
        """
        device = "cuda" if use_cuda else "cpu"

        cur_state_dict = self.mod.state_dict()
        for fqn, mod in self.mod.named_modules():
            if not isinstance(mod, torch.nn.Linear):
                continue

            # ``not mod.bias`` would take the truth value of a multi-element
            # tensor and raise RuntimeError; test for absence explicitly.
            assert mod.bias is None, "require bias=False"
            out_features = mod.out_features
            in_features = mod.in_features
            assert out_features % 8 == 0, "require out_features % 8 == 0"
            print(f"linear: {fqn}, in={in_features}, out={out_features}")

            weight = mod.weight.data
            if not _check_linear_int4_k(in_features, self.groupsize, self.inner_k_tiles):
                if self.padding:
                    from model import find_multiple
                    import torch.nn.functional as F
                    print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0")
                    padded_in_features = find_multiple(in_features, 1024)
                    weight = F.pad(weight, pad=(0, padded_in_features - in_features))
                else:
                    print(f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " +
                        "and that groupsize and inner_k_tiles*16 evenly divide into it")
                    continue
            weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros(
                weight.to(torch.bfloat16).to(device=device), self.groupsize, self.inner_k_tiles
            )
            cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to('cpu')
            cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to('cpu')

        return cur_state_dict

    def convert_for_runtime(self):
        """Swap linear layers for ``WeightOnlyInt4Linear`` in place and return the module."""
        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
        return self.mod
|
|
444
|
+
|
|
445
|
+
class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler):
    """GPTQ handler configured for 4-bit group-wise weight-only quantization.

    ``__init__`` installs the callables required by ``GPTQQuantHandler``; they
    close over the constructor arguments rather than the instance attributes.
    """

    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
        from model import find_multiple

        n_bit = 4

        self.mod = mod
        self.groupsize = groupsize
        self.inner_k_tiles = inner_k_tiles
        self.padding = padding

        def get_qparams_func(w):
            return get_group_qparams(w, n_bit, groupsize)

        def quantize_func(w, qparams):
            return group_quantize_tensor_from_qparams(
                w, qparams[0], qparams[1], n_bit, groupsize
            )

        def dequantize_func(q, qparams):
            dequantized = group_dequantize_tensor_from_qparams(
                q, qparams[0], qparams[1], n_bit, groupsize
            )
            return dequantized.float()

        def combine_qparams_list_func(qparams_list):
            # Concatenate matching qparam components along dim 1.
            return [torch.cat(x, dim=1) for x in zip(*qparams_list)]

        def skip_layer_func(linear_weight):
            # Skip unless padding is enabled or the layer is correctly sized.
            return (
                not _check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles)
                and not padding
            )

        def make_names_and_values_dict_func(q, qparams):
            # Padding happens here, for both q and the qparams if necessary.
            padded_k = find_multiple(q.shape[1], 1024)
            # Number of extra columns the weight needs.
            extra_columns = padded_k - q.shape[1]
            padded_q = F.pad(q, pad=(0, extra_columns))
            final_q = torch.ops.aten._convert_weight_to_int4pack(padded_q, inner_k_tiles)
            scales_and_zeros = pack_scales_and_zeros(*qparams)
            # Number of extra groups the padded weight needs.
            extra_groups = padded_k // groupsize - scales_and_zeros.shape[0]
            final_s_and_z = F.pad(
                scales_and_zeros, pad=(0, 0, 0, 0, 0, extra_groups), value=1
            )
            return {"weight": final_q, "scales_and_zeros": final_s_and_z}

        self.get_qparams_func = get_qparams_func
        self.quantize_func = quantize_func
        self.dequantize_func = dequantize_func
        self.combine_qparams_list_func = combine_qparams_list_func
        self.skip_layer_func = skip_layer_func
        self.make_names_and_values_dict_func = make_names_and_values_dict_func
        super().__init__()

    def convert_for_runtime(self):
        """Swap linear layers for ``WeightOnlyInt4Linear`` in place and return the module."""
        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
        return self.mod
|
|
482
|
+
|
|
483
|
+
class WeightOnlyInt4Linear(torch.nn.Module):
    """Bias-free linear layer backed by an int4-packed weight buffer.

    When ``padding`` is set, ``in_features`` is rounded up with
    ``find_multiple(in_features, 1024)`` and the original width is kept in
    ``origin_in_features`` so that ``forward`` can pad its input to match.
    ``device`` and ``dtype`` are accepted but not used by this implementation.
    """

    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
        self, in_features: int, out_features: int,
        bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, padding: bool = True,
    ) -> None:
        super().__init__()
        self.padding = padding
        if padding:
            from model import find_multiple
            self.origin_in_features = in_features
            in_features = find_multiple(in_features, 1024)

        self.in_features = in_features
        self.out_features = out_features
        assert not bias, "require bias=False"
        self.groupsize = groupsize
        self.inner_k_tiles = inner_k_tiles

        assert out_features % 8 == 0, "require out_features % 8 == 0"
        assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"

        # Buffer shapes follow the packed int4 layout; contents stay
        # uninitialised until a quantized state dict is loaded.
        packed_weight_shape = (
            out_features // 8,
            in_features // (inner_k_tiles * 16),
            32,
            inner_k_tiles // 2,
        )
        qparams_shape = (in_features // groupsize, out_features, 2)
        self.register_buffer("weight", torch.empty(packed_weight_shape, dtype=torch.int32))
        self.register_buffer("scales_and_zeros", torch.empty(qparams_shape, dtype=torch.bfloat16))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Cast the input to bfloat16, pad it if needed and apply the int4 matmul."""
        activations = input.to(torch.bfloat16)
        if self.padding:
            import torch.nn.functional as F
            missing = self.in_features - self.origin_in_features
            activations = F.pad(activations, pad=(0, missing))
        return linear_forward_int4(
            activations,
            self.weight,
            self.scales_and_zeros,
            self.out_features,
            self.groupsize,
        )
|
|
525
|
+
)
|
|
526
|
+
|
|
527
|
+
|
|
528
|
+
def quantize(
    checkpoint_path: Path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"),
    mode: str = 'int8',
    # following arguments only available when setting int4 quantization.
    groupsize: int = 128,
    # following arguments only used for GPTQ
    # NOTE: list default kept for interface compatibility; it is only read here.
    calibration_tasks: list = ["hellaswag"],
    calibration_limit: int = 1000,
    calibration_seq_length: int = 100,
    pad_calibration_inputs: bool = False,
    percdamp: float = .01,
    blocksize: int = 128,
    label: str = '',
) -> None:
    """Quantize a checkpoint and write the result next to it.

    The output file name is the checkpoint's name with ``.pth`` replaced by
    ``{label}int8.pth``, ``{label}int4.g{groupsize}.pth`` or
    ``{label}int4-gptq.g{groupsize}.pth`` depending on ``mode``. An existing
    file with that name is removed first.

    Args:
        checkpoint_path: path of the checkpoint file to quantize.
        mode: one of ``'int8'``, ``'int4'`` or ``'int4-gptq'``.
        groupsize: group size for the int4 modes.
        calibration_tasks: GPTQ calibration task names.
        calibration_limit: GPTQ calibration sample limit.
        calibration_seq_length: GPTQ calibration sequence length.
        pad_calibration_inputs: GPTQ calibration padding flag.
        percdamp: GPTQ dampening value.
        blocksize: GPTQ block size.
        label: text inserted into the output file name.

    Raises:
        AssertionError: if ``checkpoint_path`` is not a file, or (GPTQ mode)
            the ``tokenizer.model`` file next to it is missing.
        ValueError: if ``mode`` is not a supported quantization mode.
    """
    assert checkpoint_path.is_file(), checkpoint_path

    # Reject an unsupported mode before the (potentially very large)
    # checkpoint is loaded rather than after.
    if mode not in ('int8', 'int4', 'int4-gptq'):
        raise ValueError(f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gptq]")

    device = 'cpu'
    precision = torch.bfloat16

    print("Loading model ...")
    t0 = time.time()

    # Build the model skeleton on the meta device, then attach the
    # memory-mapped checkpoint tensors to it.
    with torch.device('meta'):
        model = Transformer.from_name(checkpoint_path.parent.name)

    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
    model.load_state_dict(checkpoint, assign=True)
    model = model.to(dtype=precision, device=device)

    dir_name = checkpoint_path.parent
    base_name = checkpoint_path.name

    if mode == 'int8':
        print("Quantizing model weights for int8 weight-only symmetric per-channel quantization")
        quant_handler = WeightOnlyInt8QuantHandler(model)
        quantized_state_dict = quant_handler.create_quantized_state_dict()

        new_base_name = base_name.replace('.pth', f'{label}int8.pth')

    elif mode == 'int4':
        print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization")
        quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
        quantized_state_dict = quant_handler.create_quantized_state_dict()

        new_base_name = base_name.replace('.pth', f"{label}int4.g{groupsize}.pth")

    else:  # mode == 'int4-gptq' (validated above)
        print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization using GPTQ...")
        quant_handler = WeightOnlyInt4GPTQQuantHandler(model, groupsize)

        tokenizer_path = checkpoint_path.parent / "tokenizer.model"
        assert tokenizer_path.is_file(), str(tokenizer_path)
        tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)

        quantized_state_dict = quant_handler.create_quantized_state_dict(
            tokenizer,
            blocksize,
            percdamp,
            groupsize,
            calibration_tasks,
            calibration_limit,
            calibration_seq_length,
            pad_calibration_inputs
        )

        new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.pth")

    quantize_path = dir_name / new_base_name
    print(f"Writing quantized weights to {quantize_path}")
    quantize_path.unlink(missing_ok=True) # remove existing file if one already there
    torch.save(quantized_state_dict, quantize_path)
    print(f"Quantization complete took {time.time() - t0:.02f} seconds")
    return
|
|
606
|
+
|
|
607
|
+
if __name__ == '__main__':
    import argparse

    def _str_to_bool(value):
        """Parse a command-line boolean.

        ``type=bool`` would turn every non-empty string, including "False",
        into True, so the text is interpreted explicitly instead.
        """
        if isinstance(value, bool):
            return value
        normalized = value.strip().lower()
        if normalized in ('true', 't', 'yes', 'y', '1'):
            return True
        if normalized in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(description='Quantize a model.')
    parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Path to the model checkpoint to be quantized.')
    parser.add_argument('--mode', '-q', type=str, default='int8', choices=['int8', 'int4', 'int4-gptq'], help='type of quantization to perform')
    parser.add_argument('--groupsize', type=int, default=32, help='Group size for int4 quantization.')
    parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq')
    parser.add_argument('--calibration_limit', type=int, default=1000, help='number of samples to use for gptq calibration')
    parser.add_argument('--calibration_seq_length', type=int, default=100, help='length of sequences to use for gptq calibration')
    parser.add_argument('--pad_calibration_inputs', type=_str_to_bool, default=False, help='pads sequences shorter than calibration_seq_length to that length, yielding more calibration inputs but running much slower')
    parser.add_argument('--percdamp', type=float, default=.01, help='gptq percentage dampening')
    parser.add_argument('--blocksize', type=int, default=128, help='blocksize for gptq')
    parser.add_argument('--label', type=str, default='_', help='label to add to output filename')

    args = parser.parse_args()
    quantize(args.checkpoint_path, args.mode, args.groupsize, args.calibration_tasks, args.calibration_limit, args.calibration_seq_length, args.pad_calibration_inputs, args.percdamp, args.blocksize, args.label)
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
import torch
|
|
15
|
+
import torch.nn as nn
|
|
16
|
+
from torch.nn.utils import weight_norm
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class ConvRNNF0Predictor(nn.Module):
    """Stack of five weight-normalised 1-d convolutions followed by a linear head.

    Args:
        num_class: output size of the final linear layer.
        in_channels: channel count of the input fed to the first convolution.
        cond_channels: channel count used by every convolution output and by
            the linear layer's input.
    """

    def __init__(self,
                 num_class: int = 1,
                 in_channels: int = 80,
                 cond_channels: int = 512
                 ):
        super().__init__()

        self.num_class = num_class

        # Five (conv, ELU) pairs; only the first conv consumes ``in_channels``.
        layers = []
        current_channels = in_channels
        for _ in range(5):
            conv = nn.Conv1d(current_channels, cond_channels, kernel_size=3, padding=1)
            layers.append(weight_norm(conv))
            layers.append(nn.ELU())
            current_channels = cond_channels
        self.condnet = nn.Sequential(*layers)
        self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the absolute value of the classifier output.

        The convolution output has dimensions 1 and 2 swapped so the linear
        layer acts on the channel dimension; a trailing dimension of size one
        is squeezed away.
        """
        hidden = self.condnet(x)
        hidden = hidden.transpose(1, 2)
        prediction = self.classifier(hidden).squeeze(-1)
        return torch.abs(prediction)
|