xinference 1.10.0__py3-none-any.whl → 1.10.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +11 -28
- xinference/client/restful/async_restful_client.py +20 -3
- xinference/client/restful/restful_client.py +20 -3
- xinference/core/supervisor.py +87 -53
- xinference/core/worker.py +10 -0
- xinference/deploy/cmdline.py +15 -0
- xinference/model/audio/core.py +21 -6
- xinference/model/audio/indextts2.py +166 -0
- xinference/model/audio/model_spec.json +38 -1
- xinference/model/image/model_spec.json +69 -0
- xinference/model/image/stable_diffusion/core.py +13 -4
- xinference/model/llm/__init__.py +4 -0
- xinference/model/llm/llm_family.json +464 -2
- xinference/model/llm/sglang/core.py +30 -11
- xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +94 -32
- xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
- xinference/model/llm/utils.py +12 -9
- xinference/model/llm/vllm/core.py +93 -17
- xinference/thirdparty/audiotools/__init__.py +10 -0
- xinference/thirdparty/audiotools/core/__init__.py +4 -0
- xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
- xinference/thirdparty/audiotools/core/display.py +194 -0
- xinference/thirdparty/audiotools/core/dsp.py +390 -0
- xinference/thirdparty/audiotools/core/effects.py +647 -0
- xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
- xinference/thirdparty/audiotools/core/loudness.py +320 -0
- xinference/thirdparty/audiotools/core/playback.py +252 -0
- xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
- xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
- xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
- xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
- xinference/thirdparty/audiotools/core/util.py +671 -0
- xinference/thirdparty/audiotools/core/whisper.py +97 -0
- xinference/thirdparty/audiotools/data/__init__.py +3 -0
- xinference/thirdparty/audiotools/data/datasets.py +517 -0
- xinference/thirdparty/audiotools/data/preprocess.py +81 -0
- xinference/thirdparty/audiotools/data/transforms.py +1592 -0
- xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
- xinference/thirdparty/audiotools/metrics/distance.py +131 -0
- xinference/thirdparty/audiotools/metrics/quality.py +159 -0
- xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
- xinference/thirdparty/audiotools/ml/__init__.py +5 -0
- xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
- xinference/thirdparty/audiotools/ml/decorators.py +440 -0
- xinference/thirdparty/audiotools/ml/experiment.py +90 -0
- xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
- xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
- xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
- xinference/thirdparty/audiotools/post.py +140 -0
- xinference/thirdparty/audiotools/preference.py +600 -0
- xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
- xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
- xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
- xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
- xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
- xinference/thirdparty/indextts/__init__.py +0 -0
- xinference/thirdparty/indextts/cli.py +65 -0
- xinference/thirdparty/indextts/gpt/__init__.py +0 -0
- xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
- xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
- xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
- xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
- xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
- xinference/thirdparty/indextts/gpt/model.py +713 -0
- xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
- xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
- xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
- xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
- xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
- xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
- xinference/thirdparty/indextts/infer.py +690 -0
- xinference/thirdparty/indextts/infer_v2.py +739 -0
- xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
- xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
- xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
- xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
- xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
- xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
- xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
- xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
- xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
- xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
- xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
- xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
- xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
- xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
- xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
- xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
- xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
- xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
- xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
- xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
- xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
- xinference/thirdparty/indextts/utils/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/arch_util.py +120 -0
- xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
- xinference/thirdparty/indextts/utils/common.py +121 -0
- xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
- xinference/thirdparty/indextts/utils/front.py +536 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
- xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
- xinference/thirdparty/indextts/utils/text_utils.py +41 -0
- xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
- xinference/thirdparty/indextts/utils/utils.py +93 -0
- xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
- xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
- xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
- xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
- xinference/ui/gradio/media_interface.py +66 -8
- xinference/ui/web/ui/build/asset-manifest.json +6 -6
- xinference/ui/web/ui/build/index.html +1 -1
- xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
- xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
- xinference/ui/web/ui/build/static/js/main.d192c4f3.js +3 -0
- xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.d192c4f3.js.LICENSE.txt} +0 -7
- xinference/ui/web/ui/build/static/js/main.d192c4f3.js.map +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f995a2425dfb0822fd07127f66ffe9b026883bc156b402eb8bd0b83d52460a93.json +1 -0
- xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
- xinference/ui/web/ui/package-lock.json +0 -34
- xinference/ui/web/ui/package.json +0 -1
- xinference/ui/web/ui/src/locales/en.json +9 -3
- xinference/ui/web/ui/src/locales/ja.json +9 -3
- xinference/ui/web/ui/src/locales/ko.json +9 -3
- xinference/ui/web/ui/src/locales/zh.json +9 -3
- {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/METADATA +18 -2
- {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/RECORD +285 -67
- xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
- xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
- xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
- xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
- xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
- xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
- xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
- xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
- xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
- xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
- xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
- xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
- xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
- xinference/ui/web/ui/node_modules/select/bower.json +0 -13
- xinference/ui/web/ui/node_modules/select/package.json +0 -29
- xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
- {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/WHEEL +0 -0
- {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/entry_points.txt +0 -0
- {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,257 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch import nn
|
|
3
|
+
import math
|
|
4
|
+
|
|
5
|
+
from indextts.s2mel.modules.gpt_fast.model import ModelArgs, Transformer
|
|
6
|
+
from indextts.s2mel.modules.wavenet import WN
|
|
7
|
+
from indextts.s2mel.modules.commons import sequence_mask
|
|
8
|
+
|
|
9
|
+
from torch.nn.utils import weight_norm
|
|
10
|
+
|
|
11
|
+
def modulate(x, shift, scale):
    """Apply adaptive affine modulation ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` get a singleton axis inserted at position 1 so that
    one (shift, scale) pair per batch element broadcasts across that axis.
    NOTE(review): this presumably expects x as (batch, seq, dim) and
    shift/scale as (batch, dim) -- confirm against callers.
    """
    gain = 1 + scale.unsqueeze(1)
    bias = shift.unsqueeze(1)
    return x * gain + bias
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
#################################################################################
|
|
16
|
+
# Embedding Layers for Timesteps and Style Features #
|
|
17
|
+
#################################################################################
|
|
18
|
+
|
|
19
|
+
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar (possibly fractional) timesteps into vector representations.

    Each timestep is turned into a fixed sinusoidal feature vector of size
    ``frequency_embedding_size``. That vector is then passed through a two-layer
    MLP (Linear -> SiLU -> Linear) that outputs ``hidden_size`` features.

    Args:
        hidden_size: Output dimension of the embedding MLP.
        frequency_embedding_size: Size of the sinusoidal feature vector fed to
            the MLP. Odd sizes are handled by appending one zero column.
        max_period: Controls the minimum frequency of the sinusoids.
        scale: Multiplier applied to ``t`` before the sinusoids are taken.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256, max_period=10000, scale=1000):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size
        self.max_period = max_period
        self.scale = scale

        # Geometric frequency ladder max_period ** (-k / half) for k in [0, half):
        # it starts at 1 and decreases toward (but never reaches) 1 / max_period.
        half = frequency_embedding_size // 2
        freqs = torch.exp(
            -math.log(self.max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        )
        # A (persistent) buffer follows the module across devices/dtypes and is
        # included in state_dict.
        self.register_buffer("freqs", freqs)

    def timestep_embedding(self, t):
        """
        Create sinusoidal timestep embeddings.

        :param t: a 1-D Tensor of N timesteps, one per batch element.
                  These may be fractional.
        :return: a float Tensor of shape (N, frequency_embedding_size), laid out
                 as [cos(args), sin(args)], followed by one zero column when
                 frequency_embedding_size is odd.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py

        args = self.scale * t[:, None].float() * self.freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if self.frequency_embedding_size % 2:
            # cos/sin only cover 2 * (size // 2) columns; pad to the full size.
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        """Embed a 1-D tensor of N timesteps into an (N, hidden_size) tensor."""
        t_freq = self.timestep_embedding(t)
        t_emb = self.mlp(t_freq)
        return t_emb
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class StyleEmbedder(nn.Module):
    """
    Projects style feature vectors to ``hidden_size``, with optional style
    dropout for classifier-free guidance.

    When ``dropout_prob > 0``, a single learned "null" embedding is allocated in
    ``embedding_table``. Dropped samples receive that embedding instead of their
    projected style vector.

    Args:
        input_size: Last-dimension size of the incoming style features.
        hidden_size: Size of the output embedding.
        dropout_prob: Probability of dropping a sample's style while training.
    """

    def __init__(self, input_size, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        # One row (the null/CFG embedding) when dropout is enabled, else zero rows.
        self.embedding_table = nn.Embedding(int(use_cfg_embedding), hidden_size)
        self.style_in = weight_norm(nn.Linear(input_size, hidden_size, bias=True))
        self.input_size = input_size
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Project ``labels`` and replace dropped samples with the null embedding.

        :param labels: style features whose last dim is ``input_size``.
            NOTE(review): assumed batch-first; dim 0 is treated as the batch.
        :param force_drop_ids: optional per-sample tensor. Entries equal to 1
            (or True) are dropped. When None, samples are dropped at random
            with probability ``dropout_prob``.
        :return: the embeddings, shaped like ``style_in(labels)``.
        :raises ValueError: if the module was built with ``dropout_prob <= 0``,
            which means there is no null embedding to substitute.
        """
        if self.embedding_table.num_embeddings == 0:
            raise ValueError(
                "StyleEmbedder was built with dropout_prob <= 0 and has no null embedding to drop to"
            )
        if force_drop_ids is None:
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        else:
            drop_ids = force_drop_ids == 1
        embeddings = self.style_in(labels)
        null_embedding = self.embedding_table.weight[0].to(embeddings.dtype)
        # Give the per-sample mask singleton trailing dims so it broadcasts
        # over every non-batch dimension of the embeddings.
        drop_ids = drop_ids.to(embeddings.device).reshape(-1, *([1] * (embeddings.dim() - 1)))
        return torch.where(drop_ids, null_embedding, embeddings)

    def forward(self, labels, train, force_drop_ids=None):
        """
        Embed style features.

        Dropout (see ``token_drop``) is applied when training with
        ``dropout_prob > 0``, or whenever ``force_drop_ids`` is given.
        Otherwise the features are only projected through ``style_in``.
        """
        use_dropout = self.dropout_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            return self.token_drop(labels, force_drop_ids)
        return self.style_in(labels)
|
|
83
|
+
|
|
84
|
+
class FinalLayer(nn.Module):
    """
    Output head of the DiT.

    It layer-normalizes the hidden states (without a learned affine), modulates
    them with a shift/scale derived from the conditioning vector ``c``, then
    projects them to ``patch_size * patch_size * out_channels`` features through
    a weight-normalized linear layer.
    """

    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        projected_dim = patch_size * patch_size * out_channels
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = weight_norm(nn.Linear(hidden_size, projected_dim, bias=True))
        # SiLU -> Linear emits [shift | scale], each of width hidden_size.
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

    def forward(self, x, c):
        """Normalize ``x``, modulate it with shift/scale from ``c``, and project it."""
        modulation = self.adaLN_modulation(c)
        shift, scale = modulation.chunk(2, dim=1)
        normed = self.norm_final(x)
        return self.linear(modulate(normed, shift, scale))
|
|
102
|
+
|
|
103
|
+
class DiT(torch.nn.Module):
    """Diffusion-transformer backbone.

    ``forward`` combines a noisy input, a prompt signal, content features, a
    diffusion timestep and a global style vector. It returns a tensor with
    ``args.DiT.in_channels`` channels. Hyper-parameters come from ``args.DiT``,
    plus ``args.wavenet`` for the WaveNet head and ``args.style_encoder.dim``
    for the style width.
    """
    def __init__(
        self,
        args
    ):
        super(DiT, self).__init__()
        # Optional flags: a missing attribute on args.DiT means False.
        self.time_as_token = args.DiT.time_as_token if hasattr(args.DiT, 'time_as_token') else False
        self.style_as_token = args.DiT.style_as_token if hasattr(args.DiT, 'style_as_token') else False
        self.uvit_skip_connection = args.DiT.uvit_skip_connection if hasattr(args.DiT, 'uvit_skip_connection') else False
        model_args = ModelArgs(
            block_size=16384,  # NOTE: hard-coded; args.DiT.block_size is ignored (same size as the input_pos buffer below)
            n_layer=args.DiT.depth,
            n_head=args.DiT.num_heads,
            dim=args.DiT.hidden_dim,
            head_dim=args.DiT.hidden_dim // args.DiT.num_heads,
            vocab_size=1024,
            uvit_skip_connection=self.uvit_skip_connection,
            time_as_token=self.time_as_token,
        )
        self.transformer = Transformer(model_args)
        self.in_channels = args.DiT.in_channels
        # The output has the same channel count as the input.
        self.out_channels = args.DiT.in_channels
        self.num_heads = args.DiT.num_heads

        # NOTE(review): x_embedder is not used by forward() below.
        self.x_embedder = weight_norm(nn.Linear(args.DiT.in_channels, args.DiT.hidden_dim, bias=True))

        self.content_type = args.DiT.content_type # 'discrete' or 'continuous' (forward() always uses the continuous path)
        self.content_codebook_size = args.DiT.content_codebook_size # for discrete content
        self.content_dim = args.DiT.content_dim # for continuous content
        self.cond_embedder = nn.Embedding(args.DiT.content_codebook_size, args.DiT.hidden_dim) # discrete content; NOTE(review): unused by forward()
        self.cond_projection = nn.Linear(args.DiT.content_dim, args.DiT.hidden_dim, bias=True) # continuous content

        self.is_causal = args.DiT.is_causal

        self.t_embedder = TimestepEmbedder(args.DiT.hidden_dim)

        # self.style_embedder1 = weight_norm(nn.Linear(1024, args.DiT.hidden_dim, bias=True))
        # self.style_embedder2 = weight_norm(nn.Linear(1024, args.style_encoder.dim, bias=True))

        # Position indices 0..16383; forward() slices the first sequence-length entries.
        input_pos = torch.arange(16384)
        self.register_buffer("input_pos", input_pos)

        self.final_layer_type = args.DiT.final_layer_type # 'wavenet' selects the WN head; any other value uses final_mlp
        if self.final_layer_type == 'wavenet':
            self.t_embedder2 = TimestepEmbedder(args.wavenet.hidden_dim)
            self.conv1 = nn.Linear(args.DiT.hidden_dim, args.wavenet.hidden_dim)
            self.conv2 = nn.Conv1d(args.wavenet.hidden_dim, args.DiT.in_channels, 1)
            self.wavenet = WN(hidden_channels=args.wavenet.hidden_dim,
                              kernel_size=args.wavenet.kernel_size,
                              dilation_rate=args.wavenet.dilation_rate,
                              n_layers=args.wavenet.num_layers,
                              gin_channels=args.wavenet.hidden_dim,
                              p_dropout=args.wavenet.p_dropout,
                              causal=False)
            self.final_layer = FinalLayer(args.wavenet.hidden_dim, 1, args.wavenet.hidden_dim)
            self.res_projection = nn.Linear(args.DiT.hidden_dim,
                                            args.wavenet.hidden_dim) # residual connection from transformer output to final output
            self.wavenet_style_condition = args.wavenet.style_condition
            assert args.DiT.style_condition == args.wavenet.style_condition
        else:
            self.final_mlp = nn.Sequential(
                nn.Linear(args.DiT.hidden_dim, args.DiT.hidden_dim),
                nn.SiLU(),
                nn.Linear(args.DiT.hidden_dim, args.DiT.in_channels),
            )
        # Needed by both head types (it sizes cond_x_merge_linear below).
        self.transformer_style_condition = args.DiT.style_condition


        self.class_dropout_prob = args.DiT.class_dropout_prob
        # NOTE(review): content_mask_embedder is not used by forward() below.
        self.content_mask_embedder = nn.Embedding(1, args.DiT.hidden_dim)

        self.long_skip_connection = args.DiT.long_skip_connection
        # Long skip: merges the transformer output with the (time-major) input x.
        self.skip_linear = nn.Linear(args.DiT.hidden_dim + args.DiT.in_channels, args.DiT.hidden_dim)

        # Input width = x and prompt_x (2 * in_channels) + projected cond (hidden_dim)
        # + style (style_encoder.dim), the last only when style is concatenated per frame
        # instead of being used as a separate token.
        self.cond_x_merge_linear = nn.Linear(args.DiT.hidden_dim + args.DiT.in_channels * 2 +
                                             args.style_encoder.dim * self.transformer_style_condition * (not self.style_as_token),
                                             args.DiT.hidden_dim)
        if self.style_as_token:
            self.style_in = nn.Linear(args.style_encoder.dim, args.DiT.hidden_dim)

    def setup_caches(self, max_batch_size, max_seq_length):
        """Prepare the transformer's caches; the KV cache is always disabled here."""
        self.transformer.setup_caches(max_batch_size, max_seq_length, use_kv_cache=False)

    def forward(self, x, prompt_x, x_lens, t, style, cond, mask_content=False):
        """
        Args:
            x (torch.Tensor): noisy input, shape (batch_size, in_channels, T).
            prompt_x (torch.Tensor): prompt signal, same shape as ``x``. The
                original notes describe it as reference mel followed by zero mel.
            x_lens (torch.Tensor): valid lengths passed to ``sequence_mask``;
                presumably shape (batch_size,) — TODO confirm.
            t (torch.Tensor): diffusion timesteps, presumably shape (batch_size,).
            style (torch.Tensor): global style vector, shape (batch_size, style_dim).
            cond (torch.Tensor): content features of the reference and target
                audio, shape (batch_size, T, content_dim).
            mask_content (bool): in eval mode, zero all conditioning (prompt,
                content, style). Presumably this is the unconditional pass of
                classifier-free guidance.

        Returns:
            torch.Tensor: shape (batch_size, in_channels, T).
        """
        # Drop all conditioning with probability class_dropout_prob while training,
        # or on request (mask_content) at inference.
        class_dropout = False
        if self.training and torch.rand(1) < self.class_dropout_prob:
            class_dropout = True
        if not self.training and mask_content:
            class_dropout = True
        # Content is always treated as continuous; content_type / cond_embedder are ignored:
        # cond_in_module = self.cond_embedder if self.content_type == 'discrete' else self.cond_projection
        cond_in_module = self.cond_projection

        B, _, T = x.size()  # x is (B, C, T)


        t1 = self.t_embedder(t)  # timestep embedding, presumably (B, hidden_dim)
        cond = cond_in_module(cond)  # (B, T, content_dim) -> (B, T, hidden_dim)

        x = x.transpose(1, 2)  # (B, T, in_channels)
        prompt_x = prompt_x.transpose(1, 2)  # (B, T, in_channels)

        x_in = torch.cat([x, prompt_x, cond], dim=-1)  # (B, T, 2 * in_channels + hidden_dim)

        if self.transformer_style_condition and not self.style_as_token:
            # Broadcast the global style vector to every frame and append it as channels.
            x_in = torch.cat([x_in, style[:, None, :].repeat(1, T, 1)], dim=-1)

        if class_dropout:
            # Zero every channel after the first in_channels (prompt_x, cond, style),
            # keeping only the noisy x.
            x_in[..., self.in_channels:] = x_in[..., self.in_channels:] * 0

        x_in = self.cond_x_merge_linear(x_in)  # (B, T, hidden_dim)

        if self.style_as_token:
            # Prepend the projected style as an extra token (zeroed under class dropout).
            style = self.style_in(style)
            style = torch.zeros_like(style) if class_dropout else style
            x_in = torch.cat([style.unsqueeze(1), x_in], dim=1)

        if self.time_as_token:
            # Prepend the timestep embedding as an extra token.
            x_in = torch.cat([t1.unsqueeze(1), x_in], dim=1)

        # The mask length includes the prepended style/time tokens. sequence_mask is
        # defined elsewhere; its output is presumably (B, L) before unsqueeze -> (B, 1, L).
        x_mask = sequence_mask(x_lens + self.style_as_token + self.time_as_token).to(x.device).unsqueeze(1)
        input_pos = self.input_pos[:x_in.size(1)]  # positions 0 .. L-1
        # Non-causal: (B, 1, L, L) attention mask built from x_mask. Causal: None is passed
        # and masking is left to the Transformer (defined elsewhere).
        x_mask_expanded = x_mask[:, None, :].repeat(1, 1, x_in.size(1), 1) if not self.is_causal else None
        x_res = self.transformer(x_in, t1.unsqueeze(1), input_pos, x_mask_expanded)  # (B, L, hidden_dim)
        # Strip the prepended tokens so x_res realigns with the T frames of x.
        x_res = x_res[:, 1:] if self.time_as_token else x_res
        x_res = x_res[:, 1:] if self.style_as_token else x_res

        if self.long_skip_connection:
            x_res = self.skip_linear(torch.cat([x_res, x], dim=-1))
        if self.final_layer_type == 'wavenet':
            x = self.conv1(x_res)
            x = x.transpose(1, 2)
            t2 = self.t_embedder2(t)
            x = self.wavenet(x, x_mask, g=t2.unsqueeze(2)).transpose(1, 2) + self.res_projection(
                x_res) # long residual connection
            # NOTE(review): final_layer was built with wavenet.hidden_dim but is conditioned on t1
            # (DiT.hidden_dim); this presumably requires the two sizes to match — confirm.
            x = self.final_layer(x, t1).transpose(1, 2)
            x = self.conv2(x)
        else:
            x = self.final_mlp(x_res)
            x = x.transpose(1, 2)
        # x is (B, in_channels, T)
        return x
|
|
@@ -0,0 +1,292 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
"""Convolutional layers wrappers and utilities."""
|
|
8
|
+
|
|
9
|
+
import math
|
|
10
|
+
import typing as tp
|
|
11
|
+
import warnings
|
|
12
|
+
|
|
13
|
+
import torch
|
|
14
|
+
from torch import nn
|
|
15
|
+
from torch.nn import functional as F
|
|
16
|
+
from torch.nn.utils import spectral_norm, weight_norm
|
|
17
|
+
|
|
18
|
+
import typing as tp
|
|
19
|
+
|
|
20
|
+
import einops
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class ConvLayerNorm(nn.LayerNorm):
    """
    Convolution-friendly LayerNorm.

    The time axis (the last dim) is moved next to the batch axis, so that the
    channel dims are last and can be normalized by ``nn.LayerNorm``. Afterwards
    the original ``(B, C..., T)`` layout is restored.
    """
    def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
        super().__init__(normalized_shape, **kwargs)

    def forward(self, x):
        # (B, C..., T) -> (B, T, C...): equivalent to einops 'b ... t -> b t ...'.
        x = x.movedim(-1, 1)
        x = super().forward(x)
        # (B, T, C...) -> (B, C..., T): equivalent to einops 'b t ... -> b ... t'.
        # The original implementation ended with a bare `return`, yielding None.
        return x.movedim(1, -1)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# Every normalization name accepted by the conv wrappers in this module.
CONV_NORMALIZATIONS = frozenset({
    'none',
    'weight_norm',
    'spectral_norm',
    'time_layer_norm',
    'layer_norm',
    'time_group_norm',
})
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
    """Wrap ``module`` with the weight reparametrization named by ``norm``.

    Only 'weight_norm' and 'spectral_norm' reparametrize the weights. Every
    other name in CONV_NORMALIZATIONS returns ``module`` unchanged.

    Raises:
        AssertionError: if ``norm`` is not in CONV_NORMALIZATIONS.
    """
    assert norm in CONV_NORMALIZATIONS
    reparametrizations = {'weight_norm': weight_norm, 'spectral_norm': spectral_norm}
    wrap = reparametrizations.get(norm)
    if wrap is None:
        return module
    return wrap(module)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
    """Build the normalization module that follows ``module``.

    - 'layer_norm': a ConvLayerNorm over ``module.out_channels``.
    - 'time_group_norm': a single-group GroupNorm over ``module.out_channels``.
      It is rejected when ``causal`` is True.
    - Any other accepted name: ``nn.Identity()``.

    Raises:
        AssertionError: if ``norm`` is not in CONV_NORMALIZATIONS, or ``module``
            is not a convolution when a channel-wise norm is requested.
        ValueError: for 'time_group_norm' with ``causal=True``.
    """
    assert norm in CONV_NORMALIZATIONS
    if norm == 'layer_norm':
        assert isinstance(module, nn.modules.conv._ConvNd)
        return ConvLayerNorm(module.out_channels, **norm_kwargs)
    if norm == 'time_group_norm':
        if causal:
            raise ValueError("GroupNorm doesn't support causal evaluation.")
        assert isinstance(module, nn.modules.conv._ConvNd)
        return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
    # 'none', 'weight_norm', 'spectral_norm' and 'time_layer_norm' add no module here.
    return nn.Identity()
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
                                 padding_total: int = 0) -> int:
    """Return how many samples to append to ``x`` so the last conv window is full.

    The possibly fractional frame count for the current length is rounded up.
    The result is the gap between the length that yields exactly that many
    full frames and the actual length. See `pad_for_conv1d`.
    """
    num_samples = x.shape[-1]
    span_without_last = kernel_size - padding_total
    frames = 1 + (num_samples - span_without_last) / stride
    target_length = (math.ceil(frames) - 1) * stride + span_without_last
    return target_length - num_samples
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
    """Zero-pad the end of ``x`` so the last convolution window is full.

    Without this, an output of the original length cannot always be rebuilt:
    even with padding, some time steps may be lost. Example with total
    padding = 4, kernel size = 4, stride = 2::

        0 0 1 2 3 4 5 0 0   # input (0s are padding)
            1 2 3           # conv output frames; the last 0 is never used
        0 0 1 2 3 4 5 0     # transposed-conv output; pos. 5 is removed as padding
            1 2 3 4         # after removing the padding one time step is missing!
    """
    extra = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
    return F.pad(x, (0, extra))
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
    """Tiny wrapper around F.pad, just to allow for reflect padding on small input.

    When the input is no longer than the largest padding, reflect padding would
    fail. In that case extra zero padding is inserted on the right before the
    reflection, and removed again afterwards.

    Args:
        x: tensor padded along its last dimension.
        paddings: ``(left, right)`` amounts; both must be non-negative.
        mode: any ``F.pad`` mode. ``'zero'`` (the default) is accepted as an
            alias for ``'constant'``.
        value: fill value for constant padding.
    """
    if mode == 'zero':
        # F.pad has no 'zero' mode and raises NotImplementedError for it.
        mode = 'constant'
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode != 'reflect':
        return F.pad(x, paddings, mode, value)
    max_pad = max(padding_left, padding_right)
    extra_pad = 0
    if length <= max_pad:
        # Reflect padding needs the pad to be smaller than the input length.
        extra_pad = max_pad - length + 1
        x = F.pad(x, (0, extra_pad))
    padded = F.pad(x, paddings, mode, value)
    end = padded.shape[-1] - extra_pad
    return padded[..., :end]
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Strip ``(left, right)`` samples from the last axis of ``x``.

    Handles zero padding correctly: a right padding of 0 keeps the tail intact.
    Only meant for 1d signals.
    """
    left, right = paddings
    assert left >= 0 and right >= 0, (left, right)
    total = x.shape[-1]
    assert (left + right) <= total
    return x[..., left: total - right]
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class NormConv1d(nn.Module):
    """Wrapper around Conv1d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.

    Args:
        *args, **kwargs: forwarded to ``nn.Conv1d``.
        causal: forwarded to ``get_norm_module`` ('time_group_norm' rejects True).
        norm: one of CONV_NORMALIZATIONS.
        norm_kwargs: extra keyword arguments for the normalization module.
    """
    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None, **kwargs):
        super().__init__()
        # None instead of a shared mutable `{}` default.
        norm_kwargs = {} if norm_kwargs is None else norm_kwargs
        self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
class NormConv2d(nn.Module):
    """Wrapper around Conv2d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.

    Args:
        *args, **kwargs: forwarded to ``nn.Conv2d``.
        norm: one of CONV_NORMALIZATIONS (always evaluated as non-causal).
        norm_kwargs: extra keyword arguments for the normalization module.
    """
    def __init__(self, *args, norm: str = 'none',
                 norm_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None, **kwargs):
        super().__init__()
        # None instead of a shared mutable `{}` default.
        norm_kwargs = {} if norm_kwargs is None else norm_kwargs
        self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
class NormConvTranspose1d(nn.Module):
    """Wrapper around ConvTranspose1d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.

    Args:
        *args, **kwargs: forwarded to ``nn.ConvTranspose1d``.
        causal: forwarded to ``get_norm_module`` ('time_group_norm' rejects True).
        norm: one of CONV_NORMALIZATIONS.
        norm_kwargs: extra keyword arguments for the normalization module.
    """
    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None, **kwargs):
        super().__init__()
        # None instead of a shared mutable `{}` default.
        norm_kwargs = {} if norm_kwargs is None else norm_kwargs
        self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.convtr(x)
        x = self.norm(x)
        return x
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
class NormConvTranspose2d(nn.Module):
    """Wrapper around ConvTranspose2d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.

    Args:
        *args, **kwargs: forwarded to ``nn.ConvTranspose2d``.
        norm: one of CONV_NORMALIZATIONS (always evaluated as non-causal).
        norm_kwargs: extra keyword arguments for the normalization module.
    """
    def __init__(self, *args, norm: str = 'none',
                 norm_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None, **kwargs):
        super().__init__()
        # None instead of a shared mutable `{}` default.
        norm_kwargs = {} if norm_kwargs is None else norm_kwargs
        self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
        # Recorded for parity with NormConv1d / NormConv2d / NormConvTranspose1d.
        self.norm_type = norm

    def forward(self, x):
        x = self.convtr(x)
        x = self.norm(x)
        return x
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
class SConv1d(nn.Module):
    """Conv1d with some builtin handling of asymmetric or causal padding
    and normalization.

    The input is padded so that, for stride 1, the output length matches the
    input length, and the last window is always full (see `pad_for_conv1d`).
    In causal mode all fixed padding goes on the left. Otherwise it is split,
    with the extra sample going left for odd totals.

    Note: extra ``**kwargs`` are accepted but ignored.
    """
    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, dilation: int = 1,
                 groups: int = 1, bias: bool = True, causal: bool = False,
                 norm: str = 'none', norm_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None,
                 pad_mode: str = 'reflect', **kwargs):
        super().__init__()
        # None instead of a shared mutable `{}` default.
        norm_kwargs = {} if norm_kwargs is None else norm_kwargs
        # warn user on unusual setup between dilation and stride
        if stride > 1 and dilation > 1:
            warnings.warn('SConv1d has been initialized with stride > 1 and dilation > 1'
                          f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).')
        self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
                               dilation=dilation, groups=groups, bias=bias, causal=causal,
                               norm=norm, norm_kwargs=norm_kwargs)
        self.causal = causal
        self.pad_mode = pad_mode

    def forward(self, x):
        # Unpacking enforces a 3-D (B, C, T) input; the sizes themselves are unused.
        _batch, _channels, _time = x.shape
        kernel_size = self.conv.conv.kernel_size[0]
        stride = self.conv.conv.stride[0]
        dilation = self.conv.conv.dilation[0]
        kernel_size = (kernel_size - 1) * dilation + 1  # effective kernel size with dilations
        padding_total = kernel_size - stride
        extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
        if self.causal:
            # Left padding for causal
            x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
        return self.conv(x)
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
class SConvTranspose1d(nn.Module):
    """ConvTranspose1d with some builtin handling of asymmetric or causal padding
    and normalization.

    After the transposed convolution, ``kernel_size - stride`` samples of fixed
    padding are trimmed. In causal mode, ``trim_right_ratio`` controls how much
    of that trim is taken from the right. Otherwise the trim is split, with the
    extra sample taken from the left for odd totals.

    Note: extra ``**kwargs`` are accepted but ignored.
    """
    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, causal: bool = False,
                 norm: str = 'none', trim_right_ratio: float = 1.,
                 norm_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None, **kwargs):
        super().__init__()
        # None instead of a shared mutable `{}` default.
        norm_kwargs = {} if norm_kwargs is None else norm_kwargs
        self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
                                          causal=causal, norm=norm, norm_kwargs=norm_kwargs)
        self.causal = causal
        self.trim_right_ratio = trim_right_ratio
        assert self.causal or self.trim_right_ratio == 1., \
            "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
        assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.

    def forward(self, x):
        kernel_size = self.convtr.convtr.kernel_size[0]
        stride = self.convtr.convtr.stride[0]
        padding_total = kernel_size - stride

        y = self.convtr(x)

        # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
        # removed at the very end, when keeping only the right length for the output,
        # as removing it here would require also passing the length at the matching layer
        # in the encoder.
        if self.causal:
            # Trim the padding on the right according to the specified ratio
            # if trim_right_ratio = 1.0, trim everything from right
            padding_right = math.ceil(padding_total * self.trim_right_ratio)
            padding_left = padding_total - padding_right
            y = unpad1d(y, (padding_left, padding_right))
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            y = unpad1d(y, (padding_left, padding_right))
        return y
|
|
271
|
+
|
|
272
|
+
class SLSTM(nn.Module):
    """
    LSTM that hides hidden-state bookkeeping and data layout from the caller.

    Input and output use the convolutional layout, presumably (batch, channels,
    time). The input is permuted with ``(2, 0, 1)`` into the sequence-first order
    that ``nn.LSTM`` uses by default, and permuted back afterwards.

    In eval mode the LSTM state is stored in ``self.hidden`` and carried into
    the next call. In training mode every call starts from the default zero
    state. With ``skip`` enabled, the input is added to the LSTM output.
    """
    def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
        super().__init__()
        self.skip = skip
        self.lstm = nn.LSTM(dimension, dimension, num_layers)
        self.hidden = None

    def forward(self, x):
        seq = x.permute(2, 0, 1)
        if not self.training:
            out, self.hidden = self.lstm(seq, self.hidden)
        else:
            out, _ = self.lstm(seq)
        if self.skip:
            out = out + seq
        return out.permute(1, 2, 0)
|