xinference 1.9.1__py3-none-any.whl → 1.10.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +400 -3
- xinference/client/restful/async_restful_client.py +20 -3
- xinference/client/restful/restful_client.py +20 -3
- xinference/constants.py +2 -0
- xinference/core/supervisor.py +111 -49
- xinference/core/worker.py +10 -0
- xinference/deploy/cmdline.py +15 -0
- xinference/model/audio/core.py +26 -6
- xinference/model/audio/indextts2.py +166 -0
- xinference/model/audio/kokoro.py +1 -1
- xinference/model/audio/kokoro_zh.py +124 -0
- xinference/model/audio/model_spec.json +58 -1
- xinference/model/embedding/sentence_transformers/core.py +4 -4
- xinference/model/embedding/vllm/core.py +7 -1
- xinference/model/image/model_spec.json +71 -3
- xinference/model/image/stable_diffusion/core.py +13 -4
- xinference/model/llm/__init__.py +4 -0
- xinference/model/llm/core.py +10 -0
- xinference/model/llm/llama_cpp/core.py +1 -0
- xinference/model/llm/llm_family.json +503 -21
- xinference/model/llm/llm_family.py +1 -0
- xinference/model/llm/mlx/core.py +52 -33
- xinference/model/llm/sglang/core.py +32 -55
- xinference/model/llm/tool_parsers/__init__.py +58 -0
- xinference/model/llm/tool_parsers/abstract_tool_parser.py +33 -0
- xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +190 -0
- xinference/model/llm/tool_parsers/deepseek_v3_tool_parser.py +145 -0
- xinference/model/llm/tool_parsers/glm4_tool_parser.py +123 -0
- xinference/model/llm/tool_parsers/llama3_tool_parser.py +77 -0
- xinference/model/llm/tool_parsers/qwen_tool_parser.py +320 -0
- xinference/model/llm/transformers/core.py +1 -1
- xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
- xinference/model/llm/utils.py +138 -53
- xinference/model/llm/vllm/core.py +95 -78
- xinference/thirdparty/audiotools/__init__.py +10 -0
- xinference/thirdparty/audiotools/core/__init__.py +4 -0
- xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
- xinference/thirdparty/audiotools/core/display.py +194 -0
- xinference/thirdparty/audiotools/core/dsp.py +390 -0
- xinference/thirdparty/audiotools/core/effects.py +647 -0
- xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
- xinference/thirdparty/audiotools/core/loudness.py +320 -0
- xinference/thirdparty/audiotools/core/playback.py +252 -0
- xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
- xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
- xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
- xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
- xinference/thirdparty/audiotools/core/util.py +671 -0
- xinference/thirdparty/audiotools/core/whisper.py +97 -0
- xinference/thirdparty/audiotools/data/__init__.py +3 -0
- xinference/thirdparty/audiotools/data/datasets.py +517 -0
- xinference/thirdparty/audiotools/data/preprocess.py +81 -0
- xinference/thirdparty/audiotools/data/transforms.py +1592 -0
- xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
- xinference/thirdparty/audiotools/metrics/distance.py +131 -0
- xinference/thirdparty/audiotools/metrics/quality.py +159 -0
- xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
- xinference/thirdparty/audiotools/ml/__init__.py +5 -0
- xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
- xinference/thirdparty/audiotools/ml/decorators.py +440 -0
- xinference/thirdparty/audiotools/ml/experiment.py +90 -0
- xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
- xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
- xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
- xinference/thirdparty/audiotools/post.py +140 -0
- xinference/thirdparty/audiotools/preference.py +600 -0
- xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
- xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
- xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
- xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
- xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
- xinference/thirdparty/indextts/__init__.py +0 -0
- xinference/thirdparty/indextts/cli.py +65 -0
- xinference/thirdparty/indextts/gpt/__init__.py +0 -0
- xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
- xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
- xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
- xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
- xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
- xinference/thirdparty/indextts/gpt/model.py +713 -0
- xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
- xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
- xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
- xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
- xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
- xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
- xinference/thirdparty/indextts/infer.py +690 -0
- xinference/thirdparty/indextts/infer_v2.py +739 -0
- xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
- xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
- xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
- xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
- xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
- xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
- xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
- xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
- xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
- xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
- xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
- xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
- xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
- xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
- xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
- xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
- xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
- xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
- xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
- xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
- xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
- xinference/thirdparty/indextts/utils/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/arch_util.py +120 -0
- xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
- xinference/thirdparty/indextts/utils/common.py +121 -0
- xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
- xinference/thirdparty/indextts/utils/front.py +536 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
- xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
- xinference/thirdparty/indextts/utils/text_utils.py +41 -0
- xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
- xinference/thirdparty/indextts/utils/utils.py +93 -0
- xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
- xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
- xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
- xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
- xinference/types.py +105 -2
- xinference/ui/gradio/media_interface.py +66 -8
- xinference/ui/web/ui/build/asset-manifest.json +6 -6
- xinference/ui/web/ui/build/index.html +1 -1
- xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
- xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
- xinference/ui/web/ui/build/static/js/main.d192c4f3.js +3 -0
- xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.d192c4f3.js.LICENSE.txt} +0 -7
- xinference/ui/web/ui/build/static/js/main.d192c4f3.js.map +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f995a2425dfb0822fd07127f66ffe9b026883bc156b402eb8bd0b83d52460a93.json +1 -0
- xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
- xinference/ui/web/ui/package-lock.json +0 -34
- xinference/ui/web/ui/package.json +0 -1
- xinference/ui/web/ui/src/locales/en.json +9 -3
- xinference/ui/web/ui/src/locales/ja.json +9 -3
- xinference/ui/web/ui/src/locales/ko.json +9 -3
- xinference/ui/web/ui/src/locales/zh.json +9 -3
- {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/METADATA +24 -4
- {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/RECORD +302 -76
- xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
- xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
- xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
- xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
- xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
- xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
- xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
- xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
- xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
- xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
- xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
- xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
- xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
- xinference/ui/web/ui/node_modules/select/bower.json +0 -13
- xinference/ui/web/ui/node_modules/select/package.json +0 -29
- xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
- {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/WHEEL +0 -0
- {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/entry_points.txt +0 -0
- {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/top_level.txt +0 -0
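The bulk of this release is the vendored indextts and audiotools packages backing the new IndexTTS2 and Chinese Kokoro audio models (xinference/model/audio/indextts2.py and kokoro_zh.py above), plus a new tool_parsers package for LLM function calling. As a rough sketch of how one of the new TTS models would be driven through the existing RESTful client — the model name below is a guess, not something this diff confirms; the registered name lives in the updated model_spec.json:

    from xinference.client import Client

    # Sketch only. Assumptions not confirmed by this diff: a local server at
    # port 9997, and "IndexTTS-2" as the registered model name (the real
    # name is defined in xinference/model/audio/model_spec.json).
    client = Client("http://127.0.0.1:9997")

    model_uid = client.launch_model(
        model_name="IndexTTS-2",  # hypothetical name
        model_type="audio",
    )
    model = client.get_model(model_uid)

    # Audio model handles expose an OpenAI-style speech() call that
    # returns raw audio bytes.
    audio_bytes = model.speech(input="你好，世界", response_format="mp3")
    with open("hello.mp3", "wb") as f:
        f.write(audio_bytes)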
xinference/thirdparty/indextts/gpt/conformer_encoder.py (new file)
@@ -0,0 +1,520 @@

from typing import Optional, Tuple

import torch
import torch.nn as nn

from indextts.gpt.conformer.attention import (MultiHeadedAttention,
                                              RelPositionMultiHeadedAttention)
from indextts.gpt.conformer.embedding import (NoPositionalEncoding,
                                              PositionalEncoding,
                                              RelPositionalEncoding)
from indextts.gpt.conformer.subsampling import (Conv2dSubsampling2,
                                                Conv2dSubsampling4,
                                                Conv2dSubsampling6,
                                                Conv2dSubsampling8,
                                                LinearNoSubsampling)
from indextts.utils.common import make_pad_mask


class PositionwiseFeedForward(torch.nn.Module):
    """Positionwise feed forward layer.

    The feed forward is applied to each position of the sequence.
    The output dim is the same as the input dim.

    Args:
        idim (int): Input dimension.
        hidden_units (int): The number of hidden units.
        dropout_rate (float): Dropout rate.
        activation (torch.nn.Module): Activation function
    """

    def __init__(self,
                 idim: int,
                 hidden_units: int,
                 dropout_rate: float,
                 activation: torch.nn.Module = torch.nn.ReLU()):
        """Construct a PositionwiseFeedForward object."""
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        self.activation = activation
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.w_2 = torch.nn.Linear(hidden_units, idim)

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        """Forward function.

        Args:
            xs: input tensor (B, L, D)
        Returns:
            output tensor, (B, L, D)
        """
        return self.w_2(self.dropout(self.activation(self.w_1(xs))))


class ConvolutionModule(nn.Module):
    """ConvolutionModule in Conformer model."""

    def __init__(self,
                 channels: int,
                 kernel_size: int = 15,
                 activation: nn.Module = nn.ReLU(),
                 bias: bool = True):
        """Construct a ConvolutionModule object.
        Args:
            channels (int): The number of channels of conv layers.
            kernel_size (int): Kernel size of conv layers.
            activation (nn.Module): Activation function.
            bias (bool): Whether the conv layers use a bias term.
        """
        super().__init__()

        self.pointwise_conv1 = nn.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        # self.lorder is used to distinguish if it's a causal convolution:
        # if self.lorder > 0: it's a causal convolution, and the input will
        # be padded with self.lorder frames on the left in forward.
        # else: it's a symmetrical convolution.
        # kernel_size should be an odd number for non-causal convolution
        assert (kernel_size - 1) % 2 == 0
        padding = (kernel_size - 1) // 2
        self.lorder = 0

        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=padding,
            groups=channels,
            bias=bias,
        )

        self.use_layer_norm = True
        self.norm = nn.LayerNorm(channels)

        self.pointwise_conv2 = nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.activation = activation

    def forward(
        self,
        x: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        cache: torch.Tensor = torch.zeros((0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute convolution module.
        Args:
            x (torch.Tensor): Input tensor (#batch, time, channels).
            mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
                (0, 0, 0) means fake mask.
            cache (torch.Tensor): left context cache, it is only
                used in causal convolution (#batch, channels, cache_t),
                (0, 0, 0) means fake cache.
        Returns:
            torch.Tensor: Output tensor (#batch, time, channels).
            torch.Tensor: New convolution cache (#batch, channels, cache_t).
        """
        # exchange the temporal dimension and the feature dimension
        x = x.transpose(1, 2)  # (#batch, channels, time)

        # mask batch padding
        if mask_pad.size(2) > 0:  # time > 0
            x.masked_fill_(~mask_pad, 0.0)

        if self.lorder > 0:
            if cache.size(2) == 0:  # cache_t == 0
                x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
            else:
                assert cache.size(0) == x.size(0)  # equal batch
                assert cache.size(1) == x.size(1)  # equal channel
                x = torch.cat((cache, x), dim=2)
            assert (x.size(2) > self.lorder)
            new_cache = x[:, :, -self.lorder:]
        else:
            # It would be better to return None when no cache is required,
            # but for JIT export we return a fake empty tensor instead of
            # None.
            new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)

        # GLU mechanism
        x = self.pointwise_conv1(x)  # (batch, 2*channel, dim)
        x = nn.functional.glu(x, dim=1)  # (batch, channel, dim)

        # 1D Depthwise Conv
        x = self.depthwise_conv(x)
        if self.use_layer_norm:
            x = x.transpose(1, 2)
        x = self.activation(self.norm(x))
        if self.use_layer_norm:
            x = x.transpose(1, 2)
        x = self.pointwise_conv2(x)
        # mask batch padding
        if mask_pad.size(2) > 0:  # time > 0
            x.masked_fill_(~mask_pad, 0.0)

        return x.transpose(1, 2), new_cache


class ConformerEncoderLayer(nn.Module):
    """Encoder layer module.
    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
            instance can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward` instance can be used as the argument.
        feed_forward_macaron (torch.nn.Module): Additional feed-forward module
            instance.
            `PositionwiseFeedForward` instance can be used as the argument.
        conv_module (torch.nn.Module): Convolution module instance.
            `ConvolutionModule` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool):
            True: use layer_norm before each sub-block.
            False: use layer_norm after each sub-block.
        concat_after (bool): Whether to concat attention layer's input and
            output.
            True: x -> x + linear(concat(x, att(x)))
            False: x -> x + att(x)
    """

    def __init__(
        self,
        size: int,
        self_attn: torch.nn.Module,
        feed_forward: Optional[nn.Module] = None,
        feed_forward_macaron: Optional[nn.Module] = None,
        conv_module: Optional[nn.Module] = None,
        dropout_rate: float = 0.1,
        normalize_before: bool = True,
        concat_after: bool = False,
    ):
        """Construct an EncoderLayer object."""
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.conv_module = conv_module
        self.norm_ff = nn.LayerNorm(size, eps=1e-5)  # for the FFN module
        self.norm_mha = nn.LayerNorm(size, eps=1e-5)  # for the MHA module
        if feed_forward_macaron is not None:
            self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5)
            self.ff_scale = 0.5
        else:
            self.ff_scale = 1.0
        if self.conv_module is not None:
            self.norm_conv = nn.LayerNorm(size,
                                          eps=1e-5)  # for the CNN module
            self.norm_final = nn.LayerNorm(
                size, eps=1e-5)  # for the final output of the block
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        else:
            self.concat_linear = nn.Identity()

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        pos_emb: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute encoded features.

        Args:
            x (torch.Tensor): (#batch, time, size)
            mask (torch.Tensor): Mask tensor for the input (#batch, time, time),
                (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): positional encoding, must not be None
                for ConformerEncoderLayer.
            mask_pad (torch.Tensor): batch padding mask used for conv module.
                (#batch, 1, time), (0, 0, 0) means fake mask.
            att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
                (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
            cnn_cache (torch.Tensor): Convolution cache in conformer layer
                (#batch=1, size, cache_t2)
        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time, time).
            torch.Tensor: att_cache tensor,
                (#batch=1, head, cache_t1 + time, d_k * 2).
            torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
        """

        # whether to use macaron style
        if self.feed_forward_macaron is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = residual + self.ff_scale * self.dropout(
                self.feed_forward_macaron(x))
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)

        # multi-headed self-attention module
        residual = x
        if self.normalize_before:
            x = self.norm_mha(x)

        x_att, new_att_cache = self.self_attn(
            x, x, x, mask, pos_emb, att_cache)
        if self.concat_after:
            x_concat = torch.cat((x, x_att), dim=-1)
            x = residual + self.concat_linear(x_concat)
        else:
            x = residual + self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm_mha(x)

        # convolution module
        # Fake new cnn cache here, and then change it in conv_module
        new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
            x = residual + self.dropout(x)

            if not self.normalize_before:
                x = self.norm_conv(x)

        # feed forward module
        residual = x
        if self.normalize_before:
            x = self.norm_ff(x)

        x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm_ff(x)

        if self.conv_module is not None:
            x = self.norm_final(x)

        return x, mask, new_att_cache, new_cnn_cache


class BaseEncoder(torch.nn.Module):
    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        pos_enc_layer_type: str = "abs_pos",
        normalize_before: bool = True,
        concat_after: bool = False,
    ):
        """
        Args:
            input_size (int): input dim
            output_size (int): dimension of attention
            attention_heads (int): the number of heads of multi head attention
            linear_units (int): the hidden units number of position-wise feed
                forward
            num_blocks (int): the number of encoder blocks
            dropout_rate (float): dropout rate
            input_layer (str): input layer type.
                optional [linear, conv2d2, conv2d, conv2d6, conv2d8]
            pos_enc_layer_type (str): Encoder positional encoding layer type.
                optional [abs_pos, rel_pos, no_pos]
            normalize_before (bool):
                True: use layer_norm before each sub-block of a layer.
                False: use layer_norm after each sub-block of a layer.
            concat_after (bool): whether to concat attention layer's input
                and output.
                True: x -> x + linear(concat(x, att(x)))
                False: x -> x + att(x)
        """
        super().__init__()
        self._output_size = output_size

        if pos_enc_layer_type == "abs_pos":
            pos_enc_class = PositionalEncoding
        elif pos_enc_layer_type == "rel_pos":
            pos_enc_class = RelPositionalEncoding
        elif pos_enc_layer_type == "no_pos":
            pos_enc_class = NoPositionalEncoding
        else:
            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)

        if input_layer == "linear":
            subsampling_class = LinearNoSubsampling
        elif input_layer == "conv2d2":
            subsampling_class = Conv2dSubsampling2
        elif input_layer == "conv2d":
            subsampling_class = Conv2dSubsampling4
        elif input_layer == "conv2d6":
            subsampling_class = Conv2dSubsampling6
        elif input_layer == "conv2d8":
            subsampling_class = Conv2dSubsampling8
        else:
            raise ValueError("unknown input_layer: " + input_layer)

        self.embed = subsampling_class(
            input_size,
            output_size,
            dropout_rate,
            pos_enc_class(output_size, dropout_rate),
        )

        self.normalize_before = normalize_before
        self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs: torch.Tensor,
        xs_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed positions in tensor.

        Args:
            xs: padded input tensor (B, T, D)
            xs_lens: input length (B)
        Returns:
            encoder output tensor xs, and subsampled masks
            xs: padded output tensor (B, T' ~= T/subsample_rate, D)
            masks: torch.Tensor batch padding mask after subsample
                (B, 1, T' ~= T/subsample_rate)
        """
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
        xs, pos_emb, masks = self.embed(xs, masks)
        chunk_masks = masks
        mask_pad = masks  # (B, 1, T/subsample_rate)
        for layer in self.encoders:
            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
        if self.normalize_before:
            xs = self.after_norm(xs)
        # Here we assume the mask is not changed in encoder layers, so just
        # return the masks before encoder layers, and the masks will be used
        # for cross attention with decoder later
        return xs, masks


class ConformerEncoder(BaseEncoder):
    """Conformer encoder module."""

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        pos_enc_layer_type: str = "rel_pos",
        normalize_before: bool = True,
        concat_after: bool = False,
        macaron_style: bool = False,
        use_cnn_module: bool = True,
        cnn_module_kernel: int = 15,
    ):
        """Construct ConformerEncoder

        Args:
            input_size to concat_after: see BaseEncoder.
            macaron_style (bool): Whether to use macaron style for
                positionwise layer.
            use_cnn_module (bool): Whether to use convolution module.
            cnn_module_kernel (int): Kernel size of convolution module.
        """

        super().__init__(input_size, output_size, attention_heads,
                         linear_units, num_blocks, dropout_rate,
                         input_layer, pos_enc_layer_type, normalize_before,
                         concat_after)

        activation = torch.nn.SiLU()

        # self-attention module definition
        if pos_enc_layer_type != "rel_pos":
            encoder_selfattn_layer = MultiHeadedAttention
        else:
            encoder_selfattn_layer = RelPositionMultiHeadedAttention
        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            dropout_rate,
        )

        # feed-forward module definition
        positionwise_layer = PositionwiseFeedForward
        positionwise_layer_args = (
            output_size,
            linear_units,
            dropout_rate,
            activation,
        )
        # convolution module definition
        convolution_layer = ConvolutionModule
        convolution_layer_args = (output_size,
                                  cnn_module_kernel,
                                  activation,)

        self.encoders = torch.nn.ModuleList([
            ConformerEncoderLayer(
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                positionwise_layer(
                    *positionwise_layer_args) if macaron_style else None,
                convolution_layer(
                    *convolution_layer_args) if use_cnn_module else None,
                dropout_rate,
                normalize_before,
                concat_after,
            ) for _ in range(num_blocks)
        ])