xinference 1.10.0__py3-none-any.whl → 1.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +473 -31
- xinference/client/restful/async_restful_client.py +178 -8
- xinference/client/restful/restful_client.py +151 -3
- xinference/core/supervisor.py +99 -53
- xinference/core/worker.py +10 -0
- xinference/deploy/cmdline.py +15 -0
- xinference/model/audio/core.py +21 -6
- xinference/model/audio/indextts2.py +166 -0
- xinference/model/audio/model_spec.json +58 -21
- xinference/model/image/model_spec.json +159 -90
- xinference/model/image/stable_diffusion/core.py +13 -4
- xinference/model/llm/__init__.py +6 -2
- xinference/model/llm/llm_family.json +1299 -174
- xinference/model/llm/mlx/distributed_models/core.py +41 -0
- xinference/model/llm/mlx/distributed_models/qwen2.py +1 -2
- xinference/model/llm/sglang/core.py +44 -11
- xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +94 -32
- xinference/model/llm/tool_parsers/qwen_tool_parser.py +29 -4
- xinference/model/llm/transformers/chatglm.py +3 -0
- xinference/model/llm/transformers/core.py +129 -36
- xinference/model/llm/transformers/multimodal/minicpmv45.py +340 -0
- xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
- xinference/model/llm/transformers/utils.py +23 -0
- xinference/model/llm/utils.py +48 -32
- xinference/model/llm/vllm/core.py +207 -72
- xinference/model/utils.py +74 -31
- xinference/thirdparty/audiotools/__init__.py +10 -0
- xinference/thirdparty/audiotools/core/__init__.py +4 -0
- xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
- xinference/thirdparty/audiotools/core/display.py +194 -0
- xinference/thirdparty/audiotools/core/dsp.py +390 -0
- xinference/thirdparty/audiotools/core/effects.py +647 -0
- xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
- xinference/thirdparty/audiotools/core/loudness.py +320 -0
- xinference/thirdparty/audiotools/core/playback.py +252 -0
- xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
- xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
- xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
- xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
- xinference/thirdparty/audiotools/core/util.py +671 -0
- xinference/thirdparty/audiotools/core/whisper.py +97 -0
- xinference/thirdparty/audiotools/data/__init__.py +3 -0
- xinference/thirdparty/audiotools/data/datasets.py +517 -0
- xinference/thirdparty/audiotools/data/preprocess.py +81 -0
- xinference/thirdparty/audiotools/data/transforms.py +1592 -0
- xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
- xinference/thirdparty/audiotools/metrics/distance.py +131 -0
- xinference/thirdparty/audiotools/metrics/quality.py +159 -0
- xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
- xinference/thirdparty/audiotools/ml/__init__.py +5 -0
- xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
- xinference/thirdparty/audiotools/ml/decorators.py +440 -0
- xinference/thirdparty/audiotools/ml/experiment.py +90 -0
- xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
- xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
- xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
- xinference/thirdparty/audiotools/post.py +140 -0
- xinference/thirdparty/audiotools/preference.py +600 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +1 -1
- xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
- xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
- xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
- xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
- xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
- xinference/thirdparty/indextts/__init__.py +0 -0
- xinference/thirdparty/indextts/cli.py +65 -0
- xinference/thirdparty/indextts/gpt/__init__.py +0 -0
- xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
- xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
- xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
- xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
- xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
- xinference/thirdparty/indextts/gpt/model.py +713 -0
- xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
- xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
- xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
- xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
- xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
- xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
- xinference/thirdparty/indextts/infer.py +690 -0
- xinference/thirdparty/indextts/infer_v2.py +739 -0
- xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
- xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
- xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
- xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
- xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
- xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
- xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
- xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
- xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
- xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
- xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
- xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
- xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
- xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
- xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
- xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
- xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
- xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
- xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
- xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
- xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
- xinference/thirdparty/indextts/utils/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/arch_util.py +120 -0
- xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
- xinference/thirdparty/indextts/utils/common.py +121 -0
- xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
- xinference/thirdparty/indextts/utils/front.py +536 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
- xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
- xinference/thirdparty/indextts/utils/text_utils.py +41 -0
- xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
- xinference/thirdparty/indextts/utils/utils.py +93 -0
- xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
- xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
- xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
- xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
- xinference/thirdparty/melo/text/chinese_mix.py +2 -2
- xinference/types.py +9 -0
- xinference/ui/gradio/media_interface.py +66 -8
- xinference/ui/web/ui/build/asset-manifest.json +6 -6
- xinference/ui/web/ui/build/index.html +1 -1
- xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
- xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
- xinference/ui/web/ui/build/static/js/main.45e78536.js +3 -0
- xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.45e78536.js.LICENSE.txt} +0 -7
- xinference/ui/web/ui/build/static/js/main.45e78536.js.map +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/ea2a26361204e70cf1018d6990fb6354bed82b3ac69690391e0f100385e7abb7.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
- xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
- xinference/ui/web/ui/package-lock.json +0 -34
- xinference/ui/web/ui/package.json +0 -1
- xinference/ui/web/ui/src/locales/en.json +9 -3
- xinference/ui/web/ui/src/locales/ja.json +9 -3
- xinference/ui/web/ui/src/locales/ko.json +9 -3
- xinference/ui/web/ui/src/locales/zh.json +9 -3
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/METADATA +24 -6
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/RECORD +296 -77
- xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
- xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
- xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
- xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
- xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
- xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
- xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
- xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
- xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
- xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
- xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
- xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
- xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
- xinference/ui/web/ui/node_modules/select/bower.json +0 -13
- xinference/ui/web/ui/node_modules/select/package.json +0 -29
- xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/WHEEL +0 -0
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,670 @@
|
|
|
1
|
+
"""Library implementing normalization.
|
|
2
|
+
|
|
3
|
+
Authors
|
|
4
|
+
* Mirco Ravanelli 2020
|
|
5
|
+
* Guillermo Cámbara 2021
|
|
6
|
+
* Sarthak Yadav 2022
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
import torch.nn as nn
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class BatchNorm1d(nn.Module):
    """Applies 1d batch normalization to the input tensor.

    Arguments
    ---------
    input_shape : tuple
        The expected shape of the input. Alternatively, use ``input_size``.
    input_size : int
        The expected size of the input. Alternatively, use ``input_shape``.
    eps : float
        This value is added to std deviation estimation to improve the numerical
        stability.
    momentum : float
        It is a value used for the running_mean and running_var computation.
    affine : bool
        When set to True, the affine parameters are learned.
    track_running_stats : bool
        When set to True, this module tracks the running mean and variance,
        and when set to False, this module does not track such statistics.
    combine_batch_time : bool
        When true, the batch and time axes are merged into one before
        normalization and restored afterwards.
    skip_transpose : bool
        Whether to skip the transposition. When True, dimension 1 of the
        input is treated as the feature axis (and ``input_size`` defaults to
        ``input_shape[1]``).

    Raises
    ------
    ValueError
        If neither ``input_shape`` nor ``input_size`` is provided.

    Example
    -------
    >>> input = torch.randn(100, 10)
    >>> norm = BatchNorm1d(input_shape=input.shape)
    >>> output = norm(input)
    >>> output.shape
    torch.Size([100, 10])
    """

    def __init__(
        self,
        input_shape=None,
        input_size=None,
        eps=1e-05,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
        combine_batch_time=False,
        skip_transpose=False,
    ):
        super().__init__()
        self.combine_batch_time = combine_batch_time
        self.skip_transpose = skip_transpose

        if input_size is None:
            if input_shape is None:
                # Fail with a clear message (matching BatchNorm2d) instead of a
                # TypeError from subscripting None.
                raise ValueError("Expected input_shape or input_size as input")
            # With skip_transpose the input is normalized as-is, so the feature
            # axis is dim 1; otherwise forward() moves the last axis to dim 1.
            input_size = input_shape[1] if skip_transpose else input_shape[-1]

        self.norm = nn.BatchNorm1d(
            input_size,
            eps=eps,
            momentum=momentum,
            affine=affine,
            track_running_stats=track_running_stats,
        )

    def forward(self, x):
        """Returns the normalized input tensor.

        Arguments
        ---------
        x : torch.Tensor (batch, time, [channels])
            input to normalize. 2d or 3d tensors are expected in input;
            4d tensors can be used when combine_batch_time=True.

        Returns
        -------
        x_n : torch.Tensor
            The normalized outputs.
        """
        shape_or = x.shape
        if self.combine_batch_time:
            if x.ndim == 3:
                # (batch, time, feat) -> (batch * time, feat)
                x = x.reshape(shape_or[0] * shape_or[1], shape_or[2])
            else:
                # NOTE(review): this reinterprets memory via reshape rather
                # than permuting axes, so dim 1 of the result does not index
                # the original last axis element-wise. Kept unchanged for
                # compatibility with existing checkpoints; confirm intent.
                x = x.reshape(
                    shape_or[0] * shape_or[1], shape_or[3], shape_or[2]
                )

        elif not self.skip_transpose:
            # nn.BatchNorm1d normalizes over dim 1, so move features there.
            x = x.transpose(-1, 1)

        x_n = self.norm(x)

        if self.combine_batch_time:
            x_n = x_n.reshape(shape_or)
        elif not self.skip_transpose:
            x_n = x_n.transpose(1, -1)

        return x_n
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
class BatchNorm2d(nn.Module):
    """Batch-normalizes a 4d tensor whose channel axis is the last dimension.

    The wrapped ``nn.BatchNorm2d`` normalizes over dim 1, so the input's last
    axis is swapped onto dim 1 before normalization and swapped back after.

    Arguments
    ---------
    input_shape : tuple
        Shape of the expected input; its last entry gives the channel count.
        Only consulted when ``input_size`` is not provided.
    input_size : int
        Number of channels (size of the last input dimension).
    eps : float
        Small constant added to the variance estimate for numerical stability.
    momentum : float
        Momentum used when updating running_mean and running_var.
    affine : bool
        If True, learnable affine parameters are created.
    track_running_stats : bool
        If True, running mean and variance statistics are tracked.

    Raises
    ------
    ValueError
        If both ``input_shape`` and ``input_size`` are None.

    Example
    -------
    >>> input = torch.randn(100, 10, 5, 20)
    >>> norm = BatchNorm2d(input_shape=input.shape)
    >>> output = norm(input)
    >>> output.shape
    torch.Size([100, 10, 5, 20])
    """

    def __init__(
        self,
        input_shape=None,
        input_size=None,
        eps=1e-05,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
    ):
        super().__init__()

        # An explicit input_size wins; otherwise derive it from the shape.
        if input_size is None:
            if input_shape is None:
                raise ValueError("Expected input_shape or input_size as input")
            input_size = input_shape[-1]

        self.norm = nn.BatchNorm2d(
            input_size,
            eps=eps,
            momentum=momentum,
            affine=affine,
            track_running_stats=track_running_stats,
        )

    def forward(self, x):
        """Normalize ``x`` and return a tensor of the same shape.

        Arguments
        ---------
        x : torch.Tensor (batch, time, channel1, channel2)
            4d input; its last axis is the one normalized as channels.

        Returns
        -------
        x_n : torch.Tensor
            The normalized tensor, in the original axis order.
        """
        channels_first = x.transpose(-1, 1)
        normalized = self.norm(channels_first)
        return normalized.transpose(1, -1)
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
class LayerNorm(nn.Module):
    """Applies layer normalization to the input tensor.

    Arguments
    ---------
    input_size : int
        The expected size of the dimension(s) to be normalized. Ignored when
        ``input_shape`` is given.
    input_shape : tuple
        The expected shape of the input. When given, every axis after the
        first two (``input_shape[2:]``) is normalized, and this takes
        precedence over ``input_size``.
    eps : float
        This value is added to std deviation estimation to improve the numerical
        stability.
    elementwise_affine : bool
        If True, this module has learnable per-element affine parameters
        initialized to ones (for weights) and zeros (for biases).

    Raises
    ------
    ValueError
        If neither ``input_shape`` nor ``input_size`` is given.

    Example
    -------
    >>> input = torch.randn(100, 101, 128)
    >>> norm = LayerNorm(input_shape=input.shape)
    >>> output = norm(input)
    >>> output.shape
    torch.Size([100, 101, 128])
    """

    def __init__(
        self,
        input_size=None,
        input_shape=None,
        eps=1e-05,
        elementwise_affine=True,
    ):
        super().__init__()
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        # Fail early with the same message as the other normalization wrappers
        # instead of letting torch.nn.LayerNorm(None) raise an opaque TypeError.
        if input_shape is None and input_size is None:
            raise ValueError("Expected input_shape or input_size as input")

        if input_shape is not None:
            # Normalize over every axis after the first two (e.g. batch, time).
            input_size = input_shape[2:]

        self.norm = torch.nn.LayerNorm(
            input_size,
            eps=self.eps,
            elementwise_affine=self.elementwise_affine,
        )

    def forward(self, x):
        """Returns the normalized input tensor.

        Arguments
        ---------
        x : torch.Tensor (batch, time, channels)
            input to normalize. 3d or 4d tensors are expected.

        Returns
        -------
        x_n : torch.Tensor
            The normalized outputs, same shape as ``x``.
        """
        return self.norm(x)
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
class InstanceNorm1d(nn.Module):
    """Instance-normalize a 3d tensor whose feature axis is the last one.

    ``nn.InstanceNorm1d`` works on (batch, channels, length), so the last axis
    of the input is swapped onto axis 1 for normalization and swapped back.

    Arguments
    ---------
    input_shape : tuple
        Shape of the expected input; its last entry gives the number of
        features when ``input_size`` is not given.
    input_size : int
        Number of features (size of the last axis). When given, it is used
        instead of ``input_shape``.
    eps : float
        Added to the variance estimate to improve numerical stability.
    momentum : float
        Momentum used for the running_mean and running_var updates.
    track_running_stats : bool
        When True, running mean and variance are tracked; when False they are
        not.
    affine : bool
        When True, learnable affine parameters are used. Default: False.

    Example
    -------
    >>> input = torch.randn(100, 10, 20)
    >>> norm = InstanceNorm1d(input_shape=input.shape)
    >>> output = norm(input)
    >>> output.shape
    torch.Size([100, 10, 20])
    """

    def __init__(
        self,
        input_shape=None,
        input_size=None,
        eps=1e-05,
        momentum=0.1,
        track_running_stats=True,
        affine=False,
    ):
        super().__init__()

        if input_shape is None and input_size is None:
            raise ValueError("Expected input_shape or input_size as input")

        n_features = input_shape[-1] if input_size is None else input_size

        self.norm = nn.InstanceNorm1d(
            n_features,
            eps=eps,
            momentum=momentum,
            track_running_stats=track_running_stats,
            affine=affine,
        )

    def forward(self, x):
        """Normalize ``x`` and return a tensor of the same shape.

        Arguments
        ---------
        x : torch.Tensor (batch, time, channels)
            3d tensor to normalize; its last axis holds the features.

        Returns
        -------
        x_n : torch.Tensor
            The normalized outputs, in the input's axis order.
        """
        normalized = self.norm(x.transpose(-1, 1))
        return normalized.transpose(1, -1)
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
class InstanceNorm2d(nn.Module):
    """Instance-normalize a 4d tensor whose feature axis is the last one.

    The last axis is swapped onto axis 1 (where ``nn.InstanceNorm2d`` expects
    channels) for normalization and then swapped back.

    Arguments
    ---------
    input_shape : tuple
        Shape of the expected input; its last entry gives the number of
        features when ``input_size`` is not given.
    input_size : int
        Number of features (size of the last axis). When given, it is used
        instead of ``input_shape``.
    eps : float
        Added to the variance estimate to improve numerical stability.
    momentum : float
        Momentum used for the running_mean and running_var updates.
    track_running_stats : bool
        When True, running mean and variance are tracked; when False they are
        not.
    affine : bool
        When True, learnable affine parameters are used. Default: False.

    Example
    -------
    >>> input = torch.randn(100, 10, 20, 2)
    >>> norm = InstanceNorm2d(input_shape=input.shape)
    >>> output = norm(input)
    >>> output.shape
    torch.Size([100, 10, 20, 2])
    """

    def __init__(
        self,
        input_shape=None,
        input_size=None,
        eps=1e-05,
        momentum=0.1,
        track_running_stats=True,
        affine=False,
    ):
        super().__init__()

        if input_size is not None:
            n_features = input_size
        elif input_shape is not None:
            n_features = input_shape[-1]
        else:
            raise ValueError("Expected input_shape or input_size as input")

        self.norm = nn.InstanceNorm2d(
            n_features,
            eps=eps,
            momentum=momentum,
            track_running_stats=track_running_stats,
            affine=affine,
        )

    def forward(self, x):
        """Normalize ``x`` and return a tensor of the same shape.

        Arguments
        ---------
        x : torch.Tensor (batch, time, channel1, channel2)
            4d tensor to normalize; its last axis holds the features.

        Returns
        -------
        x_n : torch.Tensor
            The normalized outputs, in the input's axis order.
        """
        swapped = x.transpose(-1, 1)
        return self.norm(swapped).transpose(1, -1)
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
class GroupNorm(nn.Module):
    """Group-normalize a tensor whose feature axis is the last one.

    The last axis is moved onto axis 1 (the channel axis of
    ``torch.nn.GroupNorm``), normalized in ``num_groups`` groups, and moved
    back.

    Arguments
    ---------
    input_shape : tuple
        Shape of the expected input; when given, its last entry is used as
        the number of channels (overriding ``input_size``).
    input_size : int
        Number of channels (size of the last axis), used when ``input_shape``
        is not given.
    num_groups : int
        Number of groups to split the channels into. Required.
    eps : float
        Added to the variance estimate to improve numerical stability.
    affine : bool
        When True, learnable per-channel affine parameters are used,
        initialized to ones (weights) and zeros (biases).

    Example
    -------
    >>> input = torch.randn(100, 101, 128)
    >>> norm = GroupNorm(input_size=128, num_groups=128)
    >>> output = norm(input)
    >>> output.shape
    torch.Size([100, 101, 128])
    """

    def __init__(
        self,
        input_shape=None,
        input_size=None,
        num_groups=None,
        eps=1e-05,
        affine=True,
    ):
        super().__init__()
        self.eps = eps
        self.affine = affine

        if input_shape is None and input_size is None:
            raise ValueError("Expected input_shape or input_size as input")
        if num_groups is None:
            raise ValueError("Expected num_groups as input")

        # Here a given input_shape takes priority over input_size.
        num_channels = input_size if input_shape is None else input_shape[-1]

        self.norm = torch.nn.GroupNorm(
            num_groups,
            num_channels,
            eps=self.eps,
            affine=self.affine,
        )

    def forward(self, x):
        """Normalize ``x`` and return a tensor of the same shape.

        Arguments
        ---------
        x : torch.Tensor (batch, time, channels)
            3d or 4d tensor to normalize; its last axis holds the channels.

        Returns
        -------
        x_n : torch.Tensor
            The normalized outputs, in the input's axis order.
        """
        grouped = self.norm(x.transpose(-1, 1))
        return grouped.transpose(1, -1)
|
|
470
|
+
|
|
471
|
+
|
|
472
|
+
class ExponentialMovingAverage(nn.Module):
    """
    Learnable exponential moving average over time, as used by the learnable
    PCEN layer.

    For each time step ``t`` the smoothed value is
    ``s[t] = w * x[t] + (1 - w) * s[t - 1]``, where ``w`` is the smoothing
    coefficient clamped to [0, 1] and the recursion is seeded with the first
    frame, so ``s[0] == x[0]``.

    Arguments
    ---------
    input_size : int
        The expected size of the input (number of channels).
    coeff_init: float
        Initial value of the smoothing coefficient(s).
    per_channel: bool
        If True, one smoothing coefficient is learned per input channel;
        otherwise a single shared coefficient is used.
    trainable: bool
        Whether the smoothing coefficients are learned or kept fixed.
    skip_transpose : bool
        If False, uses batch x time x channel convention of speechbrain.
        If True, uses batch x channel x time convention.

    Example
    -------
    >>> inp_tensor = torch.rand([10, 50, 40])
    >>> pcen = ExponentialMovingAverage(40)
    >>> out_tensor = pcen(inp_tensor)
    >>> out_tensor.shape
    torch.Size([10, 50, 40])
    """

    def __init__(
        self,
        input_size: int,
        coeff_init: float = 0.04,
        per_channel: bool = False,
        trainable: bool = True,
        skip_transpose: bool = False,
    ):
        super().__init__()
        self._coeff_init = coeff_init
        self._per_channel = per_channel
        self.skip_transpose = skip_transpose
        self.trainable = trainable

        n_coeffs = input_size if per_channel else 1
        self._weights = nn.Parameter(
            torch.ones(n_coeffs) * coeff_init, requires_grad=trainable
        )

    def forward(self, x):
        """Returns the smoothed input tensor, same shape as ``x``.

        Arguments
        ---------
        x : torch.Tensor (batch, time, channels)
            input to smooth ((batch, channels, time) when skip_transpose=True).
        """
        if not self.skip_transpose:
            x = x.transpose(1, -1)
        coeff = torch.clamp(self._weights, min=0.0, max=1.0)

        # From here on time is the last axis; the running state drops it.
        state = x[:, :, 0]
        steps = []
        for frame in x.permute(2, 0, 1):
            state = coeff * frame + (1.0 - coeff) * state
            steps.append(state)
        smoothed = torch.stack(steps, dim=0).permute(1, 2, 0)

        if not self.skip_transpose:
            smoothed = smoothed.transpose(1, -1)
        return smoothed
|
|
555
|
+
|
|
556
|
+
|
|
557
|
+
class PCEN(nn.Module):
    """
    Learnable per-channel energy normalization (PCEN) layer, supporting both
    the original PCEN of [1] and the sPCEN variant of [2].

    The output is ``(x / (floor + M) ** alpha + delta) ** (1 / root)
    - delta ** (1 / root)``, where ``M`` is an exponential moving average of
    ``x`` over time. In ``forward`` alpha is clamped to at most 1 and root to
    at least 1.

    [1] Yuxuan Wang, Pascal Getreuer, Thad Hughes, Richard F. Lyon, Rif A. Saurous, "Trainable Frontend For
    Robust and Far-Field Keyword Spotting", in Proc of ICASSP 2017 (https://arxiv.org/abs/1607.05666)

    [2] Neil Zeghidour, Olivier Teboul, Felix de Chaumont Quitry & Marco Tagliasacchi, "LEAF: A LEARNABLE FRONTEND
    FOR AUDIO CLASSIFICATION", in Proc of ICLR 2021 (https://arxiv.org/abs/2101.08596)

    The default argument values correspond with those used by [2].

    Arguments
    ---------
    input_size : int
        The expected size of the input (number of channels).
    alpha: float
        Initial value of the per-channel alpha (gain exponent).
    smooth_coef: float
        Initial smoothing coefficient of the internal moving average.
    delta: float
        Initial value of the per-channel delta (bias).
    root: float
        Initial value of the per-channel root (compression).
    floor: float
        Added to the smoothed energy before raising it to alpha.
    trainable: bool
        Whether alpha, delta, root and the smoothing coefficients are learned
        or kept fixed.
    per_channel_smooth_coef: bool
        Whether to learn independent smooth coefficients for every channel;
        when True, this is essentially sPCEN from [2].
    skip_transpose : bool
        If False, uses batch x time x channel convention of speechbrain.
        If True, uses batch x channel x time convention.

    Example
    -------
    >>> inp_tensor = torch.rand([10, 50, 40])
    >>> pcen = PCEN(40, alpha=0.96) # sPCEN
    >>> out_tensor = pcen(inp_tensor)
    >>> out_tensor.shape
    torch.Size([10, 50, 40])
    """

    def __init__(
        self,
        input_size,
        alpha: float = 0.96,
        smooth_coef: float = 0.04,
        delta: float = 2.0,
        root: float = 2.0,
        floor: float = 1e-12,
        trainable: bool = True,
        per_channel_smooth_coef: bool = True,
        skip_transpose: bool = False,
    ):
        super().__init__()
        self._smooth_coef = smooth_coef
        self._floor = floor
        self._per_channel_smooth_coef = per_channel_smooth_coef
        self.skip_transpose = skip_transpose

        def _channel_param(initial):
            # One value per channel, all starting at `initial`.
            return nn.Parameter(
                torch.ones(input_size) * initial, requires_grad=trainable
            )

        # Registration order (alpha, delta, root, ema) is kept stable.
        self.alpha = _channel_param(alpha)
        self.delta = _channel_param(delta)
        self.root = _channel_param(root)

        # The layer itself transposes, so the smoother always sees
        # (batch, channel, time).
        self.ema = ExponentialMovingAverage(
            input_size,
            coeff_init=self._smooth_coef,
            per_channel=self._per_channel_smooth_coef,
            skip_transpose=True,
            trainable=trainable,
        )

    def forward(self, x):
        """Returns the normalized input tensor.

        Arguments
        ---------
        x : torch.Tensor (batch, time, channels)
            input to normalize ((batch, channels, time) when skip_transpose=True).

        Returns
        -------
        output : torch.Tensor
            The normalized outputs, same layout as ``x``.
        """
        if not self.skip_transpose:
            x = x.transpose(1, -1)

        one = torch.tensor(1.0, dtype=x.dtype, device=x.device)
        alpha = torch.min(self.alpha, one)
        root = torch.max(self.root, one)
        smoothed = self.ema(x)

        def _per_channel(values):
            # Broadcast a per-channel vector over (batch, channel, time).
            return values.view(1, -1, 1)

        exponent = _per_channel(1.0 / root)
        offset = _per_channel(self.delta)
        gain = (self._floor + smoothed) ** _per_channel(alpha)
        output = (x / gain + offset) ** exponent - offset**exponent

        if not self.skip_transpose:
            output = output.transpose(1, -1)
        return output
|