xinference 1.10.0__py3-none-any.whl → 1.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +473 -31
- xinference/client/restful/async_restful_client.py +178 -8
- xinference/client/restful/restful_client.py +151 -3
- xinference/core/supervisor.py +99 -53
- xinference/core/worker.py +10 -0
- xinference/deploy/cmdline.py +15 -0
- xinference/model/audio/core.py +21 -6
- xinference/model/audio/indextts2.py +166 -0
- xinference/model/audio/model_spec.json +58 -21
- xinference/model/image/model_spec.json +159 -90
- xinference/model/image/stable_diffusion/core.py +13 -4
- xinference/model/llm/__init__.py +6 -2
- xinference/model/llm/llm_family.json +1299 -174
- xinference/model/llm/mlx/distributed_models/core.py +41 -0
- xinference/model/llm/mlx/distributed_models/qwen2.py +1 -2
- xinference/model/llm/sglang/core.py +44 -11
- xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +94 -32
- xinference/model/llm/tool_parsers/qwen_tool_parser.py +29 -4
- xinference/model/llm/transformers/chatglm.py +3 -0
- xinference/model/llm/transformers/core.py +129 -36
- xinference/model/llm/transformers/multimodal/minicpmv45.py +340 -0
- xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
- xinference/model/llm/transformers/utils.py +23 -0
- xinference/model/llm/utils.py +48 -32
- xinference/model/llm/vllm/core.py +207 -72
- xinference/model/utils.py +74 -31
- xinference/thirdparty/audiotools/__init__.py +10 -0
- xinference/thirdparty/audiotools/core/__init__.py +4 -0
- xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
- xinference/thirdparty/audiotools/core/display.py +194 -0
- xinference/thirdparty/audiotools/core/dsp.py +390 -0
- xinference/thirdparty/audiotools/core/effects.py +647 -0
- xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
- xinference/thirdparty/audiotools/core/loudness.py +320 -0
- xinference/thirdparty/audiotools/core/playback.py +252 -0
- xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
- xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
- xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
- xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
- xinference/thirdparty/audiotools/core/util.py +671 -0
- xinference/thirdparty/audiotools/core/whisper.py +97 -0
- xinference/thirdparty/audiotools/data/__init__.py +3 -0
- xinference/thirdparty/audiotools/data/datasets.py +517 -0
- xinference/thirdparty/audiotools/data/preprocess.py +81 -0
- xinference/thirdparty/audiotools/data/transforms.py +1592 -0
- xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
- xinference/thirdparty/audiotools/metrics/distance.py +131 -0
- xinference/thirdparty/audiotools/metrics/quality.py +159 -0
- xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
- xinference/thirdparty/audiotools/ml/__init__.py +5 -0
- xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
- xinference/thirdparty/audiotools/ml/decorators.py +440 -0
- xinference/thirdparty/audiotools/ml/experiment.py +90 -0
- xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
- xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
- xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
- xinference/thirdparty/audiotools/post.py +140 -0
- xinference/thirdparty/audiotools/preference.py +600 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +1 -1
- xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
- xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
- xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
- xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
- xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
- xinference/thirdparty/indextts/__init__.py +0 -0
- xinference/thirdparty/indextts/cli.py +65 -0
- xinference/thirdparty/indextts/gpt/__init__.py +0 -0
- xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
- xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
- xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
- xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
- xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
- xinference/thirdparty/indextts/gpt/model.py +713 -0
- xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
- xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
- xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
- xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
- xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
- xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
- xinference/thirdparty/indextts/infer.py +690 -0
- xinference/thirdparty/indextts/infer_v2.py +739 -0
- xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
- xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
- xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
- xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
- xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
- xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
- xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
- xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
- xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
- xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
- xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
- xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
- xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
- xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
- xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
- xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
- xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
- xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
- xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
- xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
- xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
- xinference/thirdparty/indextts/utils/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/arch_util.py +120 -0
- xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
- xinference/thirdparty/indextts/utils/common.py +121 -0
- xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
- xinference/thirdparty/indextts/utils/front.py +536 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
- xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
- xinference/thirdparty/indextts/utils/text_utils.py +41 -0
- xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
- xinference/thirdparty/indextts/utils/utils.py +93 -0
- xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
- xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
- xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
- xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
- xinference/thirdparty/melo/text/chinese_mix.py +2 -2
- xinference/types.py +9 -0
- xinference/ui/gradio/media_interface.py +66 -8
- xinference/ui/web/ui/build/asset-manifest.json +6 -6
- xinference/ui/web/ui/build/index.html +1 -1
- xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
- xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
- xinference/ui/web/ui/build/static/js/main.45e78536.js +3 -0
- xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.45e78536.js.LICENSE.txt} +0 -7
- xinference/ui/web/ui/build/static/js/main.45e78536.js.map +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/ea2a26361204e70cf1018d6990fb6354bed82b3ac69690391e0f100385e7abb7.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
- xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
- xinference/ui/web/ui/package-lock.json +0 -34
- xinference/ui/web/ui/package.json +0 -1
- xinference/ui/web/ui/src/locales/en.json +9 -3
- xinference/ui/web/ui/src/locales/ja.json +9 -3
- xinference/ui/web/ui/src/locales/ko.json +9 -3
- xinference/ui/web/ui/src/locales/zh.json +9 -3
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/METADATA +24 -6
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/RECORD +296 -77
- xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
- xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
- xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
- xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
- xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
- xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
- xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
- xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
- xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
- xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
- xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
- xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
- xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
- xinference/ui/web/ui/node_modules/select/bower.json +0 -13
- xinference/ui/web/ui/node_modules/select/package.json +0 -29
- xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/WHEEL +0 -0
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
# Copyright (c) 2023 Amphion.
|
|
2
|
+
#
|
|
3
|
+
# This code is modified from https://github.com/ZhangXInFD/SpeechTokenizer/blob/main/speechtokenizer/model.py
|
|
4
|
+
# Licensed under Apache License 2.0
|
|
5
|
+
|
|
6
|
+
from .modules.seanet import SEANetEncoder, SEANetDecoder
|
|
7
|
+
from .modules.quantization import ResidualVectorQuantizer
|
|
8
|
+
import torch.nn as nn
|
|
9
|
+
from einops import rearrange
|
|
10
|
+
import torch
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class SpeechTokenizer(nn.Module):
|
|
15
|
+
def __init__(self, config):
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
Parameters
|
|
19
|
+
----------
|
|
20
|
+
config : json
|
|
21
|
+
Model Config.
|
|
22
|
+
|
|
23
|
+
"""
|
|
24
|
+
super().__init__()
|
|
25
|
+
self.encoder = SEANetEncoder(
|
|
26
|
+
n_filters=config.get("n_filters"),
|
|
27
|
+
dimension=config.get("dimension"),
|
|
28
|
+
ratios=config.get("strides"),
|
|
29
|
+
lstm=config.get("lstm_layers"),
|
|
30
|
+
bidirectional=config.get("bidirectional"),
|
|
31
|
+
dilation_base=config.get("dilation_base"),
|
|
32
|
+
residual_kernel_size=config.get("residual_kernel_size"),
|
|
33
|
+
n_residual_layers=config.get("n_residual_layers"),
|
|
34
|
+
activation=config.get("activation"),
|
|
35
|
+
)
|
|
36
|
+
self.sample_rate = config.get("sample_rate")
|
|
37
|
+
self.n_q = config.get("n_q")
|
|
38
|
+
self.downsample_rate = np.prod(config.get("strides"))
|
|
39
|
+
if config.get("dimension") != config.get("semantic_dimension"):
|
|
40
|
+
self.transform = nn.Linear(
|
|
41
|
+
config.get("dimension"), config.get("semantic_dimension")
|
|
42
|
+
)
|
|
43
|
+
else:
|
|
44
|
+
self.transform = nn.Identity()
|
|
45
|
+
self.quantizer = ResidualVectorQuantizer(
|
|
46
|
+
dimension=config.get("dimension"),
|
|
47
|
+
n_q=config.get("n_q"),
|
|
48
|
+
bins=config.get("codebook_size"),
|
|
49
|
+
)
|
|
50
|
+
self.decoder = SEANetDecoder(
|
|
51
|
+
n_filters=config.get("n_filters"),
|
|
52
|
+
dimension=config.get("dimension"),
|
|
53
|
+
ratios=config.get("strides"),
|
|
54
|
+
lstm=config.get("lstm_layers"),
|
|
55
|
+
bidirectional=False,
|
|
56
|
+
dilation_base=config.get("dilation_base"),
|
|
57
|
+
residual_kernel_size=config.get("residual_kernel_size"),
|
|
58
|
+
n_residual_layers=config.get("n_residual_layers"),
|
|
59
|
+
activation=config.get("activation"),
|
|
60
|
+
)
|
|
61
|
+
|
|
62
|
+
@classmethod
|
|
63
|
+
def load_from_checkpoint(cls, config_path: str, ckpt_path: str):
|
|
64
|
+
"""
|
|
65
|
+
|
|
66
|
+
Parameters
|
|
67
|
+
----------
|
|
68
|
+
config_path : str
|
|
69
|
+
Path of model configuration file.
|
|
70
|
+
ckpt_path : str
|
|
71
|
+
Path of model checkpoint.
|
|
72
|
+
|
|
73
|
+
Returns
|
|
74
|
+
-------
|
|
75
|
+
model : SpeechTokenizer
|
|
76
|
+
SpeechTokenizer model.
|
|
77
|
+
|
|
78
|
+
"""
|
|
79
|
+
import json
|
|
80
|
+
|
|
81
|
+
with open(config_path) as f:
|
|
82
|
+
cfg = json.load(f)
|
|
83
|
+
model = cls(cfg)
|
|
84
|
+
params = torch.load(ckpt_path, map_location="cpu")
|
|
85
|
+
model.load_state_dict(params)
|
|
86
|
+
return model
|
|
87
|
+
|
|
88
|
+
def forward(self, x: torch.tensor, n_q: int = None, layers: list = [0]):
|
|
89
|
+
"""
|
|
90
|
+
|
|
91
|
+
Parameters
|
|
92
|
+
----------
|
|
93
|
+
x : torch.tensor
|
|
94
|
+
Input wavs. Shape: (batch, channels, timesteps).
|
|
95
|
+
n_q : int, optional
|
|
96
|
+
Number of quantizers in RVQ used to encode. The default is all layers.
|
|
97
|
+
layers : list[int], optional
|
|
98
|
+
Layers of RVQ should return quantized result. The default is the first layer.
|
|
99
|
+
|
|
100
|
+
Returns
|
|
101
|
+
-------
|
|
102
|
+
o : torch.tensor
|
|
103
|
+
Output wavs. Shape: (batch, channels, timesteps).
|
|
104
|
+
commit_loss : torch.tensor
|
|
105
|
+
Commitment loss from residual vector quantizers.
|
|
106
|
+
feature : torch.tensor
|
|
107
|
+
Output of RVQ's first layer. Shape: (batch, timesteps, dimension)
|
|
108
|
+
|
|
109
|
+
"""
|
|
110
|
+
n_q = n_q if n_q else self.n_q
|
|
111
|
+
e = self.encoder(x)
|
|
112
|
+
quantized, codes, commit_loss, quantized_list = self.quantizer(
|
|
113
|
+
e, n_q=n_q, layers=layers
|
|
114
|
+
)
|
|
115
|
+
feature = rearrange(quantized_list[0], "b d t -> b t d")
|
|
116
|
+
feature = self.transform(feature)
|
|
117
|
+
o = self.decoder(quantized)
|
|
118
|
+
return o, commit_loss, feature
|
|
119
|
+
|
|
120
|
+
def forward_feature(self, x: torch.tensor, layers: list = None):
|
|
121
|
+
"""
|
|
122
|
+
|
|
123
|
+
Parameters
|
|
124
|
+
----------
|
|
125
|
+
x : torch.tensor
|
|
126
|
+
Input wavs. Shape should be (batch, channels, timesteps).
|
|
127
|
+
layers : list[int], optional
|
|
128
|
+
Layers of RVQ should return quantized result. The default is all layers.
|
|
129
|
+
|
|
130
|
+
Returns
|
|
131
|
+
-------
|
|
132
|
+
quantized_list : list[torch.tensor]
|
|
133
|
+
Quantized of required layers.
|
|
134
|
+
|
|
135
|
+
"""
|
|
136
|
+
e = self.encoder(x)
|
|
137
|
+
layers = layers if layers else list(range(self.n_q))
|
|
138
|
+
quantized, codes, commit_loss, quantized_list = self.quantizer(e, layers=layers)
|
|
139
|
+
return quantized_list
|
|
140
|
+
|
|
141
|
+
def encode(self, x: torch.tensor, n_q: int = None, st: int = None):
    """Turn wavs into RVQ code indices.

    Parameters
    ----------
    x : torch.tensor
        Input wavs; documented as (batch, channels, timesteps).
    n_q : int, optional
        How many RVQ quantizers to use. Falsy (``None``/``0``) means
        ``self.n_q``.
    st : int, optional
        Index of the first quantizer to use; ``None`` means 0.

    Returns
    -------
    codes : torch.tensor
        Whatever ``self.quantizer.encode`` returns; documented as
        (n_q, batch, timesteps) indices.
    """
    embeddings = self.encoder(x)
    start = 0 if st is None else st
    num_quantizers = n_q or self.n_q
    return self.quantizer.encode(embeddings, n_q=num_quantizers, st=start)
|
|
165
|
+
|
|
166
|
+
def decode(self, codes: torch.tensor, st: int = 0):
    """Rebuild wavs from RVQ code indices.

    Parameters
    ----------
    codes : torch.tensor
        Per-quantizer indices; documented as (n_q, batch, timesteps).
    st : int, optional
        Index of the first quantizer the codes correspond to (default 0).

    Returns
    -------
    o : torch.tensor
        ``self.decoder`` output for the de-quantized codes; documented as
        (batch, channels, timesteps).
    """
    return self.decoder(self.quantizer.decode(codes, st=st))
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
# Copyright (c) 2023 Amphion.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
# This source file is copied from https://github.com/facebookresearch/encodec
|
|
6
|
+
|
|
7
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
8
|
+
# All rights reserved.
|
|
9
|
+
#
|
|
10
|
+
# This source code is licensed under the license found in the
|
|
11
|
+
# LICENSE file in the root directory of this source tree.
|
|
12
|
+
|
|
13
|
+
"""Torch modules."""
|
|
14
|
+
|
|
15
|
+
# flake8: noqa
|
|
16
|
+
from .conv import (
|
|
17
|
+
pad1d,
|
|
18
|
+
unpad1d,
|
|
19
|
+
NormConv1d,
|
|
20
|
+
NormConvTranspose1d,
|
|
21
|
+
NormConv2d,
|
|
22
|
+
NormConvTranspose2d,
|
|
23
|
+
SConv1d,
|
|
24
|
+
SConvTranspose1d,
|
|
25
|
+
)
|
|
26
|
+
from .lstm import SLSTM
|
|
27
|
+
from .seanet import SEANetEncoder, SEANetDecoder
|
|
@@ -0,0 +1,346 @@
|
|
|
1
|
+
# Copyright (c) 2023 Amphion.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
# This source file is copied from https://github.com/facebookresearch/encodec
|
|
6
|
+
|
|
7
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
8
|
+
# All rights reserved.
|
|
9
|
+
#
|
|
10
|
+
# This source code is licensed under the license found in the
|
|
11
|
+
# LICENSE file in the root directory of this source tree.
|
|
12
|
+
|
|
13
|
+
"""Convolutional layers wrappers and utilities."""
|
|
14
|
+
|
|
15
|
+
import math
|
|
16
|
+
import typing as tp
|
|
17
|
+
import warnings
|
|
18
|
+
|
|
19
|
+
import torch
|
|
20
|
+
from torch import nn
|
|
21
|
+
from torch.nn import functional as F
|
|
22
|
+
from torch.nn.utils import spectral_norm, weight_norm
|
|
23
|
+
|
|
24
|
+
from .norm import ConvLayerNorm
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# Every `norm` name accepted by the conv wrappers in this module. Only
# "weight_norm"/"spectral_norm" reparametrize the conv weights and only
# "layer_norm"/"time_group_norm" add a post-conv normalization layer; the
# remaining names ("none", "time_layer_norm") leave the conv untouched.
CONV_NORMALIZATIONS = frozenset(
    {
        "none",
        "weight_norm",
        "spectral_norm",
        "time_layer_norm",
        "layer_norm",
        "time_group_norm",
    }
)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def apply_parametrization_norm(module: nn.Module, norm: str = "none") -> nn.Module:
    """Wrap ``module`` with the weight parametrization named by ``norm``.

    "weight_norm" and "spectral_norm" apply the matching ``torch.nn.utils``
    helper; every other accepted name returns ``module`` unchanged because it
    either needs no reparametrization or is handled by a separate norm layer.
    """
    assert norm in CONV_NORMALIZATIONS
    parametrizers = {"weight_norm": weight_norm, "spectral_norm": spectral_norm}
    wrap = parametrizers.get(norm)
    return module if wrap is None else wrap(module)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def get_norm_module(
    module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs
) -> nn.Module:
    """Build the normalization layer that follows the conv ``module``.

    "layer_norm" yields a ``ConvLayerNorm`` and "time_group_norm" a single-group
    ``nn.GroupNorm``, both sized to ``module.out_channels``. Every other
    accepted name yields ``nn.Identity()``. Requesting "time_group_norm" with
    ``causal=True`` raises ``ValueError``, since it cannot be evaluated causally.
    """
    assert norm in CONV_NORMALIZATIONS
    if norm not in ("layer_norm", "time_group_norm"):
        # Parametrization-based norms (or "none") need no extra layer.
        return nn.Identity()
    if norm == "time_group_norm" and causal:
        raise ValueError("GroupNorm doesn't support causal evaluation.")
    assert isinstance(module, nn.modules.conv._ConvNd)
    if norm == "layer_norm":
        return ConvLayerNorm(module.out_channels, **norm_kwargs)
    return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def get_extra_padding_for_conv1d(
    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
) -> int:
    """Return how many samples to append to ``x`` so the last conv window is full.

    See `pad_for_conv1d` for the motivation. Only ``x.shape[-1]`` is read.
    The frame count uses exact integer ceil-division rather than float
    division, so the result cannot be misrounded (into a wrong or negative
    padding) for very long inputs.
    """
    length = x.shape[-1]
    span = length - kernel_size + padding_total
    # ceil(span / stride) computed as floor division of the negation.
    n_frames = -(-span // stride) + 1
    ideal_length = (n_frames - 1) * stride + (kernel_size - padding_total)
    return ideal_length - length
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def pad_for_conv1d(
    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
):
    """Right-pad ``x`` with zeros so the final convolution window is complete.

    Without this, part of the signal can be lost and a matching transposed
    convolution cannot rebuild an output of the original length, even with
    padding. Example with total padding = 4, kernel size = 4, stride = 2:
        0 0 1 2 3 4 5 0 0   # (0s are padding)
        1   2   3           # (output frames of a convolution, last 0 is never used)
        0 0 1 2 3 4 5 0     # (output of tr. conv., but pos. 5 is going to get removed as padding)
            1 2 3 4         # once you removed padding, we are missing one time step !
    """
    missing = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
    return F.pad(x, (0, missing))
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def pad1d(
    x: torch.Tensor,
    paddings: tp.Tuple[int, int],
    mode: str = "zero",
    value: float = 0.0,
):
    """Pad the last axis of ``x`` by ``paddings`` = (left, right) using ``F.pad``.

    ``mode`` is any ``F.pad`` mode. ``"zero"`` (the default) is accepted as an
    alias for ``"constant"`` padding with ``value``; ``F.pad`` itself rejects
    ``"zero"``. For ``"reflect"`` on an input no longer than the largest pad,
    zeros are first appended on the right so the reflection is legal, and the
    same number of samples is trimmed from the end of the result.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode == "zero":
        # "zero" is not a valid F.pad mode; translate it to constant padding.
        mode = "constant"
    if mode == "reflect":
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            # Reflect padding requires the pad to be smaller than the input.
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, paddings, mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]
    return F.pad(x, paddings, mode, value)
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Strip ``paddings`` = (left, right) samples from the last axis of ``x``.

    The end index is computed from the length (not a negative slice) so that a
    right padding of 0 keeps the full tail.
    """
    left, right = paddings
    assert left >= 0 and right >= 0, (left, right)
    total_length = x.shape[-1]
    assert (left + right) <= total_length
    return x[..., left : total_length - right]
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class NormConv1d(nn.Module):
    """``nn.Conv1d`` paired with the normalization selected by ``norm``.

    Positional and extra keyword arguments go straight to ``nn.Conv1d``.
    ``norm`` may reparametrize the conv weights and/or add a post-conv
    normalization layer (built with ``norm_kwargs``), so every normalization
    choice is used through the same interface.
    """

    def __init__(
        self,
        *args,
        causal: bool = False,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        **kwargs,
    ):
        super().__init__()
        raw_conv = nn.Conv1d(*args, **kwargs)
        self.conv = apply_parametrization_norm(raw_conv, norm)
        self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        return self.norm(self.conv(x))
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
class NormConv2d(nn.Module):
    """``nn.Conv2d`` paired with the normalization selected by ``norm``.

    Positional and extra keyword arguments go straight to ``nn.Conv2d``. The
    post-conv normalization layer is always built non-causally.
    """

    def __init__(
        self,
        *args,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        **kwargs,
    ):
        super().__init__()
        raw_conv = nn.Conv2d(*args, **kwargs)
        self.conv = apply_parametrization_norm(raw_conv, norm)
        self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        return self.norm(self.conv(x))
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
class NormConvTranspose1d(nn.Module):
    """``nn.ConvTranspose1d`` paired with the normalization selected by ``norm``.

    Positional and extra keyword arguments go straight to
    ``nn.ConvTranspose1d``; ``causal`` and ``norm_kwargs`` configure the
    post-conv normalization layer.
    """

    def __init__(
        self,
        *args,
        causal: bool = False,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        **kwargs,
    ):
        super().__init__()
        raw_convtr = nn.ConvTranspose1d(*args, **kwargs)
        self.convtr = apply_parametrization_norm(raw_convtr, norm)
        self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        return self.norm(self.convtr(x))
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
class NormConvTranspose2d(nn.Module):
    """Wrapper around ConvTranspose2d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.

    Positional and extra keyword arguments go straight to
    ``nn.ConvTranspose2d``. The post-conv normalization layer is always built
    non-causally.
    """

    def __init__(
        self,
        *args,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        **kwargs,
    ):
        super().__init__()
        self.convtr = apply_parametrization_norm(
            nn.ConvTranspose2d(*args, **kwargs), norm
        )
        self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
        # Record the chosen normalization like NormConv1d / NormConv2d /
        # NormConvTranspose1d do, so all four wrappers expose `norm_type`.
        self.norm_type = norm

    def forward(self, x):
        x = self.convtr(x)
        x = self.norm(x)
        return x
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
class SConv1d(nn.Module):
    """Conv1d with some builtin handling of asymmetric or causal padding
    and normalization.

    Before the conv, the time axis is padded by ``effective_kernel - stride``
    samples: all on the left when ``causal``, otherwise split with the extra
    sample on the left. Extra right padding (from
    ``get_extra_padding_for_conv1d``) is added so the last window is complete.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        causal: bool = False,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        pad_mode: str = "reflect",
    ):
        super().__init__()
        # warn user on unusual setup between dilation and stride
        if stride > 1 and dilation > 1:
            warnings.warn(
                "SConv1d has been initialized with stride > 1 and dilation > 1"
                f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})."
            )
        self.conv = NormConv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
            causal=causal,
            norm=norm,
            norm_kwargs=norm_kwargs,
        )
        self.causal = causal
        self.pad_mode = pad_mode

    def forward(self, x):
        # Unpacking enforces a 3-D (batch, channels, time) input.
        B, C, T = x.shape
        kernel_size = self.conv.conv.kernel_size[0]
        stride = self.conv.conv.stride[0]
        dilation = self.conv.conv.dilation[0]
        # A dilated kernel spans (k - 1) * d + 1 samples; the extra-padding
        # computation must use that span, otherwise with stride > 1 and
        # dilation > 1 the last window can be incomplete and a frame dropped.
        effective_kernel_size = (kernel_size - 1) * dilation + 1
        padding_total = effective_kernel_size - stride
        extra_padding = get_extra_padding_for_conv1d(
            x, effective_kernel_size, stride, padding_total
        )
        if self.causal:
            # Left padding for causal
            x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            x = pad1d(
                x, (padding_left, padding_right + extra_padding), mode=self.pad_mode
            )
        return self.conv(x)
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
class SConvTranspose1d(nn.Module):
    """ConvTranspose1d that trims its fixed padding back off, with normalization.

    After the transposed conv, ``kernel_size - stride`` samples are removed.
    When ``causal``, ``ceil(total * trim_right_ratio)`` of them come off the
    right and the rest off the left. Otherwise the total is split with the
    extra sample removed on the left. Any extra padding added by
    ``pad_for_conv1d`` upstream is left for the caller to trim.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        causal: bool = False,
        norm: str = "none",
        trim_right_ratio: float = 1.0,
        norm_kwargs: tp.Dict[str, tp.Any] = {},
    ):
        super().__init__()
        self.convtr = NormConvTranspose1d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            causal=causal,
            norm=norm,
            norm_kwargs=norm_kwargs,
        )
        self.causal = causal
        self.trim_right_ratio = trim_right_ratio
        assert (
            self.causal or self.trim_right_ratio == 1.0
        ), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
        assert 0.0 <= self.trim_right_ratio <= 1.0

    def forward(self, x):
        layer = self.convtr.convtr
        padding_total = layer.kernel_size[0] - layer.stride[0]
        y = self.convtr(x)
        if self.causal:
            # With trim_right_ratio == 1.0 everything is trimmed from the right.
            padding_right = math.ceil(padding_total * self.trim_right_ratio)
        else:
            padding_right = padding_total // 2
        padding_left = padding_total - padding_right
        return unpad1d(y, (padding_left, padding_right))
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
# Copyright (c) 2023 Amphion.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
# This source file is copied from https://github.com/facebookresearch/encodec
|
|
6
|
+
|
|
7
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
8
|
+
# All rights reserved.
|
|
9
|
+
#
|
|
10
|
+
# This source code is licensed under the license found in the
|
|
11
|
+
# LICENSE file in the root directory of this source tree.
|
|
12
|
+
|
|
13
|
+
"""LSTM layers module."""
|
|
14
|
+
|
|
15
|
+
from torch import nn
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class SLSTM(nn.Module):
    """
    LSTM that hides the hidden state and works on convolutional-layout tensors.

    The input is assumed to be (batch, channels, time) with
    ``channels == dimension``. With ``skip`` the input is added back to the
    LSTM output; with ``bidirectional`` the output has ``2 * dimension``
    channels and the residual is the input tiled twice along channels.
    """

    def __init__(
        self,
        dimension: int,
        num_layers: int = 2,
        skip: bool = True,
        bidirectional: bool = False,
    ):
        super().__init__()
        self.bidirectional = bidirectional
        self.skip = skip
        self.lstm = nn.LSTM(
            dimension, dimension, num_layers, bidirectional=bidirectional
        )

    def forward(self, x):
        # (batch, channels, time) -> (time, batch, channels): the sequence-first
        # layout nn.LSTM uses since batch_first is not set.
        seq = x.permute(2, 0, 1)
        out, _ = self.lstm(seq)
        if self.skip:
            residual = seq.repeat(1, 1, 2) if self.bidirectional else seq
            out = out + residual
        return out.permute(1, 2, 0)
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
# Copyright (c) 2023 Amphion.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
# This source file is copied from https://github.com/facebookresearch/encodec
|
|
6
|
+
|
|
7
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
8
|
+
# All rights reserved.
|
|
9
|
+
#
|
|
10
|
+
# This source code is licensed under the license found in the
|
|
11
|
+
# LICENSE file in the root directory of this source tree.
|
|
12
|
+
|
|
13
|
+
"""Normalization modules."""
|
|
14
|
+
|
|
15
|
+
import typing as tp
|
|
16
|
+
|
|
17
|
+
import einops
|
|
18
|
+
import torch
|
|
19
|
+
from torch import nn
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class ConvLayerNorm(nn.LayerNorm):
    """
    Convolution-friendly LayerNorm that moves channels to last dimensions
    before running the normalization and moves them back to original position right after.

    The last (time) axis is moved next to the batch axis, ``nn.LayerNorm``
    normalizes over the trailing ``normalized_shape`` dims, and the time axis
    is moved back to the end, so the output has the input's shape.
    """

    def __init__(
        self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs
    ):
        super().__init__(normalized_shape, **kwargs)

    def forward(self, x):
        # "b ... t -> b t ...": bring time next to batch so channel dims trail.
        x = x.movedim(-1, 1)
        x = super().forward(x)
        # "b t ... -> b ... t": restore the original layout.
        x = x.movedim(1, -1)
        # Return the normalized tensor; a bare `return` would hand back None.
        return x
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# Copyright (c) 2023 Amphion.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
# This source file is copied from https://github.com/facebookresearch/encodec
|
|
6
|
+
|
|
7
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
8
|
+
# All rights reserved.
|
|
9
|
+
#
|
|
10
|
+
# This source code is licensed under the license found in the
|
|
11
|
+
# LICENSE file in the root directory of this source tree.
|
|
12
|
+
|
|
13
|
+
# flake8: noqa
|
|
14
|
+
from .vq import QuantizedResult, ResidualVectorQuantizer
|