xinference 1.10.0__py3-none-any.whl → 1.10.1__py3-none-any.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Potentially problematic release. This version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +11 -28
- xinference/client/restful/async_restful_client.py +20 -3
- xinference/client/restful/restful_client.py +20 -3
- xinference/core/supervisor.py +87 -53
- xinference/core/worker.py +10 -0
- xinference/deploy/cmdline.py +15 -0
- xinference/model/audio/core.py +21 -6
- xinference/model/audio/indextts2.py +166 -0
- xinference/model/audio/model_spec.json +38 -1
- xinference/model/image/model_spec.json +69 -0
- xinference/model/image/stable_diffusion/core.py +13 -4
- xinference/model/llm/__init__.py +4 -0
- xinference/model/llm/llm_family.json +464 -2
- xinference/model/llm/sglang/core.py +30 -11
- xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +94 -32
- xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
- xinference/model/llm/utils.py +12 -9
- xinference/model/llm/vllm/core.py +93 -17
- xinference/thirdparty/audiotools/__init__.py +10 -0
- xinference/thirdparty/audiotools/core/__init__.py +4 -0
- xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
- xinference/thirdparty/audiotools/core/display.py +194 -0
- xinference/thirdparty/audiotools/core/dsp.py +390 -0
- xinference/thirdparty/audiotools/core/effects.py +647 -0
- xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
- xinference/thirdparty/audiotools/core/loudness.py +320 -0
- xinference/thirdparty/audiotools/core/playback.py +252 -0
- xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
- xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
- xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
- xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
- xinference/thirdparty/audiotools/core/util.py +671 -0
- xinference/thirdparty/audiotools/core/whisper.py +97 -0
- xinference/thirdparty/audiotools/data/__init__.py +3 -0
- xinference/thirdparty/audiotools/data/datasets.py +517 -0
- xinference/thirdparty/audiotools/data/preprocess.py +81 -0
- xinference/thirdparty/audiotools/data/transforms.py +1592 -0
- xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
- xinference/thirdparty/audiotools/metrics/distance.py +131 -0
- xinference/thirdparty/audiotools/metrics/quality.py +159 -0
- xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
- xinference/thirdparty/audiotools/ml/__init__.py +5 -0
- xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
- xinference/thirdparty/audiotools/ml/decorators.py +440 -0
- xinference/thirdparty/audiotools/ml/experiment.py +90 -0
- xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
- xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
- xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
- xinference/thirdparty/audiotools/post.py +140 -0
- xinference/thirdparty/audiotools/preference.py +600 -0
- xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
- xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
- xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
- xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
- xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
- xinference/thirdparty/indextts/__init__.py +0 -0
- xinference/thirdparty/indextts/cli.py +65 -0
- xinference/thirdparty/indextts/gpt/__init__.py +0 -0
- xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
- xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
- xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
- xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
- xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
- xinference/thirdparty/indextts/gpt/model.py +713 -0
- xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
- xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
- xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
- xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
- xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
- xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
- xinference/thirdparty/indextts/infer.py +690 -0
- xinference/thirdparty/indextts/infer_v2.py +739 -0
- xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
- xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
- xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
- xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
- xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
- xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
- xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
- xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
- xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
- xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
- xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
- xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
- xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
- xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
- xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
- xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
- xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
- xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
- xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
- xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
- xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
- xinference/thirdparty/indextts/utils/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/arch_util.py +120 -0
- xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
- xinference/thirdparty/indextts/utils/common.py +121 -0
- xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
- xinference/thirdparty/indextts/utils/front.py +536 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
- xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
- xinference/thirdparty/indextts/utils/text_utils.py +41 -0
- xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
- xinference/thirdparty/indextts/utils/utils.py +93 -0
- xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
- xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
- xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
- xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
- xinference/ui/gradio/media_interface.py +66 -8
- xinference/ui/web/ui/build/asset-manifest.json +6 -6
- xinference/ui/web/ui/build/index.html +1 -1
- xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
- xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
- xinference/ui/web/ui/build/static/js/main.d192c4f3.js +3 -0
- xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.d192c4f3.js.LICENSE.txt} +0 -7
- xinference/ui/web/ui/build/static/js/main.d192c4f3.js.map +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f995a2425dfb0822fd07127f66ffe9b026883bc156b402eb8bd0b83d52460a93.json +1 -0
- xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
- xinference/ui/web/ui/package-lock.json +0 -34
- xinference/ui/web/ui/package.json +0 -1
- xinference/ui/web/ui/src/locales/en.json +9 -3
- xinference/ui/web/ui/src/locales/ja.json +9 -3
- xinference/ui/web/ui/src/locales/ko.json +9 -3
- xinference/ui/web/ui/src/locales/zh.json +9 -3
- {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/METADATA +18 -2
- {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/RECORD +285 -67
- xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
- xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
- xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
- xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
- xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
- xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
- xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
- xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
- xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
- xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
- xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
- xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
- xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
- xinference/ui/web/ui/node_modules/select/bower.json +0 -13
- xinference/ui/web/ui/node_modules/select/package.json +0 -29
- xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
- {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/WHEEL +0 -0
- {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/entry_points.txt +0 -0
- {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/top_level.txt +0 -0
xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py
@@ -0,0 +1,164 @@
+from typing import Optional
+
+import torch
+from torch import nn
+from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz
+
+from .spectral_ops import IMDCT, ISTFT
+from .modules import symexp
+
+
+class FourierHead(nn.Module):
+    """Base class for inverse fourier modules."""
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+                L is the sequence length, and H denotes the model dimension.
+
+        Returns:
+            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+        """
+        raise NotImplementedError("Subclasses must implement the forward method.")
+
+
+class ISTFTHead(FourierHead):
+    """
+    ISTFT Head module for predicting STFT complex coefficients.
+
+    Args:
+        dim (int): Hidden dimension of the model.
+        n_fft (int): Size of Fourier transform.
+        hop_length (int): The distance between neighboring sliding window frames, which should align with
+            the resolution of the input features.
+        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+    """
+
+    def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
+        super().__init__()
+        out_dim = n_fft + 2
+        self.out = torch.nn.Linear(dim, out_dim)
+        self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass of the ISTFTHead module.
+
+        Args:
+            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+                L is the sequence length, and H denotes the model dimension.
+
+        Returns:
+            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+        """
+        x = self.out(x).transpose(1, 2)
+        mag, p = x.chunk(2, dim=1)
+        mag = torch.exp(mag)
+        mag = torch.clip(mag, max=1e2)  # safeguard to prevent excessively large magnitudes
+        # wrapping happens here. These two lines produce real and imaginary value
+        x = torch.cos(p)
+        y = torch.sin(p)
+        # recalculating phase here does not produce anything new
+        # only costs time
+        # phase = torch.atan2(y, x)
+        # S = mag * torch.exp(phase * 1j)
+        # better directly produce the complex value
+        S = mag * (x + 1j * y)
+        audio = self.istft(S)
+        return audio
+
+
+class IMDCTSymExpHead(FourierHead):
+    """
+    IMDCT Head module for predicting MDCT coefficients with symmetric exponential function
+
+    Args:
+        dim (int): Hidden dimension of the model.
+        mdct_frame_len (int): Length of the MDCT frame.
+        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+        sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
+            based on perceptual scaling. Defaults to None.
+        clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        mdct_frame_len: int,
+        padding: str = "same",
+        sample_rate: Optional[int] = None,
+        clip_audio: bool = False,
+    ):
+        super().__init__()
+        out_dim = mdct_frame_len // 2
+        self.out = nn.Linear(dim, out_dim)
+        self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
+        self.clip_audio = clip_audio
+
+        if sample_rate is not None:
+            # optionally init the last layer following mel-scale
+            m_max = _hz_to_mel(sample_rate // 2)
+            m_pts = torch.linspace(0, m_max, out_dim)
+            f_pts = _mel_to_hz(m_pts)
+            scale = 1 - (f_pts / f_pts.max())
+
+            with torch.no_grad():
+                self.out.weight.mul_(scale.view(-1, 1))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass of the IMDCTSymExpHead module.
+
+        Args:
+            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+                L is the sequence length, and H denotes the model dimension.
+
+        Returns:
+            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+        """
+        x = self.out(x)
+        x = symexp(x)
+        x = torch.clip(x, min=-1e2, max=1e2)  # safeguard to prevent excessively large magnitudes
+        audio = self.imdct(x)
+        if self.clip_audio:
+            audio = torch.clip(x, min=-1.0, max=1.0)
+
+        return audio
+
+
+class IMDCTCosHead(FourierHead):
+    """
+    IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) · cos(p)
+
+    Args:
+        dim (int): Hidden dimension of the model.
+        mdct_frame_len (int): Length of the MDCT frame.
+        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+        clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
+    """
+
+    def __init__(self, dim: int, mdct_frame_len: int, padding: str = "same", clip_audio: bool = False):
+        super().__init__()
+        self.clip_audio = clip_audio
+        self.out = nn.Linear(dim, mdct_frame_len)
+        self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass of the IMDCTCosHead module.
+
+        Args:
+            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+                L is the sequence length, and H denotes the model dimension.
+
+        Returns:
+            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+        """
+        x = self.out(x)
+        m, p = x.chunk(2, dim=2)
+        m = torch.exp(m).clip(max=1e2)  # safeguard to prevent excessively large magnitudes
+        audio = self.imdct(m * torch.cos(p))
+        if self.clip_audio:
+            audio = torch.clip(x, min=-1.0, max=1.0)
+        return audio
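Not part of the diff: a minimal usage sketch of the new ISTFTHead, assuming the vendored package's relative imports resolve once the wheel is installed. The (n_fft + 2)-channel projection is chunked into n_fft/2 + 1 log-magnitude channels and n_fft/2 + 1 phase channels before the inverse STFT.

```python
import torch
from xinference.thirdparty.indextts.s2mel.modules.vocos.heads import ISTFTHead

# Hypothetical shapes: 512-dim backbone features, 1024-point FFT, 256-sample hop.
head = ISTFTHead(dim=512, n_fft=1024, hop_length=256, padding="same")
features = torch.randn(2, 100, 512)  # (batch, frames, hidden)
audio = head(features)               # (batch, samples), roughly frames * hop_length samples
```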
xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py
@@ -0,0 +1,71 @@
+import matplotlib
+import numpy as np
+import torch
+from matplotlib import pyplot as plt
+from pytorch_lightning import Callback
+
+matplotlib.use("Agg")
+
+
+def save_figure_to_numpy(fig: plt.Figure) -> np.ndarray:
+    """
+    Save a matplotlib figure to a numpy array.
+
+    Args:
+        fig (Figure): Matplotlib figure object.
+
+    Returns:
+        ndarray: Numpy array representing the figure.
+    """
+    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
+    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+    return data
+
+
+def plot_spectrogram_to_numpy(spectrogram: np.ndarray) -> np.ndarray:
+    """
+    Plot a spectrogram and convert it to a numpy array.
+
+    Args:
+        spectrogram (ndarray): Spectrogram data.
+
+    Returns:
+        ndarray: Numpy array representing the plotted spectrogram.
+    """
+    spectrogram = spectrogram.astype(np.float32)
+    fig, ax = plt.subplots(figsize=(12, 3))
+    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
+    plt.colorbar(im, ax=ax)
+    plt.xlabel("Frames")
+    plt.ylabel("Channels")
+    plt.tight_layout()
+
+    fig.canvas.draw()
+    data = save_figure_to_numpy(fig)
+    plt.close()
+    return data
+
+
+class GradNormCallback(Callback):
+    """
+    Callback to log the gradient norm.
+    """
+
+    def on_after_backward(self, trainer, model):
+        model.log("grad_norm", gradient_norm(model))
+
+
+def gradient_norm(model: torch.nn.Module, norm_type: float = 2.0) -> torch.Tensor:
+    """
+    Compute the gradient norm.
+
+    Args:
+        model (Module): PyTorch model.
+        norm_type (float, optional): Type of the norm. Defaults to 2.0.
+
+    Returns:
+        Tensor: Gradient norm.
+    """
+    grads = [p.grad for p in model.parameters() if p.grad is not None]
+    total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type) for g in grads]), norm_type)
+    return total_norm
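For p = 2, the nested norm in gradient_norm (an outer norm over per-parameter norms) equals the global L2 norm over all gradient entries. A small self-contained check of that identity, using plain PyTorch rather than the vendored module:

```python
import torch
from torch import nn

model = nn.Linear(4, 2)
model(torch.randn(8, 4)).sum().backward()  # populate .grad on all parameters

grads = [p.grad for p in model.parameters() if p.grad is not None]
nested = torch.norm(torch.stack([torch.norm(g.detach(), 2.0) for g in grads]), 2.0)
flat = torch.cat([g.reshape(-1) for g in grads]).norm(2.0)
assert torch.allclose(nested, flat)  # same scalar GradNormCallback would log
```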
xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py
@@ -0,0 +1,114 @@
+from typing import List, Tuple
+
+import torch
+import torchaudio
+from torch import nn
+
+from vocos.modules import safe_log
+
+
+class MelSpecReconstructionLoss(nn.Module):
+    """
+    L1 distance between the mel-scaled magnitude spectrograms of the ground truth sample and the generated sample
+    """
+
+    def __init__(
+        self, sample_rate: int = 24000, n_fft: int = 1024, hop_length: int = 256, n_mels: int = 100,
+    ):
+        super().__init__()
+        self.mel_spec = torchaudio.transforms.MelSpectrogram(
+            sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, center=True, power=1,
+        )
+
+    def forward(self, y_hat, y) -> torch.Tensor:
+        """
+        Args:
+            y_hat (Tensor): Predicted audio waveform.
+            y (Tensor): Ground truth audio waveform.
+
+        Returns:
+            Tensor: L1 loss between the mel-scaled magnitude spectrograms.
+        """
+        mel_hat = safe_log(self.mel_spec(y_hat))
+        mel = safe_log(self.mel_spec(y))
+
+        loss = torch.nn.functional.l1_loss(mel, mel_hat)
+
+        return loss
+
+
+class GeneratorLoss(nn.Module):
+    """
+    Generator Loss module. Calculates the loss for the generator based on discriminator outputs.
+    """
+
+    def forward(self, disc_outputs: List[torch.Tensor]) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+        """
+        Args:
+            disc_outputs (List[Tensor]): List of discriminator outputs.
+
+        Returns:
+            Tuple[Tensor, List[Tensor]]: Tuple containing the total loss and a list of loss values from
+                the sub-discriminators
+        """
+        loss = torch.zeros(1, device=disc_outputs[0].device, dtype=disc_outputs[0].dtype)
+        gen_losses = []
+        for dg in disc_outputs:
+            l = torch.mean(torch.clamp(1 - dg, min=0))
+            gen_losses.append(l)
+            loss += l
+
+        return loss, gen_losses
+
+
+class DiscriminatorLoss(nn.Module):
+    """
+    Discriminator Loss module. Calculates the loss for the discriminator based on real and generated outputs.
+    """
+
+    def forward(
+        self, disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor]
+    ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
+        """
+        Args:
+            disc_real_outputs (List[Tensor]): List of discriminator outputs for real samples.
+            disc_generated_outputs (List[Tensor]): List of discriminator outputs for generated samples.
+
+        Returns:
+            Tuple[Tensor, List[Tensor], List[Tensor]]: A tuple containing the total loss, a list of loss values from
+                the sub-discriminators for real outputs, and a list of
+                loss values for generated outputs.
+        """
+        loss = torch.zeros(1, device=disc_real_outputs[0].device, dtype=disc_real_outputs[0].dtype)
+        r_losses = []
+        g_losses = []
+        for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+            r_loss = torch.mean(torch.clamp(1 - dr, min=0))
+            g_loss = torch.mean(torch.clamp(1 + dg, min=0))
+            loss += r_loss + g_loss
+            r_losses.append(r_loss)
+            g_losses.append(g_loss)
+
+        return loss, r_losses, g_losses
+
+
+class FeatureMatchingLoss(nn.Module):
+    """
+    Feature Matching Loss module. Calculates the feature matching loss between feature maps of the sub-discriminators.
+    """
+
+    def forward(self, fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]) -> torch.Tensor:
+        """
+        Args:
+            fmap_r (List[List[Tensor]]): List of feature maps from real samples.
+            fmap_g (List[List[Tensor]]): List of feature maps from generated samples.
+
+        Returns:
+            Tensor: The calculated feature matching loss.
+        """
+        loss = torch.zeros(1, device=fmap_r[0][0].device, dtype=fmap_r[0][0].dtype)
+        for dr, dg in zip(fmap_r, fmap_g):
+            for rl, gl in zip(dr, dg):
+                loss += torch.mean(torch.abs(rl - gl))
+
+        return loss
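The GAN losses above are the hinge formulation: the discriminator drives real scores above +1 and generated scores below -1, while the generator drives its scores back above +1. A toy check of the arithmetic with illustrative tensors only (not part of the diff):

```python
import torch

real = torch.tensor([1.5, 0.2])   # discriminator scores on real audio
fake = torch.tensor([-1.2, 0.8])  # discriminator scores on generated audio

# DiscriminatorLoss terms: clamp(1 - real) and clamp(1 + fake)
d_loss = torch.mean(torch.clamp(1 - real, min=0)) + torch.mean(torch.clamp(1 + fake, min=0))
# GeneratorLoss term: clamp(1 - fake)
g_loss = torch.mean(torch.clamp(1 - fake, min=0))
print(d_loss.item(), g_loss.item())  # 0.4 + 0.9 = 1.3 for D, 1.2 for G
```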
xinference/thirdparty/indextts/s2mel/modules/vocos/models.py
@@ -0,0 +1,118 @@
+from typing import Optional
+
+import torch
+from torch import nn
+from torch.nn.utils import weight_norm
+
+from .modules import ConvNeXtBlock, ResBlock1, AdaLayerNorm
+
+
+class Backbone(nn.Module):
+    """Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
+
+    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+        """
+        Args:
+            x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
+                C denotes output features, and L is the sequence length.
+
+        Returns:
+            Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
+                and H denotes the model dimension.
+        """
+        raise NotImplementedError("Subclasses must implement the forward method.")
+
+
+class VocosBackbone(Backbone):
+    """
+    Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization
+
+    Args:
+        input_channels (int): Number of input features channels.
+        dim (int): Hidden dimension of the model.
+        intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
+        num_layers (int): Number of ConvNeXtBlock layers.
+        layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
+        adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
+            None means non-conditional model. Defaults to None.
+    """
+
+    def __init__(
+        self,
+        input_channels: int,
+        dim: int,
+        intermediate_dim: int,
+        num_layers: int,
+        layer_scale_init_value: Optional[float] = None,
+        adanorm_num_embeddings: Optional[int] = None,
+    ):
+        super().__init__()
+        self.input_channels = input_channels
+        self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
+        self.adanorm = adanorm_num_embeddings is not None
+        if adanorm_num_embeddings:
+            self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
+        else:
+            self.norm = nn.LayerNorm(dim, eps=1e-6)
+        layer_scale_init_value = layer_scale_init_value or 1 / num_layers
+        self.convnext = nn.ModuleList(
+            [
+                ConvNeXtBlock(
+                    dim=dim,
+                    intermediate_dim=intermediate_dim,
+                    layer_scale_init_value=layer_scale_init_value,
+                    adanorm_num_embeddings=adanorm_num_embeddings,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+        self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, (nn.Conv1d, nn.Linear)):
+            nn.init.trunc_normal_(m.weight, std=0.02)
+            nn.init.constant_(m.bias, 0)
+
+    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+        bandwidth_id = kwargs.get('bandwidth_id', None)
+        x = self.embed(x)
+        if self.adanorm:
+            assert bandwidth_id is not None
+            x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id)
+        else:
+            x = self.norm(x.transpose(1, 2))
+        x = x.transpose(1, 2)
+        for conv_block in self.convnext:
+            x = conv_block(x, cond_embedding_id=bandwidth_id)
+        x = self.final_layer_norm(x.transpose(1, 2))
+        return x
+
+
+class VocosResNetBackbone(Backbone):
+    """
+    Vocos backbone module built with ResBlocks.
+
+    Args:
+        input_channels (int): Number of input features channels.
+        dim (int): Hidden dimension of the model.
+        num_blocks (int): Number of ResBlock1 blocks.
+        layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
+    """
+
+    def __init__(
+        self, input_channels, dim, num_blocks, layer_scale_init_value=None,
+    ):
+        super().__init__()
+        self.input_channels = input_channels
+        self.embed = weight_norm(nn.Conv1d(input_channels, dim, kernel_size=3, padding=1))
+        layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
+        self.resnet = nn.Sequential(
+            *[ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value) for _ in range(num_blocks)]
+        )
+
+    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+        x = self.embed(x)
+        x = self.resnet(x)
+        x = x.transpose(1, 2)
+        return x
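The backbones keep a constant temporal resolution, so frames in equals frames out; only the channel axis changes and the output is transposed to (B, L, H) for a FourierHead. A usage sketch, not part of the diff; the 100-channel input and the other sizes are assumptions (100 mirrors the default n_mels in loss.py above):

```python
import torch
from xinference.thirdparty.indextts.s2mel.modules.vocos.models import VocosBackbone

backbone = VocosBackbone(input_channels=100, dim=512, intermediate_dim=1536, num_layers=8)
mel = torch.randn(2, 100, 250)  # (batch, mel bins, frames)
feats = backbone(mel)           # (batch, 250, 512) -> ready for a FourierHead
```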
xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py
@@ -0,0 +1,213 @@
+from typing import Optional, Tuple
+
+import torch
+from torch import nn
+from torch.nn.utils import weight_norm, remove_weight_norm
+
+
+class ConvNeXtBlock(nn.Module):
+    """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
+
+    Args:
+        dim (int): Number of input channels.
+        intermediate_dim (int): Dimensionality of the intermediate layer.
+        layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
+            Defaults to None.
+        adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
+            None means non-conditional LayerNorm. Defaults to None.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        intermediate_dim: int,
+        layer_scale_init_value: float,
+        adanorm_num_embeddings: Optional[int] = None,
+    ):
+        super().__init__()
+        self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
+        self.adanorm = adanorm_num_embeddings is not None
+        if adanorm_num_embeddings:
+            self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
+        else:
+            self.norm = nn.LayerNorm(dim, eps=1e-6)
+        self.pwconv1 = nn.Linear(dim, intermediate_dim)  # pointwise/1x1 convs, implemented with linear layers
+        self.act = nn.GELU()
+        self.pwconv2 = nn.Linear(intermediate_dim, dim)
+        self.gamma = (
+            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
+            if layer_scale_init_value > 0
+            else None
+        )
+
+    def forward(self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None) -> torch.Tensor:
+        residual = x
+        x = self.dwconv(x)
+        x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
+        if self.adanorm:
+            assert cond_embedding_id is not None
+            x = self.norm(x, cond_embedding_id)
+        else:
+            x = self.norm(x)
+        x = self.pwconv1(x)
+        x = self.act(x)
+        x = self.pwconv2(x)
+        if self.gamma is not None:
+            x = self.gamma * x
+        x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)
+
+        x = residual + x
+        return x
+
+
+class AdaLayerNorm(nn.Module):
+    """
+    Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
+
+    Args:
+        num_embeddings (int): Number of embeddings.
+        embedding_dim (int): Dimension of the embeddings.
+    """
+
+    def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.dim = embedding_dim
+        self.scale = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
+        self.shift = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
+        torch.nn.init.ones_(self.scale.weight)
+        torch.nn.init.zeros_(self.shift.weight)
+
+    def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
+        scale = self.scale(cond_embedding_id)
+        shift = self.shift(cond_embedding_id)
+        x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
+        x = x * scale + shift
+        return x
+
+
+class ResBlock1(nn.Module):
+    """
+    ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
+    but without upsampling layers.
+
+    Args:
+        dim (int): Number of input channels.
+        kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
+        dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
+            Defaults to (1, 3, 5).
+        lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
+            Defaults to 0.1.
+        layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
+            Defaults to None.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        kernel_size: int = 3,
+        dilation: Tuple[int, int, int] = (1, 3, 5),
+        lrelu_slope: float = 0.1,
+        layer_scale_init_value: Optional[float] = None,
+    ):
+        super().__init__()
+        self.lrelu_slope = lrelu_slope
+        self.convs1 = nn.ModuleList(
+            [
+                weight_norm(
+                    nn.Conv1d(
+                        dim,
+                        dim,
+                        kernel_size,
+                        1,
+                        dilation=dilation[0],
+                        padding=self.get_padding(kernel_size, dilation[0]),
+                    )
+                ),
+                weight_norm(
+                    nn.Conv1d(
+                        dim,
+                        dim,
+                        kernel_size,
+                        1,
+                        dilation=dilation[1],
+                        padding=self.get_padding(kernel_size, dilation[1]),
+                    )
+                ),
+                weight_norm(
+                    nn.Conv1d(
+                        dim,
+                        dim,
+                        kernel_size,
+                        1,
+                        dilation=dilation[2],
+                        padding=self.get_padding(kernel_size, dilation[2]),
+                    )
+                ),
+            ]
+        )
+
+        self.convs2 = nn.ModuleList(
+            [
+                weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
+                weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
+                weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
+            ]
+        )
+
+        self.gamma = nn.ParameterList(
+            [
+                nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
+                if layer_scale_init_value is not None
+                else None,
+                nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
+                if layer_scale_init_value is not None
+                else None,
+                nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
+                if layer_scale_init_value is not None
+                else None,
+            ]
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
+            xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
+            xt = c1(xt)
+            xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
+            xt = c2(xt)
+            if gamma is not None:
+                xt = gamma * xt
+            x = xt + x
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs1:
+            remove_weight_norm(l)
+        for l in self.convs2:
+            remove_weight_norm(l)
+
+    @staticmethod
+    def get_padding(kernel_size: int, dilation: int = 1) -> int:
+        return int((kernel_size * dilation - dilation) / 2)
+
+
+def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
+    """
+    Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.
+
+    Args:
+        x (Tensor): Input tensor.
+        clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.
+
+    Returns:
+        Tensor: Element-wise logarithm of the input tensor with clipping applied.
+    """
+    return torch.log(torch.clip(x, min=clip_val))
+
+
+def symlog(x: torch.Tensor) -> torch.Tensor:
+    return torch.sign(x) * torch.log1p(x.abs())
+
+
+def symexp(x: torch.Tensor) -> torch.Tensor:
+    return torch.sign(x) * (torch.exp(x.abs()) - 1)