xinference 1.10.0__py3-none-any.whl → 1.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +473 -31
- xinference/client/restful/async_restful_client.py +178 -8
- xinference/client/restful/restful_client.py +151 -3
- xinference/core/supervisor.py +99 -53
- xinference/core/worker.py +10 -0
- xinference/deploy/cmdline.py +15 -0
- xinference/model/audio/core.py +21 -6
- xinference/model/audio/indextts2.py +166 -0
- xinference/model/audio/model_spec.json +58 -21
- xinference/model/image/model_spec.json +159 -90
- xinference/model/image/stable_diffusion/core.py +13 -4
- xinference/model/llm/__init__.py +6 -2
- xinference/model/llm/llm_family.json +1299 -174
- xinference/model/llm/mlx/distributed_models/core.py +41 -0
- xinference/model/llm/mlx/distributed_models/qwen2.py +1 -2
- xinference/model/llm/sglang/core.py +44 -11
- xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +94 -32
- xinference/model/llm/tool_parsers/qwen_tool_parser.py +29 -4
- xinference/model/llm/transformers/chatglm.py +3 -0
- xinference/model/llm/transformers/core.py +129 -36
- xinference/model/llm/transformers/multimodal/minicpmv45.py +340 -0
- xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
- xinference/model/llm/transformers/utils.py +23 -0
- xinference/model/llm/utils.py +48 -32
- xinference/model/llm/vllm/core.py +207 -72
- xinference/model/utils.py +74 -31
- xinference/thirdparty/audiotools/__init__.py +10 -0
- xinference/thirdparty/audiotools/core/__init__.py +4 -0
- xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
- xinference/thirdparty/audiotools/core/display.py +194 -0
- xinference/thirdparty/audiotools/core/dsp.py +390 -0
- xinference/thirdparty/audiotools/core/effects.py +647 -0
- xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
- xinference/thirdparty/audiotools/core/loudness.py +320 -0
- xinference/thirdparty/audiotools/core/playback.py +252 -0
- xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
- xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
- xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
- xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
- xinference/thirdparty/audiotools/core/util.py +671 -0
- xinference/thirdparty/audiotools/core/whisper.py +97 -0
- xinference/thirdparty/audiotools/data/__init__.py +3 -0
- xinference/thirdparty/audiotools/data/datasets.py +517 -0
- xinference/thirdparty/audiotools/data/preprocess.py +81 -0
- xinference/thirdparty/audiotools/data/transforms.py +1592 -0
- xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
- xinference/thirdparty/audiotools/metrics/distance.py +131 -0
- xinference/thirdparty/audiotools/metrics/quality.py +159 -0
- xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
- xinference/thirdparty/audiotools/ml/__init__.py +5 -0
- xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
- xinference/thirdparty/audiotools/ml/decorators.py +440 -0
- xinference/thirdparty/audiotools/ml/experiment.py +90 -0
- xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
- xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
- xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
- xinference/thirdparty/audiotools/post.py +140 -0
- xinference/thirdparty/audiotools/preference.py +600 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +1 -1
- xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
- xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
- xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
- xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
- xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
- xinference/thirdparty/indextts/__init__.py +0 -0
- xinference/thirdparty/indextts/cli.py +65 -0
- xinference/thirdparty/indextts/gpt/__init__.py +0 -0
- xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
- xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
- xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
- xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
- xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
- xinference/thirdparty/indextts/gpt/model.py +713 -0
- xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
- xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
- xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
- xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
- xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
- xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
- xinference/thirdparty/indextts/infer.py +690 -0
- xinference/thirdparty/indextts/infer_v2.py +739 -0
- xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
- xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
- xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
- xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
- xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
- xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
- xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
- xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
- xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
- xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
- xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
- xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
- xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
- xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
- xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
- xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
- xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
- xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
- xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
- xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
- xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
- xinference/thirdparty/indextts/utils/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/arch_util.py +120 -0
- xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
- xinference/thirdparty/indextts/utils/common.py +121 -0
- xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
- xinference/thirdparty/indextts/utils/front.py +536 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
- xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
- xinference/thirdparty/indextts/utils/text_utils.py +41 -0
- xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
- xinference/thirdparty/indextts/utils/utils.py +93 -0
- xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
- xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
- xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
- xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
- xinference/thirdparty/melo/text/chinese_mix.py +2 -2
- xinference/types.py +9 -0
- xinference/ui/gradio/media_interface.py +66 -8
- xinference/ui/web/ui/build/asset-manifest.json +6 -6
- xinference/ui/web/ui/build/index.html +1 -1
- xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
- xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
- xinference/ui/web/ui/build/static/js/main.45e78536.js +3 -0
- xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.45e78536.js.LICENSE.txt} +0 -7
- xinference/ui/web/ui/build/static/js/main.45e78536.js.map +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/ea2a26361204e70cf1018d6990fb6354bed82b3ac69690391e0f100385e7abb7.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
- xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
- xinference/ui/web/ui/package-lock.json +0 -34
- xinference/ui/web/ui/package.json +0 -1
- xinference/ui/web/ui/src/locales/en.json +9 -3
- xinference/ui/web/ui/src/locales/ja.json +9 -3
- xinference/ui/web/ui/src/locales/ko.json +9 -3
- xinference/ui/web/ui/src/locales/zh.json +9 -3
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/METADATA +24 -6
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/RECORD +296 -77
- xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
- xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
- xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
- xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
- xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
- xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
- xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
- xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
- xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
- xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
- xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
- xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
- xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
- xinference/ui/web/ui/node_modules/select/bower.json +0 -13
- xinference/ui/web/ui/node_modules/select/package.json +0 -29
- xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/WHEEL +0 -0
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,881 @@
|
|
|
1
|
+
# Copyright (c) 2024 Amphion.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
from typing import Optional, Tuple
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
import scipy
|
|
10
|
+
import torch
|
|
11
|
+
from torch import nn, view_as_real, view_as_complex
|
|
12
|
+
from torch import nn
|
|
13
|
+
from torch.nn.utils import weight_norm, remove_weight_norm
|
|
14
|
+
from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz
|
|
15
|
+
import librosa
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
    """
    Take the natural logarithm of a tensor after flooring it at ``clip_val``.

    Flooring first keeps zero or near-zero entries from producing ``-inf`` or
    very large negative values.

    Args:
        x (Tensor): Input tensor.
        clip_val (float, optional): Lower bound applied to ``x`` before the log. Defaults to 1e-7.

    Returns:
        Tensor: ``log(max(x, clip_val))`` computed element-wise.
    """
    floored = torch.clamp(x, min=clip_val)
    return floored.log()
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def symlog(x: torch.Tensor) -> torch.Tensor:
    """
    Symmetric logarithm: ``sign(x) * log(1 + |x|)``, applied element-wise.

    Args:
        x (Tensor): Input tensor.

    Returns:
        Tensor: Tensor of the same shape as ``x``.
    """
    compressed = torch.log1p(torch.abs(x))
    return torch.sign(x) * compressed
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def symexp(x: torch.Tensor) -> torch.Tensor:
    """
    Symmetric exponential: ``sign(x) * (exp(|x|) - 1)``, applied element-wise.

    This is the mathematical inverse of ``sign(x) * log(1 + |x|)``.

    Args:
        x (Tensor): Input tensor.

    Returns:
        Tensor: Tensor of the same shape as ``x``.
    """
    # expm1 evaluates exp(v) - 1 without the cancellation that makes
    # `torch.exp(v) - 1` collapse to 0 for small |v| in float32.
    return torch.sign(x) * torch.expm1(x.abs())
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class STFT(nn.Module):
    """
    Short-time Fourier transform that returns log-magnitude and phase.

    A Hann window of length ``win_length`` is registered as a buffer so it
    moves with the module across devices.

    Args:
        n_fft (int): Size of Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames.
        win_length (int): The size of window frame and STFT filter.
        center (bool, optional): Forwarded to ``torch.stft``. When False, the input is
            reflect-padded by ``(win_length - hop_length) // 2`` samples on each side
            before the transform. Defaults to True.
    """

    def __init__(
        self,
        n_fft: int,
        hop_length: int,
        win_length: int,
        center: bool = True,
    ):
        super().__init__()
        self.center = center
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        window = torch.hann_window(win_length)
        self.register_buffer("window", window)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the log-magnitude and phase spectrograms of ``x``.

        Args:
            x (Tensor): Input signal; the original code comments describe it as
                (B, T * hop_length).

        Returns:
            Tuple[Tensor, Tensor]: ``(log_mag, phase)``, each of shape
            (B, n_fft // 2 + 1, T). ``log_mag`` is ``log(|X| + 1e-5)`` and
            ``phase`` is ``atan2(imag, real)``.
        """
        if not self.center:
            pad = self.win_length - self.hop_length
            x = torch.nn.functional.pad(x, (pad // 2, pad // 2), mode="reflect")

        # `return_complex=False` is deprecated in torch.stft; request the complex
        # result and view it as real to obtain the same (..., 2) layout.
        stft_spec = torch.view_as_real(
            torch.stft(
                x,
                self.n_fft,
                hop_length=self.hop_length,
                win_length=self.win_length,
                window=self.window,
                center=self.center,
                return_complex=True,
            )
        )  # (B, n_fft // 2 + 1, T, 2)

        rea = stft_spec[:, :, :, 0]  # (B, n_fft // 2 + 1, T)
        imag = stft_spec[:, :, :, 1]  # (B, n_fft // 2 + 1, T)

        # The 1e-5 offset keeps the log finite for silent bins.
        log_mag = torch.log(
            torch.sqrt(torch.pow(rea, 2) + torch.pow(imag, 2)) + 1e-5
        )  # (B, n_fft // 2 + 1, T)
        phase = torch.atan2(imag, rea)  # (B, n_fft // 2 + 1, T)

        return log_mag, phase
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class ISTFT(nn.Module):
    """
    Inverse Short Time Fourier Transform with a selectable padding mode.

    ``torch.istft`` only supports ``center=True`` style padding when a window is
    applied, because its NOLA (Nonzero Overlap Add) check fails at the edges
    (see https://github.com/pytorch/pytorch/issues/62323). For neural vocoding
    a CNN-like "same" padding is wanted, so that mode is implemented by hand
    here. The NOLA constraint still holds because the padded samples are
    trimmed before normalisation.

    Args:
        n_fft (int): Size of Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames.
        win_length (int): The size of window frame and STFT filter.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".

    Raises:
        ValueError: If ``padding`` is neither "center" nor "same".
    """

    def __init__(
        self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"
    ):
        super().__init__()
        if padding not in ("center", "same"):
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        # Registered as a buffer so the window follows the module across devices.
        self.register_buffer("window", torch.hann_window(win_length))

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """
        Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.

        Args:
            spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
                            N is the number of frequency bins, and T is the number of time frames.

        Returns:
            Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
        """
        if self.padding == "center":
            # Delegate to the native PyTorch implementation.
            return torch.istft(
                spec,
                self.n_fft,
                self.hop_length,
                self.win_length,
                self.window,
                center=True,
            )
        if self.padding != "same":
            raise ValueError("Padding must be 'center' or 'same'.")

        # Number of samples trimmed from each end of the overlap-added signal.
        trim = (self.win_length - self.hop_length) // 2

        assert spec.dim() == 3, "Expected a 3D tensor as input"
        _, _, num_frames = spec.shape

        # Per-frame inverse FFT, then apply the synthesis window.
        frames = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
        frames = frames * self.window[None, :, None]

        # Overlap-add expressed as a fold over a (1, total_len) canvas.
        total_len = (num_frames - 1) * self.hop_length + self.win_length
        fold_args = dict(
            output_size=(1, total_len),
            kernel_size=(1, self.win_length),
            stride=(1, self.hop_length),
        )
        signal = torch.nn.functional.fold(frames, **fold_args)[:, 0, 0, trim:-trim]

        # Overlap-added squared window: the envelope the signal is divided by.
        squared_window = self.window.square().expand(1, num_frames, -1).transpose(1, 2)
        envelope = torch.nn.functional.fold(squared_window, **fold_args).squeeze()[
            trim:-trim
        ]

        # Guard against division by (near) zero before normalising.
        assert (envelope > 1e-11).all()
        return signal / envelope
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
class MDCT(nn.Module):
    """
    Modified Discrete Cosine Transform (MDCT) module.

    The transform is computed with an FFT: frames are windowed, multiplied by a
    pre-twiddle factor, transformed, truncated to the first ``frame_len // 2``
    bins and multiplied by a post-twiddle factor.

    Args:
        frame_len (int): Length of the MDCT frame.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".

    Raises:
        ValueError: If ``padding`` is neither "center" nor "same".
    """

    def __init__(self, frame_len: int, padding: str = "same"):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        # The `scipy.signal.cosine` alias was removed in SciPy 1.13; the window
        # functions live in `scipy.signal.windows`. Imported locally so this
        # block does not rely on how the file imports scipy.
        from scipy.signal.windows import cosine as cosine_window

        self.padding = padding
        self.frame_len = frame_len
        N = frame_len // 2
        n0 = (N + 1) / 2
        window = torch.from_numpy(cosine_window(frame_len)).float()
        self.register_buffer("window", window)

        pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len)
        post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N)
        # view_as_real: NCCL Backend does not support ComplexFloat data type
        # https://github.com/pytorch/pytorch/issues/71613
        self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
        self.register_buffer("post_twiddle", view_as_real(post_twiddle))

    def forward(self, audio: torch.Tensor) -> torch.Tensor:
        """
        Apply the Modified Discrete Cosine Transform (MDCT) to the input audio.

        Args:
            audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size
                and T is the length of the audio.

        Returns:
            Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames
                and N is the number of frequency bins.

        Raises:
            ValueError: If ``self.padding`` is neither "center" nor "same".
        """
        if self.padding == "center":
            audio = torch.nn.functional.pad(
                audio, (self.frame_len // 2, self.frame_len // 2)
            )
        elif self.padding == "same":
            # hop_length is 1/2 frame_len
            audio = torch.nn.functional.pad(
                audio, (self.frame_len // 4, self.frame_len // 4)
            )
        else:
            raise ValueError("Padding must be 'center' or 'same'.")

        # Slice into frames of frame_len samples with a hop of frame_len // 2.
        x = audio.unfold(-1, self.frame_len, self.frame_len // 2)
        N = self.frame_len // 2
        x = x * self.window.expand(x.shape)
        # Only the first N FFT bins are kept.
        X = torch.fft.fft(
            x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1
        )[..., :N]
        res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N)
        return torch.real(res) * np.sqrt(2)
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
class IMDCT(nn.Module):
    """
    Inverse Modified Discrete Cosine Transform (IMDCT) module.

    The coefficients are extended to a full-length spectrum, inverted with an
    inverse FFT using pre/post twiddle factors, windowed, and overlap-added
    with a hop of ``frame_len // 2``.

    Args:
        frame_len (int): Length of the MDCT frame.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".

    Raises:
        ValueError: If ``padding`` is neither "center" nor "same".
    """

    def __init__(self, frame_len: int, padding: str = "same"):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        # The `scipy.signal.cosine` alias was removed in SciPy 1.13; the window
        # functions live in `scipy.signal.windows`. Imported locally so this
        # block does not rely on how the file imports scipy.
        from scipy.signal.windows import cosine as cosine_window

        self.padding = padding
        self.frame_len = frame_len
        N = frame_len // 2
        n0 = (N + 1) / 2
        window = torch.from_numpy(cosine_window(frame_len)).float()
        self.register_buffer("window", window)

        pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N)
        post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2))
        # Stored as real views (trailing dim of 2); converted back to complex in forward().
        self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
        self.register_buffer("post_twiddle", view_as_real(post_twiddle))

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients.

        Args:
            X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size,
                L is the number of frames, and N is the number of frequency bins.

        Returns:
            Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio.

        Raises:
            ValueError: If ``self.padding`` is neither "center" nor "same".
        """
        B, L, N = X.shape
        # Build a length-2N spectrum: the coefficients followed by their
        # negated, conjugated mirror image.
        Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device)
        Y[..., :N] = X
        Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,)))
        y = torch.fft.ifft(
            Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1
        )
        y = (
            torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape))
            * np.sqrt(N)
            * np.sqrt(2)
        )
        result = y * self.window.expand(y.shape)
        # Overlap-add the windowed frames with a hop of frame_len // 2.
        output_size = (1, (L + 1) * N)
        audio = torch.nn.functional.fold(
            result.transpose(1, 2),
            output_size=output_size,
            kernel_size=(1, self.frame_len),
            stride=(1, self.frame_len // 2),
        )[:, 0, 0, :]

        if self.padding == "center":
            pad = self.frame_len // 2
        elif self.padding == "same":
            pad = self.frame_len // 4
        else:
            raise ValueError("Padding must be 'center' or 'same'.")

        # Drop `pad` samples from both ends.
        audio = audio[:, pad:-pad]
        return audio
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
class FourierHead(nn.Module):
    """Abstract parent of the inverse-Fourier output heads; subclasses override ``forward``."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Turn backbone features into a waveform (to be implemented by subclasses).

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                        L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Subclasses must implement the forward method.")
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
class ISTFTHead(FourierHead):
    """
    ISTFT Head module for predicting STFT complex coefficients.

    A linear layer projects each frame to ``n_fft + 2`` values which are split
    in half into log-magnitudes and phases; the resulting complex spectrogram
    is inverted with :class:`ISTFT`.

    Args:
        dim (int): Hidden dimension of the model.
        n_fft (int): Size of Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames, which should align with
                          the resolution of the input features.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
        super().__init__()
        self.out = torch.nn.Linear(dim, n_fft + 2)
        # The window length is tied to n_fft.
        self.istft = ISTFT(
            n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the ISTFTHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                        L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        projected = self.out(x).transpose(1, 2)
        log_magnitude, phase = projected.chunk(2, dim=1)
        # Cap the magnitude as a safeguard against excessively large values.
        magnitude = torch.clip(torch.exp(log_magnitude), max=1e2)
        # Phase wrapping happens implicitly here: cos/sin give the real and
        # imaginary parts directly, so recomputing the angle with atan2 and
        # exponentiating would only cost time.
        real_part = torch.cos(phase)
        imag_part = torch.sin(phase)
        spectrum = magnitude * (real_part + 1j * imag_part)
        return self.istft(spectrum)
|
|
362
|
+
|
|
363
|
+
|
|
364
|
+
class IMDCTSymExpHead(FourierHead):
    """
    IMDCT Head module for predicting MDCT coefficients with symmetric exponential function

    Args:
        dim (int): Hidden dimension of the model.
        mdct_frame_len (int): Length of the MDCT frame.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
        sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
                                     based on perceptual scaling. Defaults to None.
        clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
    """

    def __init__(
        self,
        dim: int,
        mdct_frame_len: int,
        padding: str = "same",
        sample_rate: Optional[int] = None,
        clip_audio: bool = False,
    ):
        super().__init__()
        out_dim = mdct_frame_len // 2
        self.out = nn.Linear(dim, out_dim)
        self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
        self.clip_audio = clip_audio

        if sample_rate is not None:
            # optionally init the last layer following mel-scale
            m_max = _hz_to_mel(sample_rate // 2)
            m_pts = torch.linspace(0, m_max, out_dim)
            f_pts = _mel_to_hz(m_pts)
            # Rows for higher frequencies get a smaller factor (0 at the maximum).
            scale = 1 - (f_pts / f_pts.max())

            with torch.no_grad():
                self.out.weight.mul_(scale.view(-1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the IMDCTSymExpHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                        L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        x = self.out(x)
        x = symexp(x)
        x = torch.clip(
            x, min=-1e2, max=1e2
        )  # safeguard to prevent excessively large magnitudes
        audio = self.imdct(x)
        if self.clip_audio:
            # Clip the reconstructed waveform. Clipping `x` here would return
            # the MDCT coefficients instead of audio.
            audio = torch.clip(audio, min=-1.0, max=1.0)

        return audio
|
|
422
|
+
|
|
423
|
+
|
|
424
|
+
class IMDCTCosHead(FourierHead):
    """
    IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) · cos(p)

    Args:
        dim (int): Hidden dimension of the model.
        mdct_frame_len (int): Length of the MDCT frame.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
        clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
    """

    def __init__(
        self,
        dim: int,
        mdct_frame_len: int,
        padding: str = "same",
        clip_audio: bool = False,
    ):
        super().__init__()
        self.clip_audio = clip_audio
        # mdct_frame_len outputs per frame: split in half into m and p in forward().
        self.out = nn.Linear(dim, mdct_frame_len)
        self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the IMDCTCosHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                        L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        x = self.out(x)
        m, p = x.chunk(2, dim=2)
        m = torch.exp(m).clip(
            max=1e2
        )  # safeguard to prevent excessively large magnitudes
        audio = self.imdct(m * torch.cos(p))
        if self.clip_audio:
            # Clip the reconstructed waveform. Clipping `x` here would return
            # the linear projection instead of audio.
            audio = torch.clip(audio, min=-1.0, max=1.0)
        return audio
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
class ConvNeXtBlock(nn.Module):
    """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.

    The block applies a depthwise convolution, a (possibly conditional) layer
    norm, two pointwise linear layers with a GELU in between, an optional
    learnable per-channel scale, and finally a residual connection.

    Args:
        dim (int): Number of input channels.
        intermediate_dim (int): Dimensionality of the intermediate layer.
        layer_scale_init_value (float, optional): Initial value for the layer scale. None (or a
            non-positive value) means no scaling. Defaults to None.
        adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
            None means non-conditional LayerNorm. Defaults to None.
    """

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        layer_scale_init_value: Optional[float] = None,
        adanorm_num_embeddings: Optional[int] = None,
    ):
        super().__init__()
        self.dwconv = nn.Conv1d(
            dim, dim, kernel_size=7, padding=3, groups=dim
        )  # depthwise conv
        self.adanorm = adanorm_num_embeddings is not None
        if adanorm_num_embeddings:
            self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
        else:
            self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(
            dim, intermediate_dim
        )  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        # Check for None first: comparing None with `>` raises TypeError, yet
        # None is the documented way to disable layer scaling.
        if layer_scale_init_value is not None and layer_scale_init_value > 0:
            self.gamma = nn.Parameter(
                layer_scale_init_value * torch.ones(dim), requires_grad=True
            )
        else:
            self.gamma = None

    def forward(
        self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Apply the block to a channels-first tensor.

        Args:
            x (Tensor): Input of shape (B, C, T).
            cond_embedding_id (Tensor, optional): Embedding index passed to the
                adaptive norm; must be given when the block was built with
                ``adanorm_num_embeddings``.

        Returns:
            Tensor: Output with the same shape as ``x``.
        """
        residual = x
        x = self.dwconv(x)
        x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
        if self.adanorm:
            assert cond_embedding_id is not None
            x = self.norm(x, cond_embedding_id)
        else:
            x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)

        x = residual + x
        return x
|
|
528
|
+
|
|
529
|
+
|
|
530
|
+
class AdaLayerNorm(nn.Module):
    """
    Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes

    Two embedding tables hold one scale vector and one shift vector per class.
    They start at ones and zeros respectively, so a freshly built module is a
    plain, non-affine layer norm.

    Args:
        num_embeddings (int): Number of embeddings.
        embedding_dim (int): Dimension of the embeddings.
        eps (float, optional): Epsilon passed to ``layer_norm``. Defaults to 1e-6.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.dim = embedding_dim
        self.scale = nn.Embedding(num_embeddings, embedding_dim)
        self.shift = nn.Embedding(num_embeddings, embedding_dim)
        # Replace the default random embedding init with an identity transform.
        nn.init.ones_(self.scale.weight)
        nn.init.zeros_(self.shift.weight)

    def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
        """
        Normalise ``x`` over its last dimension, then scale and shift it with
        the embeddings selected by ``cond_embedding_id``.

        NOTE(review): the looked-up embeddings are broadcast against ``x``
        as-is, with no extra axis inserted — confirm the id shape callers pass.
        """
        gain = self.scale(cond_embedding_id)
        bias = self.shift(cond_embedding_id)
        normalised = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
        return normalised * gain + bias
|
|
558
|
+
|
|
559
|
+
|
|
560
|
+
class ResBlock1(nn.Module):
    """
    ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
    but without upsampling layers.

    The block holds three residual stages. Each stage is a dilated convolution
    followed by a non-dilated one, both weight-normalised and both preceded by
    a LeakyReLU.

    Args:
        dim (int): Number of input channels.
        kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
        dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
            Defaults to (1, 3, 5).
        lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
            Defaults to 0.1.
        layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
            Defaults to None.
    """

    def __init__(
        self,
        dim: int,
        kernel_size: int = 3,
        dilation: Tuple[int, int, int] = (1, 3, 5),
        lrelu_slope: float = 0.1,
        layer_scale_init_value: Optional[float] = None,
    ):
        super().__init__()
        self.lrelu_slope = lrelu_slope

        def build_conv(rate: int) -> nn.Module:
            # Stride-1 convolution whose padding keeps the sequence length unchanged.
            return weight_norm(
                nn.Conv1d(
                    dim,
                    dim,
                    kernel_size,
                    1,
                    dilation=rate,
                    padding=self.get_padding(kernel_size, rate),
                )
            )

        def build_scale():
            # One learnable (dim, 1) scale per stage, or None when scaling is off.
            if layer_scale_init_value is None:
                return None
            return nn.Parameter(
                layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
            )

        # First conv of each stage uses that stage's dilation; the second never dilates.
        self.convs1 = nn.ModuleList([build_conv(dilation[stage]) for stage in range(3)])
        self.convs2 = nn.ModuleList([build_conv(1) for _ in range(3)])
        self.gamma = nn.ParameterList([build_scale() for _ in range(3)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the three residual stages in sequence and return the result."""
        leaky_relu = torch.nn.functional.leaky_relu
        for dilated, plain, scale in zip(self.convs1, self.convs2, self.gamma):
            branch = dilated(leaky_relu(x, negative_slope=self.lrelu_slope))
            branch = plain(leaky_relu(branch, negative_slope=self.lrelu_slope))
            if scale is not None:
                branch = scale * branch
            x = branch + x
        return x

    def remove_weight_norm(self):
        """Strip the weight-norm reparametrisation from every convolution."""
        for group in (self.convs1, self.convs2):
            for layer in group:
                remove_weight_norm(layer)

    @staticmethod
    def get_padding(kernel_size: int, dilation: int = 1) -> int:
        """Return the one-sided padding for the given kernel size and dilation."""
        return int((kernel_size * dilation - dilation) / 2)
|
|
702
|
+
|
|
703
|
+
|
|
704
|
+
class Backbone(nn.Module):
    """Abstract parent of the generator backbones, which keep the temporal resolution unchanged across all layers."""

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Map input features to hidden features (to be implemented by subclasses).

        Args:
            x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
                        C denotes output features, and L is the sequence length.

        Returns:
            Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
                    and H denotes the model dimension.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Subclasses must implement the forward method.")
|
|
718
|
+
|
|
719
|
+
|
|
720
|
+
class VocosBackbone(Backbone):
    """
    Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization

    Args:
        input_channels (int): Number of input features channels.
        dim (int): Hidden dimension of the model.
        intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
        num_layers (int): Number of ConvNeXtBlock layers.
        layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
        adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
                                                None means non-conditional model. Defaults to None.
    """

    def __init__(
        self,
        input_channels: int,
        dim: int,
        intermediate_dim: int,
        num_layers: int,
        layer_scale_init_value: Optional[float] = None,
        adanorm_num_embeddings: Optional[int] = None,
    ):
        super().__init__()
        self.input_channels = input_channels
        self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
        self.adanorm = adanorm_num_embeddings is not None
        self.norm = (
            AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
            if adanorm_num_embeddings
            else nn.LayerNorm(dim, eps=1e-6)
        )
        # A falsy value (None or 0) falls back to 1 / num_layers.
        layer_scale_init_value = layer_scale_init_value or 1 / num_layers
        blocks = []
        for _ in range(num_layers):
            blocks.append(
                ConvNeXtBlock(
                    dim=dim,
                    intermediate_dim=intermediate_dim,
                    layer_scale_init_value=layer_scale_init_value,
                    adanorm_num_embeddings=adanorm_num_embeddings,
                )
            )
        self.convnext = nn.ModuleList(blocks)
        self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
        # Re-initialise every conv and linear layer in the module tree.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Give Conv1d/Linear layers truncated-normal weights (std 0.02) and zero biases."""
        if not isinstance(m, (nn.Conv1d, nn.Linear)):
            return
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Embed the input, run it through the ConvNeXt blocks and apply a final layer norm.

        Args:
            x (Tensor): Channels-first input features.
            **kwargs: May carry ``bandwidth_id``, the conditioning index, which
                is required when the model was built with adaptive norm.

        Returns:
            Tensor: Features with the channel axis moved to the last position.
        """
        bandwidth_id = kwargs.get("bandwidth_id", None)
        hidden = self.embed(x)
        # The norms act on the last axis, so move channels there and back.
        channels_last = hidden.transpose(1, 2)
        if self.adanorm:
            assert bandwidth_id is not None
            normed = self.norm(channels_last, cond_embedding_id=bandwidth_id)
        else:
            normed = self.norm(channels_last)
        hidden = normed.transpose(1, 2)
        for block in self.convnext:
            hidden = block(hidden, cond_embedding_id=bandwidth_id)
        return self.final_layer_norm(hidden.transpose(1, 2))
|
|
784
|
+
|
|
785
|
+
|
|
786
|
+
class VocosResNetBackbone(Backbone):
    """
    Vocos backbone module built with ResBlocks.

    Args:
        input_channels (int): Number of input features channels.
        dim (int): Hidden dimension of the model.
        num_blocks (int): Number of ResBlock1 blocks.
        layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
    """

    def __init__(
        self,
        input_channels,
        dim,
        num_blocks,
        layer_scale_init_value=None,
    ):
        super().__init__()
        self.input_channels = input_channels
        self.embed = weight_norm(
            nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)
        )
        # A falsy value (None or 0) falls back to 1 / num_blocks / 3.
        layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
        blocks = []
        for _ in range(num_blocks):
            blocks.append(
                ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value)
            )
        self.resnet = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """Embed ``x``, apply the residual stack and swap axes 1 and 2; ``kwargs`` are ignored."""
        hidden = self.resnet(self.embed(x))
        return hidden.transpose(1, 2)
|
|
822
|
+
|
|
823
|
+
|
|
824
|
+
class Vocos(nn.Module):
    """
    Vocoder assembled from a :class:`VocosBackbone` and an :class:`ISTFTHead`.

    Every hyper-parameter can be given either as a keyword argument or as an
    attribute of ``cfg``; an attribute present on ``cfg`` takes precedence.

    Args:
        input_channels (int): Number of input feature channels.
        dim (int): Hidden dimension of the backbone.
        intermediate_dim (int): Intermediate dimension of the ConvNeXt blocks.
        num_layers (int): Number of ConvNeXt blocks.
        n_fft (int): Size of the Fourier transform used by the head.
        hop_size (int): Hop length used by the head.
        padding (str): Padding mode of the head ("center" or "same").
        adanorm_num_embeddings (int, optional): Number of AdaLayerNorm embeddings.
        cfg (optional): Object whose attributes override the arguments above.
    """

    def __init__(
        self,
        input_channels: int = 256,
        dim: int = 384,
        intermediate_dim: int = 1152,
        num_layers: int = 8,
        n_fft: int = 800,
        hop_size: int = 200,
        padding: str = "same",
        adanorm_num_embeddings=None,
        cfg=None,
    ):
        super().__init__()

        def resolve(name, fallback):
            # Use cfg.<name> when cfg exists and defines it, else the argument.
            if cfg is None or not hasattr(cfg, name):
                return fallback
            return getattr(cfg, name)

        input_channels = resolve("input_channels", input_channels)
        dim = resolve("dim", dim)
        intermediate_dim = resolve("intermediate_dim", intermediate_dim)
        num_layers = resolve("num_layers", num_layers)
        adanorm_num_embeddings = resolve(
            "adanorm_num_embeddings", adanorm_num_embeddings
        )
        n_fft = resolve("n_fft", n_fft)
        hop_size = resolve("hop_size", hop_size)
        padding = resolve("padding", padding)

        self.backbone = VocosBackbone(
            input_channels=input_channels,
            dim=dim,
            intermediate_dim=intermediate_dim,
            num_layers=num_layers,
            adanorm_num_embeddings=adanorm_num_embeddings,
        )
        self.head = ISTFTHead(dim, n_fft, hop_size, padding)

    def forward(self, x):
        """Run backbone and head, then insert a singleton axis at position 1 of the result."""
        waveform = self.head(self.backbone(x))
        return waveform[:, None, :]
|