xinference 1.10.0__py3-none-any.whl → 1.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +473 -31
- xinference/client/restful/async_restful_client.py +178 -8
- xinference/client/restful/restful_client.py +151 -3
- xinference/core/supervisor.py +99 -53
- xinference/core/worker.py +10 -0
- xinference/deploy/cmdline.py +15 -0
- xinference/model/audio/core.py +21 -6
- xinference/model/audio/indextts2.py +166 -0
- xinference/model/audio/model_spec.json +58 -21
- xinference/model/image/model_spec.json +159 -90
- xinference/model/image/stable_diffusion/core.py +13 -4
- xinference/model/llm/__init__.py +6 -2
- xinference/model/llm/llm_family.json +1299 -174
- xinference/model/llm/mlx/distributed_models/core.py +41 -0
- xinference/model/llm/mlx/distributed_models/qwen2.py +1 -2
- xinference/model/llm/sglang/core.py +44 -11
- xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +94 -32
- xinference/model/llm/tool_parsers/qwen_tool_parser.py +29 -4
- xinference/model/llm/transformers/chatglm.py +3 -0
- xinference/model/llm/transformers/core.py +129 -36
- xinference/model/llm/transformers/multimodal/minicpmv45.py +340 -0
- xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
- xinference/model/llm/transformers/utils.py +23 -0
- xinference/model/llm/utils.py +48 -32
- xinference/model/llm/vllm/core.py +207 -72
- xinference/model/utils.py +74 -31
- xinference/thirdparty/audiotools/__init__.py +10 -0
- xinference/thirdparty/audiotools/core/__init__.py +4 -0
- xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
- xinference/thirdparty/audiotools/core/display.py +194 -0
- xinference/thirdparty/audiotools/core/dsp.py +390 -0
- xinference/thirdparty/audiotools/core/effects.py +647 -0
- xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
- xinference/thirdparty/audiotools/core/loudness.py +320 -0
- xinference/thirdparty/audiotools/core/playback.py +252 -0
- xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
- xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
- xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
- xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
- xinference/thirdparty/audiotools/core/util.py +671 -0
- xinference/thirdparty/audiotools/core/whisper.py +97 -0
- xinference/thirdparty/audiotools/data/__init__.py +3 -0
- xinference/thirdparty/audiotools/data/datasets.py +517 -0
- xinference/thirdparty/audiotools/data/preprocess.py +81 -0
- xinference/thirdparty/audiotools/data/transforms.py +1592 -0
- xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
- xinference/thirdparty/audiotools/metrics/distance.py +131 -0
- xinference/thirdparty/audiotools/metrics/quality.py +159 -0
- xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
- xinference/thirdparty/audiotools/ml/__init__.py +5 -0
- xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
- xinference/thirdparty/audiotools/ml/decorators.py +440 -0
- xinference/thirdparty/audiotools/ml/experiment.py +90 -0
- xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
- xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
- xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
- xinference/thirdparty/audiotools/post.py +140 -0
- xinference/thirdparty/audiotools/preference.py +600 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +1 -1
- xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
- xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
- xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
- xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
- xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
- xinference/thirdparty/indextts/__init__.py +0 -0
- xinference/thirdparty/indextts/cli.py +65 -0
- xinference/thirdparty/indextts/gpt/__init__.py +0 -0
- xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
- xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
- xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
- xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
- xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
- xinference/thirdparty/indextts/gpt/model.py +713 -0
- xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
- xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
- xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
- xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
- xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
- xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
- xinference/thirdparty/indextts/infer.py +690 -0
- xinference/thirdparty/indextts/infer_v2.py +739 -0
- xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
- xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
- xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
- xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
- xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
- xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
- xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
- xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
- xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
- xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
- xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
- xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
- xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
- xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
- xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
- xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
- xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
- xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
- xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
- xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
- xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
- xinference/thirdparty/indextts/utils/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/arch_util.py +120 -0
- xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
- xinference/thirdparty/indextts/utils/common.py +121 -0
- xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
- xinference/thirdparty/indextts/utils/front.py +536 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
- xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
- xinference/thirdparty/indextts/utils/text_utils.py +41 -0
- xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
- xinference/thirdparty/indextts/utils/utils.py +93 -0
- xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
- xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
- xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
- xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
- xinference/thirdparty/melo/text/chinese_mix.py +2 -2
- xinference/types.py +9 -0
- xinference/ui/gradio/media_interface.py +66 -8
- xinference/ui/web/ui/build/asset-manifest.json +6 -6
- xinference/ui/web/ui/build/index.html +1 -1
- xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
- xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
- xinference/ui/web/ui/build/static/js/main.45e78536.js +3 -0
- xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.45e78536.js.LICENSE.txt} +0 -7
- xinference/ui/web/ui/build/static/js/main.45e78536.js.map +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/ea2a26361204e70cf1018d6990fb6354bed82b3ac69690391e0f100385e7abb7.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
- xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
- xinference/ui/web/ui/package-lock.json +0 -34
- xinference/ui/web/ui/package.json +0 -1
- xinference/ui/web/ui/src/locales/en.json +9 -3
- xinference/ui/web/ui/src/locales/ja.json +9 -3
- xinference/ui/web/ui/src/locales/ko.json +9 -3
- xinference/ui/web/ui/src/locales/zh.json +9 -3
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/METADATA +24 -6
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/RECORD +296 -77
- xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
- xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
- xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
- xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
- xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
- xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
- xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
- xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
- xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
- xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
- xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
- xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
- xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
- xinference/ui/web/ui/node_modules/select/bower.json +0 -13
- xinference/ui/web/ui/node_modules/select/package.json +0 -29
- xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/WHEEL +0 -0
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/top_level.txt +0 -0
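The headline additions in this release are the vendored IndexTTS-2 stack (xinference/thirdparty/indextts/ and xinference/thirdparty/audiotools/), the new audio wrapper xinference/model/audio/indextts2.py, new MiniCPM-V 4.5 multimodal support, and a substantially expanded REST API and client surface. As a minimal sketch of how the new TTS model might be used through the Python client — the registered model name "IndexTTS-2", the endpoint address, and the output format below are assumptions for illustration, not taken from this diff:

from xinference.client import Client

client = Client("http://127.0.0.1:9997")  # assumed local supervisor endpoint

# Launch the newly added TTS model (registered name assumed, not confirmed here).
model_uid = client.launch_model(model_name="IndexTTS-2", model_type="audio")
model = client.get_model(model_uid)

# speech() mirrors the OpenAI-style audio/speech endpoint and returns audio bytes.
audio = model.speech("Hello from xinference 1.11.0")
with open("speech.mp3", "wb") as f:
    f.write(audio)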
xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py
@@ -0,0 +1,546 @@
+"""Library implementing convolutional neural networks.
+
+Authors
+ * Mirco Ravanelli 2020
+ * Jianyuan Zhong 2020
+ * Cem Subakan 2021
+ * Davide Borra 2021
+ * Andreas Nautsch 2022
+ * Sarthak Yadav 2022
+"""
+
+import logging
+import math
+from typing import Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchaudio
+
+
+class SincConv(nn.Module):
+    """This function implements SincConv (SincNet).
+
+    M. Ravanelli, Y. Bengio, "Speaker Recognition from raw waveform with
+    SincNet", in Proc. of SLT 2018 (https://arxiv.org/abs/1808.00158)
+
+    Arguments
+    ---------
+    out_channels : int
+        It is the number of output channels.
+    kernel_size: int
+        Kernel size of the convolutional filters.
+    input_shape : tuple
+        The shape of the input. Alternatively use ``in_channels``.
+    in_channels : int
+        The number of input channels. Alternatively use ``input_shape``.
+    stride : int
+        Stride factor of the convolutional filters. When the stride factor > 1,
+        a decimation in time is performed.
+    dilation : int
+        Dilation factor of the convolutional filters.
+    padding : str
+        (same, valid, causal). If "valid", no padding is performed.
+        If "same" and stride is 1, output shape is the same as the input shape.
+        "causal" results in causal (dilated) convolutions.
+    padding_mode : str
+        This flag specifies the type of padding. See torch.nn documentation
+        for more information.
+    sample_rate : int
+        Sampling rate of the input signals. It is only used for sinc_conv.
+    min_low_hz : float
+        Lowest possible frequency (in Hz) for a filter. It is only used for
+        sinc_conv.
+    min_band_hz : float
+        Lowest possible value (in Hz) for a filter bandwidth.
+
+    Example
+    -------
+    >>> inp_tensor = torch.rand([10, 16000])
+    >>> conv = SincConv(input_shape=inp_tensor.shape, out_channels=25, kernel_size=11)
+    >>> out_tensor = conv(inp_tensor)
+    >>> out_tensor.shape
+    torch.Size([10, 16000, 25])
+    """
+
+    def __init__(
+        self,
+        out_channels,
+        kernel_size,
+        input_shape=None,
+        in_channels=None,
+        stride=1,
+        dilation=1,
+        padding="same",
+        padding_mode="reflect",
+        sample_rate=16000,
+        min_low_hz=50,
+        min_band_hz=50,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.dilation = dilation
+        self.padding = padding
+        self.padding_mode = padding_mode
+        self.sample_rate = sample_rate
+        self.min_low_hz = min_low_hz
+        self.min_band_hz = min_band_hz
+
+        # input shape inference
+        if input_shape is None and self.in_channels is None:
+            raise ValueError("Must provide one of input_shape or in_channels")
+
+        if self.in_channels is None:
+            self.in_channels = self._check_input_shape(input_shape)
+
+        if self.out_channels % self.in_channels != 0:
+            raise ValueError(
+                "Number of output channels must be divisible by in_channels"
+            )
+
+        # Initialize Sinc filters
+        self._init_sinc_conv()
+
+    def forward(self, x):
+        """Returns the output of the convolution.
+
+        Arguments
+        ---------
+        x : torch.Tensor (batch, time, channel)
+            input to convolve. 2d or 3d tensors are expected.
+
+        Returns
+        -------
+        wx : torch.Tensor
+            The convolved outputs.
+        """
+        x = x.transpose(1, -1)
+        self.device = x.device
+
+        unsqueeze = x.ndim == 2
+        if unsqueeze:
+            x = x.unsqueeze(1)
+
+        if self.padding == "same":
+            x = self._manage_padding(
+                x, self.kernel_size, self.dilation, self.stride
+            )
+
+        elif self.padding == "causal":
+            num_pad = (self.kernel_size - 1) * self.dilation
+            x = F.pad(x, (num_pad, 0))
+
+        elif self.padding == "valid":
+            pass
+
+        else:
+            raise ValueError(
+                "Padding must be 'same', 'valid' or 'causal'. Got %s."
+                % (self.padding)
+            )
+
+        sinc_filters = self._get_sinc_filters()
+
+        wx = F.conv1d(
+            x,
+            sinc_filters,
+            stride=self.stride,
+            padding=0,
+            dilation=self.dilation,
+            groups=self.in_channels,
+        )
+
+        if unsqueeze:
+            wx = wx.squeeze(1)
+
+        wx = wx.transpose(1, -1)
+
+        return wx
+
+    def _check_input_shape(self, shape):
+        """Checks the input shape and returns the number of input channels."""
+
+        if len(shape) == 2:
+            in_channels = 1
+        elif len(shape) == 3:
+            in_channels = shape[-1]
+        else:
+            raise ValueError(
+                "sincconv expects 2d or 3d inputs. Got " + str(len(shape))
+            )
+
+        # Kernel size must be odd
+        if self.kernel_size % 2 == 0:
+            raise ValueError(
+                "The field kernel size must be an odd number. Got %s."
+                % (self.kernel_size)
+            )
+        return in_channels
+
+    def _get_sinc_filters(self):
+        """This function creates the sinc filters used for sinc-conv."""
+        # Computing the low frequencies of the filters
+        low = self.min_low_hz + torch.abs(self.low_hz_)
+
+        # Setting minimum band and minimum freq
+        high = torch.clamp(
+            low + self.min_band_hz + torch.abs(self.band_hz_),
+            self.min_low_hz,
+            self.sample_rate / 2,
+        )
+        band = (high - low)[:, 0]
+
+        # Passing from n_ to the corresponding f_times_t domain
+        self.n_ = self.n_.to(self.device)
+        self.window_ = self.window_.to(self.device)
+        f_times_t_low = torch.matmul(low, self.n_)
+        f_times_t_high = torch.matmul(high, self.n_)
+
+        # Left part of the filters.
+        band_pass_left = (
+            (torch.sin(f_times_t_high) - torch.sin(f_times_t_low))
+            / (self.n_ / 2)
+        ) * self.window_
+
+        # Central element of the filter
+        band_pass_center = 2 * band.view(-1, 1)
+
+        # Right part of the filter (sinc filters are symmetric)
+        band_pass_right = torch.flip(band_pass_left, dims=[1])
+
+        # Combining left, central, and right part of the filter
+        band_pass = torch.cat(
+            [band_pass_left, band_pass_center, band_pass_right], dim=1
+        )
+
+        # Amplitude normalization
+        band_pass = band_pass / (2 * band[:, None])
+
+        # Setting up the filter coefficients
+        filters = band_pass.view(self.out_channels, 1, self.kernel_size)
+
+        return filters
+
+    def _init_sinc_conv(self):
+        """Initializes the parameters of the sinc_conv layer."""
+
+        # Initialize filterbanks such that they are equally spaced in Mel scale
+        high_hz = self.sample_rate / 2 - (self.min_low_hz + self.min_band_hz)
+
+        mel = torch.linspace(
+            self._to_mel(self.min_low_hz),
+            self._to_mel(high_hz),
+            self.out_channels + 1,
+        )
+
+        hz = self._to_hz(mel)
+
+        # Filter lower frequency and bands
+        self.low_hz_ = hz[:-1].unsqueeze(1)
+        self.band_hz_ = (hz[1:] - hz[:-1]).unsqueeze(1)
+
+        # Making freq and bands learnable
+        self.low_hz_ = nn.Parameter(self.low_hz_)
+        self.band_hz_ = nn.Parameter(self.band_hz_)
+
+        # Hamming window
+        n_lin = torch.linspace(
+            0, (self.kernel_size / 2) - 1, steps=int((self.kernel_size / 2))
+        )
+        self.window_ = 0.54 - 0.46 * torch.cos(
+            2 * math.pi * n_lin / self.kernel_size
+        )
+
+        # Time axis (only half is needed due to symmetry)
+        n = (self.kernel_size - 1) / 2.0
+        self.n_ = (
+            2 * math.pi * torch.arange(-n, 0).view(1, -1) / self.sample_rate
+        )
+
+    def _to_mel(self, hz):
+        """Converts frequency in Hz to the mel scale."""
+        return 2595 * np.log10(1 + hz / 700)
+
+    def _to_hz(self, mel):
+        """Converts frequency in the mel scale to Hz."""
+        return 700 * (10 ** (mel / 2595) - 1)
+
+    def _manage_padding(self, x, kernel_size: int, dilation: int, stride: int):
+        """This function performs zero-padding on the time axis
+        so that the sequence length is unchanged after the convolution.
+
+        Arguments
+        ---------
+        x : torch.Tensor
+            Input tensor.
+        kernel_size : int
+            Size of kernel.
+        dilation : int
+            Dilation used.
+        stride : int
+            Stride.
+
+        Returns
+        -------
+        x : torch.Tensor
+        """
+
+        # Detecting input shape
+        L_in = self.in_channels
+
+        # Time padding
+        padding = get_padding_elem(L_in, stride, kernel_size, dilation)
+
+        # Applying padding
+        x = F.pad(x, padding, mode=self.padding_mode)
+
+        return x
+
+
+class Conv1d(nn.Module):
+    """This function implements 1d convolution.
+
+    Arguments
+    ---------
+    out_channels : int
+        It is the number of output channels.
+    kernel_size : int
+        Kernel size of the convolutional filters.
+    input_shape : tuple
+        The shape of the input. Alternatively use ``in_channels``.
+    in_channels : int
+        The number of input channels. Alternatively use ``input_shape``.
+    stride : int
+        Stride factor of the convolutional filters. When the stride factor > 1,
+        a decimation in time is performed.
+    dilation : int
+        Dilation factor of the convolutional filters.
+    padding : str
+        (same, valid, causal). If "valid", no padding is performed.
+        If "same" and stride is 1, output shape is the same as the input shape.
+        "causal" results in causal (dilated) convolutions.
+    groups : int
+        Number of blocked connections from input channels to output channels.
+    bias : bool
+        Whether to add a bias term to convolution operation.
+    padding_mode : str
+        This flag specifies the type of padding. See torch.nn documentation
+        for more information.
+    skip_transpose : bool
+        If False, uses batch x time x channel convention of speechbrain.
+        If True, uses batch x channel x time convention.
+    weight_norm : bool
+        If True, use weight normalization,
+        to be removed with self.remove_weight_norm() at inference
+    conv_init : str
+        Weight initialization for the convolution network
+    default_padding: str or int
+        This sets the default padding mode that will be used by the pytorch Conv1d backend.
+
+    Example
+    -------
+    >>> inp_tensor = torch.rand([10, 40, 16])
+    >>> cnn_1d = Conv1d(
+    ...     input_shape=inp_tensor.shape, out_channels=8, kernel_size=5
+    ... )
+    >>> out_tensor = cnn_1d(inp_tensor)
+    >>> out_tensor.shape
+    torch.Size([10, 40, 8])
+    """
+
+    def __init__(
+        self,
+        out_channels,
+        kernel_size,
+        input_shape=None,
+        in_channels=None,
+        stride=1,
+        dilation=1,
+        padding="same",
+        groups=1,
+        bias=True,
+        padding_mode="reflect",
+        skip_transpose=False,
+        weight_norm=False,
+        conv_init=None,
+        default_padding=0,
+    ):
+        super().__init__()
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.dilation = dilation
+        self.padding = padding
+        self.padding_mode = padding_mode
+        self.unsqueeze = False
+        self.skip_transpose = skip_transpose
+
+        if input_shape is None and in_channels is None:
+            raise ValueError("Must provide one of input_shape or in_channels")
+
+        if in_channels is None:
+            in_channels = self._check_input_shape(input_shape)
+
+        self.in_channels = in_channels
+
+        self.conv = nn.Conv1d(
+            in_channels,
+            out_channels,
+            self.kernel_size,
+            stride=self.stride,
+            dilation=self.dilation,
+            padding=default_padding,
+            groups=groups,
+            bias=bias,
+        )
+
+        if conv_init == "kaiming":
+            nn.init.kaiming_normal_(self.conv.weight)
+        elif conv_init == "zero":
+            nn.init.zeros_(self.conv.weight)
+        elif conv_init == "normal":
+            nn.init.normal_(self.conv.weight, std=1e-6)
+
+        if weight_norm:
+            self.conv = nn.utils.weight_norm(self.conv)
+
+    def forward(self, x):
+        """Returns the output of the convolution.
+
+        Arguments
+        ---------
+        x : torch.Tensor (batch, time, channel)
+            input to convolve. 2d or 3d tensors are expected.
+
+        Returns
+        -------
+        wx : torch.Tensor
+            The convolved outputs.
+        """
+        if not self.skip_transpose:
+            x = x.transpose(1, -1)
+
+        if self.unsqueeze:
+            x = x.unsqueeze(1)
+
+        if self.padding == "same":
+            x = self._manage_padding(
+                x, self.kernel_size, self.dilation, self.stride
+            )
+
+        elif self.padding == "causal":
+            num_pad = (self.kernel_size - 1) * self.dilation
+            x = F.pad(x, (num_pad, 0))
+
+        elif self.padding == "valid":
+            pass
+
+        else:
+            raise ValueError(
+                "Padding must be 'same', 'valid' or 'causal'. Got "
+                + self.padding
+            )
+
+        wx = self.conv(x)
+
+        if self.unsqueeze:
+            wx = wx.squeeze(1)
+
+        if not self.skip_transpose:
+            wx = wx.transpose(1, -1)
+
+        return wx
+
+    def _manage_padding(self, x, kernel_size: int, dilation: int, stride: int):
+        """This function performs zero-padding on the time axis
+        so that the sequence length is unchanged after the convolution.
+
+        Arguments
+        ---------
+        x : torch.Tensor
+            Input tensor.
+        kernel_size : int
+            Size of kernel.
+        dilation : int
+            Dilation used.
+        stride : int
+            Stride.
+
+        Returns
+        -------
+        x : torch.Tensor
+            The padded outputs.
+        """
+
+        # Detecting input shape
+        L_in = self.in_channels
+
+        # Time padding
+        padding = get_padding_elem(L_in, stride, kernel_size, dilation)
+
+        # Applying padding
+        x = F.pad(x, padding, mode=self.padding_mode)
+
+        return x
+
+    def _check_input_shape(self, shape):
+        """Checks the input shape and returns the number of input channels."""
+
+        if len(shape) == 2:
+            self.unsqueeze = True
+            in_channels = 1
+        elif self.skip_transpose:
+            in_channels = shape[1]
+        elif len(shape) == 3:
+            in_channels = shape[2]
+        else:
+            raise ValueError(
+                "conv1d expects 2d, 3d inputs. Got " + str(len(shape))
+            )
+
+        # Kernel size must be odd
+        if not self.padding == "valid" and self.kernel_size % 2 == 0:
+            raise ValueError(
+                "The field kernel size must be an odd number. Got %s."
+                % (self.kernel_size)
+            )
+
+        return in_channels
+
+    def remove_weight_norm(self):
+        """Removes weight normalization at inference if used during training."""
+        self.conv = nn.utils.remove_weight_norm(self.conv)
+
+
+def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
+    """This function computes the number of elements to add for zero-padding.
+
+    Arguments
+    ---------
+    L_in : int
+    stride: int
+    kernel_size : int
+    dilation : int
+
+    Returns
+    -------
+    padding : list
+        The size of the padding to be added
+    """
+    if stride > 1:
+        padding = [math.floor(kernel_size / 2), math.floor(kernel_size / 2)]
+
+    else:
+        L_out = (
+            math.floor((L_in - dilation * (kernel_size - 1) - 1) / stride) + 1
+        )
+        padding = [
+            math.floor((L_in - L_out) / 2),
+            math.floor((L_in - L_out) / 2),
+        ]
+    return padding
+
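Two reference notes on CNN.py above (commentary, not part of the shipped file). First, the band-pass construction in _get_sinc_filters is the SincNet kernel from the cited paper; in the code's own quantities (sample rate f_s, odd kernel size K, time axis t_n = n / f_s for n = -(K-1)/2, ..., -1, and learnable per-filter scalars theta_low and theta_band):

\[
f_1 = f_{\min} + \lvert\theta_{\mathrm{low}}\rvert,
\qquad
f_2 = \operatorname{clip}\bigl(f_1 + b_{\min} + \lvert\theta_{\mathrm{band}}\rvert,\; f_{\min},\; f_s/2\bigr)
\]
\[
g[n] = \frac{\sin(2\pi f_2 t_n) - \sin(2\pi f_1 t_n)}{\pi t_n}\, w[n],
\qquad
w[n] = 0.54 - 0.46\cos\!\left(\frac{2\pi n}{K}\right)
\]

The right half of the kernel is the mirror image of the left, the center tap equals the limit value 2(f_2 - f_1), and the whole kernel is divided by 2(f_2 - f_1), giving each filter a unit center tap.

Second, a worked example of get_padding_elem with illustrative values: with stride 1 it solves for the padding that keeps the output length equal to the input length, while with stride > 1 it falls back to half-kernel padding on both sides.

# L_in=16, stride=1, kernel_size=5, dilation=1:
#   L_out = floor((16 - 1*(5-1) - 1) / 1) + 1 = 12, so pad (16 - 12) / 2 = 2 per side.
print(get_padding_elem(16, 1, 5, 1))  # [2, 2]
# stride=2 takes the first branch: floor(5 / 2) = 2 per side.
print(get_padding_elem(16, 2, 5, 1))  # [2, 2]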
xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py
File without changes
xinference/thirdparty/indextts/BigVGAN/nnet/linear.py
@@ -0,0 +1,89 @@
+"""Library implementing linear transformation.
+
+Authors
+ * Mirco Ravanelli 2020
+ * Davide Borra 2021
+"""
+
+import logging
+
+import torch
+import torch.nn as nn
+
+
+class Linear(torch.nn.Module):
+    """Computes a linear transformation y = wx + b.
+
+    Arguments
+    ---------
+    n_neurons : int
+        It is the number of output neurons (i.e., the dimensionality of the
+        output).
+    input_shape : tuple
+        It is the shape of the input tensor.
+    input_size : int
+        Size of the input tensor.
+    bias : bool
+        If True, the additive bias b is adopted.
+    max_norm : float
+        weight max-norm.
+    combine_dims : bool
+        If True and the input is 4D, combine 3rd and 4th dimensions of input.
+
+    Example
+    -------
+    >>> inputs = torch.rand(10, 50, 40)
+    >>> lin_t = Linear(input_shape=(10, 50, 40), n_neurons=100)
+    >>> output = lin_t(inputs)
+    >>> output.shape
+    torch.Size([10, 50, 100])
+    """
+
+    def __init__(
+        self,
+        n_neurons,
+        input_shape=None,
+        input_size=None,
+        bias=True,
+        max_norm=None,
+        combine_dims=False,
+    ):
+        super().__init__()
+        self.max_norm = max_norm
+        self.combine_dims = combine_dims
+
+        if input_shape is None and input_size is None:
+            raise ValueError("Expected one of input_shape or input_size")
+
+        if input_size is None:
+            input_size = input_shape[-1]
+            if len(input_shape) == 4 and self.combine_dims:
+                input_size = input_shape[2] * input_shape[3]
+
+        # Weights are initialized following pytorch approach
+        self.w = nn.Linear(input_size, n_neurons, bias=bias)
+
+    def forward(self, x):
+        """Returns the linear transformation of input tensor.
+
+        Arguments
+        ---------
+        x : torch.Tensor
+            Input to transform linearly.
+
+        Returns
+        -------
+        wx : torch.Tensor
+            The linearly transformed outputs.
+        """
+        if x.ndim == 4 and self.combine_dims:
+            x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])
+
+        if self.max_norm is not None:
+            self.w.weight.data = torch.renorm(
+                self.w.weight.data, p=2, dim=0, maxnorm=self.max_norm
+            )
+
+        wx = self.w(x)
+
+        return wx
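The max_norm branch in Linear.forward re-projects the weight matrix on every call via torch.renorm(p=2, dim=0, maxnorm=max_norm), which caps the L2 norm of each output neuron's weight row. A minimal sketch of the resulting invariant (shapes and values are illustrative only):

import torch

lin = Linear(n_neurons=100, input_size=40, max_norm=1.0)
out = lin(torch.rand(10, 50, 40))
assert out.shape == torch.Size([10, 50, 100])
# After a forward pass, no weight row exceeds the configured norm.
assert lin.w.weight.data.norm(p=2, dim=1).max().item() <= 1.0 + 1e-6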