xinference 1.10.0__py3-none-any.whl → 1.10.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +11 -28
- xinference/client/restful/async_restful_client.py +20 -3
- xinference/client/restful/restful_client.py +20 -3
- xinference/core/supervisor.py +87 -53
- xinference/core/worker.py +10 -0
- xinference/deploy/cmdline.py +15 -0
- xinference/model/audio/core.py +21 -6
- xinference/model/audio/indextts2.py +166 -0
- xinference/model/audio/model_spec.json +38 -1
- xinference/model/image/model_spec.json +69 -0
- xinference/model/image/stable_diffusion/core.py +13 -4
- xinference/model/llm/__init__.py +4 -0
- xinference/model/llm/llm_family.json +464 -2
- xinference/model/llm/sglang/core.py +30 -11
- xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +94 -32
- xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
- xinference/model/llm/utils.py +12 -9
- xinference/model/llm/vllm/core.py +93 -17
- xinference/thirdparty/audiotools/__init__.py +10 -0
- xinference/thirdparty/audiotools/core/__init__.py +4 -0
- xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
- xinference/thirdparty/audiotools/core/display.py +194 -0
- xinference/thirdparty/audiotools/core/dsp.py +390 -0
- xinference/thirdparty/audiotools/core/effects.py +647 -0
- xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
- xinference/thirdparty/audiotools/core/loudness.py +320 -0
- xinference/thirdparty/audiotools/core/playback.py +252 -0
- xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
- xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
- xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
- xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
- xinference/thirdparty/audiotools/core/util.py +671 -0
- xinference/thirdparty/audiotools/core/whisper.py +97 -0
- xinference/thirdparty/audiotools/data/__init__.py +3 -0
- xinference/thirdparty/audiotools/data/datasets.py +517 -0
- xinference/thirdparty/audiotools/data/preprocess.py +81 -0
- xinference/thirdparty/audiotools/data/transforms.py +1592 -0
- xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
- xinference/thirdparty/audiotools/metrics/distance.py +131 -0
- xinference/thirdparty/audiotools/metrics/quality.py +159 -0
- xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
- xinference/thirdparty/audiotools/ml/__init__.py +5 -0
- xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
- xinference/thirdparty/audiotools/ml/decorators.py +440 -0
- xinference/thirdparty/audiotools/ml/experiment.py +90 -0
- xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
- xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
- xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
- xinference/thirdparty/audiotools/post.py +140 -0
- xinference/thirdparty/audiotools/preference.py +600 -0
- xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
- xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
- xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
- xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
- xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
- xinference/thirdparty/indextts/__init__.py +0 -0
- xinference/thirdparty/indextts/cli.py +65 -0
- xinference/thirdparty/indextts/gpt/__init__.py +0 -0
- xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
- xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
- xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
- xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
- xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
- xinference/thirdparty/indextts/gpt/model.py +713 -0
- xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
- xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
- xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
- xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
- xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
- xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
- xinference/thirdparty/indextts/infer.py +690 -0
- xinference/thirdparty/indextts/infer_v2.py +739 -0
- xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
- xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
- xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
- xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
- xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
- xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
- xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
- xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
- xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
- xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
- xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
- xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
- xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
- xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
- xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
- xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
- xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
- xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
- xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
- xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
- xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
- xinference/thirdparty/indextts/utils/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/arch_util.py +120 -0
- xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
- xinference/thirdparty/indextts/utils/common.py +121 -0
- xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
- xinference/thirdparty/indextts/utils/front.py +536 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
- xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
- xinference/thirdparty/indextts/utils/text_utils.py +41 -0
- xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
- xinference/thirdparty/indextts/utils/utils.py +93 -0
- xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
- xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
- xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
- xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
- xinference/ui/gradio/media_interface.py +66 -8
- xinference/ui/web/ui/build/asset-manifest.json +6 -6
- xinference/ui/web/ui/build/index.html +1 -1
- xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
- xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
- xinference/ui/web/ui/build/static/js/main.d192c4f3.js +3 -0
- xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.d192c4f3.js.LICENSE.txt} +0 -7
- xinference/ui/web/ui/build/static/js/main.d192c4f3.js.map +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f995a2425dfb0822fd07127f66ffe9b026883bc156b402eb8bd0b83d52460a93.json +1 -0
- xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
- xinference/ui/web/ui/package-lock.json +0 -34
- xinference/ui/web/ui/package.json +0 -1
- xinference/ui/web/ui/src/locales/en.json +9 -3
- xinference/ui/web/ui/src/locales/ja.json +9 -3
- xinference/ui/web/ui/src/locales/ko.json +9 -3
- xinference/ui/web/ui/src/locales/zh.json +9 -3
- {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/METADATA +18 -2
- {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/RECORD +285 -67
- xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
- xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
- xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
- xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
- xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
- xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
- xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
- xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
- xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
- xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
- xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
- xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
- xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
- xinference/ui/web/ui/node_modules/select/bower.json +0 -13
- xinference/ui/web/ui/node_modules/select/package.json +0 -29
- xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
- {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/WHEEL +0 -0
- {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/entry_points.txt +0 -0
- {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/top_level.txt +0 -0

xinference/thirdparty/indextts/BigVGAN/utils.py (new file)

@@ -0,0 +1,101 @@
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory.

import glob
import os

import matplotlib
import matplotlib.pylab as plt
import torch
from scipy.io.wavfile import write
from torch.nn.utils import weight_norm

matplotlib.use("Agg")

MAX_WAV_VALUE = 32768.0


def plot_spectrogram(spectrogram):
    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)

    fig.canvas.draw()
    plt.close()

    return fig


def plot_spectrogram_clipped(spectrogram, clip_max=2.0):
    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(
        spectrogram,
        aspect="auto",
        origin="lower",
        interpolation="none",
        vmin=1e-6,
        vmax=clip_max,
    )
    plt.colorbar(im, ax=ax)

    fig.canvas.draw()
    plt.close()

    return fig


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def apply_weight_norm(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        weight_norm(m)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


def load_checkpoint(filepath, device):
    assert os.path.isfile(filepath)
    print(f"Loading '{filepath}'")
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict


def save_checkpoint(filepath, obj):
    print(f"Saving checkpoint to {filepath}")
    torch.save(obj, filepath)
    print("Complete.")


def scan_checkpoint(cp_dir, prefix, renamed_file=None):
    # Fallback to original scanning logic first
    pattern = os.path.join(cp_dir, prefix + "????????")
    cp_list = glob.glob(pattern)

    if len(cp_list) > 0:
        last_checkpoint_path = sorted(cp_list)[-1]
        print(f"[INFO] Resuming from checkpoint: '{last_checkpoint_path}'")
        return last_checkpoint_path

    # If no pattern-based checkpoints are found, check for renamed file
    if renamed_file:
        renamed_path = os.path.join(cp_dir, renamed_file)
        if os.path.isfile(renamed_path):
            print(f"[INFO] Resuming from renamed checkpoint: '{renamed_file}'")
            return renamed_path

    return None


def save_audio(audio, path, sr):
    # wav: torch with 1d shape
    audio = audio * MAX_WAV_VALUE
    audio = audio.cpu().numpy().astype("int16")
    write(path, sr, audio)
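
Note that `scan_checkpoint` accepts a `renamed_file` fallback in addition to the usual pattern scan. A minimal usage sketch, assuming the vendored module is importable from the installed wheel; the directory, prefix, and file names below are illustrative, not taken from this release:

# Illustrative caller of the renamed_file fallback in scan_checkpoint.
from xinference.thirdparty.indextts.BigVGAN.utils import scan_checkpoint

# First looks for pattern-named checkpoints (e.g. "g_00001000"); if none are
# found, falls back to a single renamed weight file in the same directory.
ckpt = scan_checkpoint(
    cp_dir="checkpoints",                 # illustrative directory
    prefix="g_",                          # illustrative prefix
    renamed_file="bigvgan_generator.pt",  # illustrative renamed weight file
)
if ckpt is None:
    print("No checkpoint found in 'checkpoints'.")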
File without changes

xinference/thirdparty/indextts/cli.py (new file)

@@ -0,0 +1,65 @@
import os
import sys
import warnings
# Suppress warnings from tensorflow and other libraries
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
def main():
    import argparse
    parser = argparse.ArgumentParser(description="IndexTTS Command Line")
    parser.add_argument("text", type=str, help="Text to be synthesized")
    parser.add_argument("-v", "--voice", type=str, required=True, help="Path to the audio prompt file (wav format)")
    parser.add_argument("-o", "--output_path", type=str, default="gen.wav", help="Path to the output wav file")
    parser.add_argument("-c", "--config", type=str, default="checkpoints/config.yaml", help="Path to the config file. Default is 'checkpoints/config.yaml'")
    parser.add_argument("--model_dir", type=str, default="checkpoints", help="Path to the model directory. Default is 'checkpoints'")
    parser.add_argument("--fp16", action="store_true", default=False, help="Use FP16 for inference if available")
    parser.add_argument("-f", "--force", action="store_true", default=False, help="Force to overwrite the output file if it exists")
    parser.add_argument("-d", "--device", type=str, default=None, help="Device to run the model on (cpu, cuda, mps, xpu).")
    args = parser.parse_args()
    if len(args.text.strip()) == 0:
        print("ERROR: Text is empty.")
        parser.print_help()
        sys.exit(1)
    if not os.path.exists(args.voice):
        print(f"Audio prompt file {args.voice} does not exist.")
        parser.print_help()
        sys.exit(1)
    if not os.path.exists(args.config):
        print(f"Config file {args.config} does not exist.")
        parser.print_help()
        sys.exit(1)

    output_path = args.output_path
    if os.path.exists(output_path):
        if not args.force:
            print(f"ERROR: Output file {output_path} already exists. Use --force to overwrite.")
            parser.print_help()
            sys.exit(1)
        else:
            os.remove(output_path)

    try:
        import torch
    except ImportError:
        print("ERROR: PyTorch is not installed. Please install it first.")
        sys.exit(1)

    if args.device is None:
        if torch.cuda.is_available():
            args.device = "cuda:0"
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            args.device = "xpu"
        elif hasattr(torch, "mps") and torch.mps.is_available():
            args.device = "mps"
        else:
            args.device = "cpu"
            args.fp16 = False  # Disable FP16 on CPU
            print("WARNING: Running on CPU may be slow.")

    # TODO: Add CLI support for IndexTTS2.
    from indextts.infer import IndexTTS
    tts = IndexTTS(cfg_path=args.config, model_dir=args.model_dir, use_fp16=args.fp16, device=args.device)
    tts.infer(audio_prompt=args.voice, text=args.text.strip(), output_path=output_path)

if __name__ == "__main__":
    main()
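
The CLI is a thin wrapper around `IndexTTS`. A rough programmatic equivalent of what `main()` ends up calling, with illustrative paths and text (the module ships in this wheel as xinference/thirdparty/indextts/infer.py):

# Illustrative programmatic equivalent of the CLI above.
from indextts.infer import IndexTTS

tts = IndexTTS(
    cfg_path="checkpoints/config.yaml",  # same default as --config
    model_dir="checkpoints",             # same default as --model_dir
    use_fp16=False,
    device="cuda:0",
)
tts.infer(
    audio_prompt="voice.wav",            # reference speaker prompt (wav)
    text="Hello from IndexTTS.",
    output_path="gen.wav",               # same default as --output_path
)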
File without changes
File without changes

xinference/thirdparty/indextts/gpt/conformer/attention.py (new file)

@@ -0,0 +1,312 @@
# Copyright (c) 2019 Shigeki Karita
#               2020 Mobvoi Inc (Binbin Zhang)
#               2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Multi-Head Attention layer definition."""

import math
from typing import Tuple

import torch
from torch import nn


class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer.

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.

    """
    def __init__(self, n_head: int, n_feat: int, dropout_rate: float):
        """Construct an MultiHeadedAttention object."""
        super().__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Transform query, key and value.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).

        Returns:
            torch.Tensor: Transformed query tensor, size
                (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor, size
                (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor, size
                (#batch, n_head, time2, d_k).

        """
        n_batch = query.size(0)
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head, time2, d_k)

        return q, k, v

    def forward_attention(
        self, value: torch.Tensor, scores: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
    ) -> torch.Tensor:
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value, size
                (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score, size
                (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask, size (#batch, 1, time2) or
                (#batch, time1, time2), (0, 0, 0) means fake mask.

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).

        """
        n_batch = value.size(0)
        # NOTE(xcsong): When will `if mask.size(2) > 0` be True?
        #   1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
        #           1st chunk to ease the onnx export.]
        #   2. pytorch training
        if mask.size(2) > 0 :  # time2 > 0
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            # For last chunk, time2 might be larger than scores.size(-1)
            mask = mask[:, :, :, :scores.size(-1)]  # (batch, 1, *, time2)
            scores = scores.masked_fill(mask, -float('inf'))
            attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0)  # (batch, head, time1, time2)
        # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
        #   1. onnx(16/-1, -1/-1, 16/0)
        #   2. jit (16/-1, -1/-1, 16/0, 16/4)
        else:
            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
                                                 self.h * self.d_k)
             )  # (batch, time1, d_model)

        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, query: torch.Tensor, key: torch.Tensor,
                value: torch.Tensor,
                mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
                pos_emb: torch.Tensor = torch.empty(0),
                cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
                ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute scaled dot product attention.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).
                1.When applying cross attention between decoder and encoder,
                the batch padding mask for input is in (#batch, 1, T) shape.
                2.When applying self attention of encoder,
                the mask is in (#batch, T, T) shape.
                3.When applying self attention of decoder,
                the mask is in (#batch, L, L) shape.
                4.If the different position in decoder see different block
                of the encoder, such as Mocha, the passed in mask could be
                in (#batch, L, T) shape. But there is no such case in current
                Wenet.
            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`


        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`

        """
        q, k, v = self.forward_qkv(query, key, value)

        # NOTE(xcsong):
        #   when export onnx model, for 1st chunk, we feed
        #       cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
        #       or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
        #       In all modes, `if cache.size(0) > 0` will alwayse be `True`
        #       and we will always do splitting and
        #       concatnation(this will simplify onnx export). Note that
        #       it's OK to concat & split zero-shaped tensors(see code below).
        #   when export jit model, for 1st chunk, we always feed
        #       cache(0, 0, 0, 0) since jit supports dynamic if-branch.
        # >>> a = torch.ones((1, 2, 0, 4))
        # >>> b = torch.ones((1, 2, 3, 4))
        # >>> c = torch.cat((a, b), dim=2)
        # >>> torch.equal(b, c)        # True
        # >>> d = torch.split(a, 2, dim=-1)
        # >>> torch.equal(d[0], d[1])  # True
        if cache.size(0) > 0:
            key_cache, value_cache = torch.split(
                cache, cache.size(-1) // 2, dim=-1)
            k = torch.cat([key_cache, k], dim=2)
            v = torch.cat([value_cache, v], dim=2)
        # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
        #   non-trivial to calculate `next_cache_start` here.
        new_cache = torch.cat((k, v), dim=-1)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask), new_cache


class RelPositionMultiHeadedAttention(MultiHeadedAttention):
    """Multi-Head Attention layer with relative position encoding.
    Paper: https://arxiv.org/abs/1901.02860
    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
    """
    def __init__(self, n_head, n_feat, dropout_rate):
        """Construct an RelPositionMultiHeadedAttention object."""
        super().__init__(n_head, n_feat, dropout_rate)
        # linear transformation for positional encoding
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        # these two learnable bias are used in matrix c and matrix d
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
        torch.nn.init.xavier_uniform_(self.pos_bias_u)
        torch.nn.init.xavier_uniform_(self.pos_bias_v)

    def rel_shift(self, x, zero_triu: bool = False):
        """Compute relative positinal encoding.
        Args:
            x (torch.Tensor): Input tensor (batch, time, size).
            zero_triu (bool): If true, return the lower triangular part of
                the matrix.
        Returns:
            torch.Tensor: Output tensor.
        """

        zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
                               device=x.device,
                               dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=-1)

        x_padded = x_padded.view(x.size()[0],
                                 x.size()[1],
                                 x.size(3) + 1, x.size(2))
        x = x_padded[:, :, 1:].view_as(x)

        if zero_triu:
            ones = torch.ones((x.size(2), x.size(3)))
            x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]

        return x

    def forward(self, query: torch.Tensor,
                key: torch.Tensor, value: torch.Tensor,
                mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
                pos_emb: torch.Tensor = torch.empty(0),
                cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
                ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2), (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): Positional embedding tensor
                (#batch, time2, size).
            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`
        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`
        """
        q, k, v = self.forward_qkv(query, key, value)
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)

        # NOTE(xcsong):
        #   when export onnx model, for 1st chunk, we feed
        #       cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
        #       or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
        #       In all modes, `if cache.size(0) > 0` will alwayse be `True`
        #       and we will always do splitting and
        #       concatnation(this will simplify onnx export). Note that
        #       it's OK to concat & split zero-shaped tensors(see code below).
        #   when export jit model, for 1st chunk, we always feed
        #       cache(0, 0, 0, 0) since jit supports dynamic if-branch.
        # >>> a = torch.ones((1, 2, 0, 4))
        # >>> b = torch.ones((1, 2, 3, 4))
        # >>> c = torch.cat((a, b), dim=2)
        # >>> torch.equal(b, c)        # True
        # >>> d = torch.split(a, 2, dim=-1)
        # >>> torch.equal(d[0], d[1])  # True
        if cache.size(0) > 0:
            key_cache, value_cache = torch.split(
                cache, cache.size(-1) // 2, dim=-1)
            k = torch.cat([key_cache, k], dim=2)
            v = torch.cat([value_cache, v], dim=2)
        # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
        #   non-trivial to calculate `next_cache_start` here.
        new_cache = torch.cat((k, v), dim=-1)

        n_batch_pos = pos_emb.size(0)
        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
        p = p.transpose(1, 2)  # (batch, head, time1, d_k)

        # (batch, head, time1, d_k)
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        # (batch, head, time1, d_k)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

        # compute attention score
        # first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

        # compute matrix b and matrix d
        # (batch, head, time1, time2)
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        # Remove rel_shift since it is useless in speech recognition,
        #   and it requires special attention for streaming.
        # matrix_bd = self.rel_shift(matrix_bd)

        scores = (matrix_ac + matrix_bd) / math.sqrt(
            self.d_k)  # (batch, head, time1, time2)

        return self.forward_attention(v, scores, mask), new_cache
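
For orientation, a small sketch of how the incremental key/value cache in `MultiHeadedAttention.forward` behaves; the sizes are made up and the import path assumes the vendored location listed in the file table above:

# Illustrative only: run two chunks through MultiHeadedAttention with a cache.
import torch
from xinference.thirdparty.indextts.gpt.conformer.attention import MultiHeadedAttention

n_head, n_feat = 4, 256                  # d_k = 64
attn = MultiHeadedAttention(n_head, n_feat, dropout_rate=0.0)

x1 = torch.randn(1, 16, n_feat)          # first chunk: (batch, time1, size)
out1, cache = attn(x1, x1, x1)           # default cache is the empty (0, 0, 0, 0) tensor
print(cache.shape)                       # torch.Size([1, 4, 16, 128]) = (1, head, t, d_k * 2)

x2 = torch.randn(1, 8, n_feat)           # next chunk also attends to the cached keys/values
out2, cache = attn(x2, x2, x2, cache=cache)
print(out2.shape, cache.shape)           # torch.Size([1, 8, 256]) torch.Size([1, 4, 24, 128])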

xinference/thirdparty/indextts/gpt/conformer/embedding.py (new file)

@@ -0,0 +1,163 @@
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)

"""Positonal Encoding Module."""

import math
from typing import Tuple, Union

import torch
import torch.nn.functional as F


class PositionalEncoding(torch.nn.Module):
    """Positional encoding.

    :param int d_model: embedding dim
    :param float dropout_rate: dropout rate
    :param int max_len: maximum input length

    PE(pos, 2i)   = sin(pos/(10000^(2i/dmodel)))
    PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
    """
    def __init__(self,
                 d_model: int,
                 dropout_rate: float,
                 max_len: int = 5000,
                 reverse: bool = False):
        """Construct an PositionalEncoding object."""
        super().__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.max_len = max_len

        pe = torch.zeros(self.max_len, self.d_model)
        position = torch.arange(0, self.max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2) *
            -(math.log(10000.0) / self.d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self,
                x: torch.Tensor,
                offset: Union[int, torch.Tensor] = 0) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input. Its shape is (batch, time, ...)
            offset (int, torch.tensor): position offset

        Returns:
            torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
            torch.Tensor: for compatibility to RelPositionalEncoding
        """

        self.pe = self.pe.to(x.device)
        pos_emb = self.position_encoding(offset, x.size(1), False)
        x = x * self.xscale + pos_emb
        return self.dropout(x), self.dropout(pos_emb)

    def position_encoding(self, offset: Union[int, torch.Tensor], size: int,
                          apply_dropout: bool = True) -> torch.Tensor:
        """ For getting encoding in a streaming fashion

        Attention!!!!!
        we apply dropout only once at the whole utterance level in a none
        streaming way, but will call this function several times with
        increasing input size in a streaming scenario, so the dropout will
        be applied several times.

        Args:
            offset (int or torch.tensor): start offset
            size (int): required size of position encoding

        Returns:
            torch.Tensor: Corresponding encoding
        """
        # How to subscript a Union type:
        #   https://github.com/pytorch/pytorch/issues/69434
        if isinstance(offset, int):
            assert offset + size < self.max_len
            pos_emb = self.pe[:, offset:offset + size]
        elif isinstance(offset, torch.Tensor) and offset.dim() == 0:  # scalar
            assert offset + size < self.max_len
            pos_emb = self.pe[:, offset:offset + size]
        else:  # for batched streaming decoding on GPU
            assert torch.max(offset) + size < self.max_len
            index = offset.unsqueeze(1) + \
                torch.arange(0, size).to(offset.device)  # B X T
            flag = index > 0
            # remove negative offset
            index = index * flag
            pos_emb = F.embedding(index, self.pe[0])  # B X T X d_model

        if apply_dropout:
            pos_emb = self.dropout(pos_emb)
        return pos_emb

class RelPositionalEncoding(PositionalEncoding):
    """Relative positional encoding module.
    See : Appendix B in https://arxiv.org/abs/1901.02860
    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
    """
    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
        """Initialize class."""
        super().__init__(d_model, dropout_rate, max_len, reverse=True)

    def forward(self,
                x: torch.Tensor,
                offset: Union[int, torch.Tensor] = 0) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute positional encoding.
        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
            torch.Tensor: Positional embedding tensor (1, time, `*`).
        """
        self.pe = self.pe.to(x.device)
        x = x * self.xscale
        pos_emb = self.position_encoding(offset, x.size(1), False)
        return self.dropout(x), self.dropout(pos_emb)


class NoPositionalEncoding(torch.nn.Module):
    """ No position encoding
    """
    def __init__(self, d_model: int, dropout_rate: float):
        super().__init__()
        self.d_model = d_model
        self.dropout = torch.nn.Dropout(p=dropout_rate)

    def forward(self,
                x: torch.Tensor,
                offset: Union[int, torch.Tensor] = 0) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """ Just return zero vector for interface compatibility
        """
        pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device)
        return self.dropout(x), pos_emb

    def position_encoding(
        self, offset: Union[int, torch.Tensor], size: int) -> torch.Tensor:
        return torch.zeros(1, size, self.d_model)
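
These encodings pair with the attention classes above: `RelPositionalEncoding` returns the scaled input and a separate positional embedding that `RelPositionMultiHeadedAttention` consumes through its `pos_emb` argument. A small sketch with made-up sizes, assuming the vendored import paths from the file table:

# Illustrative pairing of RelPositionalEncoding with RelPositionMultiHeadedAttention.
import torch
from xinference.thirdparty.indextts.gpt.conformer.attention import RelPositionMultiHeadedAttention
from xinference.thirdparty.indextts.gpt.conformer.embedding import RelPositionalEncoding

d_model, n_head = 256, 4
pos_enc = RelPositionalEncoding(d_model, dropout_rate=0.0)
attn = RelPositionMultiHeadedAttention(n_head, d_model, dropout_rate=0.0)

x = torch.randn(2, 32, d_model)          # (batch, time, d_model)
x_scaled, pos_emb = pos_enc(x)           # pos_emb: (1, time, d_model)
out, _ = attn(x_scaled, x_scaled, x_scaled, pos_emb=pos_emb)
print(out.shape)                         # torch.Size([2, 32, 256])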