xinference-1.9.1-py3-none-any.whl → xinference-1.10.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +400 -3
- xinference/client/restful/async_restful_client.py +20 -3
- xinference/client/restful/restful_client.py +20 -3
- xinference/constants.py +2 -0
- xinference/core/supervisor.py +111 -49
- xinference/core/worker.py +10 -0
- xinference/deploy/cmdline.py +15 -0
- xinference/model/audio/core.py +26 -6
- xinference/model/audio/indextts2.py +166 -0
- xinference/model/audio/kokoro.py +1 -1
- xinference/model/audio/kokoro_zh.py +124 -0
- xinference/model/audio/model_spec.json +58 -1
- xinference/model/embedding/sentence_transformers/core.py +4 -4
- xinference/model/embedding/vllm/core.py +7 -1
- xinference/model/image/model_spec.json +71 -3
- xinference/model/image/stable_diffusion/core.py +13 -4
- xinference/model/llm/__init__.py +4 -0
- xinference/model/llm/core.py +10 -0
- xinference/model/llm/llama_cpp/core.py +1 -0
- xinference/model/llm/llm_family.json +503 -21
- xinference/model/llm/llm_family.py +1 -0
- xinference/model/llm/mlx/core.py +52 -33
- xinference/model/llm/sglang/core.py +32 -55
- xinference/model/llm/tool_parsers/__init__.py +58 -0
- xinference/model/llm/tool_parsers/abstract_tool_parser.py +33 -0
- xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +190 -0
- xinference/model/llm/tool_parsers/deepseek_v3_tool_parser.py +145 -0
- xinference/model/llm/tool_parsers/glm4_tool_parser.py +123 -0
- xinference/model/llm/tool_parsers/llama3_tool_parser.py +77 -0
- xinference/model/llm/tool_parsers/qwen_tool_parser.py +320 -0
- xinference/model/llm/transformers/core.py +1 -1
- xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
- xinference/model/llm/utils.py +138 -53
- xinference/model/llm/vllm/core.py +95 -78
- xinference/thirdparty/audiotools/__init__.py +10 -0
- xinference/thirdparty/audiotools/core/__init__.py +4 -0
- xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
- xinference/thirdparty/audiotools/core/display.py +194 -0
- xinference/thirdparty/audiotools/core/dsp.py +390 -0
- xinference/thirdparty/audiotools/core/effects.py +647 -0
- xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
- xinference/thirdparty/audiotools/core/loudness.py +320 -0
- xinference/thirdparty/audiotools/core/playback.py +252 -0
- xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
- xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
- xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
- xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
- xinference/thirdparty/audiotools/core/util.py +671 -0
- xinference/thirdparty/audiotools/core/whisper.py +97 -0
- xinference/thirdparty/audiotools/data/__init__.py +3 -0
- xinference/thirdparty/audiotools/data/datasets.py +517 -0
- xinference/thirdparty/audiotools/data/preprocess.py +81 -0
- xinference/thirdparty/audiotools/data/transforms.py +1592 -0
- xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
- xinference/thirdparty/audiotools/metrics/distance.py +131 -0
- xinference/thirdparty/audiotools/metrics/quality.py +159 -0
- xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
- xinference/thirdparty/audiotools/ml/__init__.py +5 -0
- xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
- xinference/thirdparty/audiotools/ml/decorators.py +440 -0
- xinference/thirdparty/audiotools/ml/experiment.py +90 -0
- xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
- xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
- xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
- xinference/thirdparty/audiotools/post.py +140 -0
- xinference/thirdparty/audiotools/preference.py +600 -0
- xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
- xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
- xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
- xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
- xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
- xinference/thirdparty/indextts/__init__.py +0 -0
- xinference/thirdparty/indextts/cli.py +65 -0
- xinference/thirdparty/indextts/gpt/__init__.py +0 -0
- xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
- xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
- xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
- xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
- xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
- xinference/thirdparty/indextts/gpt/model.py +713 -0
- xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
- xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
- xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
- xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
- xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
- xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
- xinference/thirdparty/indextts/infer.py +690 -0
- xinference/thirdparty/indextts/infer_v2.py +739 -0
- xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
- xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
- xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
- xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
- xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
- xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
- xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
- xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
- xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
- xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
- xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
- xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
- xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
- xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
- xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
- xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
- xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
- xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
- xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
- xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
- xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
- xinference/thirdparty/indextts/utils/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/arch_util.py +120 -0
- xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
- xinference/thirdparty/indextts/utils/common.py +121 -0
- xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
- xinference/thirdparty/indextts/utils/front.py +536 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
- xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
- xinference/thirdparty/indextts/utils/text_utils.py +41 -0
- xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
- xinference/thirdparty/indextts/utils/utils.py +93 -0
- xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
- xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
- xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
- xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
- xinference/types.py +105 -2
- xinference/ui/gradio/media_interface.py +66 -8
- xinference/ui/web/ui/build/asset-manifest.json +6 -6
- xinference/ui/web/ui/build/index.html +1 -1
- xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
- xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
- xinference/ui/web/ui/build/static/js/main.d192c4f3.js +3 -0
- xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.d192c4f3.js.LICENSE.txt} +0 -7
- xinference/ui/web/ui/build/static/js/main.d192c4f3.js.map +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f995a2425dfb0822fd07127f66ffe9b026883bc156b402eb8bd0b83d52460a93.json +1 -0
- xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
- xinference/ui/web/ui/package-lock.json +0 -34
- xinference/ui/web/ui/package.json +0 -1
- xinference/ui/web/ui/src/locales/en.json +9 -3
- xinference/ui/web/ui/src/locales/ja.json +9 -3
- xinference/ui/web/ui/src/locales/ko.json +9 -3
- xinference/ui/web/ui/src/locales/zh.json +9 -3
- {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/METADATA +24 -4
- {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/RECORD +302 -76
- xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
- xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
- xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
- xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
- xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
- xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
- xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
- xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
- xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
- xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
- xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
- xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
- xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
- xinference/ui/web/ui/node_modules/select/bower.json +0 -13
- xinference/ui/web/ui/node_modules/select/package.json +0 -29
- xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
- {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/WHEEL +0 -0
- {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/entry_points.txt +0 -0
- {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,656 @@
|
|
|
1
|
+
"""A popular speaker recognition and diarization model.
|
|
2
|
+
|
|
3
|
+
Authors
|
|
4
|
+
* Hwidong Na 2020
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import torch # noqa: F401
|
|
8
|
+
import torch.nn as nn
|
|
9
|
+
import torch.nn.functional as F
|
|
10
|
+
|
|
11
|
+
from indextts.BigVGAN.nnet.CNN import Conv1d as _Conv1d
|
|
12
|
+
from indextts.BigVGAN.nnet.linear import Linear
|
|
13
|
+
from indextts.BigVGAN.nnet.normalization import BatchNorm1d as _BatchNorm1d
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def length_to_mask(length, max_len=None, dtype=None, device=None):
    """Build a binary mask from a 1D tensor of sequence lengths.

    Entry ``[i, j]`` of the mask is set exactly when ``j < length[i]``.

    Reference: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397/3

    Arguments
    ---------
    length : torch.Tensor
        1D tensor holding the length of every sequence in the batch.
    max_len : int, optional
        Width of the mask (its second dimension). When omitted, the largest
        value in ``length`` is used.
    dtype : torch.dtype, optional
        dtype of the returned mask. Defaults to ``length.dtype``.
    device : torch.device, optional
        Device of the returned mask. Defaults to ``length.device``.

    Returns
    -------
    mask : torch.Tensor
        Tensor of shape ``(len(length), max_len)``.

    Example
    -------
    >>> length=torch.Tensor([1,2,3])
    >>> mask=length_to_mask(length)
    >>> mask
    tensor([[1., 0., 0.],
            [1., 1., 0.],
            [1., 1., 1.]])
    """
    assert len(length.shape) == 1

    width = length.max().long().item() if max_len is None else max_len
    batch = len(length)

    # One row of positions [0, width), broadcast over the batch and compared
    # against each sequence's length to produce a boolean mask.
    positions = torch.arange(width, device=length.device, dtype=length.dtype)
    bool_mask = positions.expand(batch, width) < length.unsqueeze(1)

    target_dtype = length.dtype if dtype is None else dtype
    target_device = length.device if device is None else device
    return torch.as_tensor(bool_mask, dtype=target_dtype, device=target_device)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
# Skip transpose as much as possible for efficiency
class Conv1d(_Conv1d):
    """1D convolution that always constructs its parent with ``skip_transpose=True``.

    All positional and keyword arguments are forwarded to the parent
    ``Conv1d`` unchanged; only ``skip_transpose`` is pinned. The exact
    input layout this implies is defined by the parent class (presumably
    channel-first, avoiding a transpose — confirm in ``nnet.CNN``).
    """

    def __init__(self, *args, **kwargs):
        # The keyword is written after *args (flake8-bugbear B026); Python
        # binds *args positionally first either way, so the call is identical.
        super().__init__(*args, skip_transpose=True, **kwargs)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class BatchNorm1d(_BatchNorm1d):
    """1D batch normalization that always constructs its parent with ``skip_transpose=True``.

    All positional and keyword arguments are forwarded to the parent
    ``BatchNorm1d`` unchanged; only ``skip_transpose`` is pinned.
    """

    def __init__(self, *args, **kwargs):
        # The keyword is written after *args (flake8-bugbear B026); Python
        # binds *args positionally first either way, so the call is identical.
        super().__init__(*args, skip_transpose=True, **kwargs)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class TDNNBlock(nn.Module):
    """Time-delay neural network (TDNN) layer: convolution, activation, batch norm.

    Arguments
    ---------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    kernel_size : int
        Kernel size of the convolution.
    dilation : int
        Dilation of the convolution.
    activation : torch class
        Class (not an instance) used to build the activation layer; it is
        instantiated with no arguments.
    groups : int
        Number of groups of the convolution.

    Example
    -------
    >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
    >>> layer = TDNNBlock(64, 64, kernel_size=3, dilation=1)
    >>> out_tensor = layer(inp_tensor).transpose(1, 2)
    >>> out_tensor.shape
    torch.Size([8, 120, 64])
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        dilation,
        activation=nn.ReLU,
        groups=1,
    ):
        super().__init__()
        # Submodules are registered in the order conv -> activation -> norm.
        self.conv = Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            groups=groups,
        )
        self.activation = activation()
        self.norm = BatchNorm1d(input_size=out_channels)

    def forward(self, x):
        """Apply convolution, then the activation, then batch normalization to ``x``."""
        hidden = self.conv(x)
        hidden = self.activation(hidden)
        return self.norm(hidden)
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class Res2NetBlock(nn.Module):
    """Res2Net block built from dilated TDNN sub-blocks.

    The input is split along dim 1 into ``scale`` chunks:

    * chunk 0 is passed through unchanged;
    * chunk 1 is fed to sub-block 0;
    * chunk k (k >= 2) is added to the previous chunk's output and fed to
      sub-block k - 1.

    All outputs are concatenated back along dim 1.

    Arguments
    ---------
    in_channels : int
        The number of channels expected in the input.
    out_channels : int
        The number of output channels.
    scale : int
        The scale of the Res2Net block (number of chunks).
    kernel_size: int
        The kernel size of each TDNN sub-block.
    dilation : int
        The dilation of each TDNN sub-block.

    Example
    -------
    >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
    >>> layer = Res2NetBlock(64, 64, scale=4, dilation=3)
    >>> out_tensor = layer(inp_tensor).transpose(1, 2)
    >>> out_tensor.shape
    torch.Size([8, 120, 64])
    """

    def __init__(
        self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1
    ):
        super().__init__()
        assert in_channels % scale == 0
        assert out_channels % scale == 0

        chunk_in = in_channels // scale
        chunk_hidden = out_channels // scale

        # NOTE(review): forward() adds an input chunk (chunk_in channels) to a
        # sub-block output (chunk_hidden channels), which appears to require
        # in_channels == out_channels whenever scale > 2 — confirm with callers.
        self.blocks = nn.ModuleList(
            [
                TDNNBlock(
                    chunk_in,
                    chunk_hidden,
                    kernel_size=kernel_size,
                    dilation=dilation,
                )
                for _ in range(scale - 1)
            ]
        )
        self.scale = scale

    def forward(self, x):
        """Run the hierarchical Res2Net split/transform/concat over ``x``."""
        outputs = []
        previous = None
        for index, piece in enumerate(torch.chunk(x, self.scale, dim=1)):
            if index == 0:
                current = piece
            else:
                block_input = piece if index == 1 else piece + previous
                current = self.blocks[index - 1](block_input)
            outputs.append(current)
            previous = current
        return torch.cat(outputs, dim=1)
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
class SEBlock(nn.Module):
    """Squeeze-and-excitation block.

    Each channel of the input is rescaled by a sigmoid gate. The gate is
    computed from the time-average of the input, which is length-masked
    when ``lengths`` is given.

    Arguments
    ---------
    in_channels : int
        The number of input channels.
    se_channels : int
        The number of channels of the squeezed descriptor.
    out_channels : int
        The number of output channels of the gate.

    Example
    -------
    >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
    >>> se_layer = SEBlock(64, 16, 64)
    >>> lengths = torch.rand((8,))
    >>> out_tensor = se_layer(inp_tensor, lengths).transpose(1, 2)
    >>> out_tensor.shape
    torch.Size([8, 120, 64])
    """

    def __init__(self, in_channels, se_channels, out_channels):
        super().__init__()

        # Squeeze: project the pooled descriptor down to se_channels.
        self.conv1 = Conv1d(
            in_channels=in_channels, out_channels=se_channels, kernel_size=1
        )
        self.relu = torch.nn.ReLU(inplace=True)
        # Excite: project back up and squash into a per-channel gate.
        self.conv2 = Conv1d(
            in_channels=se_channels, out_channels=out_channels, kernel_size=1
        )
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x, lengths=None):
        """Gate the channels of x by a sigmoid of its pooled summary.

        Arguments
        ---------
        x : torch.Tensor
            Input pooled over dim 2; the last dim is treated as time.
        lengths : torch.Tensor, optional
            Relative lengths; they are multiplied by the time size to build
            the pooling mask. If omitted, a plain mean over dim 2 is used.
        """
        if lengths is None:
            pooled = x.mean(dim=2, keepdim=True)
        else:
            steps = x.shape[-1]
            valid = length_to_mask(lengths * steps, max_len=steps, device=x.device)
            valid = valid.unsqueeze(1)
            pooled = (x * valid).sum(dim=2, keepdim=True) / valid.sum(
                dim=2, keepdim=True
            )

        gate = self.sigmoid(self.conv2(self.relu(self.conv1(pooled))))
        return gate * x
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
class AttentiveStatisticsPooling(nn.Module):
    """This class implements an attentive statistic pooling layer for each channel.
    It returns the concatenated mean and std of the input tensor.

    Arguments
    ---------
    channels: int
        The number of input channels.
    attention_channels: int
        The number of attention channels.
    global_context: bool
        Whether the attention also sees the (masked) utterance-level mean and
        std of the input, concatenated on the channel dimension.

    Example
    -------
    >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
    >>> asp_layer = AttentiveStatisticsPooling(64)
    >>> lengths = torch.rand((8,))
    >>> out_tensor = asp_layer(inp_tensor, lengths).transpose(1, 2)
    >>> out_tensor.shape
    torch.Size([8, 1, 128])
    """

    def __init__(self, channels, attention_channels=128, global_context=True):
        super().__init__()

        # Floor applied to the variance before the sqrt in the std estimate.
        self.eps = 1e-12
        self.global_context = global_context
        if global_context:
            # The attention TDNN sees [x, mean, std] stacked on channels.
            self.tdnn = TDNNBlock(channels * 3, attention_channels, 1, 1)
        else:
            self.tdnn = TDNNBlock(channels, attention_channels, 1, 1)
        self.tanh = nn.Tanh()
        self.conv = Conv1d(
            in_channels=attention_channels, out_channels=channels, kernel_size=1
        )

    @staticmethod
    def _compute_statistics(x, m, dim=2, eps=1e-12):
        """Return the ``m``-weighted mean and std of ``x`` along ``dim``.

        The weighted variance is clamped to at least ``eps`` before the sqrt.
        """
        mean = (m * x).sum(dim)
        std = torch.sqrt(
            (m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps)
        )
        return mean, std

    def forward(self, x, lengths=None):
        """Calculates mean and std for a batch (input tensor).

        Arguments
        ---------
        x : torch.Tensor
            Tensor of shape [N, C, L].
        lengths : torch.Tensor
            The corresponding relative lengths of the inputs. If omitted,
            every frame is treated as valid.

        Returns
        -------
        pooled_stats : torch.Tensor
            Attention-weighted mean and std concatenated on dim 1, with a
            trailing singleton dim: [N, 2 * C, 1].
        """
        L = x.shape[-1]

        if lengths is None:
            lengths = torch.ones(x.shape[0], device=x.device)

        # Make binary mask of shape [N, 1, L]
        mask = length_to_mask(lengths * L, max_len=L, device=x.device)
        mask = mask.unsqueeze(1)

        # Expand the temporal context of the pooling layer by allowing the
        # self-attention to look at global properties of the utterance.
        if self.global_context:
            # torch.std is unstable for backward computation
            # https://github.com/pytorch/pytorch/issues/4320
            total = mask.sum(dim=2, keepdim=True).float()
            mean, std = self._compute_statistics(x, mask / total, eps=self.eps)
            # expand() returns a broadcast view instead of materializing L
            # copies as repeat() would; torch.cat below copies the data once.
            mean = mean.unsqueeze(2).expand(-1, -1, L)
            std = std.unsqueeze(2).expand(-1, -1, L)
            attn = torch.cat([x, mean, std], dim=1)
        else:
            attn = x

        # Apply layers
        attn = self.conv(self.tanh(self.tdnn(attn)))

        # Filter out zero-paddings so they get zero attention weight.
        attn = attn.masked_fill(mask == 0, float("-inf"))

        attn = F.softmax(attn, dim=2)
        mean, std = self._compute_statistics(x, attn, eps=self.eps)
        # Append mean and std of the batch
        pooled_stats = torch.cat((mean, std), dim=1)
        pooled_stats = pooled_stats.unsqueeze(2)

        return pooled_stats
|
|
339
|
+
|
|
340
|
+
|
|
341
|
+
class SERes2NetBlock(nn.Module):
    """An implementation of building block in ECAPA-TDNN, i.e.,
    TDNN-Res2Net-TDNN-SEBlock, with a residual connection around the stack.

    Arguments
    ---------
    in_channels: int
        Expected size of input channels.
    out_channels: int
        The number of output channels.
    res2net_scale: int
        The scale of the Res2Net block.
    se_channels : int
        The number of output channels after squeeze.
    kernel_size: int
        The kernel size of the Res2Net block (both TDNN blocks use kernel
        size 1).
    dilation: int
        The dilation of the Res2Net block.
    activation : torch class
        A class for constructing the activation layers of the TDNN blocks.
    groups: int
        Number of blocked connections from input channels to output channels,
        applied to the two TDNN blocks.

    Example
    -------
    >>> x = torch.rand(8, 120, 64).transpose(1, 2)
    >>> conv = SERes2NetBlock(64, 64, res2net_scale=4)
    >>> out = conv(x).transpose(1, 2)
    >>> out.shape
    torch.Size([8, 120, 64])
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        res2net_scale=8,
        se_channels=128,
        kernel_size=1,
        dilation=1,
        activation=torch.nn.ReLU,
        groups=1,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.tdnn1 = TDNNBlock(
            in_channels,
            out_channels,
            kernel_size=1,
            dilation=1,
            activation=activation,
            groups=groups,
        )
        self.res2net_block = Res2NetBlock(
            out_channels, out_channels, res2net_scale, kernel_size, dilation
        )
        self.tdnn2 = TDNNBlock(
            out_channels,
            out_channels,
            kernel_size=1,
            dilation=1,
            activation=activation,
            groups=groups,
        )
        self.se_block = SEBlock(out_channels, se_channels, out_channels)

        # A 1x1 projection is only needed when the residual's channel count
        # differs from the block output.
        self.shortcut = None
        if in_channels != out_channels:
            self.shortcut = Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
            )

    def forward(self, x, lengths=None):
        """Apply TDNN -> Res2Net -> TDNN -> SE to x and add the residual.

        Arguments
        ---------
        x : torch.Tensor
            Input with channels on dim 1 (see the example).
        lengths : torch.Tensor, optional
            Relative lengths; only the SE block uses them.
        """
        # Compare against None explicitly: nn.Module truthiness can be
        # overridden (e.g. an empty nn.Sequential is falsy).
        residual = x
        if self.shortcut is not None:
            residual = self.shortcut(x)

        x = self.tdnn1(x)
        x = self.res2net_block(x)
        x = self.tdnn2(x)
        x = self.se_block(x, lengths)

        return x + residual
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
class ECAPA_TDNN(torch.nn.Module):
    """An implementation of the speaker embedding model in a paper.
    "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in
    TDNN Based Speaker Verification" (https://arxiv.org/abs/2005.07143).

    The network is built from these stages, in order:

    1. An initial TDNN block.
    2. ``len(channels) - 2`` SE-Res2Net blocks.
    3. A multi-layer feature aggregation (MFA) TDNN block over the
       concatenated SE-Res2Net outputs.
    4. Attentive statistics pooling and batch normalization.
    5. A final 1x1 convolution producing the embedding.

    Note: the MFA input size is ``channels[-2] * (len(channels) - 2)``, so
    all SE-Res2Net layers are expected to output ``channels[-2]`` channels.

    Arguments
    ---------
    input_size : int
        Expected size of the input (feature) dimension.
    device : str
        Accepted for interface compatibility; not used by this class.
    lin_neurons : int
        Size of the output embedding.
    activation : torch class
        A class for constructing the activation layers.
    channels : sequence of ints
        Output channels for each TDNN/SERes2Net layer; the last entry is the
        MFA output size.
    kernel_sizes : sequence of ints
        List of kernel sizes for each layer.
    dilations : sequence of ints
        List of dilations for kernels in each layer.
    attention_channels: int
        The number of attention channels.
    res2net_scale : int
        The scale of the Res2Net block.
    se_channels : int
        The number of output channels after squeeze.
    global_context: bool
        Whether to use global context.
    groups : sequence of ints
        List of groups for kernels in each layer.

    Example
    -------
    >>> input_feats = torch.rand([5, 120, 80])
    >>> compute_embedding = ECAPA_TDNN(80, lin_neurons=192)
    >>> outputs = compute_embedding(input_feats)
    >>> outputs.shape
    torch.Size([5, 1, 192])
    """

    def __init__(
        self,
        input_size,
        device="cpu",
        lin_neurons=192,
        activation=torch.nn.ReLU,
        channels=(512, 512, 512, 512, 1536),
        kernel_sizes=(5, 3, 3, 3, 1),
        dilations=(1, 2, 3, 4, 1),
        attention_channels=128,
        res2net_scale=8,
        se_channels=128,
        global_context=True,
        groups=(1, 1, 1, 1, 1),
    ):
        super().__init__()
        assert len(channels) == len(kernel_sizes)
        assert len(channels) == len(dilations)
        # Store a private list copy so later mutation of self.channels cannot
        # leak into a caller's list (the defaults are immutable tuples).
        self.channels = list(channels)
        self.blocks = nn.ModuleList()

        # The initial TDNN layer
        self.blocks.append(
            TDNNBlock(
                input_size,
                channels[0],
                kernel_sizes[0],
                dilations[0],
                activation,
                groups[0],
            )
        )

        # SE-Res2Net layers
        for i in range(1, len(channels) - 1):
            self.blocks.append(
                SERes2NetBlock(
                    channels[i - 1],
                    channels[i],
                    res2net_scale=res2net_scale,
                    se_channels=se_channels,
                    kernel_size=kernel_sizes[i],
                    dilation=dilations[i],
                    activation=activation,
                    groups=groups[i],
                )
            )

        # Multi-layer feature aggregation
        self.mfa = TDNNBlock(
            channels[-2] * (len(channels) - 2),
            channels[-1],
            kernel_sizes[-1],
            dilations[-1],
            activation,
            groups=groups[-1],
        )

        # Attentive Statistical Pooling
        self.asp = AttentiveStatisticsPooling(
            channels[-1],
            attention_channels=attention_channels,
            global_context=global_context,
        )
        self.asp_bn = BatchNorm1d(input_size=channels[-1] * 2)

        # Final linear transformation
        self.fc = Conv1d(
            in_channels=channels[-1] * 2,
            out_channels=lin_neurons,
            kernel_size=1,
        )

    def forward(self, x, lengths=None):
        """Returns the embedding vector.

        Arguments
        ---------
        x : torch.Tensor
            Tensor of shape (batch, time, channel).
        lengths : torch.Tensor
            Corresponding relative lengths of inputs.

        Returns
        -------
        x : torch.Tensor
            Embedding vector of shape (batch, 1, lin_neurons).
        """
        # Minimize transpose for efficiency
        x = x.transpose(1, 2)

        xl = []
        for layer in self.blocks:
            # Only SE-Res2Net blocks accept `lengths`; the initial TDNN block
            # does not. Dispatch explicitly instead of catching TypeError,
            # which would also swallow genuine TypeErrors raised inside a
            # block and silently retry it without `lengths`.
            if isinstance(layer, SERes2NetBlock):
                x = layer(x, lengths=lengths)
            else:
                x = layer(x)
            xl.append(x)

        # Multi-layer feature aggregation over the SE-Res2Net outputs only
        x = torch.cat(xl[1:], dim=1)
        x = self.mfa(x)

        # Attentive Statistical Pooling
        x = self.asp(x, lengths=lengths)
        x = self.asp_bn(x)

        # Final linear transformation
        x = self.fc(x)

        x = x.transpose(1, 2)
        return x
|
|
582
|
+
|
|
583
|
+
|
|
584
|
+
class Classifier(torch.nn.Module):
    """This class implements the cosine similarity on the top of features.

    It first applies ``lin_blocks`` (batch-norm, linear) pairs. It then
    returns the cosine similarity between the L2-normalized features and
    each L2-normalized row of a learned weight matrix, with one row per
    class.

    Arguments
    ---------
    input_size : int
        Expected size of input dimension.
    device : str
        Device on which the class weight matrix is allocated, e.g., "cpu"
        or "cuda".
    lin_blocks : int
        Number of linear layers.
    lin_neurons : int
        Number of neurons in linear layers.
    out_neurons : int
        Number of classes.

    Example
    -------
    >>> classify = Classifier(input_size=2, lin_neurons=2, out_neurons=2)
    >>> outputs = torch.tensor([ [1., -1.], [-9., 1.], [0.9, 0.1], [0.1, 0.9] ])
    >>> outputs = outputs.unsqueeze(1)
    >>> cos = classify(outputs)
    >>> (cos < -1.0).long().sum()
    tensor(0)
    >>> (cos > 1.0).long().sum()
    tensor(0)
    """

    def __init__(
        self,
        input_size,
        device="cpu",
        lin_blocks=0,
        lin_neurons=192,
        out_neurons=1211,
    ):
        super().__init__()
        self.blocks = nn.ModuleList()

        for _ in range(lin_blocks):
            self.blocks.extend(
                [
                    _BatchNorm1d(input_size=input_size),
                    Linear(input_size=input_size, n_neurons=lin_neurons),
                ]
            )
            input_size = lin_neurons

        # Final Layer: one weight row per class. torch.empty honours `device`
        # for any device type, whereas the legacy torch.FloatTensor(...)
        # constructor rejects non-CPU devices. dtype is pinned to float32 to
        # match what FloatTensor produced.
        self.weight = nn.Parameter(
            torch.empty(
                out_neurons, input_size, dtype=torch.float32, device=device
            )
        )
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x):
        """Returns the cosine similarities between x and each class weight.

        Arguments
        ---------
        x : torch.Tensor
            Features with a singleton dim 1, e.g. (batch, 1, input_size) as
            in the example.

        Returns
        -------
        out : torch.Tensor
            Cosine similarities in [-1, 1], shape (batch, 1, out_neurons).
            These are not probabilities.
        """
        for layer in self.blocks:
            x = layer(x)

        # Normalize both features and class weights -> cosine similarity.
        x = F.linear(F.normalize(x.squeeze(1)), F.normalize(self.weight))
        return x.unsqueeze(1)
|
|
File without changes
|