xinference 1.9.1__py3-none-any.whl → 1.10.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +400 -3
- xinference/client/restful/async_restful_client.py +20 -3
- xinference/client/restful/restful_client.py +20 -3
- xinference/constants.py +2 -0
- xinference/core/supervisor.py +111 -49
- xinference/core/worker.py +10 -0
- xinference/deploy/cmdline.py +15 -0
- xinference/model/audio/core.py +26 -6
- xinference/model/audio/indextts2.py +166 -0
- xinference/model/audio/kokoro.py +1 -1
- xinference/model/audio/kokoro_zh.py +124 -0
- xinference/model/audio/model_spec.json +58 -1
- xinference/model/embedding/sentence_transformers/core.py +4 -4
- xinference/model/embedding/vllm/core.py +7 -1
- xinference/model/image/model_spec.json +71 -3
- xinference/model/image/stable_diffusion/core.py +13 -4
- xinference/model/llm/__init__.py +4 -0
- xinference/model/llm/core.py +10 -0
- xinference/model/llm/llama_cpp/core.py +1 -0
- xinference/model/llm/llm_family.json +503 -21
- xinference/model/llm/llm_family.py +1 -0
- xinference/model/llm/mlx/core.py +52 -33
- xinference/model/llm/sglang/core.py +32 -55
- xinference/model/llm/tool_parsers/__init__.py +58 -0
- xinference/model/llm/tool_parsers/abstract_tool_parser.py +33 -0
- xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +190 -0
- xinference/model/llm/tool_parsers/deepseek_v3_tool_parser.py +145 -0
- xinference/model/llm/tool_parsers/glm4_tool_parser.py +123 -0
- xinference/model/llm/tool_parsers/llama3_tool_parser.py +77 -0
- xinference/model/llm/tool_parsers/qwen_tool_parser.py +320 -0
- xinference/model/llm/transformers/core.py +1 -1
- xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
- xinference/model/llm/utils.py +138 -53
- xinference/model/llm/vllm/core.py +95 -78
- xinference/thirdparty/audiotools/__init__.py +10 -0
- xinference/thirdparty/audiotools/core/__init__.py +4 -0
- xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
- xinference/thirdparty/audiotools/core/display.py +194 -0
- xinference/thirdparty/audiotools/core/dsp.py +390 -0
- xinference/thirdparty/audiotools/core/effects.py +647 -0
- xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
- xinference/thirdparty/audiotools/core/loudness.py +320 -0
- xinference/thirdparty/audiotools/core/playback.py +252 -0
- xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
- xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
- xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
- xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
- xinference/thirdparty/audiotools/core/util.py +671 -0
- xinference/thirdparty/audiotools/core/whisper.py +97 -0
- xinference/thirdparty/audiotools/data/__init__.py +3 -0
- xinference/thirdparty/audiotools/data/datasets.py +517 -0
- xinference/thirdparty/audiotools/data/preprocess.py +81 -0
- xinference/thirdparty/audiotools/data/transforms.py +1592 -0
- xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
- xinference/thirdparty/audiotools/metrics/distance.py +131 -0
- xinference/thirdparty/audiotools/metrics/quality.py +159 -0
- xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
- xinference/thirdparty/audiotools/ml/__init__.py +5 -0
- xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
- xinference/thirdparty/audiotools/ml/decorators.py +440 -0
- xinference/thirdparty/audiotools/ml/experiment.py +90 -0
- xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
- xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
- xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
- xinference/thirdparty/audiotools/post.py +140 -0
- xinference/thirdparty/audiotools/preference.py +600 -0
- xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
- xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
- xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
- xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
- xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
- xinference/thirdparty/indextts/__init__.py +0 -0
- xinference/thirdparty/indextts/cli.py +65 -0
- xinference/thirdparty/indextts/gpt/__init__.py +0 -0
- xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
- xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
- xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
- xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
- xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
- xinference/thirdparty/indextts/gpt/model.py +713 -0
- xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
- xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
- xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
- xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
- xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
- xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
- xinference/thirdparty/indextts/infer.py +690 -0
- xinference/thirdparty/indextts/infer_v2.py +739 -0
- xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
- xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
- xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
- xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
- xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
- xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
- xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
- xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
- xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
- xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
- xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
- xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
- xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
- xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
- xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
- xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
- xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
- xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
- xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
- xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
- xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
- xinference/thirdparty/indextts/utils/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/arch_util.py +120 -0
- xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
- xinference/thirdparty/indextts/utils/common.py +121 -0
- xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
- xinference/thirdparty/indextts/utils/front.py +536 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
- xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
- xinference/thirdparty/indextts/utils/text_utils.py +41 -0
- xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
- xinference/thirdparty/indextts/utils/utils.py +93 -0
- xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
- xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
- xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
- xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
- xinference/types.py +105 -2
- xinference/ui/gradio/media_interface.py +66 -8
- xinference/ui/web/ui/build/asset-manifest.json +6 -6
- xinference/ui/web/ui/build/index.html +1 -1
- xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
- xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
- xinference/ui/web/ui/build/static/js/main.d192c4f3.js +3 -0
- xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.d192c4f3.js.LICENSE.txt} +0 -7
- xinference/ui/web/ui/build/static/js/main.d192c4f3.js.map +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f995a2425dfb0822fd07127f66ffe9b026883bc156b402eb8bd0b83d52460a93.json +1 -0
- xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
- xinference/ui/web/ui/package-lock.json +0 -34
- xinference/ui/web/ui/package.json +0 -1
- xinference/ui/web/ui/src/locales/en.json +9 -3
- xinference/ui/web/ui/src/locales/ja.json +9 -3
- xinference/ui/web/ui/src/locales/ko.json +9 -3
- xinference/ui/web/ui/src/locales/zh.json +9 -3
- {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/METADATA +24 -4
- {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/RECORD +302 -76
- xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
- xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
- xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
- xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
- xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
- xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
- xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
- xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
- xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
- xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
- xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
- xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
- xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
- xinference/ui/web/ui/node_modules/select/bower.json +0 -13
- xinference/ui/web/ui/node_modules/select/package.json +0 -29
- xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
- {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/WHEEL +0 -0
- {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/entry_points.txt +0 -0
- {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/top_level.txt +0 -0
xinference/thirdparty/audiotools/ml/layers/base.py

@@ -0,0 +1,328 @@
+import inspect
+import shutil
+import tempfile
+import typing
+from pathlib import Path
+
+import torch
+from torch import nn
+
+
+class BaseModel(nn.Module):
+    """This is a class that adds useful save/load functionality to a
+    ``torch.nn.Module`` object. ``BaseModel`` objects can be saved
+    as ``torch.package`` easily, making them super easy to port between
+    machines without requiring a ton of dependencies. Files can also be
+    saved as just weights, in the standard way.
+
+    >>> class Model(ml.BaseModel):
+    >>>     def __init__(self, arg1: float = 1.0):
+    >>>         super().__init__()
+    >>>         self.arg1 = arg1
+    >>>         self.linear = nn.Linear(1, 1)
+    >>>
+    >>>     def forward(self, x):
+    >>>         return self.linear(x)
+    >>>
+    >>> model1 = Model()
+    >>>
+    >>> with tempfile.NamedTemporaryFile(suffix=".pth") as f:
+    >>>     model1.save(
+    >>>         f.name,
+    >>>     )
+    >>>     model2 = Model.load(f.name)
+    >>>     out2 = seed_and_run(model2, x)
+    >>>     assert torch.allclose(out1, out2)
+    >>>
+    >>>     model1.save(f.name, package=True)
+    >>>     model2 = Model.load(f.name)
+    >>>     model2.save(f.name, package=False)
+    >>>     model3 = Model.load(f.name)
+    >>>     out3 = seed_and_run(model3, x)
+    >>>
+    >>> with tempfile.TemporaryDirectory() as d:
+    >>>     model1.save_to_folder(d, {"data": 1.0})
+    >>>     Model.load_from_folder(d)
+
+    """
+
+    EXTERN = [
+        "audiotools.**",
+        "tqdm",
+        "__main__",
+        "numpy.**",
+        "julius.**",
+        "torchaudio.**",
+        "scipy.**",
+        "einops",
+    ]
+    """Names of libraries that are external to the torch.package saving mechanism.
+    Source code from these libraries will not be packaged into the model. This can
+    be edited by the user of this class by editing ``model.EXTERN``."""
+    INTERN = []
+    """Names of libraries that are internal to the torch.package saving mechanism.
+    Source code from these libraries will be saved alongside the model."""
+
+    def save(
+        self,
+        path: str,
+        metadata: dict = None,
+        package: bool = True,
+        intern: list = [],
+        extern: list = [],
+        mock: list = [],
+    ):
+        """Saves the model, either as a torch package, or just as
+        weights, alongside some specified metadata.
+
+        Parameters
+        ----------
+        path : str
+            Path to save model to.
+        metadata : dict, optional
+            Any metadata to save alongside the model,
+            by default None
+        package : bool, optional
+            Whether to use ``torch.package`` to save the model in
+            a format that is portable, by default True
+        intern : list, optional
+            List of additional libraries that are internal
+            to the model, used with torch.package, by default []
+        extern : list, optional
+            List of additional libraries that are external to
+            the model, used with torch.package, by default []
+        mock : list, optional
+            List of libraries to mock, used with torch.package,
+            by default []
+
+        Returns
+        -------
+        str
+            Path to saved model.
+        """
+        sig = inspect.signature(self.__class__)
+        args = {}
+
+        for key, val in sig.parameters.items():
+            arg_val = val.default
+            if arg_val is not inspect.Parameter.empty:
+                args[key] = arg_val
+
+        # Look up attibutes in self, and if any of them are in args,
+        # overwrite them in args.
+        for attribute in dir(self):
+            if attribute in args:
+                args[attribute] = getattr(self, attribute)
+
+        metadata = {} if metadata is None else metadata
+        metadata["kwargs"] = args
+        if not hasattr(self, "metadata"):
+            self.metadata = {}
+        self.metadata.update(metadata)
+
+        if not package:
+            state_dict = {"state_dict": self.state_dict(), "metadata": metadata}
+            torch.save(state_dict, path)
+        else:
+            self._save_package(path, intern=intern, extern=extern, mock=mock)
+
+        return path
+
+    @property
+    def device(self):
+        """Gets the device the model is on by looking at the device of
+        the first parameter. May not be valid if model is split across
+        multiple devices.
+        """
+        return list(self.parameters())[0].device
+
+    @classmethod
+    def load(
+        cls,
+        location: str,
+        *args,
+        package_name: str = None,
+        strict: bool = False,
+        **kwargs,
+    ):
+        """Load model from a path. Tries first to load as a package, and if
+        that fails, tries to load as weights. The arguments to the class are
+        specified inside the model weights file.
+
+        Parameters
+        ----------
+        location : str
+            Path to file.
+        package_name : str, optional
+            Name of package, by default ``cls.__name__``.
+        strict : bool, optional
+            Ignore unmatched keys, by default False
+        kwargs : dict
+            Additional keyword arguments to the model instantiation, if
+            not loading from package.
+
+        Returns
+        -------
+        BaseModel
+            A model that inherits from BaseModel.
+        """
+        try:
+            model = cls._load_package(location, package_name=package_name)
+        except:
+            model_dict = torch.load(location, "cpu")
+            metadata = model_dict["metadata"]
+            metadata["kwargs"].update(kwargs)
+
+            sig = inspect.signature(cls)
+            class_keys = list(sig.parameters.keys())
+            for k in list(metadata["kwargs"].keys()):
+                if k not in class_keys:
+                    metadata["kwargs"].pop(k)
+
+            model = cls(*args, **metadata["kwargs"])
+            model.load_state_dict(model_dict["state_dict"], strict=strict)
+            model.metadata = metadata
+
+        return model
+
+    def _save_package(self, path, intern=[], extern=[], mock=[], **kwargs):
+        package_name = type(self).__name__
+        resource_name = f"{type(self).__name__}.pth"
+
+        # Below is for loading and re-saving a package.
+        if hasattr(self, "importer"):
+            kwargs["importer"] = (self.importer, torch.package.sys_importer)
+            del self.importer
+
+        # Why do we use a tempfile, you ask?
+        # It's so we can load a packaged model and then re-save
+        # it to the same location. torch.package throws an
+        # error if it's loading and writing to the same
+        # file (this is undocumented).
+        with tempfile.NamedTemporaryFile(suffix=".pth") as f:
+            with torch.package.PackageExporter(f.name, **kwargs) as exp:
+                exp.intern(self.INTERN + intern)
+                exp.mock(mock)
+                exp.extern(self.EXTERN + extern)
+                exp.save_pickle(package_name, resource_name, self)
+
+                if hasattr(self, "metadata"):
+                    exp.save_pickle(
+                        package_name, f"{package_name}.metadata", self.metadata
+                    )
+
+            shutil.copyfile(f.name, path)
+
+        # Must reset the importer back to `self` if it existed
+        # so that you can save the model again!
+        if "importer" in kwargs:
+            self.importer = kwargs["importer"][0]
+        return path
+
+    @classmethod
+    def _load_package(cls, path, package_name=None):
+        package_name = cls.__name__ if package_name is None else package_name
+        resource_name = f"{package_name}.pth"
+
+        imp = torch.package.PackageImporter(path)
+        model = imp.load_pickle(package_name, resource_name, "cpu")
+        try:
+            model.metadata = imp.load_pickle(package_name, f"{package_name}.metadata")
+        except:  # pragma: no cover
+            pass
+        model.importer = imp
+
+        return model
+
+    def save_to_folder(
+        self,
+        folder: typing.Union[str, Path],
+        extra_data: dict = None,
+        package: bool = True,
+    ):
+        """Dumps a model into a folder, as both a package
+        and as weights, as well as anything specified in
+        ``extra_data``. ``extra_data`` is a dictionary of other
+        pickleable files, with the keys being the paths
+        to save them in. The model is saved under a subfolder
+        specified by the name of the class (e.g. ``folder/generator/[package, weights].pth``
+        if the model name was ``Generator``).
+
+        >>> with tempfile.TemporaryDirectory() as d:
+        >>>     extra_data = {
+        >>>         "optimizer.pth": optimizer.state_dict()
+        >>>     }
+        >>>     model.save_to_folder(d, extra_data)
+        >>>     Model.load_from_folder(d)
+
+        Parameters
+        ----------
+        folder : typing.Union[str, Path]
+            _description_
+        extra_data : dict, optional
+            _description_, by default None
+
+        Returns
+        -------
+        str
+            Path to folder
+        """
+        extra_data = {} if extra_data is None else extra_data
+        model_name = type(self).__name__.lower()
+        target_base = Path(f"{folder}/{model_name}/")
+        target_base.mkdir(exist_ok=True, parents=True)
+
+        if package:
+            package_path = target_base / f"package.pth"
+            self.save(package_path)
+
+        weights_path = target_base / f"weights.pth"
+        self.save(weights_path, package=False)
+
+        for path, obj in extra_data.items():
+            torch.save(obj, target_base / path)
+
+        return target_base
+
+    @classmethod
+    def load_from_folder(
+        cls,
+        folder: typing.Union[str, Path],
+        package: bool = True,
+        strict: bool = False,
+        **kwargs,
+    ):
+        """Loads the model from a folder generated by
+        :py:func:`audiotools.ml.layers.base.BaseModel.save_to_folder`.
+        Like that function, this one looks for a subfolder that has
+        the name of the class (e.g. ``folder/generator/[package, weights].pth`` if the
+        model name was ``Generator``).
+
+        Parameters
+        ----------
+        folder : typing.Union[str, Path]
+            _description_
+        package : bool, optional
+            Whether to use ``torch.package`` to load the model,
+            loading the model from ``package.pth``.
+        strict : bool, optional
+            Ignore unmatched keys, by default False
+
+        Returns
+        -------
+        tuple
+            tuple of model and extra data as saved by
+            :py:func:`audiotools.ml.layers.base.BaseModel.save_to_folder`.
+        """
+        folder = Path(folder) / cls.__name__.lower()
+        model_pth = "package.pth" if package else "weights.pth"
+        model_pth = folder / model_pth
+
+        model = cls.load(model_pth, strict=strict)
+        extra_data = {}
+        excluded = ["package.pth", "weights.pth"]
+        files = [x for x in folder.glob("*") if x.is_file() and x.name not in excluded]
+        for f in files:
+            extra_data[f.name] = torch.load(f, **kwargs)
+
+        return model, extra_data
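For orientation, here is a minimal sketch of the weights-only save/load round trip implemented above. The Toy class, its width argument, and the driver code are illustrative assumptions, not part of the diff; only the BaseModel import path comes from the file added above.

# Illustrative sketch only: Toy and its width argument are hypothetical.
import tempfile

import torch
from torch import nn

from xinference.thirdparty.audiotools.ml.layers.base import BaseModel


class Toy(BaseModel):
    def __init__(self, width: int = 4):
        super().__init__()
        self.width = width  # matched against the __init__ signature by save()
        self.linear = nn.Linear(width, width)

    def forward(self, x):
        return self.linear(x)


model = Toy(width=2)
with tempfile.NamedTemporaryFile(suffix=".pth") as f:
    # package=False writes {"state_dict": ..., "metadata": {"kwargs": {"width": 2}}}.
    model.save(f.name, package=False)
    # load() first tries the torch.package path, falls back to the weights path,
    # and rebuilds Toy(width=2) from the stored kwargs before loading the state dict.
    restored = Toy.load(f.name)
    assert torch.allclose(restored.linear.weight, model.linear.weight)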
xinference/thirdparty/audiotools/ml/layers/spectral_gate.py

@@ -0,0 +1,127 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ...core import AudioSignal
+from ...core import STFTParams
+from ...core import util
+
+
+class SpectralGate(nn.Module):
+    """Spectral gating algorithm for noise reduction,
+    as in Audacity/Ocenaudio. The steps are as follows:
+
+    1. An FFT is calculated over the noise audio clip
+    2. Statistics are calculated over FFT of the the noise
+       (in frequency)
+    3. A threshold is calculated based upon the statistics
+       of the noise (and the desired sensitivity of the algorithm)
+    4. An FFT is calculated over the signal
+    5. A mask is determined by comparing the signal FFT to the
+       threshold
+    6. The mask is smoothed with a filter over frequency and time
+    7. The mask is appled to the FFT of the signal, and is inverted
+
+    Implementation inspired by Tim Sainburg's noisereduce:
+
+    https://timsainburg.com/noise-reduction-python.html
+
+    Parameters
+    ----------
+    n_freq : int, optional
+        Number of frequency bins to smooth by, by default 3
+    n_time : int, optional
+        Number of time bins to smooth by, by default 5
+    """
+
+    def __init__(self, n_freq: int = 3, n_time: int = 5):
+        super().__init__()
+
+        smoothing_filter = torch.outer(
+            torch.cat(
+                [
+                    torch.linspace(0, 1, n_freq + 2)[:-1],
+                    torch.linspace(1, 0, n_freq + 2),
+                ]
+            )[..., 1:-1],
+            torch.cat(
+                [
+                    torch.linspace(0, 1, n_time + 2)[:-1],
+                    torch.linspace(1, 0, n_time + 2),
+                ]
+            )[..., 1:-1],
+        )
+        smoothing_filter = smoothing_filter / smoothing_filter.sum()
+        smoothing_filter = smoothing_filter.unsqueeze(0).unsqueeze(0)
+        self.register_buffer("smoothing_filter", smoothing_filter)
+
+    def forward(
+        self,
+        audio_signal: AudioSignal,
+        nz_signal: AudioSignal,
+        denoise_amount: float = 1.0,
+        n_std: float = 3.0,
+        win_length: int = 2048,
+        hop_length: int = 512,
+    ):
+        """Perform noise reduction.
+
+        Parameters
+        ----------
+        audio_signal : AudioSignal
+            Audio signal that noise will be removed from.
+        nz_signal : AudioSignal, optional
+            Noise signal to compute noise statistics from.
+        denoise_amount : float, optional
+            Amount to denoise by, by default 1.0
+        n_std : float, optional
+            Number of standard deviations above which to consider
+            noise, by default 3.0
+        win_length : int, optional
+            Length of window for STFT, by default 2048
+        hop_length : int, optional
+            Hop length for STFT, by default 512
+
+        Returns
+        -------
+        AudioSignal
+            Denoised audio signal.
+        """
+        stft_params = STFTParams(win_length, hop_length, "sqrt_hann")
+
+        audio_signal = audio_signal.clone()
+        audio_signal.stft_data = None
+        audio_signal.stft_params = stft_params
+
+        nz_signal = nz_signal.clone()
+        nz_signal.stft_params = stft_params
+
+        nz_stft_db = 20 * nz_signal.magnitude.clamp(1e-4).log10()
+        nz_freq_mean = nz_stft_db.mean(keepdim=True, dim=-1)
+        nz_freq_std = nz_stft_db.std(keepdim=True, dim=-1)
+
+        nz_thresh = nz_freq_mean + nz_freq_std * n_std
+
+        stft_db = 20 * audio_signal.magnitude.clamp(1e-4).log10()
+        nb, nac, nf, nt = stft_db.shape
+        db_thresh = nz_thresh.expand(nb, nac, -1, nt)
+
+        stft_mask = (stft_db < db_thresh).float()
+        shape = stft_mask.shape
+
+        stft_mask = stft_mask.reshape(nb * nac, 1, nf, nt)
+        pad_tuple = (
+            self.smoothing_filter.shape[-2] // 2,
+            self.smoothing_filter.shape[-1] // 2,
+        )
+        stft_mask = F.conv2d(stft_mask, self.smoothing_filter, padding=pad_tuple)
+        stft_mask = stft_mask.reshape(*shape)
+        stft_mask *= util.ensure_tensor(denoise_amount, ndim=stft_mask.ndim).to(
+            audio_signal.device
+        )
+        stft_mask = 1 - stft_mask
+
+        audio_signal.stft_data *= stft_mask
+        audio_signal.istft()
+
+        return audio_signal
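A short, hedged usage sketch of the SpectralGate module above: noise statistics are estimated from a noise-only clip and used to mask the signal's STFT. The wav file names are placeholders, the write call assumes the upstream AudioSignal API, and only the import paths of the vendored modules come from this diff.

# Illustrative sketch only: the wav paths below are placeholders.
from xinference.thirdparty.audiotools.core import AudioSignal
from xinference.thirdparty.audiotools.ml.layers.spectral_gate import SpectralGate

gate = SpectralGate(n_freq=3, n_time=5)

signal = AudioSignal("speech_with_hiss.wav")  # placeholder path
noise = AudioSignal("hiss_only.wav")          # placeholder path

# Per-frequency dB thresholds come from the noise clip (mean + n_std * std);
# the smoothed binary mask is scaled by denoise_amount and applied to the STFT.
denoised = gate(signal, noise, denoise_amount=0.8, n_std=3.0)
denoised.write("speech_denoised.wav")  # assumes upstream AudioSignal.write()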
xinference/thirdparty/audiotools/post.py

@@ -0,0 +1,140 @@
+import tempfile
+import typing
+import zipfile
+from pathlib import Path
+
+import markdown2 as md
+import matplotlib.pyplot as plt
+import torch
+from IPython.display import HTML
+
+
+def audio_table(
+    audio_dict: dict,
+    first_column: str = None,
+    format_fn: typing.Callable = None,
+    **kwargs,
+):  # pragma: no cover
+    """Embeds an audio table into HTML, or as the output cell
+    in a notebook.
+
+    Parameters
+    ----------
+    audio_dict : dict
+        Dictionary of data to embed.
+    first_column : str, optional
+        The label for the first column of the table, by default None
+    format_fn : typing.Callable, optional
+        How to format the data, by default None
+
+    Returns
+    -------
+    str
+        Table as a string
+
+    Examples
+    --------
+
+    >>> audio_dict = {}
+    >>> for i in range(signal_batch.batch_size):
+    >>>     audio_dict[i] = {
+    >>>         "input": signal_batch[i],
+    >>>         "output": output_batch[i]
+    >>>     }
+    >>> audiotools.post.audio_zip(audio_dict)
+
+    """
+    from audiotools import AudioSignal
+
+    output = []
+    columns = None
+
+    def _default_format_fn(label, x, **kwargs):
+        if torch.is_tensor(x):
+            x = x.tolist()
+
+        if x is None:
+            return "."
+        elif isinstance(x, AudioSignal):
+            return x.embed(display=False, return_html=True, **kwargs)
+        else:
+            return str(x)
+
+    if format_fn is None:
+        format_fn = _default_format_fn
+
+    if first_column is None:
+        first_column = "."
+
+    for k, v in audio_dict.items():
+        if not isinstance(v, dict):
+            v = {"Audio": v}
+
+        v_keys = list(v.keys())
+        if columns is None:
+            columns = [first_column] + v_keys
+            output.append(" | ".join(columns))
+
+            layout = "|---" + len(v_keys) * "|:-:"
+            output.append(layout)
+
+        formatted_audio = []
+        for col in columns[1:]:
+            formatted_audio.append(format_fn(col, v[col], **kwargs))
+
+        row = f"| {k} | "
+        row += " | ".join(formatted_audio)
+        output.append(row)
+
+    output = "\n" + "\n".join(output)
+    return output
+
+
+def in_notebook():  # pragma: no cover
+    """Determines if code is running in a notebook.
+
+    Returns
+    -------
+    bool
+        Whether or not this is running in a notebook.
+    """
+    try:
+        from IPython import get_ipython
+
+        if "IPKernelApp" not in get_ipython().config:  # pragma: no cover
+            return False
+    except ImportError:
+        return False
+    except AttributeError:
+        return False
+    return True
+
+
+def disp(obj, **kwargs):  # pragma: no cover
+    """Displays an object, depending on if its in a notebook
+    or not.
+
+    Parameters
+    ----------
+    obj : typing.Any
+        Any object to display.
+
+    """
+    from audiotools import AudioSignal
+
+    IN_NOTEBOOK = in_notebook()
+
+    if isinstance(obj, AudioSignal):
+        audio_elem = obj.embed(display=False, return_html=True)
+        if IN_NOTEBOOK:
+            return HTML(audio_elem)
+        else:
+            print(audio_elem)
+    if isinstance(obj, dict):
+        table = audio_table(obj, **kwargs)
+        if IN_NOTEBOOK:
+            return HTML(md.markdown(table, extras=["tables"]))
+        else:
+            print(table)
+    if isinstance(obj, plt.Figure):
+        plt.show()
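Finally, a brief sketch of how the notebook helpers above fit together: disp() embeds a single AudioSignal, while a dict of dicts is rendered by audio_table() as a Markdown table with one embedded player per cell. The variable names and wav paths are illustrative, and the imports assume the vendored package layout shown in this diff mirrors upstream audiotools.

# Illustrative sketch only: the wav paths below are placeholders.
from xinference.thirdparty.audiotools import post
from xinference.thirdparty.audiotools.core import AudioSignal

noisy = AudioSignal("noisy.wav")  # placeholder path
clean = AudioSignal("clean.wav")  # placeholder path

# In a notebook this returns an HTML audio player; in a terminal it prints the HTML.
post.disp(noisy)

# One row per key, one embedded player per column, rendered via markdown2.
post.disp({"example 0": {"input": noisy, "output": clean}})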