xinference 1.9.1__py3-none-any.whl → 1.10.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +400 -3
- xinference/client/restful/async_restful_client.py +20 -3
- xinference/client/restful/restful_client.py +20 -3
- xinference/constants.py +2 -0
- xinference/core/supervisor.py +111 -49
- xinference/core/worker.py +10 -0
- xinference/deploy/cmdline.py +15 -0
- xinference/model/audio/core.py +26 -6
- xinference/model/audio/indextts2.py +166 -0
- xinference/model/audio/kokoro.py +1 -1
- xinference/model/audio/kokoro_zh.py +124 -0
- xinference/model/audio/model_spec.json +58 -1
- xinference/model/embedding/sentence_transformers/core.py +4 -4
- xinference/model/embedding/vllm/core.py +7 -1
- xinference/model/image/model_spec.json +71 -3
- xinference/model/image/stable_diffusion/core.py +13 -4
- xinference/model/llm/__init__.py +4 -0
- xinference/model/llm/core.py +10 -0
- xinference/model/llm/llama_cpp/core.py +1 -0
- xinference/model/llm/llm_family.json +503 -21
- xinference/model/llm/llm_family.py +1 -0
- xinference/model/llm/mlx/core.py +52 -33
- xinference/model/llm/sglang/core.py +32 -55
- xinference/model/llm/tool_parsers/__init__.py +58 -0
- xinference/model/llm/tool_parsers/abstract_tool_parser.py +33 -0
- xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +190 -0
- xinference/model/llm/tool_parsers/deepseek_v3_tool_parser.py +145 -0
- xinference/model/llm/tool_parsers/glm4_tool_parser.py +123 -0
- xinference/model/llm/tool_parsers/llama3_tool_parser.py +77 -0
- xinference/model/llm/tool_parsers/qwen_tool_parser.py +320 -0
- xinference/model/llm/transformers/core.py +1 -1
- xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
- xinference/model/llm/utils.py +138 -53
- xinference/model/llm/vllm/core.py +95 -78
- xinference/thirdparty/audiotools/__init__.py +10 -0
- xinference/thirdparty/audiotools/core/__init__.py +4 -0
- xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
- xinference/thirdparty/audiotools/core/display.py +194 -0
- xinference/thirdparty/audiotools/core/dsp.py +390 -0
- xinference/thirdparty/audiotools/core/effects.py +647 -0
- xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
- xinference/thirdparty/audiotools/core/loudness.py +320 -0
- xinference/thirdparty/audiotools/core/playback.py +252 -0
- xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
- xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
- xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
- xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
- xinference/thirdparty/audiotools/core/util.py +671 -0
- xinference/thirdparty/audiotools/core/whisper.py +97 -0
- xinference/thirdparty/audiotools/data/__init__.py +3 -0
- xinference/thirdparty/audiotools/data/datasets.py +517 -0
- xinference/thirdparty/audiotools/data/preprocess.py +81 -0
- xinference/thirdparty/audiotools/data/transforms.py +1592 -0
- xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
- xinference/thirdparty/audiotools/metrics/distance.py +131 -0
- xinference/thirdparty/audiotools/metrics/quality.py +159 -0
- xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
- xinference/thirdparty/audiotools/ml/__init__.py +5 -0
- xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
- xinference/thirdparty/audiotools/ml/decorators.py +440 -0
- xinference/thirdparty/audiotools/ml/experiment.py +90 -0
- xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
- xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
- xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
- xinference/thirdparty/audiotools/post.py +140 -0
- xinference/thirdparty/audiotools/preference.py +600 -0
- xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
- xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
- xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
- xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
- xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
- xinference/thirdparty/indextts/__init__.py +0 -0
- xinference/thirdparty/indextts/cli.py +65 -0
- xinference/thirdparty/indextts/gpt/__init__.py +0 -0
- xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
- xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
- xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
- xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
- xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
- xinference/thirdparty/indextts/gpt/model.py +713 -0
- xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
- xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
- xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
- xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
- xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
- xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
- xinference/thirdparty/indextts/infer.py +690 -0
- xinference/thirdparty/indextts/infer_v2.py +739 -0
- xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
- xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
- xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
- xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
- xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
- xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
- xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
- xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
- xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
- xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
- xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
- xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
- xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
- xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
- xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
- xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
- xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
- xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
- xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
- xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
- xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
- xinference/thirdparty/indextts/utils/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/arch_util.py +120 -0
- xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
- xinference/thirdparty/indextts/utils/common.py +121 -0
- xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
- xinference/thirdparty/indextts/utils/front.py +536 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
- xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
- xinference/thirdparty/indextts/utils/text_utils.py +41 -0
- xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
- xinference/thirdparty/indextts/utils/utils.py +93 -0
- xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
- xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
- xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
- xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
- xinference/types.py +105 -2
- xinference/ui/gradio/media_interface.py +66 -8
- xinference/ui/web/ui/build/asset-manifest.json +6 -6
- xinference/ui/web/ui/build/index.html +1 -1
- xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
- xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
- xinference/ui/web/ui/build/static/js/main.d192c4f3.js +3 -0
- xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.d192c4f3.js.LICENSE.txt} +0 -7
- xinference/ui/web/ui/build/static/js/main.d192c4f3.js.map +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f995a2425dfb0822fd07127f66ffe9b026883bc156b402eb8bd0b83d52460a93.json +1 -0
- xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
- xinference/ui/web/ui/package-lock.json +0 -34
- xinference/ui/web/ui/package.json +0 -1
- xinference/ui/web/ui/src/locales/en.json +9 -3
- xinference/ui/web/ui/src/locales/ja.json +9 -3
- xinference/ui/web/ui/src/locales/ko.json +9 -3
- xinference/ui/web/ui/src/locales/zh.json +9 -3
- {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/METADATA +24 -4
- {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/RECORD +302 -76
- xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
- xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
- xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
- xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
- xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
- xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
- xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
- xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
- xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
- xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
- xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
- xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
- xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
- xinference/ui/web/ui/node_modules/select/bower.json +0 -13
- xinference/ui/web/ui/node_modules/select/package.json +0 -29
- xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
- {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/WHEEL +0 -0
- {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/entry_points.txt +0 -0
- {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,1682 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import functools
|
|
3
|
+
import hashlib
|
|
4
|
+
import math
|
|
5
|
+
import pathlib
|
|
6
|
+
import tempfile
|
|
7
|
+
import typing
|
|
8
|
+
import warnings
|
|
9
|
+
from collections import namedtuple
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
|
|
12
|
+
import julius
|
|
13
|
+
import numpy as np
|
|
14
|
+
import soundfile
|
|
15
|
+
import torch
|
|
16
|
+
|
|
17
|
+
from . import util
|
|
18
|
+
from .display import DisplayMixin
|
|
19
|
+
from .dsp import DSPMixin
|
|
20
|
+
from .effects import EffectMixin
|
|
21
|
+
from .effects import ImpulseResponseMixin
|
|
22
|
+
from .ffmpeg import FFMPEGMixin
|
|
23
|
+
from .loudness import LoudnessMixin
|
|
24
|
+
from .playback import PlayMixin
|
|
25
|
+
from .whisper import WhisperMixin
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
STFTParams = namedtuple(
    "STFTParams",
    ["window_length", "hop_length", "window_type", "match_stride", "padding_type"],
)
"""
STFTParams object is a container that holds STFT parameters - window_length,
hop_length, window_type, match_stride, and padding_type. Not all parameters
need to be specified. Ones that are not specified will be inferred by the
AudioSignal parameters.

Parameters
----------
window_length : int, optional
    Window length of STFT, by default ``0.032 * self.sample_rate``.
hop_length : int, optional
    Hop length of STFT, by default ``window_length // 4``.
window_type : str, optional
    Type of window to use, by default ``sqrt_hann``.
match_stride : bool, optional
    Whether to match the stride of convolutional layers, by default False
padding_type : str, optional
    Type of padding to use, by default 'reflect'
"""
# Every field defaults to None, so callers may construct STFTParams with any
# subset of the fields (including none at all).
STFTParams.__new__.__defaults__ = (None, None, None, None, None)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class AudioSignal(
|
|
54
|
+
EffectMixin,
|
|
55
|
+
LoudnessMixin,
|
|
56
|
+
PlayMixin,
|
|
57
|
+
ImpulseResponseMixin,
|
|
58
|
+
DSPMixin,
|
|
59
|
+
DisplayMixin,
|
|
60
|
+
FFMPEGMixin,
|
|
61
|
+
WhisperMixin,
|
|
62
|
+
):
|
|
63
|
+
"""This is the core object of this library. Audio is always
|
|
64
|
+
loaded into an AudioSignal, which then enables all the features
|
|
65
|
+
of this library, including audio augmentations, I/O, playback,
|
|
66
|
+
and more.
|
|
67
|
+
|
|
68
|
+
The structure of this object is that the base functionality
|
|
69
|
+
is defined in ``core/audio_signal.py``, while extensions to
|
|
70
|
+
that functionality are defined in the other ``core/*.py``
|
|
71
|
+
files. For example, all the display-based functionality
|
|
72
|
+
(e.g. plot spectrograms, waveforms, write to tensorboard)
|
|
73
|
+
are in ``core/display.py``.
|
|
74
|
+
|
|
75
|
+
Parameters
|
|
76
|
+
----------
|
|
77
|
+
audio_path_or_array : typing.Union[torch.Tensor, str, Path, np.ndarray]
|
|
78
|
+
Object to create AudioSignal from. Can be a tensor, numpy array,
|
|
79
|
+
or a path to a file. The file is always reshaped to
|
|
80
|
+
sample_rate : int, optional
|
|
81
|
+
Sample rate of the audio. If different from underlying file, resampling is
|
|
82
|
+
performed. If passing in an array or tensor, this must be defined,
|
|
83
|
+
by default None
|
|
84
|
+
stft_params : STFTParams, optional
|
|
85
|
+
Parameters of STFT to use, by default None
|
|
86
|
+
offset : float, optional
|
|
87
|
+
Offset in seconds to read from file, by default 0
|
|
88
|
+
duration : float, optional
|
|
89
|
+
Duration in seconds to read from file, by default None
|
|
90
|
+
device : str, optional
|
|
91
|
+
Device to load audio onto, by default None
|
|
92
|
+
|
|
93
|
+
Examples
|
|
94
|
+
--------
|
|
95
|
+
Loading an AudioSignal from an array, at a sample rate of
|
|
96
|
+
44100.
|
|
97
|
+
|
|
98
|
+
>>> signal = AudioSignal(torch.randn(5*44100), 44100)
|
|
99
|
+
|
|
100
|
+
Note, the signal is reshaped to have a batch size, and one
|
|
101
|
+
audio channel:
|
|
102
|
+
|
|
103
|
+
>>> print(signal.shape)
|
|
104
|
+
(1, 1, 44100)
|
|
105
|
+
|
|
106
|
+
You can treat AudioSignals like tensors, and many of the same
|
|
107
|
+
functions you might use on tensors are defined for AudioSignals
|
|
108
|
+
as well:
|
|
109
|
+
|
|
110
|
+
>>> signal.to("cuda")
|
|
111
|
+
>>> signal.cuda()
|
|
112
|
+
>>> signal.clone()
|
|
113
|
+
>>> signal.detach()
|
|
114
|
+
|
|
115
|
+
Indexing AudioSignals returns an AudioSignal:
|
|
116
|
+
|
|
117
|
+
>>> signal[..., 3*44100:4*44100]
|
|
118
|
+
|
|
119
|
+
The above signal is 1 second long, and is also an AudioSignal.
|
|
120
|
+
"""
|
|
121
|
+
|
|
122
|
+
def __init__(
    self,
    audio_path_or_array: typing.Union[torch.Tensor, str, Path, np.ndarray],
    sample_rate: int = None,
    stft_params: STFTParams = None,
    offset: float = 0,
    duration: float = None,
    device: str = None,
):
    """Create the signal from a file path or from in-memory samples.

    Strings and ``pathlib.Path`` objects are read with ``load_from_file``;
    numpy arrays and torch tensors are handed to ``load_from_array``.

    Parameters
    ----------
    audio_path_or_array : typing.Union[torch.Tensor, str, Path, np.ndarray]
        Path to an audio file, or the samples themselves.
    sample_rate : int, optional
        Sample rate of the samples. Required for arrays/tensors; it is not
        forwarded to ``load_from_file`` when a path is given. By default None.
    stft_params : STFTParams, optional
        STFT parameters, resolved by the ``stft_params`` setter,
        by default None
    offset : float, optional
        Offset in seconds to read from file, by default 0
    duration : float, optional
        Duration in seconds to read from file, by default None
    device : str, optional
        Device forwarded to the loader, by default None

    Raises
    ------
    ValueError
        If ``audio_path_or_array`` is none of the supported types.
    AssertionError
        If an array/tensor is given without ``sample_rate``.
    """
    source = audio_path_or_array
    from_file = isinstance(source, (str, pathlib.Path))
    from_memory = isinstance(source, np.ndarray) or torch.is_tensor(source)
    if not (from_file or from_memory):
        raise ValueError(
            "audio_path_or_array must be either a Path, "
            "string, numpy array, or torch Tensor!"
        )

    # Start from an empty state; the loaders below fill in the real values.
    self.path_to_file = None
    self.audio_data = None
    self.sources = None  # List of AudioSignal objects.
    self.stft_data = None

    if from_file:
        self.load_from_file(source, offset=offset, duration=duration, device=device)
    else:
        assert sample_rate is not None, "Must set sample rate!"
        self.load_from_array(source, sample_rate, device=device)

    self.window = None
    # The stft_params setter reads self.sample_rate, so it has to run after
    # one of the loaders has assigned it.
    self.stft_params = stft_params

    self.metadata = {"offset": offset, "duration": duration}
|
|
168
|
+
|
|
169
|
+
@property
def path_to_input_file(self):
    """Backwards-compatible alias of ``path_to_file``.

    Returns
    -------
    The current value of ``self.path_to_file``.
    """
    source_path = self.path_to_file
    return source_path
|
|
178
|
+
|
|
179
|
+
@classmethod
def excerpt(
    cls,
    audio_path: typing.Union[str, Path],
    offset: float = None,
    duration: float = None,
    state: typing.Union[np.random.RandomState, int] = None,
    **kwargs,
):
    """Randomly draw an excerpt of ``duration`` seconds from an
    audio file specified at ``audio_path``, between ``offset`` seconds
    and end of file. ``state`` can be used to seed the random draw.

    Parameters
    ----------
    audio_path : typing.Union[str, Path]
        Path to audio file to grab excerpt from.
    offset : float, optional
        Lower bound for the start time, in seconds drawn from
        the file, by default None (treated as 0).
    duration : float, optional
        Duration of excerpt, in seconds, by default None. When None there
        is no fixed-length window to place, so no random draw is made: the
        excerpt starts at the lower bound and ``duration=None`` is passed
        on to the constructor.
    state : typing.Union[np.random.RandomState, int], optional
        RandomState or seed of random state, by default None
    kwargs : dict
        Extra keyword arguments forwarded to the constructor.

    Returns
    -------
    AudioSignal
        AudioSignal containing excerpt. Its ``metadata`` records the
        ``offset`` and ``duration`` that were used.

    Examples
    --------
    >>> signal = AudioSignal.excerpt("path/to/audio", duration=5)
    """
    info = util.info(audio_path)
    total_duration = info.duration

    state = util.random_state(state)
    lower_bound = 0 if offset is None else offset
    if duration is None:
        # Previously this path raised TypeError on ``total_duration - None``
        # even though None is the documented default.
        offset = lower_bound
    else:
        upper_bound = max(total_duration - duration, 0)
        offset = state.uniform(lower_bound, upper_bound)

    signal = cls(audio_path, offset=offset, duration=duration, **kwargs)
    signal.metadata["offset"] = offset
    signal.metadata["duration"] = duration

    return signal
|
|
226
|
+
|
|
227
|
+
@classmethod
def salient_excerpt(
    cls,
    audio_path: typing.Union[str, Path],
    loudness_cutoff: float = None,
    num_tries: int = 8,
    state: typing.Union[np.random.RandomState, int] = None,
    **kwargs,
):
    """Draw excerpts with ``excerpt`` until one is louder than
    ``loudness_cutoff`` (as reported by the excerpt's ``loudness()``), or
    until ``num_tries`` draws have been made.

    Parameters
    ----------
    audio_path : typing.Union[str, Path]
        Path to audio file to grab excerpt from.
    loudness_cutoff : float, optional
        Loudness threshold in dB. Typical values are ``-40, -60``,
        etc, by default None, in which case a single excerpt is drawn
        without any loudness check.
    num_tries : int, optional
        Number of tries to grab an excerpt above the threshold
        before giving up, by default 8.
    state : typing.Union[np.random.RandomState, int], optional
        RandomState or seed of random state, by default None
    kwargs : dict
        Keyword arguments to AudioSignal.excerpt

    Returns
    -------
    AudioSignal
        AudioSignal containing excerpt. If every try stayed at or below
        the cutoff, the last excerpt drawn is returned.


    .. warning::
        if ``num_tries`` is set to None, ``salient_excerpt`` may try forever, which can
        result in an infinite loop if ``audio_path`` does not have
        any loud enough excerpts.

    Examples
    --------
    >>> signal = AudioSignal.salient_excerpt(
            "path/to/audio",
            loudness_cutoff=-40,
            duration=5
        )
    """
    state = util.random_state(state)
    if loudness_cutoff is None:
        return cls.excerpt(audio_path, state=state, **kwargs)

    attempts = 0
    loudness = -np.inf
    while loudness <= loudness_cutoff:
        excerpt = cls.excerpt(audio_path, state=state, **kwargs)
        loudness = excerpt.loudness()
        attempts += 1
        out_of_tries = num_tries is not None and attempts >= num_tries
        if out_of_tries:
            break
    return excerpt
|
|
287
|
+
|
|
288
|
+
@classmethod
def zeros(
    cls,
    duration: float,
    sample_rate: int,
    num_channels: int = 1,
    batch_size: int = 1,
    **kwargs,
):
    """Build a signal whose samples are all zero.

    Parameters
    ----------
    duration : float
        Duration of AudioSignal, in seconds.
    sample_rate : int
        Sample rate of AudioSignal
    num_channels : int, optional
        Number of channels, by default 1
    batch_size : int, optional
        Batch size, by default 1
    kwargs : dict
        Extra keyword arguments forwarded to the constructor.

    Returns
    -------
    AudioSignal
        AudioSignal of shape ``(batch_size, num_channels,
        int(duration * sample_rate))`` containing all zeros.

    Examples
    --------
    Generate 5 seconds of all zeros at a sample rate of 44100.

    >>> signal = AudioSignal.zeros(5.0, 44100)
    """
    n_samples = int(duration * sample_rate)
    silence = torch.zeros(batch_size, num_channels, n_samples)
    return cls(silence, sample_rate, **kwargs)
|
|
325
|
+
|
|
326
|
+
@classmethod
def wave(
    cls,
    frequency: float,
    duration: float,
    sample_rate: int,
    num_channels: int = 1,
    shape: str = "sine",
    **kwargs,
):
    """
    Generate a waveform of a given frequency and shape.

    Parameters
    ----------
    frequency : float
        Frequency of the waveform
    duration : float
        Duration of the waveform
    sample_rate : int
        Sample rate of the waveform
    num_channels : int, optional
        Number of channels, by default 1. Every channel holds the same
        waveform.
    shape : str, optional
        Shape of the waveform, by default "sine"
        One of "sawtooth", "square", "sine", "triangle"
    kwargs : dict
        Keyword arguments to AudioSignal

    Returns
    -------
    AudioSignal
        Signal of shape ``(1, num_channels, int(duration * sample_rate))``
        in float32.

    Raises
    ------
    ValueError
        If ``shape`` is not one of the supported names.
    """
    n_samples = int(duration * sample_rate)
    t = torch.linspace(0, duration, n_samples)

    # scipy is only imported for the shapes that need it.
    if shape == "sine":
        samples = np.sin(2 * np.pi * frequency * t)
    elif shape == "square":
        from scipy.signal import square

        samples = square(2 * np.pi * frequency * t)
    elif shape == "sawtooth":
        from scipy.signal import sawtooth

        samples = sawtooth(2 * np.pi * frequency * t, 0.5)
    elif shape == "triangle":
        from scipy.signal import sawtooth

        # frequency is doubled by the abs call, so omit the 2 in 2pi
        samples = sawtooth(np.pi * frequency * t, 0.5)
        samples = -np.abs(samples) * 2 + 1
    else:
        raise ValueError(f"Invalid shape {shape}")

    samples = torch.tensor(samples, dtype=torch.float32)
    # (T,) -> (1, 1, T) -> (1, num_channels, T)
    samples = samples.unsqueeze(0).unsqueeze(0).repeat(1, num_channels, 1)
    return cls(samples, sample_rate, **kwargs)
|
|
379
|
+
|
|
380
|
+
@classmethod
def batch(
    cls,
    audio_signals: list,
    pad_signals: bool = False,
    truncate_signals: bool = False,
    resample: bool = False,
    dim: int = 0,
):
    """Creates a batched AudioSignal from a list of AudioSignals.

    The signals in ``audio_signals`` are modified in place when they are
    resampled, padded or truncated.

    Parameters
    ----------
    audio_signals : list[AudioSignal]
        List of AudioSignal objects
    pad_signals : bool, optional
        Whether to pad signals to length of the maximum length
        AudioSignal in the list, by default False
    truncate_signals : bool, optional
        Whether to truncate signals to length of shortest length
        AudioSignal in the list, by default False. Ignored when
        ``pad_signals`` is True.
    resample : bool, optional
        Whether to resample AudioSignal to the sample rate of
        the first AudioSignal in the list, by default False
    dim : int, optional
        Dimension along which to batch the signals, by default 0.

    Returns
    -------
    AudioSignal
        Batched AudioSignal. Its ``path_to_file`` is the list of the
        inputs' ``path_to_file`` values.

    Raises
    ------
    RuntimeError
        If not all AudioSignals are the same sample rate, and
        ``resample=False``, an error is raised.
    RuntimeError
        If not all AudioSignals are the same the length, and
        both ``pad_signals=False`` and ``truncate_signals=False``,
        an error is raised.

    Examples
    --------
    Batching a bunch of random signals:

    >>> signal_list = [AudioSignal(torch.randn(44100), 44100) for _ in range(10)]
    >>> signal = AudioSignal.batch(signal_list)
    >>> print(signal.shape)
    (10, 1, 44100)

    """
    sample_rates = [x.sample_rate for x in audio_signals]

    if len(set(sample_rates)) != 1:
        if not resample:
            raise RuntimeError(
                f"Not all signals had the same sample rate! Got {sample_rates}. "
                f"All signals must have the same sample rate, or resample must be True. "
            )
        for x in audio_signals:
            x.resample(sample_rates[0])

    # Measure lengths only now: resampling changes the number of samples,
    # so lengths taken before it would be stale for the checks below.
    signal_lengths = [x.signal_length for x in audio_signals]

    if len(set(signal_lengths)) != 1:
        if pad_signals:
            max_length = max(signal_lengths)
            for x in audio_signals:
                x.zero_pad(0, max_length - x.signal_length)
        elif truncate_signals:
            min_length = min(signal_lengths)
            for x in audio_signals:
                x.truncate_samples(min_length)
        else:
            raise RuntimeError(
                f"Not all signals had the same length! Got {signal_lengths}. "
                f"All signals must be the same length, or pad_signals/truncate_signals "
                f"must be True. "
            )

    # Concatenate along the specified dimension (default 0)
    audio_data = torch.cat([x.audio_data for x in audio_signals], dim=dim)
    audio_paths = [x.path_to_file for x in audio_signals]

    batched_signal = cls(
        audio_data,
        sample_rate=audio_signals[0].sample_rate,
    )
    batched_signal.path_to_file = audio_paths
    return batched_signal
|
|
471
|
+
|
|
472
|
+
# I/O
|
|
473
|
+
def load_from_file(
    self,
    audio_path: typing.Union[str, Path],
    offset: float,
    duration: float,
    device: str = "cpu",
):
    """Read audio from disk with ``librosa.load`` and store it on this
    signal. Used internally when AudioSignal is instantiated with a path
    to a file.

    The file is loaded with ``sr=None`` and ``mono=False``; the sample rate
    returned by librosa becomes ``self.sample_rate``.

    Parameters
    ----------
    audio_path : typing.Union[str, Path]
        Path to file
    offset : float
        Offset in seconds
    duration : float
        Duration in seconds
    device : str, optional
        Device to put AudioSignal on, by default "cpu"

    Returns
    -------
    AudioSignal
        The result of ``self.to(device)``.

    Raises
    ------
    RuntimeError
        If the loaded audio has no samples.
    """
    import librosa

    samples, native_rate = librosa.load(
        audio_path, offset=offset, duration=duration, sr=None, mono=False
    )
    samples = util.ensure_tensor(samples)
    if samples.shape[-1] == 0:
        raise RuntimeError(
            f"Audio file {audio_path} with offset {offset} and duration {duration} is empty!"
        )

    # Add up to two leading axes so 1-D and 2-D data become 3-D.
    for _ in range(2):
        if samples.ndim < 3:
            samples = samples.unsqueeze(0)
    self.audio_data = samples

    self.original_signal_length = self.signal_length

    self.sample_rate = native_rate
    self.path_to_file = audio_path
    return self.to(device)
|
|
525
|
+
|
|
526
|
+
def load_from_array(
    self,
    audio_array: typing.Union[torch.Tensor, np.ndarray],
    sample_rate: int,
    device: str = "cpu",
):
    """Store in-memory samples on this signal, adding leading axes so
    1-D and 2-D input becomes 3-D. Used internally when AudioSignal is
    called with a tensor or an array.

    Parameters
    ----------
    audio_array : typing.Union[torch.Tensor, np.ndarray]
        Array/tensor of audio of samples.
    sample_rate : int
        Sample rate of audio
    device : str, optional
        Device to move audio onto, by default "cpu"

    Returns
    -------
    AudioSignal
        The result of ``self.to(device)``.
    """
    samples = util.ensure_tensor(audio_array)

    # Double precision input is stored as single precision.
    if samples.dtype == torch.double:
        samples = samples.float()

    # Add up to two leading axes so 1-D and 2-D data become 3-D.
    for _ in range(2):
        if samples.ndim < 3:
            samples = samples.unsqueeze(0)
    self.audio_data = samples

    self.original_signal_length = self.signal_length

    self.sample_rate = sample_rate
    return self.to(device)
|
|
565
|
+
|
|
566
|
+
def write(self, audio_path: typing.Union[str, Path]):
    """Writes audio to a file. Only writes the audio
    that is in the very first item of the batch. To write other items
    in the batch, index the signal along the batch dimension
    before writing. After writing, the signal's ``path_to_file``
    attribute is updated to the new path.

    Parameters
    ----------
    audio_path : typing.Union[str, Path]
        Path to write audio to.

    Returns
    -------
    AudioSignal
        Returns original AudioSignal, so you can use this in a fluent
        interface.

    Examples
    --------
    Creating and writing a signal to disk:

    >>> signal = AudioSignal(torch.randn(10, 1, 44100), 44100)
    >>> signal.write("/tmp/out.wav")

    Writing a different element of the batch:

    >>> signal[5].write("/tmp/out.wav")

    Using this in a fluent interface:

    >>> signal.write("/tmp/original.wav").low_pass(4000).write("/tmp/lowpass.wav")

    """
    first_item = self.audio_data[0]
    if first_item.abs().max() > 1:
        warnings.warn("Audio amplitude > 1 clipped when saving")
    # Detach and move to cpu before converting: Tensor.numpy() raises for
    # tensors that require grad or live on another device.
    frames = first_item.detach().cpu().numpy().T
    soundfile.write(str(audio_path), frames, self.sample_rate)

    self.path_to_file = audio_path
    return self
|
|
606
|
+
|
|
607
|
+
def deepcopy(self):
    """Return a deep copy of the signal made with ``copy.deepcopy``.

    Returns
    -------
    AudioSignal
        Deep copy of the audio signal.
    """
    duplicate = copy.deepcopy(self)
    return duplicate
|
|
616
|
+
|
|
617
|
+
def copy(self):
    """Return a shallow copy of the signal made with ``copy.copy``.

    Returns
    -------
    AudioSignal
        Shallow copy of the audio signal.
    """
    # Imported under another name so the stdlib function is not confused
    # with this method, which shares the name ``copy``.
    from copy import copy as shallow_copy

    return shallow_copy(self)
|
|
626
|
+
|
|
627
|
+
def clone(self):
    """Build a new signal of the same type whose tensors are clones of
    this signal's tensors. Useful when using AudioSignal within autograd
    computation graphs.

    The audio data is always cloned; the stft data and the cached
    loudness are cloned when present. ``path_to_file`` and ``metadata``
    are deep-copied.

    Returns
    -------
    AudioSignal
        Clone of AudioSignal.
    """
    duplicate = type(self)(
        self.audio_data.clone(), self.sample_rate, stft_params=self.stft_params
    )

    spectrogram = self.stft_data
    if spectrogram is not None:
        duplicate.stft_data = spectrogram.clone()

    cached_loudness = self._loudness
    if cached_loudness is not None:
        duplicate._loudness = cached_loudness.clone()

    duplicate.path_to_file = copy.deepcopy(self.path_to_file)
    duplicate.metadata = copy.deepcopy(self.metadata)
    return duplicate
|
|
653
|
+
|
|
654
|
+
def detach(self):
    """Detaches tensors contained in AudioSignal.

    Relevant attributes are the stft data, the audio data,
    and the loudness of the file.

    Returns
    -------
    AudioSignal
        Same signal, but with all tensors detached.
    """
    # Assigning audio_data goes through its setter, which resets the cached
    # loudness. Hold on to the cache and write it back afterwards so the
    # detached loudness is actually kept.
    cached_loudness = self._loudness

    if self.stft_data is not None:
        self.stft_data = self.stft_data.detach()

    self.audio_data = self.audio_data.detach()

    if cached_loudness is not None:
        self._loudness = cached_loudness.detach()
    return self
|
|
672
|
+
|
|
673
|
+
def hash(self):
    """Writes the audio data to a temporary file, and then
    hashes it using hashlib. Useful for creating a file
    name based on the audio content.

    The signal's ``path_to_file`` is left as it was before the call.

    Returns
    -------
    str
        SHA-256 hex digest of the written file.

    Examples
    --------
    Creating a signal, and writing it to a unique file name:

    >>> signal = AudioSignal(torch.randn(44100), 44100)
    >>> hash = signal.hash()
    >>> signal.write(f"{hash}.wav")

    """
    digest = hashlib.sha256()
    chunk = memoryview(bytearray(128 * 1024))

    # write() repoints path_to_file at the file it wrote; remember the real
    # value so the signal does not end up referring to a deleted temp file.
    original_path = self.path_to_file
    try:
        # A temporary directory is used instead of NamedTemporaryFile
        # because reopening an open named temp file by name does not work
        # on every platform.
        with tempfile.TemporaryDirectory() as tmp_dir:
            wav_path = str(pathlib.Path(tmp_dir) / "signal.wav")
            self.write(wav_path)
            with open(wav_path, "rb", buffering=0) as stream:
                for n in iter(lambda: stream.readinto(chunk), 0):
                    digest.update(chunk[:n])
    finally:
        self.path_to_file = original_path
    return digest.hexdigest()
|
|
702
|
+
|
|
703
|
+
# Signal operations
|
|
704
|
+
def to_mono(self):
    """Replace the audio data with its mean over dimension 1 (the
    channels), keeping that dimension with size 1.

    Returns
    -------
    AudioSignal
        This signal, with mean of channels.
    """
    mixed_down = torch.mean(self.audio_data, dim=1, keepdim=True)
    self.audio_data = mixed_down
    return self
|
|
715
|
+
|
|
716
|
+
def resample(self, sample_rate: int):
    """Resample the audio in place with ``julius.resample_frac``.

    Nothing is done when the signal already has the requested rate.

    Parameters
    ----------
    sample_rate : int
        Sample rate to resample to.

    Returns
    -------
    AudioSignal
        This signal, resampled.
    """
    current_rate = self.sample_rate
    if sample_rate != current_rate:
        self.audio_data = julius.resample_frac(
            self.audio_data, current_rate, sample_rate
        )
        self.sample_rate = sample_rate
    return self
|
|
737
|
+
|
|
738
|
+
# Tensor operations
|
|
739
|
+
def to(self, device: str):
    """Moves all tensors contained in signal to the specified device.

    Parameters
    ----------
    device : str
        Device to move AudioSignal onto. Typical values are
        "cuda", "cpu", or "cuda:n" to specify the nth gpu.

    Returns
    -------
    AudioSignal
        AudioSignal with all tensors moved to specified device.
    """
    # Assigning audio_data goes through its setter, which resets the cached
    # loudness. Hold on to the cache and write it back afterwards so the
    # moved loudness is actually kept.
    cached_loudness = self._loudness

    if self.stft_data is not None:
        self.stft_data = self.stft_data.to(device)
    if self.audio_data is not None:
        self.audio_data = self.audio_data.to(device)

    if cached_loudness is not None:
        self._loudness = cached_loudness.to(device)
    return self
|
|
760
|
+
|
|
761
|
+
def float(self):
    """Replace ``self.audio_data`` with the result of calling
    ``.float()`` on it.

    Returns
    -------
    AudioSignal
        This signal.
    """
    converted = self.audio_data.float()
    self.audio_data = converted
    return self
|
|
770
|
+
|
|
771
|
+
def cpu(self):
    """Shorthand for ``self.to("cpu")``.

    Returns
    -------
    AudioSignal
        Whatever ``self.to("cpu")`` returns.
    """
    target = "cpu"
    return self.to(target)
|
|
779
|
+
|
|
780
|
+
def cuda(self):  # pragma: no cover
    """Shorthand for ``self.to("cuda")``.

    Returns
    -------
    AudioSignal
        Whatever ``self.to("cuda")`` returns.
    """
    target = "cuda"
    return self.to(target)
|
|
788
|
+
|
|
789
|
+
def numpy(self):
    """Detaches ``self.audio_data``, moves to cpu, and converts to numpy.

    Returns
    -------
    np.ndarray
        Audio data as a numpy array.
    """
    detached = self.audio_data.detach()
    on_cpu = detached.cpu()
    return on_cpu.numpy()
|
|
798
|
+
|
|
799
|
+
def zero_pad(self, before: int, after: int):
    """Pad the last dimension of ``audio_data`` with zeros.

    Parameters
    ----------
    before : int
        How many zeros to prepend to audio.
    after : int
        How many zeros to append to audio.

    Returns
    -------
    AudioSignal
        This signal, with padding applied.
    """
    padding = (before, after)
    self.audio_data = torch.nn.functional.pad(self.audio_data, padding)
    return self
|
|
816
|
+
|
|
817
|
+
def zero_pad_to(self, length: int, mode: str = "after"):
    """Pad with zeros up to ``length`` samples, either before or after
    the audio data. A signal already at least ``length`` long gets zero
    padding; any other ``mode`` value leaves the signal untouched.

    Parameters
    ----------
    length : int
        Length to pad to
    mode : str, optional
        Whether to prepend ("before") or append ("after") zeros to
        signal, by default "after"

    Returns
    -------
    AudioSignal
        This signal, with padding applied.
    """
    if mode not in ("before", "after"):
        return self

    missing = max(length - self.signal_length, 0)
    if mode == "before":
        before, after = missing, 0
    else:
        before, after = 0, missing
    self.zero_pad(before, after)
    return self
|
|
838
|
+
|
|
839
|
+
def trim(self, before: int, after: int):
    """Drop samples from both ends of the last dimension of
    ``audio_data``.

    Parameters
    ----------
    before : int
        How many samples to trim from beginning.
    after : int
        How many samples to trim from end.

    Returns
    -------
    AudioSignal
        This signal, with trimming applied.
    """
    # A stop of ``-0`` would select nothing, so use an open end instead.
    stop = None if after == 0 else -after
    self.audio_data = self.audio_data[..., before:stop]
    return self
|
|
859
|
+
|
|
860
|
+
def truncate_samples(self, length_in_samples: int):
    """Keep only the first ``length_in_samples`` entries of the last
    dimension of ``audio_data``.

    Parameters
    ----------
    length_in_samples : int
        Truncate to this many samples.

    Returns
    -------
    AudioSignal
        This signal, with truncation applied.
    """
    kept = slice(None, length_in_samples)
    self.audio_data = self.audio_data[..., kept]
    return self
|
|
875
|
+
|
|
876
|
+
@property
def device(self):
    """Get device that AudioSignal is on.

    The device of ``audio_data`` is used when it is set, otherwise the
    device of ``stft_data``.

    Returns
    -------
    torch.device
        Device that AudioSignal is on, or None when the signal holds
        neither audio data nor stft data.
    """
    if self.audio_data is not None:
        return self.audio_data.device
    if self.stft_data is not None:
        return self.stft_data.device
    # Nothing to query; previously this fell through to an unbound local.
    return None
|
|
890
|
+
|
|
891
|
+
# Properties
|
|
892
|
+
@property
def audio_data(self):
    """Returns the audio data tensor in the object.

    When assigning, the value must be None or a 3-dimensional torch
    tensor (the assertion message names the axes ``(B, C, T)``); no
    reshaping or conversion is done by the setter. Every assignment also
    resets the cached loudness.

    Parameters
    ----------
    data : torch.Tensor
        Audio data to set.

    Returns
    -------
    torch.Tensor
        Audio samples.
    """
    return self._audio_data


@audio_data.setter
def audio_data(self, data: typing.Union[torch.Tensor, np.ndarray]):
    if data is not None:
        assert torch.is_tensor(data), "audio_data should be torch.Tensor"
        assert data.ndim == 3, "audio_data should be 3-dim (B, C, T)"
    # Old loudness value not guaranteed to be right, reset it.
    self._loudness = None
    self._audio_data = data


# alias for audio_data
samples = audio_data
|
|
925
|
+
|
|
926
|
+
@property
def stft_data(self):
    """Returns the STFT data inside the signal. Shape is
    (batch, channels, frequencies, time).

    When assigning, the value must be None or a complex torch tensor. A
    warning is emitted if the new tensor's shape differs from the shape
    of the data already stored.

    Returns
    -------
    torch.Tensor
        Complex spectrogram data.
    """
    return self._stft_data


@stft_data.setter
def stft_data(self, data: typing.Union[torch.Tensor, np.ndarray]):
    if data is not None:
        assert torch.is_tensor(data) and torch.is_complex(data)
        previous = self.stft_data
        if previous is not None and previous.shape != data.shape:
            warnings.warn("stft_data changed shape")
    self._stft_data = data
|
|
946
|
+
|
|
947
|
+
@property
def batch_size(self):
    """Size of the first dimension of ``audio_data``.

    Returns
    -------
    int
        Batch size of signal.
    """
    dims = self.audio_data.shape
    return dims[0]
|
|
957
|
+
|
|
958
|
+
@property
def signal_length(self):
    """Size of the last dimension of ``audio_data``.

    Returns
    -------
    int
        Length of signal in samples.
    """
    dims = self.audio_data.shape
    return dims[-1]


# alias for signal_length
length = signal_length
|
|
971
|
+
|
|
972
|
+
@property
def shape(self):
    """Shape of ``audio_data``, exactly as the tensor reports it.

    Returns
    -------
    tuple
        Shape of audio data.
    """
    samples = self.audio_data
    return samples.shape
|
|
982
|
+
|
|
983
|
+
@property
def signal_duration(self):
    """Length of audio signal in seconds, computed as
    ``signal_length / sample_rate``.

    Returns
    -------
    float
        Length of signal in seconds.
    """
    n_samples = self.signal_length
    return n_samples / self.sample_rate


# alias for signal_duration
duration = signal_duration
|
|
996
|
+
|
|
997
|
+
@property
def num_channels(self):
    """Size of the second dimension of ``audio_data``.

    Returns
    -------
    int
        Number of audio channels.
    """
    dims = self.audio_data.shape
    return dims[1]
|
|
1007
|
+
|
|
1008
|
+
# STFT
|
|
1009
|
+
@staticmethod
@functools.lru_cache(None)
def get_window(window_type: str, window_length: int, device: str):
    """Wrapper around scipy.signal.get_window that also understands two
    extra window names: "average" (every entry is ``1 / window_length``)
    and "sqrt_hann" (element-wise square root of scipy's "hann" window).
    Results are cached with functools.lru_cache.

    Parameters
    ----------
    window_type : str
        Type of window to get
    window_length : int
        Length of the window
    device : str
        Device to put window onto.

    Returns
    -------
    torch.Tensor
        The window as a float tensor on ``device``.
    """
    from scipy import signal

    if window_type == "sqrt_hann":
        values = np.sqrt(signal.get_window("hann", window_length))
    elif window_type == "average":
        values = np.ones(window_length) / window_length
    else:
        # Any other name is passed straight to scipy.
        values = signal.get_window(window_type, window_length)

    return torch.from_numpy(values).to(device).float()
|
|
1040
|
+
|
|
1041
|
+
@property
def stft_params(self):
    """Returns STFTParams object, which can be re-used to other
    AudioSignals.

    This property can be set as well. Fields left as None in the assigned
    STFTParams (or everything, when a falsy value is assigned) are filled
    with defaults derived from the signal: a window length equal to the
    power of two at or above 32ms worth of samples, a hop of a quarter of
    that window, a "hann" window, ``match_stride=False`` and "reflect"
    padding. Assigning also clears ``stft_data``.

    Returns
    -------
    STFTParams
        STFT parameters for the AudioSignal.

    Examples
    --------
    >>> stft_params = STFTParams(128, 32)
    >>> signal1 = AudioSignal(torch.randn(44100), 44100, stft_params=stft_params)
    >>> signal2 = AudioSignal(torch.randn(44100), 44100, stft_params=signal1.stft_params)
    >>> signal1.stft_params = STFTParams() # Defaults
    """
    return self._stft_params


@stft_params.setter
def stft_params(self, value: STFTParams):
    # Round 32ms worth of samples up to a power of two.
    win_len = int(2 ** (np.ceil(np.log2(0.032 * self.sample_rate))))
    defaults = STFTParams(
        window_length=win_len,
        hop_length=win_len // 4,
        window_type="hann",
        match_stride=False,
        padding_type="reflect",
    )._asdict()

    chosen = value._asdict() if value else defaults

    # Fill in whatever the caller left unspecified.
    for name, fallback in defaults.items():
        if chosen[name] is None:
            chosen[name] = fallback

    self._stft_params = STFTParams(**chosen)
    # Any stored spectrogram was computed with the old parameters.
    self.stft_data = None
|
|
1088
|
+
|
|
1089
|
+
def compute_stft_padding(
    self, window_length: int, hop_length: int, match_stride: bool
):
    """Work out how much the audio must be padded before the STFT.

    Without ``match_stride`` no padding is applied. With it, the signal is
    padded on the right up to the next multiple of ``hop_length`` and by
    ``(window_length - hop_length) // 2`` on both sides.

    Parameters
    ----------
    window_length : int
        Window length of STFT.
    hop_length : int
        Hop length of STFT.
    match_stride : bool
        Whether or not to match stride, making the STFT have the same
        alignment as convolutional layers.

    Returns
    -------
    tuple
        ``(right_pad, pad)``: extra samples for the right edge only, and
        samples added to both sides.
    """
    num_samples = self.signal_length

    if not match_stride:
        return 0, 0

    assert (
        hop_length == window_length // 4
    ), "For match_stride, hop must equal n_fft // 4"
    num_hops = math.ceil(num_samples / hop_length)
    right_pad = num_hops * hop_length - num_samples
    pad = (window_length - hop_length) // 2
    return right_pad, pad
|
|
1122
|
+
|
|
1123
|
+
def stft(
    self,
    window_length: int = None,
    hop_length: int = None,
    window_type: str = None,
    match_stride: bool = None,
    padding_type: str = None,
):
    """Compute the short-time Fourier transform of the audio data.

    Every argument left as ``None`` falls back to the matching field of
    ``self.stft_params``. The result is returned and also stored on
    ``self.stft_data``.

    Parameters
    ----------
    window_length : int, optional
        Window length of STFT, by default ``self.stft_params.window_length``.
    hop_length : int, optional
        Hop length of STFT, by default ``self.stft_params.hop_length``.
    window_type : str, optional
        Type of window to use, by default ``self.stft_params.window_type``.
    match_stride : bool, optional
        Whether to match the stride of convolutional layers, by default
        ``self.stft_params.match_stride``.
    padding_type : str, optional
        Type of padding to use, by default ``self.stft_params.padding_type``.

    Returns
    -------
    torch.Tensor
        Complex STFT of the audio data, reshaped to
        ``(batch_size, num_channels, num_freqs, num_frames)``.

    Examples
    --------
    Compute the STFT of an AudioSignal:

    >>> signal = AudioSignal(torch.randn(44100), 44100)
    >>> signal.stft()

    Vary the window and hop length:

    >>> stft_params = [STFTParams(128, 32), STFTParams(512, 128)]
    >>> for stft_param in stft_params:
    ...     signal.stft_params = stft_param
    ...     signal.stft()

    """
    window_length = (
        self.stft_params.window_length
        if window_length is None
        else int(window_length)
    )
    hop_length = (
        self.stft_params.hop_length if hop_length is None else int(hop_length)
    )
    window_type = (
        self.stft_params.window_type if window_type is None else window_type
    )
    match_stride = (
        self.stft_params.match_stride if match_stride is None else match_stride
    )
    padding_type = (
        self.stft_params.padding_type if padding_type is None else padding_type
    )

    # get_window already places the window on the requested device, so no
    # second transfer is needed.
    window = self.get_window(window_type, window_length, self.audio_data.device)

    audio_data = self.audio_data
    right_pad, pad = self.compute_stft_padding(
        window_length, hop_length, match_stride
    )
    audio_data = torch.nn.functional.pad(
        audio_data, (pad, pad + right_pad), padding_type
    )
    # torch.stft works on (rows, time), so fold every leading dimension
    # into one before the transform and unfold afterwards.
    stft_data = torch.stft(
        audio_data.reshape(-1, audio_data.shape[-1]),
        n_fft=window_length,
        hop_length=hop_length,
        window=window,
        return_complex=True,
        center=True,
    )
    _, nf, nt = stft_data.shape
    stft_data = stft_data.reshape(self.batch_size, self.num_channels, nf, nt)

    if match_stride:
        # Drop first two and last two frames, which are added
        # because of padding. Now num_frames * hop_length = num_samples.
        stft_data = stft_data[..., 2:-2]
    self.stft_data = stft_data

    return stft_data
|
|
1213
|
+
|
|
1214
|
+
def istft(
    self,
    window_length: int = None,
    hop_length: int = None,
    window_type: str = None,
    match_stride: bool = None,
    length: int = None,
):
    """Invert ``self.stft_data`` and store the result in ``self.audio_data``.

    Every STFT argument left as ``None`` falls back to the matching field
    of ``self.stft_params``.

    Parameters
    ----------
    window_length : int, optional
        Window length of STFT, by default ``self.stft_params.window_length``.
    hop_length : int, optional
        Hop length of STFT, by default ``self.stft_params.hop_length``.
    window_type : str, optional
        Type of window to use, by default ``self.stft_params.window_type``.
    match_stride : bool, optional
        Whether to match the stride of convolutional layers, by default
        ``self.stft_params.match_stride``.
    length : int, optional
        Length passed to ``torch.istft``. When ``None``, it is derived from
        ``self.original_signal_length`` plus the STFT padding.

    Returns
    -------
    AudioSignal
        This signal (``self``), with ``audio_data`` replaced.

    Raises
    ------
    RuntimeError
        If ``self.stft_data`` is ``None``, i.e. no STFT has been computed
        or assigned.
    """
    if self.stft_data is None:
        raise RuntimeError("Cannot do inverse STFT without self.stft_data!")

    if window_length is None:
        window_length = self.stft_params.window_length
    else:
        window_length = int(window_length)

    if hop_length is None:
        hop_length = self.stft_params.hop_length
    else:
        hop_length = int(hop_length)

    if window_type is None:
        window_type = self.stft_params.window_type

    if match_stride is None:
        match_stride = self.stft_params.match_stride

    window = self.get_window(window_type, window_length, self.stft_data.device)

    # torch.istft expects (rows, freq, frames): merge batch and channels.
    batch, channels, num_freqs, num_frames = self.stft_data.shape
    flat_stft = self.stft_data.reshape(batch * channels, num_freqs, num_frames)
    right_pad, pad = self.compute_stft_padding(
        window_length, hop_length, match_stride
    )

    if length is None:
        length = self.original_signal_length
        length = length + 2 * pad + right_pad

    if match_stride:
        # Zero-pad the STFT on either side, putting back the frames that were
        # dropped in stft().
        flat_stft = torch.nn.functional.pad(flat_stft, (2, 2))

    waveform = torch.istft(
        flat_stft,
        n_fft=window_length,
        hop_length=hop_length,
        window=window,
        length=length,
        center=True,
    )
    waveform = waveform.reshape(batch, channels, -1)
    if match_stride:
        # Strip the padding that stft() added around the signal.
        waveform = waveform[..., pad : -(pad + right_pad)]
    self.audio_data = waveform

    return self
|
|
1297
|
+
|
|
1298
|
+
@staticmethod
@functools.lru_cache(None)
def get_mel_filters(
    sr: int, n_fft: int, n_mels: int, fmin: float = 0.0, fmax: float = None
):
    """Create a filterbank matrix that combines FFT bins into Mel bins.

    Delegates to ``librosa.filters.mel``; results are memoized with
    ``functools.lru_cache``.

    Parameters
    ----------
    sr : int
        Sample rate of audio
    n_fft : int
        Number of FFT bins
    n_mels : int
        Number of mels
    fmin : float, optional
        Lowest frequency, in Hz, by default 0.0
    fmax : float, optional
        Highest frequency, by default None

    Returns
    -------
    np.ndarray [shape=(n_mels, 1 + n_fft/2)]
        Mel transform matrix
    """
    from librosa.filters import mel as librosa_mel_fn

    mel_kwargs = {
        "sr": sr,
        "n_fft": n_fft,
        "n_mels": n_mels,
        "fmin": fmin,
        "fmax": fmax,
    }
    return librosa_mel_fn(**mel_kwargs)
|
|
1332
|
+
|
|
1333
|
+
def mel_spectrogram(
    self, n_mels: int = 80, mel_fmin: float = 0.0, mel_fmax: float = None, **kwargs
):
    """Compute a Mel spectrogram from the STFT magnitude.

    Parameters
    ----------
    n_mels : int, optional
        Number of mels, by default 80
    mel_fmin : float, optional
        Lowest frequency, in Hz, by default 0.0
    mel_fmax : float, optional
        Highest frequency, by default None
    kwargs : dict, optional
        Keyword arguments to self.stft().

    Returns
    -------
    torch.Tensor [shape=(batch, channels, mels, time)]
        Mel spectrogram.
    """
    magnitude = torch.abs(self.stft(**kwargs))

    # The FFT size is recovered from the number of frequency bins:
    # num_bins = n_fft / 2 + 1.
    num_bins = magnitude.shape[2]
    filters = self.get_mel_filters(
        sr=self.sample_rate,
        n_fft=2 * (num_bins - 1),
        n_mels=n_mels,
        fmin=mel_fmin,
        fmax=mel_fmax,
    )
    filters = torch.from_numpy(filters).to(self.device)

    # Move frequency to the last axis, project onto the mel filters, then
    # move the resulting mel axis back into the frequency position.
    projected = magnitude.transpose(2, -1) @ filters.T
    return projected.transpose(-1, 2)
|
|
1370
|
+
|
|
1371
|
+
@staticmethod
@functools.lru_cache(None)
def get_dct(n_mfcc: int, n_mels: int, norm: str = "ortho", device: str = None):
    """Create a discrete cosine transform (DCT) matrix of shape (``n_mels``, ``n_mfcc``).

    Delegates to ``torchaudio.functional.create_dct``; results are memoized
    with ``functools.lru_cache``. For background on the DCT see
    http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II

    Parameters
    ----------
    n_mfcc : int
        Number of mfccs
    n_mels : int
        Number of mels
    norm : str
        Use "ortho" to get an orthogonal matrix or None, by default "ortho"
    device : str, optional
        Device to load the transformation matrix on, by default None

    Returns
    -------
    torch.Tensor [shape=(n_mels, n_mfcc)]
        The dct transformation matrix.
    """
    from torchaudio.functional import create_dct

    dct_matrix = create_dct(n_mfcc, n_mels, norm)
    return dct_matrix.to(device)
|
|
1397
|
+
|
|
1398
|
+
def mfcc(
    self, n_mfcc: int = 40, n_mels: int = 80, log_offset: float = 1e-6, **kwargs
):
    """Compute mel-frequency cepstral coefficients (MFCCs).

    Takes the log of the Mel spectrogram and projects it onto an
    orthonormal DCT basis.

    Parameters
    ----------
    n_mfcc : int, optional
        Number of MFCCs, by default 40
    n_mels : int, optional
        Number of mels, by default 80
    log_offset: float, optional
        Small value to prevent numerical issues when trying to compute log(0), by default 1e-6
    kwargs : dict, optional
        Keyword arguments to self.mel_spectrogram(), note that some of them will be used for self.stft()

    Returns
    -------
    torch.Tensor [shape=(batch, channels, mfccs, time)]
        MFCCs.
    """
    mel = self.mel_spectrogram(n_mels, **kwargs)
    log_mel = torch.log(mel + log_offset)
    dct_basis = self.get_dct(n_mfcc, n_mels, "ortho", self.device)

    # Put the mel axis last for the projection, then swap it back.
    coefficients = log_mel.transpose(-1, -2) @ dct_basis
    return coefficients.transpose(-1, -2)
|
|
1427
|
+
|
|
1428
|
+
@property
def magnitude(self):
    """Absolute value of the STFT.

    Reading the property computes the STFT first if ``self.stft_data`` is
    ``None``. The property can also be assigned a tensor: ``self.stft_data``
    is then rebuilt from the new magnitude and the current phase.

    Returns
    -------
    torch.Tensor
        Magnitude of STFT.

    Examples
    --------
    >>> signal = AudioSignal(torch.randn(44100), 44100)
    >>> magnitude = signal.magnitude # Computes stft if not computed
    >>> magnitude[magnitude < magnitude.mean()] = 0
    >>> signal.magnitude = magnitude
    >>> signal.istft()
    """
    if self.stft_data is None:
        self.stft()
    spectrum = self.stft_data
    return torch.abs(spectrum)

@magnitude.setter
def magnitude(self, value):
    # Polar form: new magnitude combined with the existing phase.
    self.stft_data = value * torch.exp(1j * self.phase)
|
|
1456
|
+
|
|
1457
|
+
def log_magnitude(
    self, ref_value: float = 1.0, amin: float = 1e-5, top_db: float = 80.0
):
    """Compute the log-magnitude (decibel) spectrogram.

    The computation is done in the power domain: the magnitude, the floor
    ``amin`` and the reference ``ref_value`` are all amplitudes and are all
    squared before taking ``10 * log10``, which yields
    ``20 * log10(S / ref)``.

    Parameters
    ----------
    ref_value : float, optional
        The magnitude is scaled relative to ``ref``: ``20 * log10(S / ref)``.
        Zeros in the output correspond to positions where ``S == ref``,
        by default 1.0
    amin : float, optional
        Minimum threshold for ``S`` and ``ref``, by default 1e-5
    top_db : float, optional
        Threshold the output at ``top_db`` below the peak:
        ``max(20 * log10(S/ref)) - top_db``, by default 80.0. Pass ``None``
        to disable the threshold.

    Returns
    -------
    torch.Tensor
        Log-magnitude spectrogram
    """
    power = self.magnitude.pow(2)

    # Amplitude thresholds become power thresholds once squared. The
    # reference must be squared like the floor, otherwise any
    # ref_value != 1 is applied as 10 * log10(ref) instead of 20 * log10(ref).
    power_floor = amin**2
    ref_power = ref_value**2

    log_spec = 10.0 * torch.log10(power.clamp(min=power_floor))
    log_spec -= 10.0 * np.log10(np.maximum(power_floor, ref_power))

    if top_db is not None:
        log_spec = torch.maximum(log_spec, log_spec.max() - top_db)
    return log_spec
|
|
1488
|
+
|
|
1489
|
+
@property
def phase(self):
    """Phase (angle) of the STFT.

    Reading the property computes the STFT first if ``self.stft_data`` is
    ``None``. The property can also be assigned a tensor: ``self.stft_data``
    is then rebuilt from the current magnitude and the new phase.

    Returns
    -------
    torch.Tensor
        Phase of STFT.

    Examples
    --------
    >>> signal = AudioSignal(torch.randn(44100), 44100)
    >>> phase = signal.phase # Computes stft if not computed
    >>> phase[phase < phase.mean()] = 0
    >>> signal.phase = phase
    >>> signal.istft()
    """
    if self.stft_data is None:
        self.stft()
    spectrum = self.stft_data
    return torch.angle(spectrum)

@phase.setter
def phase(self, value):
    # Polar form: existing magnitude combined with the new phase.
    self.stft_data = self.magnitude * torch.exp(1j * value)
|
|
1517
|
+
|
|
1518
|
+
# Operator overloading
|
|
1519
|
+
def __add__(self, other):
    """Return a clone of this signal with ``other`` added to its audio data.

    ``other`` is first passed through ``util._get_value``.
    """
    result = self.clone()
    result.audio_data += util._get_value(other)
    return result
|
|
1523
|
+
|
|
1524
|
+
def __iadd__(self, other):
    """Add ``other`` to this signal's audio data in place and return ``self``.

    ``other`` is first passed through ``util._get_value``.
    """
    self.audio_data += util._get_value(other)
    return self
|
|
1527
|
+
|
|
1528
|
+
def __radd__(self, other):
    """Reflected addition: ``other + self`` is evaluated as ``self + other``."""
    total = self + other
    return total
|
|
1530
|
+
|
|
1531
|
+
def __sub__(self, other):
    """Return a clone of this signal with ``other`` subtracted from its audio data.

    ``other`` is first passed through ``util._get_value``.
    """
    result = self.clone()
    result.audio_data -= util._get_value(other)
    return result
|
|
1535
|
+
|
|
1536
|
+
def __isub__(self, other):
    """Subtract ``other`` from this signal's audio data in place and return ``self``.

    ``other`` is first passed through ``util._get_value``.
    """
    self.audio_data -= util._get_value(other)
    return self
|
|
1539
|
+
|
|
1540
|
+
def __mul__(self, other):
    """Return a clone of this signal with its audio data multiplied by ``other``.

    ``other`` is first passed through ``util._get_value``.
    """
    result = self.clone()
    result.audio_data *= util._get_value(other)
    return result
|
|
1544
|
+
|
|
1545
|
+
def __imul__(self, other):
    """Multiply this signal's audio data by ``other`` in place and return ``self``.

    ``other`` is first passed through ``util._get_value``.
    """
    self.audio_data *= util._get_value(other)
    return self
|
|
1548
|
+
|
|
1549
|
+
def __rmul__(self, other):
    """Reflected multiplication: ``other * self`` is evaluated as ``self * other``."""
    product = self * other
    return product
|
|
1551
|
+
|
|
1552
|
+
# Representation
|
|
1553
|
+
def _info(self):
    """Collect a summary of this signal as an ordered dict of display values.

    Falsy duration, path and channel count are replaced by placeholder
    strings.

    Returns
    -------
    dict
        Maps field names (``"duration"``, ``"batch_size"``, ``"path"``,
        ``"sample_rate"``, ``"num_channels"``, ``"audio_data.shape"``,
        ``"stft_params"``, ``"device"``) to their values.
    """
    seconds = self.signal_duration
    dur = f"{seconds:0.3f}" if seconds else "[unknown]"

    summary = {}
    summary["duration"] = f"{dur} seconds"
    summary["batch_size"] = self.batch_size
    summary["path"] = self.path_to_file or "path unknown"
    summary["sample_rate"] = self.sample_rate
    summary["num_channels"] = self.num_channels or "[unknown]"
    summary["audio_data.shape"] = self.audio_data.shape
    summary["stft_params"] = self.stft_params
    summary["device"] = self.device
    return summary
|
|
1567
|
+
|
|
1568
|
+
def markdown(self):
    """Render the signal summary from ``self._info()`` as a markdown table.

    Returns
    -------
    str
        Markdown representation of AudioSignal: a two-column header
        followed by one ``| key | value |`` row per summary entry.

    Examples
    --------
    >>> signal = AudioSignal(torch.randn(44100), 44100)
    >>> print(signal.markdown())
    | Key | Value
    |---|---
    | duration | 1.000 seconds |
    | batch_size | 1 |
    | path | path unknown |
    | sample_rate | 44100 |
    | num_channels | 1 |
    | audio_data.shape | torch.Size([1, 1, 44100]) |
    | stft_params | STFTParams(...) |
    | device | cpu |
    """
    summary = self._info()

    header = "| Key | Value \n" "|---|--- \n"
    rows = [f"| {name} | {value} |\n" for name, value in summary.items()]
    return header + "".join(rows)
|
|
1598
|
+
|
|
1599
|
+
def __str__(self):
    """Return the summary from ``self._info()`` as ``key: value`` lines."""
    summary = self._info()
    return "".join(f"{name}: {value}\n" for name, value in summary.items())
|
|
1606
|
+
|
|
1607
|
+
def __rich__(self):
    """Build a ``rich`` table with one row per entry of ``self._info()``.

    The table is titled with the class name and has a green ``Key`` column
    and a cyan ``Value`` column; values are converted with ``str``.
    """
    from rich.table import Table

    summary = self._info()

    title = f"{self.__class__.__name__}"
    table = Table(title=title)
    table.add_column("Key", style="green")
    table.add_column("Value", style="cyan")

    for name, value in summary.items():
        table.add_row(name, str(value))
    return table
|
|
1619
|
+
|
|
1620
|
+
# Comparison
|
|
1621
|
+
def __eq__(self, other):
    """Compare the tensor attributes of two signals within ``atol=1e-6``.

    Only attributes of ``self`` that are tensors are compared; non-tensor
    attributes are ignored. On a value mismatch the maximum absolute error
    is printed before returning ``False``.

    Parameters
    ----------
    other : object
        Object to compare against.

    Returns
    -------
    bool or NotImplemented
        ``True`` when every tensor attribute of ``self`` has a matching,
        same-shaped and numerically close tensor on ``other``; ``False``
        otherwise. ``NotImplemented`` when ``other`` has no ``__dict__``
        (e.g. a number), so Python falls back to its default comparison.
    """
    other_attrs = getattr(other, "__dict__", None)
    if other_attrs is None:
        return NotImplemented

    for k, v in list(self.__dict__.items()):
        if not torch.is_tensor(v):
            continue

        counterpart = other_attrs.get(k)
        # A missing / None / non-tensor counterpart, or one of a different
        # shape, cannot be equal; torch.allclose would raise on these.
        if not torch.is_tensor(counterpart) or counterpart.shape != v.shape:
            return False

        if not torch.allclose(v, counterpart, atol=1e-6):
            max_error = (v - counterpart).abs().max()
            print(f"Max abs error for {k}: {max_error}")
            return False
    return True
|
|
1629
|
+
|
|
1630
|
+
# Indexing
|
|
1631
|
+
def __getitem__(self, key):
    """Index the signal along the batch dimension and return a new signal.

    Parameters
    ----------
    key : bool, int, list, slice, tuple or torch.Tensor
        A scalar tensor holding ``True`` selects the whole signal (which
        must have ``batch_size == 1``). Any other supported key (including
        tensors with at most one dimension) is applied directly to
        ``audio_data``, ``_loudness`` and ``stft_data``.

    Returns
    -------
    AudioSignal
        New signal of the same type built from the selected audio data,
        with this signal's sample rate and STFT parameters.

    Raises
    ------
    TypeError
        If ``key`` is of an unsupported type (previously this surfaced as
        an accidental ``UnboundLocalError``).
    """
    key_is_tensor = torch.is_tensor(key)

    if key_is_tensor and key.ndim == 0 and key.item() is True:
        assert self.batch_size == 1
        audio_data = self.audio_data
        _loudness = self._loudness
        stft_data = self.stft_data

    elif isinstance(key, (bool, int, list, slice, tuple)) or (
        key_is_tensor and key.ndim <= 1
    ):
        # Indexing only on the batch dimension.
        # Then let's copy over relevant stuff.
        # Future work: make this work for time-indexing
        # as well, using the hop length.
        audio_data = self.audio_data[key]
        _loudness = self._loudness[key] if self._loudness is not None else None
        stft_data = self.stft_data[key] if self.stft_data is not None else None

    else:
        raise TypeError(
            f"Unsupported index type for {type(self).__name__}: "
            f"{type(key).__name__}"
        )

    sources = None

    copy = type(self)(audio_data, self.sample_rate, stft_params=self.stft_params)
    copy._loudness = _loudness
    # NOTE(review): every other method in view reads ``stft_data`` (no
    # leading underscore), so this write may never be picked up - confirm
    # whether ``copy.stft_data`` was intended.
    copy._stft_data = stft_data
    copy.sources = sources

    return copy
|
|
1657
|
+
|
|
1658
|
+
def __setitem__(self, key, value):
    """Assign into the signal along the batch dimension.

    Parameters
    ----------
    key : bool, int, list, slice, tuple or torch.Tensor
        Index to assign at.
    value : object
        When ``value`` is not an instance of this signal's type it is
        written straight into ``self.audio_data[key]``. When it is a signal
        and ``key`` is a scalar tensor holding ``True``, this signal (which
        must have ``batch_size == 1``) takes over ``value``'s audio data,
        loudness and STFT data wholesale. For the other supported keys each
        of those three attributes is assigned at ``key``, but only when it
        is set on both signals. Unsupported keys are ignored.
    """
    if not isinstance(value, type(self)):
        self.audio_data[key] = value
        return

    key_is_tensor = torch.is_tensor(key)

    if key_is_tensor and key.ndim == 0 and key.item() is True:
        assert self.batch_size == 1
        self.audio_data = value.audio_data
        self._loudness = value._loudness
        self.stft_data = value.stft_data
        return

    indexes_batch = isinstance(key, (bool, int, list, slice, tuple)) or (
        key_is_tensor and key.ndim <= 1
    )
    if not indexes_batch:
        return

    for attr in ("audio_data", "_loudness", "stft_data"):
        target = getattr(self, attr)
        if target is None:
            continue
        incoming = getattr(value, attr)
        if incoming is None:
            continue
        target[key] = incoming
    return
|
|
1680
|
+
|
|
1681
|
+
def __ne__(self, other):
    """Return the negation of ``self == other``."""
    is_equal = self == other
    return not is_equal
|