xinference 1.9.1__py3-none-any.whl → 1.10.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +400 -3
- xinference/client/restful/async_restful_client.py +20 -3
- xinference/client/restful/restful_client.py +20 -3
- xinference/constants.py +2 -0
- xinference/core/supervisor.py +111 -49
- xinference/core/worker.py +10 -0
- xinference/deploy/cmdline.py +15 -0
- xinference/model/audio/core.py +26 -6
- xinference/model/audio/indextts2.py +166 -0
- xinference/model/audio/kokoro.py +1 -1
- xinference/model/audio/kokoro_zh.py +124 -0
- xinference/model/audio/model_spec.json +58 -1
- xinference/model/embedding/sentence_transformers/core.py +4 -4
- xinference/model/embedding/vllm/core.py +7 -1
- xinference/model/image/model_spec.json +71 -3
- xinference/model/image/stable_diffusion/core.py +13 -4
- xinference/model/llm/__init__.py +4 -0
- xinference/model/llm/core.py +10 -0
- xinference/model/llm/llama_cpp/core.py +1 -0
- xinference/model/llm/llm_family.json +503 -21
- xinference/model/llm/llm_family.py +1 -0
- xinference/model/llm/mlx/core.py +52 -33
- xinference/model/llm/sglang/core.py +32 -55
- xinference/model/llm/tool_parsers/__init__.py +58 -0
- xinference/model/llm/tool_parsers/abstract_tool_parser.py +33 -0
- xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +190 -0
- xinference/model/llm/tool_parsers/deepseek_v3_tool_parser.py +145 -0
- xinference/model/llm/tool_parsers/glm4_tool_parser.py +123 -0
- xinference/model/llm/tool_parsers/llama3_tool_parser.py +77 -0
- xinference/model/llm/tool_parsers/qwen_tool_parser.py +320 -0
- xinference/model/llm/transformers/core.py +1 -1
- xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
- xinference/model/llm/utils.py +138 -53
- xinference/model/llm/vllm/core.py +95 -78
- xinference/thirdparty/audiotools/__init__.py +10 -0
- xinference/thirdparty/audiotools/core/__init__.py +4 -0
- xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
- xinference/thirdparty/audiotools/core/display.py +194 -0
- xinference/thirdparty/audiotools/core/dsp.py +390 -0
- xinference/thirdparty/audiotools/core/effects.py +647 -0
- xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
- xinference/thirdparty/audiotools/core/loudness.py +320 -0
- xinference/thirdparty/audiotools/core/playback.py +252 -0
- xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
- xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
- xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
- xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
- xinference/thirdparty/audiotools/core/util.py +671 -0
- xinference/thirdparty/audiotools/core/whisper.py +97 -0
- xinference/thirdparty/audiotools/data/__init__.py +3 -0
- xinference/thirdparty/audiotools/data/datasets.py +517 -0
- xinference/thirdparty/audiotools/data/preprocess.py +81 -0
- xinference/thirdparty/audiotools/data/transforms.py +1592 -0
- xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
- xinference/thirdparty/audiotools/metrics/distance.py +131 -0
- xinference/thirdparty/audiotools/metrics/quality.py +159 -0
- xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
- xinference/thirdparty/audiotools/ml/__init__.py +5 -0
- xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
- xinference/thirdparty/audiotools/ml/decorators.py +440 -0
- xinference/thirdparty/audiotools/ml/experiment.py +90 -0
- xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
- xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
- xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
- xinference/thirdparty/audiotools/post.py +140 -0
- xinference/thirdparty/audiotools/preference.py +600 -0
- xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
- xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
- xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
- xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
- xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
- xinference/thirdparty/indextts/__init__.py +0 -0
- xinference/thirdparty/indextts/cli.py +65 -0
- xinference/thirdparty/indextts/gpt/__init__.py +0 -0
- xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
- xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
- xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
- xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
- xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
- xinference/thirdparty/indextts/gpt/model.py +713 -0
- xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
- xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
- xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
- xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
- xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
- xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
- xinference/thirdparty/indextts/infer.py +690 -0
- xinference/thirdparty/indextts/infer_v2.py +739 -0
- xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
- xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
- xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
- xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
- xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
- xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
- xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
- xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
- xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
- xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
- xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
- xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
- xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
- xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
- xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
- xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
- xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
- xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
- xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
- xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
- xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
- xinference/thirdparty/indextts/utils/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/arch_util.py +120 -0
- xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
- xinference/thirdparty/indextts/utils/common.py +121 -0
- xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
- xinference/thirdparty/indextts/utils/front.py +536 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
- xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
- xinference/thirdparty/indextts/utils/text_utils.py +41 -0
- xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
- xinference/thirdparty/indextts/utils/utils.py +93 -0
- xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
- xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
- xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
- xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
- xinference/types.py +105 -2
- xinference/ui/gradio/media_interface.py +66 -8
- xinference/ui/web/ui/build/asset-manifest.json +6 -6
- xinference/ui/web/ui/build/index.html +1 -1
- xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
- xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
- xinference/ui/web/ui/build/static/js/main.d192c4f3.js +3 -0
- xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.d192c4f3.js.LICENSE.txt} +0 -7
- xinference/ui/web/ui/build/static/js/main.d192c4f3.js.map +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f995a2425dfb0822fd07127f66ffe9b026883bc156b402eb8bd0b83d52460a93.json +1 -0
- xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
- xinference/ui/web/ui/package-lock.json +0 -34
- xinference/ui/web/ui/package.json +0 -1
- xinference/ui/web/ui/src/locales/en.json +9 -3
- xinference/ui/web/ui/src/locales/ja.json +9 -3
- xinference/ui/web/ui/src/locales/ko.json +9 -3
- xinference/ui/web/ui/src/locales/zh.json +9 -3
- {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/METADATA +24 -4
- {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/RECORD +302 -76
- xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
- xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
- xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
- xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
- xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
- xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
- xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
- xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
- xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
- xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
- xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
- xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
- xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
- xinference/ui/web/ui/node_modules/select/bower.json +0 -13
- xinference/ui/web/ui/node_modules/select/package.json +0 -29
- xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
- {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/WHEEL +0 -0
- {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/entry_points.txt +0 -0
- {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
# Copyright (c) 2023 Amphion.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
# This source file is copied from https://github.com/facebookresearch/encodec
|
|
6
|
+
|
|
7
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
8
|
+
# All rights reserved.
|
|
9
|
+
#
|
|
10
|
+
# This source code is licensed under the license found in the
|
|
11
|
+
# LICENSE file in the root directory of this source tree.
|
|
12
|
+
|
|
13
|
+
"""Torch distributed utilities."""
|
|
14
|
+
|
|
15
|
+
import typing as tp
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def rank():
    """Return the rank of the current process.

    Falls back to 0 when ``torch.distributed`` has not been initialized.
    """
    if not torch.distributed.is_initialized():
        return 0
    return torch.distributed.get_rank()
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def world_size():
    """Return the number of workers.

    Falls back to 1 when ``torch.distributed`` has not been initialized.
    """
    dist = torch.distributed
    return dist.get_world_size() if dist.is_initialized() else 1
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def is_distributed():
    """Return True when more than one worker is reported by ``world_size()``."""
    n_workers = world_size()
    return n_workers > 1
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
    """Reduce ``tensor`` across workers with ``op`` (sum by default).

    Does nothing and returns None when not running distributed; otherwise
    returns whatever ``torch.distributed.all_reduce`` returns.
    """
    if not is_distributed():
        return None
    return torch.distributed.all_reduce(tensor, op)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _is_complex_or_float(tensor):
|
|
44
|
+
return torch.is_floating_point(tensor) or torch.is_complex(tensor)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _check_number_of_params(params: tp.List[torch.Tensor]):
    """Verify that every worker holds the same number of params.

    Guards against a deadlock in later collective calls. The local count is
    summed over all workers and compared with ``len(params) * world_size()``.

    Raises:
        RuntimeError: if the summed count differs from the expected total.
    """
    if not (is_distributed() and params):
        return
    n_params = len(params)
    counter = torch.tensor([n_params], device=params[0].device, dtype=torch.long)
    all_reduce(counter)
    summed = counter.item()
    if summed == n_params * world_size():
        return
    # If the workers disagree on the count, at least one of them sees a
    # total that does not match its own count times the world size.
    raise RuntimeError(
        f"Mismatch in number of params: ours is {n_params}, "
        "at least one worker has a different one."
    )
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
    """Broadcast the given tensors from worker ``src`` to all workers.

    Useful to make every worker start from the same model. Only
    floating-point and complex tensors are broadcast; the others are ignored.
    Does nothing when not running distributed.
    """
    if not is_distributed():
        return
    eligible = list(filter(_is_complex_or_float, tensors))
    _check_number_of_params(eligible)
    # Launch all broadcasts asynchronously, then wait on each one.
    pending = [
        torch.distributed.broadcast(item.data, src=src, async_op=True)
        for item in eligible
    ]
    for work in pending:
        work.wait()
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def sync_buffer(buffers, average=True):
    """Synchronize floating-point buffers across all workers, in place.

    Args:
        buffers: iterable of objects exposing a ``.data`` tensor. Entries
            whose data is not floating point are skipped.
        average: if True, each buffer is summed over all workers and then
            divided by the world size. If False, each buffer is instead
            broadcast from worker 0.

    Does nothing when not running distributed.
    """
    if not is_distributed():
        return
    handles = []
    for buffer in buffers:
        if torch.is_floating_point(buffer.data):
            if average:
                handle = torch.distributed.all_reduce(
                    buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True
                )
            else:
                handle = torch.distributed.broadcast(buffer.data, src=0, async_op=True)
            handles.append((buffer, handle))
    # world_size is a function and must be called: dividing by the function
    # object itself raised TypeError on every averaged sync.
    n_workers = world_size()
    for buffer, handle in handles:
        handle.wait()
        if average:
            buffer.data /= n_workers
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def sync_grad(params):
    """Average gradients across all workers, in place.

    A simpler alternative to DistributedDataParallel: call it on the model
    parameters after ``backward``. Parameters whose ``grad`` is None are
    skipped. Does nothing when not running distributed.
    """
    if not is_distributed():
        return
    in_flight = []
    for param in params:
        if param.grad is None:
            continue
        work = torch.distributed.all_reduce(
            param.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True
        )
        in_flight.append((param, work))
    for param, work in in_flight:
        work.wait()
        # The reduction was a sum; divide to turn it into a mean.
        param.grad.data /= world_size()
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def average_metrics(metrics: tp.Dict[str, float], count=1.0):
    """Average a dictionary of metrics across all workers.

    ``count`` is an optional unnormalized weight for this worker's values.
    When not running distributed, ``metrics`` is returned unchanged.
    """
    if not is_distributed():
        return metrics
    names, scores = zip(*metrics.items())
    if torch.cuda.is_available():
        target = "cuda"
    else:
        target = "cpu"
    # The trailing 1 becomes this worker's weight once scaled by ``count``;
    # after the sum-reduce it holds the total weight used as the divisor.
    weighted = torch.tensor([*scores, 1], device=target, dtype=torch.float32)
    weighted *= count
    all_reduce(weighted)
    totals = weighted[:-1]
    weight_sum = weighted[-1]
    means = (totals / weight_sum).cpu().tolist()
    return dict(zip(names, means))
|
xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
# Copyright (c) 2023 Amphion.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
# This source file is copied from https://github.com/facebookresearch/encodec
|
|
6
|
+
|
|
7
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
8
|
+
# All rights reserved.
|
|
9
|
+
#
|
|
10
|
+
# This source code is licensed under the license found in the
|
|
11
|
+
# LICENSE file in the root directory of this source tree.
|
|
12
|
+
|
|
13
|
+
"""Residual vector quantizer implementation."""
|
|
14
|
+
|
|
15
|
+
from dataclasses import dataclass, field
|
|
16
|
+
import math
|
|
17
|
+
import typing as tp
|
|
18
|
+
|
|
19
|
+
import torch
|
|
20
|
+
from torch import nn
|
|
21
|
+
|
|
22
|
+
from .core_vq import ResidualVectorQuantization
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
class QuantizedResult:
    """Record bundling the outputs of a quantization step.

    `penalty` and `metrics` are optional; every instance gets its own
    fresh `metrics` dict.
    """

    # Required tensors, in positional order.
    quantized: torch.Tensor
    codes: torch.Tensor
    # Bandwidth in kb/s used, per batch item.
    bandwidth: torch.Tensor
    # Optional extras.
    penalty: tp.Optional[torch.Tensor] = field(default=None)
    metrics: dict = field(default_factory=dict)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class ResidualVectorQuantizer(nn.Module):
    """Residual Vector Quantizer.

    Stores the quantizer hyper-parameters and delegates the actual work to a
    `ResidualVectorQuantization` instance held in `self.vq`.

    Args:
        dimension (int): Dimension of the codebooks.
        n_q (int): Number of residual vector quantizers used.
        bins (int): Codebook size.
        decay (float): Decay for exponential moving average over the codebooks.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified
            threshold with randomly selected vector from the current batch.
    """

    def __init__(
        self,
        dimension: int = 256,
        n_q: int = 8,
        bins: int = 1024,
        decay: float = 0.99,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: int = 2,
    ):
        super().__init__()
        self.n_q = n_q
        self.dimension = dimension
        self.bins = bins
        self.decay = decay
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.vq = ResidualVectorQuantization(
            dim=self.dimension,
            codebook_size=self.bins,
            num_quantizers=self.n_q,
            decay=self.decay,
            kmeans_init=self.kmeans_init,
            kmeans_iters=self.kmeans_iters,
            threshold_ema_dead_code=self.threshold_ema_dead_code,
        )

    def forward(
        self,
        x: torch.Tensor,
        n_q: tp.Optional[int] = None,
        layers: tp.Optional[list] = None,
    ) -> tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, tp.Any]:
        """Residual vector quantization on the given input tensor.

        Args:
            x (torch.Tensor): Input tensor.
            n_q (int): Number of quantizers used to quantize. Default: all quantizers.
                A falsy value (None or 0) also selects all quantizers.
            layers (list): Indices of the layers whose quantized output should be
                returned. Default: None.

        Returns:
            tuple: `(quantized, codes, commit_loss, quantized_list)` where the
            first, second and fourth items are returned by `self.vq` unchanged and
            `commit_loss` is the mean of the commitment loss returned by `self.vq`.

        Raises:
            ValueError: If the largest index in `layers` is not smaller than the
                number of quantizers used for this call.
        """
        n_q = n_q if n_q else self.n_q
        if layers and max(layers) >= n_q:
            # Report the effective number of quantizers for this call (the value
            # the check above compares against), not the configured maximum.
            raise ValueError(
                f"Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {n_q}. A must less than B."
            )
        quantized, codes, commit_loss, quantized_list = self.vq(
            x, n_q=n_q, layers=layers
        )
        return quantized, codes, torch.mean(commit_loss), quantized_list

    def encode(
        self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
    ) -> torch.Tensor:
        """Encode the input tensor into quantizer indices.

        Args:
            x (torch.Tensor): Input tensor.
            n_q (int): Number of quantizers used to quantize. Default: all quantizers.
            st (int): Layer from which to start encoding the input. Default: 0.

        Returns:
            The codes produced by `self.vq.encode`.
        """
        n_q = n_q if n_q else self.n_q
        st = st or 0
        codes = self.vq.encode(x, n_q=n_q, st=st)
        return codes

    def decode(self, codes: torch.Tensor, st: int = 0) -> torch.Tensor:
        """Decode the given codes to the quantized representation.

        Args:
            codes (torch.Tensor): Input indices for each quantizer.
            st (int): Layer from which to start decoding the codes. Default: 0.

        Returns:
            The quantized representation produced by `self.vq.decode`.
        """
        quantized = self.vq.decode(codes, st=st)
        return quantized
|
|
@@ -0,0 +1,414 @@
|
|
|
1
|
+
# Copyright (c) 2023 Amphion.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
# This source file is copied from https://github.com/facebookresearch/encodec
|
|
6
|
+
|
|
7
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
8
|
+
# All rights reserved.
|
|
9
|
+
#
|
|
10
|
+
# This source code is licensed under the license found in the
|
|
11
|
+
# LICENSE file in the root directory of this source tree.
|
|
12
|
+
|
|
13
|
+
"""Encodec SEANet-based encoder and decoder implementation."""
|
|
14
|
+
|
|
15
|
+
import typing as tp
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
import torch.nn as nn
|
|
19
|
+
import torch
|
|
20
|
+
|
|
21
|
+
from . import SConv1d, SConvTranspose1d, SLSTM
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@torch.jit.script
def snake(x, alpha):
    """Apply `x + sin(alpha * x)^2 / (alpha + 1e-9)` and keep the input shape.

    The input is viewed as (dim0, dim1, -1) for the computation so that
    `alpha` can broadcast against it, then restored to its original shape.
    """
    original_shape = x.shape
    flat = x.reshape(original_shape[0], original_shape[1], -1)
    # The small epsilon guards the reciprocal against alpha == 0.
    inv_alpha = (alpha + 1e-9).reciprocal()
    flat = flat + inv_alpha * torch.sin(alpha * flat).pow(2)
    return flat.reshape(original_shape)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class Snake1d(nn.Module):
    """Module wrapper around `snake` with one learnable `alpha` per channel."""

    def __init__(self, channels):
        super().__init__()
        # alpha starts at 1 and has shape (1, channels, 1).
        initial_alpha = torch.ones(1, channels, 1)
        self.alpha = nn.Parameter(initial_alpha)

    def forward(self, x):
        activated = snake(x, self.alpha)
        return activated
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class SEANetResnetBlock(nn.Module):
    """Residual block from SEANet model.

    The output is `shortcut(x) + block(x)`, where `block` is a sequence of
    (activation, convolution) pairs, one pair per kernel size.

    Args:
        dim (int): Dimension of the input/output.
        kernel_sizes (list): List of kernel sizes for the convolutions.
        dilations (list): List of dilations for the convolutions.
        activation (str): Name of an activation class in `torch.nn`, or "Snake".
        activation_params (dict): Parameters to provide to the activation function
            (not used for "Snake", which receives the channel count instead).
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization
            used along with the convolution.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        true_skip (bool): Whether to use true skip connection or a simple
            convolution as the skip connection.
    """

    def __init__(
        self,
        dim: int,
        kernel_sizes: tp.List[int] = [3, 1],
        dilations: tp.List[int] = [1, 1],
        activation: str = "ELU",
        activation_params: dict = {"alpha": 1.0},
        norm: str = "weight_norm",
        norm_params: tp.Dict[str, tp.Any] = {},
        causal: bool = False,
        pad_mode: str = "reflect",
        compress: int = 2,
        true_skip: bool = True,
    ):
        super().__init__()
        assert len(kernel_sizes) == len(
            dilations
        ), "Number of kernel sizes should match number of dilations"

        uses_snake = activation == "Snake"
        act_cls = Snake1d if uses_snake else getattr(nn, activation)
        hidden = dim // compress
        last_index = len(kernel_sizes) - 1
        # Options shared by every convolution built in this block.
        conv_options = dict(
            norm=norm, norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode
        )

        layers = []
        for index, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
            # Only the first conv reads `dim` channels and only the last one
            # writes `dim` channels; everything in between uses `hidden`.
            in_chs = dim if index == 0 else hidden
            out_chs = dim if index == last_index else hidden
            if uses_snake:
                layers.append(act_cls(in_chs))
            else:
                layers.append(act_cls(**activation_params))
            layers.append(
                SConv1d(
                    in_chs,
                    out_chs,
                    kernel_size=kernel_size,
                    dilation=dilation,
                    **conv_options,
                )
            )
        self.block = nn.Sequential(*layers)

        self.shortcut: nn.Module
        if not true_skip:
            self.shortcut = SConv1d(dim, dim, kernel_size=1, **conv_options)
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        skip = self.shortcut(x)
        return skip + self.block(x)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
class SEANetEncoder(nn.Module):
    """SEANet encoder.

    Builds a sequential model: an initial convolution, then for every ratio a
    stack of residual blocks followed by a strided (downsampling) convolution
    that doubles the channel count, an optional LSTM, and a final convolution
    projecting to `dimension` channels.

    Args:
        channels (int): Audio channels.
        dimension (int): Intermediate representation dimension.
        n_filters (int): Base width for the model.
        n_residual_layers (int): nb of residual layers.
        ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling
            ratios instead of upsampling ratios, hence it will use the ratios in the reverse
            order to the ones specified here that must match the decoder order.
        activation (str): Name of an activation class in `torch.nn`, or "Snake".
        activation_params (dict): Parameters to provide to the activation function
            (not used for "Snake", which receives the channel count instead).
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along
            with the convolution.
        kernel_size (int): Kernel size for the initial convolution.
        last_kernel_size (int): Kernel size for the final convolution.
        residual_kernel_size (int): Kernel size for the residual layers.
        dilation_base (int): How much to increase the dilation with each layer.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection in the residual network blocks.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        lstm (int): Number of LSTM layers inserted after the downsampling stack,
            before the final convolution. 0 disables the LSTM.
        bidirectional (bool): Forwarded to `SLSTM`. When an LSTM is inserted and
            this is true, the final convolution takes twice as many input channels.
    """

    def __init__(
        self,
        channels: int = 1,
        dimension: int = 128,
        n_filters: int = 32,
        n_residual_layers: int = 1,
        ratios: tp.List[int] = [8, 5, 4, 2],
        activation: str = "ELU",
        activation_params: dict = {"alpha": 1.0},
        norm: str = "weight_norm",
        norm_params: tp.Dict[str, tp.Any] = {},
        kernel_size: int = 7,
        last_kernel_size: int = 7,
        residual_kernel_size: int = 3,
        dilation_base: int = 2,
        causal: bool = False,
        pad_mode: str = "reflect",
        true_skip: bool = False,
        compress: int = 2,
        lstm: int = 2,
        bidirectional: bool = False,
    ):
        super().__init__()
        self.channels = channels
        self.dimension = dimension
        self.n_filters = n_filters
        # Reversed copy: the encoder applies the ratios in the opposite order.
        self.ratios = list(reversed(ratios))
        del ratios  # avoid accidentally using the un-reversed argument below
        self.n_residual_layers = n_residual_layers
        self.hop_length = np.prod(self.ratios)  # product of all the ratios

        act = getattr(nn, activation) if activation != "Snake" else Snake1d
        mult = 1
        model: tp.List[nn.Module] = [
            SConv1d(
                channels,
                mult * n_filters,
                kernel_size,
                norm=norm,
                norm_kwargs=norm_params,
                causal=causal,
                pad_mode=pad_mode,
            )
        ]
        # Downsample to raw audio scale
        for i, ratio in enumerate(self.ratios):
            # Add residual layers
            for j in range(n_residual_layers):
                model += [
                    SEANetResnetBlock(
                        mult * n_filters,
                        kernel_sizes=[residual_kernel_size, 1],
                        dilations=[dilation_base**j, 1],
                        norm=norm,
                        norm_params=norm_params,
                        activation=activation,
                        activation_params=activation_params,
                        causal=causal,
                        pad_mode=pad_mode,
                        compress=compress,
                        true_skip=true_skip,
                    )
                ]

            # Add downsampling layers
            model += [
                (
                    act(**activation_params)
                    if activation != "Snake"
                    else act(mult * n_filters)
                ),
                SConv1d(
                    mult * n_filters,
                    mult * n_filters * 2,
                    kernel_size=ratio * 2,
                    stride=ratio,
                    norm=norm,
                    norm_kwargs=norm_params,
                    causal=causal,
                    pad_mode=pad_mode,
                ),
            ]
            mult *= 2

        if lstm:
            model += [
                SLSTM(mult * n_filters, num_layers=lstm, bidirectional=bidirectional)
            ]
            # Widen the input of the final convolution for a bidirectional LSTM
            # (presumably because SLSTM then outputs twice as many channels --
            # confirm against SLSTM). This is done only when an LSTM is actually
            # inserted; without one the preceding layer still outputs
            # `mult * n_filters` channels and doubling would cause a mismatch.
            if bidirectional:
                mult *= 2

        model += [
            (
                act(**activation_params)
                if activation != "Snake"
                else act(mult * n_filters)
            ),
            SConv1d(
                mult * n_filters,
                dimension,
                last_kernel_size,
                norm=norm,
                norm_kwargs=norm_params,
                causal=causal,
                pad_mode=pad_mode,
            ),
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Run the input through the sequential encoder model."""
        return self.model(x)
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
class SEANetDecoder(nn.Module):
    """SEANet decoder.

    Builds a sequential model: an initial convolution, an optional LSTM, then
    for every ratio a transposed (upsampling) convolution that halves the
    channel count followed by a stack of residual blocks, and a final
    convolution projecting to `channels`, optionally followed by a final
    activation.

    Args:
        channels (int): Audio channels.
        dimension (int): Intermediate representation dimension.
        n_filters (int): Base width for the model.
        n_residual_layers (int): nb of residual layers.
        ratios (Sequence[int]): kernel size and stride ratios.
        activation (str): Name of an activation class in `torch.nn`, or "Snake".
        activation_params (dict): Parameters to provide to the activation function
            (not used for "Snake", which receives the channel count instead).
        final_activation (str): Final activation function after all convolutions.
        final_activation_params (dict): Parameters to provide to the final activation function.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along
            with the convolution.
        kernel_size (int): Kernel size for the initial convolution.
        last_kernel_size (int): Kernel size for the final convolution.
        residual_kernel_size (int): Kernel size for the residual layers.
        dilation_base (int): How much to increase the dilation with each layer.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection in the residual network blocks.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        lstm (int): Number of LSTM layers inserted right after the initial
            convolution. 0 disables the LSTM.
        trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution
            under the causal setup. If equal to 1.0, it means that all the trimming is done
            at the right.
        bidirectional (bool): Forwarded to `SLSTM`.
    """

    def __init__(
        self,
        channels: int = 1,
        dimension: int = 128,
        n_filters: int = 32,
        n_residual_layers: int = 1,
        ratios: tp.List[int] = [8, 5, 4, 2],
        activation: str = "ELU",
        activation_params: dict = {"alpha": 1.0},
        final_activation: tp.Optional[str] = None,
        final_activation_params: tp.Optional[dict] = None,
        norm: str = "weight_norm",
        norm_params: tp.Dict[str, tp.Any] = {},
        kernel_size: int = 7,
        last_kernel_size: int = 7,
        residual_kernel_size: int = 3,
        dilation_base: int = 2,
        causal: bool = False,
        pad_mode: str = "reflect",
        true_skip: bool = False,
        compress: int = 2,
        lstm: int = 2,
        trim_right_ratio: float = 1.0,
        bidirectional: bool = False,
    ):
        super().__init__()
        self.dimension = dimension
        self.channels = channels
        self.n_filters = n_filters
        # Store a copy: keeping a reference to the argument would share the
        # mutable default list between every instance built with the default.
        self.ratios = list(ratios)
        del ratios
        self.n_residual_layers = n_residual_layers
        self.hop_length = np.prod(self.ratios)  # product of all the ratios

        act = getattr(nn, activation) if activation != "Snake" else Snake1d
        mult = int(2 ** len(self.ratios))
        model: tp.List[nn.Module] = [
            SConv1d(
                dimension,
                mult * n_filters,
                kernel_size,
                norm=norm,
                norm_kwargs=norm_params,
                causal=causal,
                pad_mode=pad_mode,
            )
        ]

        if lstm:
            # NOTE(review): SEANetEncoder doubles its channel multiplier when
            # `bidirectional` is set, but no such adjustment is made here --
            # confirm that SLSTM's output width matches the next layer's input
            # when bidirectional=True.
            model += [
                SLSTM(mult * n_filters, num_layers=lstm, bidirectional=bidirectional)
            ]

        # Upsample to raw audio scale
        for i, ratio in enumerate(self.ratios):
            # Add upsampling layers
            model += [
                (
                    act(**activation_params)
                    if activation != "Snake"
                    else act(mult * n_filters)
                ),
                SConvTranspose1d(
                    mult * n_filters,
                    mult * n_filters // 2,
                    kernel_size=ratio * 2,
                    stride=ratio,
                    norm=norm,
                    norm_kwargs=norm_params,
                    causal=causal,
                    trim_right_ratio=trim_right_ratio,
                ),
            ]
            # Add residual layers
            for j in range(n_residual_layers):
                model += [
                    SEANetResnetBlock(
                        mult * n_filters // 2,
                        kernel_sizes=[residual_kernel_size, 1],
                        dilations=[dilation_base**j, 1],
                        activation=activation,
                        activation_params=activation_params,
                        norm=norm,
                        norm_params=norm_params,
                        causal=causal,
                        pad_mode=pad_mode,
                        compress=compress,
                        true_skip=true_skip,
                    )
                ]

            mult //= 2

        # Add final layers
        model += [
            act(**activation_params) if activation != "Snake" else act(n_filters),
            SConv1d(
                n_filters,
                channels,
                last_kernel_size,
                norm=norm,
                norm_kwargs=norm_params,
                causal=causal,
                pad_mode=pad_mode,
            ),
        ]
        # Add optional final activation to decoder (eg. tanh)
        if final_activation is not None:
            final_act = getattr(nn, final_activation)
            final_activation_params = final_activation_params or {}
            model += [final_act(**final_activation_params)]
        self.model = nn.Sequential(*model)

    def forward(self, z):
        """Run the latent input through the sequential decoder model."""
        y = self.model(z)
        return y
|
|
397
|
+
|
|
398
|
+
|
|
399
|
+
def test():
    """Smoke-test a default encoder/decoder pair on random audio.

    Checks that one second of 24 kHz mono input is encoded to a
    (1, 128, 75) latent and decoded back to the input shape.
    """
    encoder = SEANetEncoder()
    decoder = SEANetDecoder()
    x = torch.randn(1, 1, 24000)
    z = encoder(x)
    print("z ", z.shape)
    # 24000 samples / (8 * 5 * 4 * 2) = 75 frames with the default ratios.
    assert list(z.shape) == [1, 128, 75], z.shape
    y = decoder(z)
    assert y.shape == x.shape, (x.shape, y.shape)


if __name__ == "__main__":
    test()
|