xinference 1.10.0__py3-none-any.whl → 1.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +473 -31
- xinference/client/restful/async_restful_client.py +178 -8
- xinference/client/restful/restful_client.py +151 -3
- xinference/core/supervisor.py +99 -53
- xinference/core/worker.py +10 -0
- xinference/deploy/cmdline.py +15 -0
- xinference/model/audio/core.py +21 -6
- xinference/model/audio/indextts2.py +166 -0
- xinference/model/audio/model_spec.json +58 -21
- xinference/model/image/model_spec.json +159 -90
- xinference/model/image/stable_diffusion/core.py +13 -4
- xinference/model/llm/__init__.py +6 -2
- xinference/model/llm/llm_family.json +1299 -174
- xinference/model/llm/mlx/distributed_models/core.py +41 -0
- xinference/model/llm/mlx/distributed_models/qwen2.py +1 -2
- xinference/model/llm/sglang/core.py +44 -11
- xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +94 -32
- xinference/model/llm/tool_parsers/qwen_tool_parser.py +29 -4
- xinference/model/llm/transformers/chatglm.py +3 -0
- xinference/model/llm/transformers/core.py +129 -36
- xinference/model/llm/transformers/multimodal/minicpmv45.py +340 -0
- xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
- xinference/model/llm/transformers/utils.py +23 -0
- xinference/model/llm/utils.py +48 -32
- xinference/model/llm/vllm/core.py +207 -72
- xinference/model/utils.py +74 -31
- xinference/thirdparty/audiotools/__init__.py +10 -0
- xinference/thirdparty/audiotools/core/__init__.py +4 -0
- xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
- xinference/thirdparty/audiotools/core/display.py +194 -0
- xinference/thirdparty/audiotools/core/dsp.py +390 -0
- xinference/thirdparty/audiotools/core/effects.py +647 -0
- xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
- xinference/thirdparty/audiotools/core/loudness.py +320 -0
- xinference/thirdparty/audiotools/core/playback.py +252 -0
- xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
- xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
- xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
- xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
- xinference/thirdparty/audiotools/core/util.py +671 -0
- xinference/thirdparty/audiotools/core/whisper.py +97 -0
- xinference/thirdparty/audiotools/data/__init__.py +3 -0
- xinference/thirdparty/audiotools/data/datasets.py +517 -0
- xinference/thirdparty/audiotools/data/preprocess.py +81 -0
- xinference/thirdparty/audiotools/data/transforms.py +1592 -0
- xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
- xinference/thirdparty/audiotools/metrics/distance.py +131 -0
- xinference/thirdparty/audiotools/metrics/quality.py +159 -0
- xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
- xinference/thirdparty/audiotools/ml/__init__.py +5 -0
- xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
- xinference/thirdparty/audiotools/ml/decorators.py +440 -0
- xinference/thirdparty/audiotools/ml/experiment.py +90 -0
- xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
- xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
- xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
- xinference/thirdparty/audiotools/post.py +140 -0
- xinference/thirdparty/audiotools/preference.py +600 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +1 -1
- xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
- xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
- xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
- xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
- xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
- xinference/thirdparty/indextts/__init__.py +0 -0
- xinference/thirdparty/indextts/cli.py +65 -0
- xinference/thirdparty/indextts/gpt/__init__.py +0 -0
- xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
- xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
- xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
- xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
- xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
- xinference/thirdparty/indextts/gpt/model.py +713 -0
- xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
- xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
- xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
- xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
- xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
- xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
- xinference/thirdparty/indextts/infer.py +690 -0
- xinference/thirdparty/indextts/infer_v2.py +739 -0
- xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
- xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
- xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
- xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
- xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
- xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
- xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
- xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
- xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
- xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
- xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
- xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
- xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
- xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
- xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
- xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
- xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
- xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
- xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
- xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
- xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
- xinference/thirdparty/indextts/utils/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/arch_util.py +120 -0
- xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
- xinference/thirdparty/indextts/utils/common.py +121 -0
- xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
- xinference/thirdparty/indextts/utils/front.py +536 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
- xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
- xinference/thirdparty/indextts/utils/text_utils.py +41 -0
- xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
- xinference/thirdparty/indextts/utils/utils.py +93 -0
- xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
- xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
- xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
- xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
- xinference/thirdparty/melo/text/chinese_mix.py +2 -2
- xinference/types.py +9 -0
- xinference/ui/gradio/media_interface.py +66 -8
- xinference/ui/web/ui/build/asset-manifest.json +6 -6
- xinference/ui/web/ui/build/index.html +1 -1
- xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
- xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
- xinference/ui/web/ui/build/static/js/main.45e78536.js +3 -0
- xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.45e78536.js.LICENSE.txt} +0 -7
- xinference/ui/web/ui/build/static/js/main.45e78536.js.map +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/ea2a26361204e70cf1018d6990fb6354bed82b3ac69690391e0f100385e7abb7.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
- xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
- xinference/ui/web/ui/package-lock.json +0 -34
- xinference/ui/web/ui/package.json +0 -1
- xinference/ui/web/ui/src/locales/en.json +9 -3
- xinference/ui/web/ui/src/locales/ja.json +9 -3
- xinference/ui/web/ui/src/locales/ko.json +9 -3
- xinference/ui/web/ui/src/locales/zh.json +9 -3
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/METADATA +24 -6
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/RECORD +296 -77
- xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
- xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
- xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
- xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
- xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
- xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
- xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
- xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
- xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
- xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
- xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
- xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
- xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
- xinference/ui/web/ui/node_modules/select/bower.json +0 -13
- xinference/ui/web/ui/node_modules/select/package.json +0 -29
- xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/WHEEL +0 -0
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
|
|
2
|
+
# LICENSE is in incl_licenses directory.
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from torch import nn, pow, sin
|
|
6
|
+
from torch.nn import Parameter
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Snake(nn.Module):
    """Periodic "snake" activation: ``x + (1 / alpha) * sin^2(alpha * x)``.

    One ``alpha`` value is kept per channel and broadcast over a ``(B, C, T)``
    input, so the output has exactly the same shape as the input.

    Reference: activation proposed by Liu Ziyin, Tilman Hartwig and
    Masahito Ueda, https://arxiv.org/abs/2006.08195

    Example:
        >>> act = Snake(256)
        >>> y = act(torch.randn(4, 256, 100))  # (B, C, T) -> (B, C, T)
    """

    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
    ):
        """Create the per-channel ``alpha`` parameter.

        Args:
            in_features: size passed to ``torch.ones``/``torch.zeros`` when
                building ``alpha`` (the channel count ``C`` for a
                ``(B, C, T)`` input).
            alpha: initial scale. In linear scale every entry starts at
                ``alpha`` (higher values = higher frequency). In log scale
                entries are built as ``torch.zeros(...) * alpha``, i.e. they
                start at zero so ``exp(alpha)`` starts at one.
            alpha_trainable: value assigned to ``alpha.requires_grad``.
            alpha_logscale: keep ``alpha`` in log space and exponentiate it
                in :meth:`forward`.
        """
        super().__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale

        # Log-space parameters start at zero (exp(0) == 1); linear-space
        # parameters start at one. Both are multiplied by ``alpha``.
        init_fn = torch.zeros if self.alpha_logscale else torch.ones
        self.alpha = Parameter(init_fn(in_features) * alpha)
        self.alpha.requires_grad = alpha_trainable

        # Added to alpha before taking the reciprocal to avoid division by 0.
        self.no_div_by_zero = 1e-9

    def forward(self, x):
        """Apply ``x + 1/alpha * sin^2(alpha * x)`` element-wise.

        ``x`` is expected to be ``(B, C, T)``; ``alpha`` gains a leading and a
        trailing singleton axis so it broadcasts per channel.
        """
        # (C,) -> (1, C, 1): line alpha up with the channel axis of x.
        scale = self.alpha[None, ..., None]
        if self.alpha_logscale:
            scale = torch.exp(scale)
        reciprocal = 1.0 / (scale + self.no_div_by_zero)
        return x + reciprocal * sin(x * scale).pow(2)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class SnakeBeta(nn.Module):
    """Snake variant with separate frequency (``alpha``) and magnitude (``beta``).

    Computes ``x + (1 / beta) * sin^2(alpha * x)`` with one ``alpha`` and one
    ``beta`` per channel, broadcast over a ``(B, C, T)`` input; the output has
    the same shape as the input.

    Modified from the Snake activation of Liu Ziyin, Tilman Hartwig and
    Masahito Ueda, https://arxiv.org/abs/2006.08195

    Example:
        >>> act = SnakeBeta(256)
        >>> y = act(torch.randn(4, 256, 100))  # (B, C, T) -> (B, C, T)
    """

    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
    ):
        """Create the per-channel ``alpha`` and ``beta`` parameters.

        Args:
            in_features: size passed to ``torch.ones``/``torch.zeros`` when
                building ``alpha`` and ``beta`` (the channel count ``C``).
            alpha: initial scale for BOTH ``alpha`` and ``beta``; there is no
                separate ``beta`` argument. In linear scale both start at
                ``alpha``; in log scale both are ``torch.zeros(...) * alpha``,
                i.e. they start at zero.
            alpha_trainable: value assigned to ``requires_grad`` on both
                parameters.
            alpha_logscale: keep both parameters in log space and
                exponentiate them in :meth:`forward`.
        """
        super().__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale

        init_fn = torch.zeros if self.alpha_logscale else torch.ones
        # alpha is registered before beta, which fixes the state_dict key order.
        self.alpha = Parameter(init_fn(in_features) * alpha)
        self.beta = Parameter(init_fn(in_features) * alpha)

        for param in (self.alpha, self.beta):
            param.requires_grad = alpha_trainable

        # Added to beta before taking the reciprocal to avoid division by 0.
        self.no_div_by_zero = 1e-9

    def forward(self, x):
        """Apply ``x + 1/beta * sin^2(alpha * x)`` element-wise.

        ``x`` is expected to be ``(B, C, T)``; both parameters gain a leading
        and a trailing singleton axis so they broadcast per channel.
        """
        # (C,) -> (1, C, 1): line the parameters up with the channel axis.
        freq = self.alpha[None, ..., None]
        mag = self.beta[None, ..., None]
        if self.alpha_logscale:
            freq, mag = torch.exp(freq), torch.exp(mag)
        reciprocal = 1.0 / (mag + self.no_div_by_zero)
        return x + reciprocal * sin(x * freq).pow(2)
|
|
File without changes
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
/build
|
|
File without changes
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
# Copyright (c) 2024 NVIDIA CORPORATION.
|
|
2
|
+
# Licensed under the MIT license.
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
import torch.nn as nn
|
|
6
|
+
# load fused CUDA kernel: this enables importing anti_alias_activation_cuda
|
|
7
|
+
from indextts.BigVGAN.alias_free_activation.cuda import load
|
|
8
|
+
from indextts.BigVGAN.alias_free_activation.torch.resample import DownSample1d, UpSample1d
|
|
9
|
+
|
|
10
|
+
# Build and load the fused CUDA extension as a side effect of importing this
# module: load.load() runs `nvcc -V` and torch.utils.cpp_extension.load, so
# importing this module requires a working CUDA toolchain.
anti_alias_activation_cuda = load.load()
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class FusedAntiAliasActivation(torch.autograd.Function):
    """
    Fused upsample -> snake activation -> downsample backed by a custom CUDA kernel.

    Assumes filter size 12, replication padding on upsampling/downsampling, and
    logscale alpha/beta parameters as inputs. The hyperparameters are hard-coded
    in the kernel to maximize speed.
    NOTE: The fused kernel is incorrect for Activation1d with different hyperparameters.
    Only the forward pass is implemented, so this is usable for inference only.
    """

    @staticmethod
    def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta):
        """Run the fused CUDA kernel and return a new tensor shaped like ``inputs``.

        The kernel walks raw pointers as a dense ``[B, C, T]`` array and reads the
        filters and ``alpha``/``beta`` through float32 pointers. The tensors are
        therefore made contiguous, and the parameters cast to float32, before the
        call. Both conversions are no-ops when the tensors already match.
        """
        activation_results = anti_alias_activation_cuda.forward(
            inputs.contiguous(),
            up_ftr.float().contiguous(),
            down_ftr.float().contiguous(),
            alpha.float().contiguous(),
            beta.float().contiguous(),
        )

        return activation_results

    @staticmethod
    def backward(ctx, output_grads):
        """Not supported: the CUDA extension only provides a forward kernel."""
        raise NotImplementedError(
            "FusedAntiAliasActivation only implements the forward pass; "
            "use Activation1d(..., fused=False) when gradients are required"
        )
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class Activation1d(nn.Module):
    """Anti-aliased activation: upsample, apply ``activation``, then downsample.

    Args:
        activation: activation module applied at the upsampled rate.
        up_ratio: upsampling factor passed to ``UpSample1d``.
        down_ratio: downsampling factor passed to ``DownSample1d``.
        up_kernel_size: filter length passed to ``UpSample1d``.
        down_kernel_size: filter length passed to ``DownSample1d``.
        fused: request the fused CUDA kernel (``FusedAntiAliasActivation``).
            The kernel hard-codes ratio 2 and 12-tap filters and reads the
            activation's snake parameters (``alpha``, ``beta``, ``alpha_logscale``).
            It is therefore only used when the configuration matches, the
            activation exposes those parameters, and the input is a CUDA tensor.
            Otherwise the plain PyTorch path runs, instead of returning wrong
            results or crashing.
    """

    # Hyperparameters baked into the fused CUDA kernel.
    _FUSED_RATIO = 2
    _FUSED_KERNEL_SIZE = 12

    def __init__(
        self,
        activation,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
        fused: bool = True,
    ):
        super().__init__()
        self.up_ratio = up_ratio
        self.down_ratio = down_ratio
        # Kept so forward() can decide whether the fused kernel is applicable.
        self.up_kernel_size = up_kernel_size
        self.down_kernel_size = down_kernel_size
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

        self.fused = fused  # Whether to use fused CUDA kernel or not

    def _can_fuse(self, x) -> bool:
        """Return True when the fused CUDA kernel is valid for this module and ``x``."""
        if not self.fused or not x.is_cuda:
            return False
        if self.up_ratio != self._FUSED_RATIO or self.down_ratio != self._FUSED_RATIO:
            return False
        if (
            self.up_kernel_size != self._FUSED_KERNEL_SIZE
            or self.down_kernel_size != self._FUSED_KERNEL_SIZE
        ):
            return False
        act = self.act
        if not (hasattr(act, "alpha") and hasattr(act, "alpha_logscale")):
            return False
        # Snake reuses alpha as beta; any other snake variant must provide beta.
        return act.__class__.__name__ == "Snake" or hasattr(act, "beta")

    def forward(self, x):
        if not self._can_fuse(x):
            x = self.upsample(x)
            x = self.act(x)
            x = self.downsample(x)
            return x

        alpha = self.act.alpha.data
        if self.act.__class__.__name__ == "Snake":
            beta = alpha  # Snake uses same params for alpha and beta
        else:
            beta = self.act.beta.data  # SnakeBeta uses different params for alpha and beta
        if not self.act.alpha_logscale:
            # The kernel applies exp() to alpha/beta; cancel it out with a log.
            alpha = torch.log(alpha)
            beta = torch.log(beta)

        return FusedAntiAliasActivation.apply(
            x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta
        )
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
/* coding=utf-8
|
|
2
|
+
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
|
3
|
+
*
|
|
4
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
* you may not use this file except in compliance with the License.
|
|
6
|
+
* You may obtain a copy of the License at
|
|
7
|
+
*
|
|
8
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
*
|
|
10
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
* See the License for the specific language governing permissions and
|
|
14
|
+
* limitations under the License.
|
|
15
|
+
*/
|
|
16
|
+
|
|
17
|
+
#include <torch/extension.h>
|
|
18
|
+
|
|
19
|
+
// Forward entry point, defined in anti_alias_activation_cuda.cu. It takes the
// [batches, channels, seq_len] input, the up/downsampling filters and the
// per-channel log-scale alpha/beta tensors. It returns a newly allocated tensor
// of the same shape as the input.
extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta);

// Python bindings: exposes fwd_cuda as `forward` on the compiled extension module.
// Only a forward kernel is bound; there is no backward.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &fwd_cuda, "Anti-Alias Activation forward (CUDA)");
}
|
xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
1
|
+
/* coding=utf-8
|
|
2
|
+
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
|
3
|
+
*
|
|
4
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
* you may not use this file except in compliance with the License.
|
|
6
|
+
* You may obtain a copy of the License at
|
|
7
|
+
*
|
|
8
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
*
|
|
10
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
* See the License for the specific language governing permissions and
|
|
14
|
+
* limitations under the License.
|
|
15
|
+
*/
|
|
16
|
+
|
|
17
|
+
#include <ATen/ATen.h>
|
|
18
|
+
#include <cuda.h>
|
|
19
|
+
#include <cuda_runtime.h>
|
|
20
|
+
#include <cuda_fp16.h>
|
|
21
|
+
#include <cuda_profiler_api.h>
|
|
22
|
+
#include <ATen/cuda/CUDAContext.h>
|
|
23
|
+
#include <torch/extension.h>
|
|
24
|
+
#include "type_shim.h"
|
|
25
|
+
#include <assert.h>
|
|
26
|
+
#include <cfloat>
|
|
27
|
+
#include <limits>
|
|
28
|
+
#include <stdint.h>
|
|
29
|
+
#include <c10/macros/Macros.h>
|
|
30
|
+
|
|
31
|
+
namespace
{
    // Hard-coded hyperparameters. They must match the PyTorch Activation1d /
    // UpSample1d / DownSample1d configuration this kernel replaces (ratio 2,
    // 12-tap filters). The kernel itself does not check this.
    constexpr int ELEMENTS_PER_LDG_STG = 1; // output samples written per store in the final loop
    constexpr int BUFFER_SIZE = 32;         // output samples produced by one thread
    constexpr int FILTER_SIZE = 12;         // taps in both the up- and downsampling filters
    constexpr int HALF_FILTER_SIZE = 6;
    constexpr int UPSAMPLE_REPLICATION_PAD = 5; // 5 on each side, matching torch impl
    constexpr int DOWNSAMPLE_REPLICATION_PAD_LEFT = 5; // matching torch impl
    constexpr int DOWNSAMPLE_REPLICATION_PAD_RIGHT = 6; // matching torch impl

    // Fused upsample(x2) -> snake activation -> downsample(stride 2) over one row
    // of a contiguous [batch, channels, seq_len] tensor.
    //
    // Launch layout (see dispatch_anti_alias_activation_forward below):
    //   blockDim = (128, 1, 1); gridDim = (ceil(seq_len / 4096), channels, batch).
    // Each thread produces BUFFER_SIZE consecutive output samples, so one block
    // covers 128 * BUFFER_SIZE = 4096 samples of a single (batch, channel) row.
    // alpha/beta are expected in log space; they are exponentiated below.
    // NOTE(review): the thread count 128 is hard-coded in the offset math and
    // must stay in sync with threads_per_block in the dispatcher.
    // NOTE(review): offsets are computed in `int`; tensors with more than
    // INT_MAX elements would overflow. Confirm inputs stay below that.
    template <typename input_t, typename output_t, typename acc_t>
    __global__ void anti_alias_activation_forward(
        output_t *dst,
        const input_t *src,
        const acc_t *up_ftr,
        const acc_t *down_ftr,
        const acc_t *alpha,
        const acc_t *beta,
        int batch_size,
        int channels,
        int seq_len)
    {
        // Up and downsample filters, copied into per-thread arrays.
        // NOTE(review): they are stored as input_t, so for half/bfloat16 inputs the
        // float32 taps are rounded to the input precision. Confirm this is intended.
        input_t up_filter[FILTER_SIZE];
        input_t down_filter[FILTER_SIZE];

        // elements: the input window for the upsampling conv. Values are doubled and
        // written only to even slots below; odd slots keep the zero initializer
        // (zero-stuffing).
        // intermediates: the upsampled + activated signal, plus room for the
        // replication padding used by the downsampling conv.
        input_t elements[2 * FILTER_SIZE + 2 * BUFFER_SIZE + 2 * UPSAMPLE_REPLICATION_PAD] = {0};
        input_t intermediates[2 * FILTER_SIZE + 2 * BUFFER_SIZE + DOWNSAMPLE_REPLICATION_PAD_LEFT + DOWNSAMPLE_REPLICATION_PAD_RIGHT] = {0};

        // Output stores downsampled output before writing to dst
        output_t output[BUFFER_SIZE];

        // blockDim/threadIdx = (128, 1, 1)
        // gridDim/blockIdx = (seq_blocks, channels, batches)
        // block_offset: start of this block's 4096-sample slice in the flattened
        // tensor. The (blockIdx.y + gridDim.y * blockIdx.z) term selects the
        // (channel, batch) row.
        int block_offset = (blockIdx.x * 128 * BUFFER_SIZE + seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
        int local_offset = threadIdx.x * BUFFER_SIZE;
        // Position of this thread's first output sample within its row.
        int seq_offset = blockIdx.x * 128 * BUFFER_SIZE + local_offset;

        // Intermediates run at twice the sample rate, so their offsets are doubled.
        int intermediate_local_offset = threadIdx.x * BUFFER_SIZE * 2;
        int intermediate_seq_offset = blockIdx.x * 128 * BUFFER_SIZE * 2 + intermediate_local_offset;

        // Read the first and last samples of this row (needed for replication
        // padding) before src is advanced. Despite its name, right_most_pntr points
        // at the START of the row.
        const input_t *right_most_pntr = src + (seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
        input_t seq_left_most_value = right_most_pntr[0];
        input_t seq_right_most_value = right_most_pntr[seq_len - 1];

        // Move src and dst pointers to this thread's first sample.
        src += block_offset + local_offset;
        dst += block_offset + local_offset;

        // Per-channel alpha and beta for the snake activation. They arrive in log
        // space and are exponentiated here.
        alpha = alpha + blockIdx.y;
        beta = beta + blockIdx.y;

        acc_t alpha_val = expf(alpha[0]);
        acc_t beta_val = expf(beta[0]);

        #pragma unroll
        for (int it = 0; it < FILTER_SIZE; it += 1)
        {
            up_filter[it] = up_ftr[it];
            down_filter[it] = down_ftr[it];
        }

        // Fill the upsampling window:
        //   - in-range samples come from src;
        //   - up to UPSAMPLE_REPLICATION_PAD samples past either end of the row
        //     repeat the edge value (replication padding, matching the torch impl);
        //   - anything further out stays zero.
        // Values are doubled and placed on even slots.
        #pragma unroll
        for (int it = -HALF_FILTER_SIZE; it < BUFFER_SIZE + HALF_FILTER_SIZE; it += 1)
        {
            int element_index = seq_offset + it; // index for element
            if ((element_index < 0) && (element_index >= -UPSAMPLE_REPLICATION_PAD))
            {
                elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_left_most_value;
            }
            if ((element_index >= seq_len) && (element_index < seq_len + UPSAMPLE_REPLICATION_PAD))
            {
                elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_right_most_value;
            }
            if ((element_index >= 0) && (element_index < seq_len))
            {
                elements[2 * (HALF_FILTER_SIZE + it)] = 2 * src[it];
            }
        }

        // Apply the upsampling convolution and write the result into intermediates.
        // The first DOWNSAMPLE_REPLICATION_PAD_LEFT slots are left free for the
        // downsampling conv's replication padding.
        #pragma unroll
        for (int it = 0; it < (2 * BUFFER_SIZE + 2 * FILTER_SIZE); it += 1)
        {
            acc_t acc = 0.0;
            int element_index = intermediate_seq_offset + it; // index for intermediate
            #pragma unroll
            for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
            {
                if ((element_index + f_idx) >= 0)
                {
                    acc += up_filter[f_idx] * elements[it + f_idx];
                }
            }
            intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] = acc;
        }

        // Snake activation, in place: y = x + 1 / (beta + eps) * sin^2(alpha * x).
        // The left/right pad slots are filled afterwards.
        // NOTE(review): no_div_by_zero is a double, so the update below is evaluated
        // in double precision for every element. Hoisting the reciprocal out of the
        // loop (or using float) may be faster; verify numerics before changing.
        double no_div_by_zero = 0.000000001;
        #pragma unroll
        for (int it = 0; it < 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it += 1)
        {
            acc_t a = sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val);
            intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] += (1.0 / (beta_val + no_div_by_zero)) * a * a;
        }

        // Replication padding for the downsampling conv: copy the first activated
        // value into the left pad and the last activated value into the right pad.
        #pragma unroll
        for (int it = 0; it < DOWNSAMPLE_REPLICATION_PAD_LEFT; it += 1)
        {
            intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT];
        }
        #pragma unroll
        for (int it = DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it < DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE + DOWNSAMPLE_REPLICATION_PAD_RIGHT; it += 1)
        {
            intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE - 1];
        }

        // Apply the downsampling strided convolution (stride 2) from intermediates.
        #pragma unroll
        for (int it = 0; it < BUFFER_SIZE; it += 1)
        {
            acc_t acc = 0.0;
            #pragma unroll
            for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
            {
                // Offset by DOWNSAMPLE_REPLICATION_PAD_RIGHT to match the torch implementation
                acc += down_filter[f_idx] * intermediates[it * 2 + f_idx + DOWNSAMPLE_REPLICATION_PAD_RIGHT];
            }
            output[it] = acc;
        }

        // Write output to dst, skipping samples past the end of the row.
        #pragma unroll
        for (int it = 0; it < BUFFER_SIZE; it += ELEMENTS_PER_LDG_STG)
        {
            int element_index = seq_offset + it;
            if (element_index < seq_len)
            {
                dst[it] = output[it];
            }
        }

    }

    // Launches anti_alias_activation_forward on the current CUDA stream.
    // The grid is (ceil(seq_len / 4096), channels, batch_size) with 128 threads per
    // block. An empty sequence launches nothing.
    // NOTE(review): the launch result is not checked here. CUDA also caps
    // gridDim.y / gridDim.z (65535 on current GPUs); confirm that channel and
    // batch sizes stay within that limit.
    template <typename input_t, typename output_t, typename acc_t>
    void dispatch_anti_alias_activation_forward(
        output_t *dst,
        const input_t *src,
        const acc_t *up_ftr,
        const acc_t *down_ftr,
        const acc_t *alpha,
        const acc_t *beta,
        int batch_size,
        int channels,
        int seq_len)
    {
        if (seq_len == 0)
        {
            return;
        }
        else
        {
            // 128 threads per block (also hard-coded in the kernel's offset math).
            constexpr int threads_per_block = 128;
            // Must equal threads_per_block * BUFFER_SIZE (128 * 32).
            constexpr int seq_len_per_block = 4096;
            int blocks_per_seq_len = (seq_len + seq_len_per_block - 1) / seq_len_per_block;
            dim3 blocks(blocks_per_seq_len, channels, batch_size);
            dim3 threads(threads_per_block, 1, 1);

            anti_alias_activation_forward<input_t, output_t, acc_t>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, up_ftr, down_ftr, alpha, beta, batch_size, channels, seq_len);
        }
    }
}
|
|
213
|
+
|
|
214
|
+
extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta)
{
    // Host entry point for the fused anti-alias activation (bound to Python as
    // `forward`). It validates the tensors against the kernel's assumptions,
    // allocates an output with the same shape/dtype/device as `input`, and
    // launches the kernel.
    //
    //   input:        [batches, channels, seq_len], float16/bfloat16/float32, CUDA, contiguous
    //   up_filter:    float32 CUDA tensor with at least FILTER_SIZE taps
    //   down_filter:  float32 CUDA tensor with at least FILTER_SIZE taps
    //   alpha, beta:  float32 CUDA tensors with one log-scale value per channel
    //
    // The kernel indexes raw pointers. Anything else (CPU tensors, strided views,
    // other dtypes, short filters or parameter vectors) would read the wrong
    // memory, so it is rejected up front.
    TORCH_CHECK(input.is_cuda(), "anti_alias_activation: input must be a CUDA tensor");
    TORCH_CHECK(input.dim() == 3,
                "anti_alias_activation: input must be 3-D [batches, channels, seq_len], got ",
                input.dim(), " dims");
    TORCH_CHECK(input.is_contiguous(), "anti_alias_activation: input must be contiguous");

    // Input is a 3d tensor with dimensions [batches, channels, seq_len]
    const int batches = input.size(0);
    const int channels = input.size(1);
    const int seq_len = input.size(2);

    // The kernel reads FILTER_SIZE taps from each filter, and alpha/beta at the
    // channel index, all through float32 pointers.
    auto check_float_param = [](torch::Tensor const &t, const char *name, int64_t min_numel) {
        TORCH_CHECK(t.is_cuda(), "anti_alias_activation: ", name, " must be a CUDA tensor");
        TORCH_CHECK(t.scalar_type() == torch::kFloat32,
                    "anti_alias_activation: ", name, " must be float32, got ", t.scalar_type());
        TORCH_CHECK(t.is_contiguous(), "anti_alias_activation: ", name, " must be contiguous");
        TORCH_CHECK(t.numel() >= min_numel,
                    "anti_alias_activation: ", name, " needs at least ", min_numel,
                    " elements, got ", t.numel());
    };
    check_float_param(up_filter, "up_filter", FILTER_SIZE);
    check_float_param(down_filter, "down_filter", FILTER_SIZE);
    check_float_param(alpha, "alpha", channels);
    check_float_param(beta, "beta", channels);

    // Output
    auto act_options = input.options().requires_grad(false);

    torch::Tensor anti_alias_activation_results =
        torch::empty({batches, channels, seq_len}, act_options);

    // Nothing to compute, and a zero-sized grid would be an invalid launch.
    if (anti_alias_activation_results.numel() == 0)
    {
        return anti_alias_activation_results;
    }

    using float32 = float;
    // The dtype of input is float16, bfloat16, or float32 (dispatched below).
    // The dtype of up_filter, down_filter, alpha, and beta is float32 (checked above).
    void *input_ptr = static_cast<void *>(input.data_ptr());
    float32 *up_filter_ptr = static_cast<float32 *>(up_filter.data_ptr());
    float32 *down_filter_ptr = static_cast<float32 *>(down_filter.data_ptr());
    float32 *alpha_ptr = static_cast<float32 *>(alpha.data_ptr());
    float32 *beta_ptr = static_cast<float32 *>(beta.data_ptr());
    void *anti_alias_activation_results_ptr = static_cast<void *>(anti_alias_activation_results.data_ptr());

    DISPATCH_FLOAT_HALF_AND_BFLOAT(
        input.scalar_type(),
        "dispatch anti alias activation_forward",
        dispatch_anti_alias_activation_forward<scalar_t, scalar_t, float32>(
            reinterpret_cast<scalar_t *>(anti_alias_activation_results_ptr),
            reinterpret_cast<const scalar_t *>(input_ptr),
            reinterpret_cast<const float32 *>(up_filter_ptr),
            reinterpret_cast<const float32 *>(down_filter_ptr),
            reinterpret_cast<const float32 *>(alpha_ptr),
            reinterpret_cast<const float32 *>(beta_ptr),
            batches,
            channels,
            seq_len););
    return anti_alias_activation_results;
}
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
/* coding=utf-8
|
|
2
|
+
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
|
3
|
+
*
|
|
4
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
* you may not use this file except in compliance with the License.
|
|
6
|
+
* You may obtain a copy of the License at
|
|
7
|
+
*
|
|
8
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
*
|
|
10
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
* See the License for the specific language governing permissions and
|
|
14
|
+
* limitations under the License.
|
|
15
|
+
*/
|
|
16
|
+
|
|
17
|
+
/* This code is copied from NVIDIA apex:
|
|
18
|
+
* https://github.com/NVIDIA/apex
|
|
19
|
+
* with minor changes. */
|
|
20
|
+
|
|
21
|
+
// Compatibility shims for older PyTorch C++ APIs.
// If TORCH_CHECK is not defined, alias it to AT_CHECK.
// NOTE(review): presumably AT_CHECK is the older spelling used by early PyTorch
// releases. Confirm against the minimum supported PyTorch version.
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif

// Selects the tensor raw-pointer accessor name: `data_ptr` when VERSION_GE_1_3 is
// defined, `data` otherwise.
// NOTE(review): VERSION_GE_1_3 is not defined in this file. It is presumably set by
// the build for PyTorch >= 1.3; confirm.
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
# Copyright (c) 2024 NVIDIA CORPORATION.
|
|
2
|
+
# Licensed under the MIT license.
|
|
3
|
+
|
|
4
|
+
import os
|
|
5
|
+
import pathlib
|
|
6
|
+
import subprocess
|
|
7
|
+
|
|
8
|
+
from torch.utils import cpp_extension
|
|
9
|
+
|
|
10
|
+
"""
|
|
11
|
+
Setting this param to a list has a problem of generating different compilation commands (with different order of architectures) and leading to recompilation of fused kernels.
|
|
12
|
+
Set it to an empty string to avoid recompilation and assign arch flags explicitly in extra_cuda_cflags below.
|
|
13
|
+
"""
|
|
14
|
+
# NOTE: importing this module overrides any user-provided TORCH_CUDA_ARCH_LIST for
# the whole process. Architecture flags are instead passed explicitly to nvcc
# (via -gencode) in load().
os.environ["TORCH_CUDA_ARCH_LIST"] = ""
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
import re
|
|
18
|
+
import shutil
|
|
19
|
+
import tempfile
|
|
20
|
+
|
|
21
|
+
def chinese_path_compile_support(sources, buildpath):
    """Work around ninja build failures for source paths containing Chinese characters.

    When the resolved path of ``sources[0]`` contains CJK unified ideographs, the
    generated ``build.ninja`` is garbled and compilation fails. In that case, every
    ``.py``/``.cu``/``.cpp``/``.h`` file next to this module is copied to
    ``<tempdir>/BigVGAN/cuda``; this covers the sources and the headers they
    include. The build is then redirected there.

    Args:
        sources: list of ``pathlib.Path`` source files. When the workaround
            applies, the list is rewritten **in place** to point at the copied
            files, keeping the original order (the caller relies on this
            mutation).
        buildpath: build directory to use when no workaround is needed.

    Returns:
        ``buildpath`` unchanged when no workaround is needed, including when
        ``sources`` is empty. Otherwise, the temporary build directory as a ``str``.
    """
    if not sources or not re.search(r"[\u4e00-\u9fff]", str(sources[0].resolve())):
        return buildpath  # no Chinese characters in the path: build in place

    wanted_names = [item.name for item in sources]
    ninja_compile_dir = os.path.join(tempfile.gettempdir(), "BigVGAN", "cuda")
    compile_dir = pathlib.Path(ninja_compile_dir)
    new_buildpath = os.path.join(ninja_compile_dir, "build")
    # makedirs also creates ninja_compile_dir, the parent of the build directory.
    os.makedirs(new_buildpath, exist_ok=True)
    print(f"ninja_buildpath: {new_buildpath}")

    # Copy the extension sources together with the headers they include.
    allowed_extensions = {".py", ".cu", ".cpp", ".h"}
    copied_names = set()
    for item in pathlib.Path(__file__).parent.iterdir():
        if not item.is_file() or item.suffix.lower() not in allowed_extensions:
            continue
        shutil.copy2(item, compile_dir / item.name)
        copied_names.add(item.name)

    # Point the caller's list at the copies, preserving the original order.
    sources[:] = [compile_dir / name for name in wanted_names if name in copied_names]
    return new_buildpath
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def load():
    """Build (if needed) and import the fused anti-alias activation CUDA extension.

    Compiles ``anti_alias_activation.cpp`` and ``anti_alias_activation_cuda.cu``
    from this directory with ``torch.utils.cpp_extension.load``. The build goes
    into the ``build`` subdirectory, or into a temporary directory when the path
    contains Chinese characters (see ``chinese_path_compile_support``). Returns
    the loaded extension module.

    Raises whatever ``_get_cuda_bare_metal_version`` or
    ``cpp_extension.load`` raise, e.g. when nvcc is missing or compilation fails.
    """
    # Always build sm_70 machine code. With CUDA >= 11, also build sm_80 machine
    # code and embed compute_80 PTX. Without the PTX, GPUs outside the 7.x/8.x
    # families (e.g. sm_90) find no compatible kernel image; with it, the driver
    # can JIT-compile one.
    # NOTE(review): newer CUDA toolkits may drop sm_70 support; revisit the
    # hard-coded compute_70 flag below if builds start failing on them.
    cc_flag = []
    _, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
    if int(bare_metal_major) >= 11:
        cc_flag += [
            "-gencode",
            "arch=compute_80,code=sm_80",
            "-gencode",
            "arch=compute_80,code=compute_80",
        ]

    # Build path
    srcpath = pathlib.Path(__file__).parent.absolute()
    buildpath = srcpath / "build"
    _create_build_dir(buildpath)

    extra_cuda_flags = [
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
    ]

    sources = [
        srcpath / "anti_alias_activation.cpp",
        srcpath / "anti_alias_activation_cuda.cu",
    ]

    # Workaround: ninja fails on some special-character (e.g. Chinese) paths. This
    # may redirect the build to a temporary directory and rewrites `sources` in place.
    buildpath = chinese_path_compile_support(sources, buildpath)

    return cpp_extension.load(
        name="anti_alias_activation_cuda",
        sources=sources,
        build_directory=buildpath,
        extra_cflags=[
            "-O3",
        ],
        extra_cuda_cflags=[
            "-O3",
            "-gencode",
            "arch=compute_70,code=sm_70",
            "--use_fast_math",
        ]
        + extra_cuda_flags
        + cc_flag,
        verbose=True,
    )
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def _get_cuda_bare_metal_version(cuda_dir):
|
|
104
|
+
raw_output = subprocess.check_output(
|
|
105
|
+
[cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
|
|
106
|
+
)
|
|
107
|
+
output = raw_output.split()
|
|
108
|
+
release_idx = output.index("release") + 1
|
|
109
|
+
release = output[release_idx].split(".")
|
|
110
|
+
bare_metal_major = release[0]
|
|
111
|
+
bare_metal_minor = release[1][0]
|
|
112
|
+
|
|
113
|
+
return raw_output, bare_metal_major, bare_metal_minor
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def _create_build_dir(buildpath):
|
|
117
|
+
try:
|
|
118
|
+
os.mkdir(buildpath)
|
|
119
|
+
except OSError:
|
|
120
|
+
if not os.path.isdir(buildpath):
|
|
121
|
+
print(f"Creation of the build directory {buildpath} failed")
|