xinference 1.10.0__py3-none-any.whl → 1.11.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release.

Files changed (328)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +473 -31
  3. xinference/client/restful/async_restful_client.py +178 -8
  4. xinference/client/restful/restful_client.py +151 -3
  5. xinference/core/supervisor.py +99 -53
  6. xinference/core/worker.py +10 -0
  7. xinference/deploy/cmdline.py +15 -0
  8. xinference/model/audio/core.py +21 -6
  9. xinference/model/audio/indextts2.py +166 -0
  10. xinference/model/audio/model_spec.json +58 -21
  11. xinference/model/image/model_spec.json +159 -90
  12. xinference/model/image/stable_diffusion/core.py +13 -4
  13. xinference/model/llm/__init__.py +6 -2
  14. xinference/model/llm/llm_family.json +1299 -174
  15. xinference/model/llm/mlx/distributed_models/core.py +41 -0
  16. xinference/model/llm/mlx/distributed_models/qwen2.py +1 -2
  17. xinference/model/llm/sglang/core.py +44 -11
  18. xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +94 -32
  19. xinference/model/llm/tool_parsers/qwen_tool_parser.py +29 -4
  20. xinference/model/llm/transformers/chatglm.py +3 -0
  21. xinference/model/llm/transformers/core.py +129 -36
  22. xinference/model/llm/transformers/multimodal/minicpmv45.py +340 -0
  23. xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
  24. xinference/model/llm/transformers/utils.py +23 -0
  25. xinference/model/llm/utils.py +48 -32
  26. xinference/model/llm/vllm/core.py +207 -72
  27. xinference/model/utils.py +74 -31
  28. xinference/thirdparty/audiotools/__init__.py +10 -0
  29. xinference/thirdparty/audiotools/core/__init__.py +4 -0
  30. xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
  31. xinference/thirdparty/audiotools/core/display.py +194 -0
  32. xinference/thirdparty/audiotools/core/dsp.py +390 -0
  33. xinference/thirdparty/audiotools/core/effects.py +647 -0
  34. xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
  35. xinference/thirdparty/audiotools/core/loudness.py +320 -0
  36. xinference/thirdparty/audiotools/core/playback.py +252 -0
  37. xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
  38. xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
  39. xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
  40. xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
  41. xinference/thirdparty/audiotools/core/util.py +671 -0
  42. xinference/thirdparty/audiotools/core/whisper.py +97 -0
  43. xinference/thirdparty/audiotools/data/__init__.py +3 -0
  44. xinference/thirdparty/audiotools/data/datasets.py +517 -0
  45. xinference/thirdparty/audiotools/data/preprocess.py +81 -0
  46. xinference/thirdparty/audiotools/data/transforms.py +1592 -0
  47. xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
  48. xinference/thirdparty/audiotools/metrics/distance.py +131 -0
  49. xinference/thirdparty/audiotools/metrics/quality.py +159 -0
  50. xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
  51. xinference/thirdparty/audiotools/ml/__init__.py +5 -0
  52. xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
  53. xinference/thirdparty/audiotools/ml/decorators.py +440 -0
  54. xinference/thirdparty/audiotools/ml/experiment.py +90 -0
  55. xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
  56. xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
  57. xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
  58. xinference/thirdparty/audiotools/post.py +140 -0
  59. xinference/thirdparty/audiotools/preference.py +600 -0
  60. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +1 -1
  61. xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
  62. xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
  63. xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
  64. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
  65. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
  66. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
  67. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
  68. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  69. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
  70. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
  71. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
  72. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
  73. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
  74. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
  75. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
  76. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
  77. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
  78. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
  79. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
  80. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
  81. xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
  82. xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
  83. xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
  84. xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
  85. xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
  86. xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
  87. xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
  88. xinference/thirdparty/indextts/__init__.py +0 -0
  89. xinference/thirdparty/indextts/cli.py +65 -0
  90. xinference/thirdparty/indextts/gpt/__init__.py +0 -0
  91. xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
  92. xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
  93. xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
  94. xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
  95. xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
  96. xinference/thirdparty/indextts/gpt/model.py +713 -0
  97. xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
  98. xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
  99. xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
  100. xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
  101. xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
  102. xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
  103. xinference/thirdparty/indextts/infer.py +690 -0
  104. xinference/thirdparty/indextts/infer_v2.py +739 -0
  105. xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
  106. xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
  107. xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
  108. xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
  109. xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
  110. xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
  111. xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
  112. xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
  113. xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
  114. xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
  115. xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
  116. xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
  117. xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
  118. xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
  119. xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
  120. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
  121. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
  122. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
  123. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
  124. xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
  125. xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
  126. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
  127. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
  128. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  129. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
  130. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
  131. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
  132. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
  133. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
  134. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
  135. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
  136. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
  137. xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
  138. xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
  139. xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
  140. xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
  141. xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
  142. xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
  143. xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
  144. xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
  145. xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
  146. xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
  147. xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
  148. xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
  149. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
  150. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
  151. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
  152. xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
  153. xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
  154. xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
  155. xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
  156. xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
  157. xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
  158. xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
  159. xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
  160. xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
  161. xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
  162. xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
  163. xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
  164. xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
  165. xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
  166. xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
  167. xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
  168. xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
  169. xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
  170. xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
  171. xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
  172. xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
  173. xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
  174. xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
  175. xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
  176. xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
  177. xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
  178. xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
  179. xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
  180. xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
  181. xinference/thirdparty/indextts/utils/__init__.py +0 -0
  182. xinference/thirdparty/indextts/utils/arch_util.py +120 -0
  183. xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
  184. xinference/thirdparty/indextts/utils/common.py +121 -0
  185. xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
  186. xinference/thirdparty/indextts/utils/front.py +536 -0
  187. xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
  188. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
  189. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
  190. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
  191. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
  192. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
  193. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
  194. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
  195. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
  196. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
  197. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
  198. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
  199. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
  200. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
  201. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
  202. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
  203. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
  204. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
  205. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
  206. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
  207. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
  208. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
  209. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
  210. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
  211. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
  212. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
  213. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
  214. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
  215. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
  216. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
  217. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
  218. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
  219. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
  220. xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
  221. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
  222. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
  223. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
  224. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
  225. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
  226. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
  227. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
  228. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
  229. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
  230. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
  231. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
  232. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
  233. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
  234. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
  235. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
  236. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
  237. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
  238. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
  239. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
  240. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
  241. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
  242. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
  243. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
  244. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
  245. xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
  246. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
  247. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
  248. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
  249. xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
  250. xinference/thirdparty/indextts/utils/text_utils.py +41 -0
  251. xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
  252. xinference/thirdparty/indextts/utils/utils.py +93 -0
  253. xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
  254. xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
  255. xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
  256. xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
  257. xinference/thirdparty/melo/text/chinese_mix.py +2 -2
  258. xinference/types.py +9 -0
  259. xinference/ui/gradio/media_interface.py +66 -8
  260. xinference/ui/web/ui/build/asset-manifest.json +6 -6
  261. xinference/ui/web/ui/build/index.html +1 -1
  262. xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
  263. xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
  264. xinference/ui/web/ui/build/static/js/main.45e78536.js +3 -0
  265. xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.45e78536.js.LICENSE.txt} +0 -7
  266. xinference/ui/web/ui/build/static/js/main.45e78536.js.map +1 -0
  267. xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
  268. xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
  269. xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
  270. xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
  271. xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
  272. xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
  273. xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
  274. xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
  275. xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
  276. xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
  277. xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
  278. xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
  279. xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
  280. xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
  281. xinference/ui/web/ui/node_modules/.cache/babel-loader/ea2a26361204e70cf1018d6990fb6354bed82b3ac69690391e0f100385e7abb7.json +1 -0
  282. xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
  283. xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
  284. xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
  285. xinference/ui/web/ui/package-lock.json +0 -34
  286. xinference/ui/web/ui/package.json +0 -1
  287. xinference/ui/web/ui/src/locales/en.json +9 -3
  288. xinference/ui/web/ui/src/locales/ja.json +9 -3
  289. xinference/ui/web/ui/src/locales/ko.json +9 -3
  290. xinference/ui/web/ui/src/locales/zh.json +9 -3
  291. {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/METADATA +24 -6
  292. {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/RECORD +296 -77
  293. xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
  294. xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
  295. xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
  296. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
  297. xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
  298. xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
  299. xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
  300. xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
  301. xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
  302. xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
  303. xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
  304. xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
  305. xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
  306. xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
  307. xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
  308. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
  309. xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
  310. xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
  311. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
  312. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
  313. xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
  314. xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
  315. xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
  316. xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
  317. xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
  318. xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
  319. xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
  320. xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
  321. xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
  322. xinference/ui/web/ui/node_modules/select/bower.json +0 -13
  323. xinference/ui/web/ui/node_modules/select/package.json +0 -29
  324. xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
  325. {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/WHEEL +0 -0
  326. {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/entry_points.txt +0 -0
  327. {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/licenses/LICENSE +0 -0
  328. {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/top_level.txt +0 -0
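The headline additions are the IndexTTS2 text-to-speech model (xinference/model/audio/indextts2.py plus the vendored xinference/thirdparty/indextts tree) and a vendored copy of the audiotools library it depends on, whose core module is shown in full below. As a quick orientation, here is a minimal, hypothetical sketch of driving the new audio model through the standard xinference client; the model name "IndexTTS-2" is an assumption, so check the names registered in xinference/model/audio/model_spec.json for this release.

# Hypothetical usage sketch, not part of the diff: launching the new
# IndexTTS2 audio model through the xinference client API.
from xinference.client import Client

client = Client("http://localhost:9997")
# model_name "IndexTTS-2" is assumed; see model_spec.json for the real id.
model_uid = client.launch_model(model_name="IndexTTS-2", model_type="audio")
model = client.get_model(model_uid)

# TTS-style audio models expose an OpenAI-compatible speech() call that
# returns encoded audio bytes (mp3 by default).
audio = model.speech(input="Hello from xinference 1.11.0.")
with open("hello.mp3", "wb") as f:
    f.write(audio)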
--- /dev/null
+++ b/xinference/thirdparty/audiotools/core/audio_signal.py
@@ -0,0 +1,1682 @@
1
+ import copy
2
+ import functools
3
+ import hashlib
4
+ import math
5
+ import pathlib
6
+ import tempfile
7
+ import typing
8
+ import warnings
9
+ from collections import namedtuple
10
+ from pathlib import Path
11
+
12
+ import julius
13
+ import numpy as np
14
+ import soundfile
15
+ import torch
16
+
17
+ from . import util
18
+ from .display import DisplayMixin
19
+ from .dsp import DSPMixin
20
+ from .effects import EffectMixin
21
+ from .effects import ImpulseResponseMixin
22
+ from .ffmpeg import FFMPEGMixin
23
+ from .loudness import LoudnessMixin
24
+ from .playback import PlayMixin
25
+ from .whisper import WhisperMixin
26
+
27
+
28
+ STFTParams = namedtuple(
29
+ "STFTParams",
30
+ ["window_length", "hop_length", "window_type", "match_stride", "padding_type"],
31
+ )
32
+ """
33
+ STFTParams object is a container that holds STFT parameters - window_length,
34
+ hop_length, and window_type. Not all parameters need to be specified. Ones that
35
+ are not specified will be inferred by the AudioSignal parameters.
36
+
37
+ Parameters
38
+ ----------
39
+ window_length : int, optional
40
+ Window length of STFT, by default ``0.032 * self.sample_rate``.
41
+ hop_length : int, optional
42
+ Hop length of STFT, by default ``window_length // 4``.
43
+ window_type : str, optional
44
+ Type of window to use, by default ``sqrt\\_hann``.
45
+ match_stride : bool, optional
46
+ Whether to match the stride of convolutional layers, by default False
47
+ padding_type : str, optional
48
+ Type of padding to use, by default 'reflect'
49
+ """
50
+ STFTParams.__new__.__defaults__ = (None, None, None, None, None)
51
+
52
+
53
+ class AudioSignal(
54
+ EffectMixin,
55
+ LoudnessMixin,
56
+ PlayMixin,
57
+ ImpulseResponseMixin,
58
+ DSPMixin,
59
+ DisplayMixin,
60
+ FFMPEGMixin,
61
+ WhisperMixin,
62
+ ):
63
+ """This is the core object of this library. Audio is always
64
+ loaded into an AudioSignal, which then enables all the features
65
+ of this library, including audio augmentations, I/O, playback,
66
+ and more.
67
+
68
+ The structure of this object is that the base functionality
69
+ is defined in ``core/audio_signal.py``, while extensions to
70
+ that functionality are defined in the other ``core/*.py``
71
+ files. For example, all the display-based functionality
72
+ (e.g. plot spectrograms, waveforms, write to tensorboard)
73
+ are in ``core/display.py``.
74
+
75
+ Parameters
76
+ ----------
77
+ audio_path_or_array : typing.Union[torch.Tensor, str, Path, np.ndarray]
78
+ Object to create AudioSignal from. Can be a tensor, numpy array,
79
+ or a path to a file. The file is always reshaped to
80
+ sample_rate : int, optional
81
+ Sample rate of the audio. If different from underlying file, resampling is
82
+ performed. If passing in an array or tensor, this must be defined,
83
+ by default None
84
+ stft_params : STFTParams, optional
85
+ Parameters of STFT to use. , by default None
86
+ offset : float, optional
87
+ Offset in seconds to read from file, by default 0
88
+ duration : float, optional
89
+ Duration in seconds to read from file, by default None
90
+ device : str, optional
91
+ Device to load audio onto, by default None
92
+
93
+ Examples
94
+ --------
95
+ Loading an AudioSignal from an array, at a sample rate of
96
+ 44100.
97
+
98
+ >>> signal = AudioSignal(torch.randn(5*44100), 44100)
99
+
100
+ Note, the signal is reshaped to have a batch size, and one
101
+ audio channel:
102
+
103
+ >>> print(signal.shape)
104
+ (1, 1, 44100)
105
+
106
+ You can treat AudioSignals like tensors, and many of the same
107
+ functions you might use on tensors are defined for AudioSignals
108
+ as well:
109
+
110
+ >>> signal.to("cuda")
111
+ >>> signal.cuda()
112
+ >>> signal.clone()
113
+ >>> signal.detach()
114
+
115
+ Indexing AudioSignals returns an AudioSignal:
116
+
117
+ >>> signal[..., 3*44100:4*44100]
118
+
119
+ The above signal is 1 second long, and is also an AudioSignal.
120
+ """
121
+
122
+ def __init__(
123
+ self,
124
+ audio_path_or_array: typing.Union[torch.Tensor, str, Path, np.ndarray],
125
+ sample_rate: int = None,
126
+ stft_params: STFTParams = None,
127
+ offset: float = 0,
128
+ duration: float = None,
129
+ device: str = None,
130
+ ):
131
+ audio_path = None
132
+ audio_array = None
133
+
134
+ if isinstance(audio_path_or_array, str):
135
+ audio_path = audio_path_or_array
136
+ elif isinstance(audio_path_or_array, pathlib.Path):
137
+ audio_path = audio_path_or_array
138
+ elif isinstance(audio_path_or_array, np.ndarray):
139
+ audio_array = audio_path_or_array
140
+ elif torch.is_tensor(audio_path_or_array):
141
+ audio_array = audio_path_or_array
142
+ else:
143
+ raise ValueError(
144
+ "audio_path_or_array must be either a Path, "
145
+ "string, numpy array, or torch Tensor!"
146
+ )
147
+
148
+ self.path_to_file = None
149
+
150
+ self.audio_data = None
151
+ self.sources = None # List of AudioSignal objects.
152
+ self.stft_data = None
153
+ if audio_path is not None:
154
+ self.load_from_file(
155
+ audio_path, offset=offset, duration=duration, device=device
156
+ )
157
+ elif audio_array is not None:
158
+ assert sample_rate is not None, "Must set sample rate!"
159
+ self.load_from_array(audio_array, sample_rate, device=device)
160
+
161
+ self.window = None
162
+ self.stft_params = stft_params
163
+
164
+ self.metadata = {
165
+ "offset": offset,
166
+ "duration": duration,
167
+ }
168
+
169
+ @property
170
+ def path_to_input_file(
171
+ self,
172
+ ):
173
+ """
174
+ Path to input file, if it exists.
175
+ Alias to ``path_to_file`` for backwards compatibility
176
+ """
177
+ return self.path_to_file
178
+
179
+ @classmethod
180
+ def excerpt(
181
+ cls,
182
+ audio_path: typing.Union[str, Path],
183
+ offset: float = None,
184
+ duration: float = None,
185
+ state: typing.Union[np.random.RandomState, int] = None,
186
+ **kwargs,
187
+ ):
188
+ """Randomly draw an excerpt of ``duration`` seconds from an
189
+ audio file specified at ``audio_path``, between ``offset`` seconds
190
+ and end of file. ``state`` can be used to seed the random draw.
191
+
192
+ Parameters
193
+ ----------
194
+ audio_path : typing.Union[str, Path]
195
+ Path to audio file to grab excerpt from.
196
+ offset : float, optional
197
+ Lower bound for the start time, in seconds drawn from
198
+ the file, by default None.
199
+ duration : float, optional
200
+ Duration of excerpt, in seconds, by default None
201
+ state : typing.Union[np.random.RandomState, int], optional
202
+ RandomState or seed of random state, by default None
203
+
204
+ Returns
205
+ -------
206
+ AudioSignal
207
+ AudioSignal containing excerpt.
208
+
209
+ Examples
210
+ --------
211
+ >>> signal = AudioSignal.excerpt("path/to/audio", duration=5)
212
+ """
213
+ info = util.info(audio_path)
214
+ total_duration = info.duration
215
+
216
+ state = util.random_state(state)
217
+ lower_bound = 0 if offset is None else offset
218
+ upper_bound = max(total_duration - duration, 0)
219
+ offset = state.uniform(lower_bound, upper_bound)
220
+
221
+ signal = cls(audio_path, offset=offset, duration=duration, **kwargs)
222
+ signal.metadata["offset"] = offset
223
+ signal.metadata["duration"] = duration
224
+
225
+ return signal
226
+
227
+ @classmethod
228
+ def salient_excerpt(
229
+ cls,
230
+ audio_path: typing.Union[str, Path],
231
+ loudness_cutoff: float = None,
232
+ num_tries: int = 8,
233
+ state: typing.Union[np.random.RandomState, int] = None,
234
+ **kwargs,
235
+ ):
236
+ """Similar to AudioSignal.excerpt, except it extracts excerpts only
237
+ if they are above a specified loudness threshold, which is computed via
238
+ a fast LUFS routine.
239
+
240
+ Parameters
241
+ ----------
242
+ audio_path : typing.Union[str, Path]
243
+ Path to audio file to grab excerpt from.
244
+ loudness_cutoff : float, optional
245
+ Loudness threshold in dB. Typical values are ``-40, -60``,
246
+ etc, by default None
247
+ num_tries : int, optional
248
+ Number of tries to grab an excerpt above the threshold
249
+ before giving up, by default 8.
250
+ state : typing.Union[np.random.RandomState, int], optional
251
+ RandomState or seed of random state, by default None
252
+ kwargs : dict
253
+ Keyword arguments to AudioSignal.excerpt
254
+
255
+ Returns
256
+ -------
257
+ AudioSignal
258
+ AudioSignal containing excerpt.
259
+
260
+
261
+ .. warning::
262
+ if ``num_tries`` is set to None, ``salient_excerpt`` may try forever, which can
263
+ result in an infinite loop if ``audio_path`` does not have
264
+ any loud enough excerpts.
265
+
266
+ Examples
267
+ --------
268
+ >>> signal = AudioSignal.salient_excerpt(
269
+ "path/to/audio",
270
+ loudness_cutoff=-40,
271
+ duration=5
272
+ )
273
+ """
274
+ state = util.random_state(state)
275
+ if loudness_cutoff is None:
276
+ excerpt = cls.excerpt(audio_path, state=state, **kwargs)
277
+ else:
278
+ loudness = -np.inf
279
+ num_try = 0
280
+ while loudness <= loudness_cutoff:
281
+ excerpt = cls.excerpt(audio_path, state=state, **kwargs)
282
+ loudness = excerpt.loudness()
283
+ num_try += 1
284
+ if num_tries is not None and num_try >= num_tries:
285
+ break
286
+ return excerpt
287
+
288
+ @classmethod
289
+ def zeros(
290
+ cls,
291
+ duration: float,
292
+ sample_rate: int,
293
+ num_channels: int = 1,
294
+ batch_size: int = 1,
295
+ **kwargs,
296
+ ):
297
+ """Helper function create an AudioSignal of all zeros.
298
+
299
+ Parameters
300
+ ----------
301
+ duration : float
302
+ Duration of AudioSignal
303
+ sample_rate : int
304
+ Sample rate of AudioSignal
305
+ num_channels : int, optional
306
+ Number of channels, by default 1
307
+ batch_size : int, optional
308
+ Batch size, by default 1
309
+
310
+ Returns
311
+ -------
312
+ AudioSignal
313
+ AudioSignal containing all zeros.
314
+
315
+ Examples
316
+ --------
317
+ Generate 5 seconds of all zeros at a sample rate of 44100.
318
+
319
+ >>> signal = AudioSignal.zeros(5.0, 44100)
320
+ """
321
+ n_samples = int(duration * sample_rate)
322
+ return cls(
323
+ torch.zeros(batch_size, num_channels, n_samples), sample_rate, **kwargs
324
+ )
325
+
326
+ @classmethod
327
+ def wave(
328
+ cls,
329
+ frequency: float,
330
+ duration: float,
331
+ sample_rate: int,
332
+ num_channels: int = 1,
333
+ shape: str = "sine",
334
+ **kwargs,
335
+ ):
336
+ """
337
+ Generate a waveform of a given frequency and shape.
338
+
339
+ Parameters
340
+ ----------
341
+ frequency : float
342
+ Frequency of the waveform
343
+ duration : float
344
+ Duration of the waveform
345
+ sample_rate : int
346
+ Sample rate of the waveform
347
+ num_channels : int, optional
348
+ Number of channels, by default 1
349
+ shape : str, optional
350
+ Shape of the waveform, by default "saw"
351
+ One of "sawtooth", "square", "sine", "triangle"
352
+ kwargs : dict
353
+ Keyword arguments to AudioSignal
354
+ """
355
+ n_samples = int(duration * sample_rate)
356
+ t = torch.linspace(0, duration, n_samples)
357
+ if shape == "sawtooth":
358
+ from scipy.signal import sawtooth
359
+
360
+ wave_data = sawtooth(2 * np.pi * frequency * t, 0.5)
361
+ elif shape == "square":
362
+ from scipy.signal import square
363
+
364
+ wave_data = square(2 * np.pi * frequency * t)
365
+ elif shape == "sine":
366
+ wave_data = np.sin(2 * np.pi * frequency * t)
367
+ elif shape == "triangle":
368
+ from scipy.signal import sawtooth
369
+
370
+ # frequency is doubled by the abs call, so omit the 2 in 2pi
371
+ wave_data = sawtooth(np.pi * frequency * t, 0.5)
372
+ wave_data = -np.abs(wave_data) * 2 + 1
373
+ else:
374
+ raise ValueError(f"Invalid shape {shape}")
375
+
376
+ wave_data = torch.tensor(wave_data, dtype=torch.float32)
377
+ wave_data = wave_data.unsqueeze(0).unsqueeze(0).repeat(1, num_channels, 1)
378
+ return cls(wave_data, sample_rate, **kwargs)
379
+
380
+ @classmethod
381
+ def batch(
382
+ cls,
383
+ audio_signals: list,
384
+ pad_signals: bool = False,
385
+ truncate_signals: bool = False,
386
+ resample: bool = False,
387
+ dim: int = 0,
388
+ ):
389
+ """Creates a batched AudioSignal from a list of AudioSignals.
390
+
391
+ Parameters
392
+ ----------
393
+ audio_signals : list[AudioSignal]
394
+ List of AudioSignal objects
395
+ pad_signals : bool, optional
396
+ Whether to pad signals to length of the maximum length
397
+ AudioSignal in the list, by default False
398
+ truncate_signals : bool, optional
399
+ Whether to truncate signals to length of shortest length
400
+ AudioSignal in the list, by default False
401
+ resample : bool, optional
402
+ Whether to resample AudioSignal to the sample rate of
403
+ the first AudioSignal in the list, by default False
404
+ dim : int, optional
405
+ Dimension along which to batch the signals.
406
+
407
+ Returns
408
+ -------
409
+ AudioSignal
410
+ Batched AudioSignal.
411
+
412
+ Raises
413
+ ------
414
+ RuntimeError
415
+ If not all AudioSignals are the same sample rate, and
416
+ ``resample=False``, an error is raised.
417
+ RuntimeError
418
+ If not all AudioSignals are the same the length, and
419
+ both ``pad_signals=False`` and ``truncate_signals=False``,
420
+ an error is raised.
421
+
422
+ Examples
423
+ --------
424
+ Batching a bunch of random signals:
425
+
426
+ >>> signal_list = [AudioSignal(torch.randn(44100), 44100) for _ in range(10)]
427
+ >>> signal = AudioSignal.batch(signal_list)
428
+ >>> print(signal.shape)
429
+ (10, 1, 44100)
430
+
431
+ """
432
+ signal_lengths = [x.signal_length for x in audio_signals]
433
+ sample_rates = [x.sample_rate for x in audio_signals]
434
+
435
+ if len(set(sample_rates)) != 1:
436
+ if resample:
437
+ for x in audio_signals:
438
+ x.resample(sample_rates[0])
439
+ else:
440
+ raise RuntimeError(
441
+ f"Not all signals had the same sample rate! Got {sample_rates}. "
442
+ f"All signals must have the same sample rate, or resample must be True. "
443
+ )
444
+
445
+ if len(set(signal_lengths)) != 1:
446
+ if pad_signals:
447
+ max_length = max(signal_lengths)
448
+ for x in audio_signals:
449
+ pad_len = max_length - x.signal_length
450
+ x.zero_pad(0, pad_len)
451
+ elif truncate_signals:
452
+ min_length = min(signal_lengths)
453
+ for x in audio_signals:
454
+ x.truncate_samples(min_length)
455
+ else:
456
+ raise RuntimeError(
457
+ f"Not all signals had the same length! Got {signal_lengths}. "
458
+ f"All signals must be the same length, or pad_signals/truncate_signals "
459
+ f"must be True. "
460
+ )
461
+ # Concatenate along the specified dimension (default 0)
462
+ audio_data = torch.cat([x.audio_data for x in audio_signals], dim=dim)
463
+ audio_paths = [x.path_to_file for x in audio_signals]
464
+
465
+ batched_signal = cls(
466
+ audio_data,
467
+ sample_rate=audio_signals[0].sample_rate,
468
+ )
469
+ batched_signal.path_to_file = audio_paths
470
+ return batched_signal
471
+
472
+ # I/O
473
+ def load_from_file(
474
+ self,
475
+ audio_path: typing.Union[str, Path],
476
+ offset: float,
477
+ duration: float,
478
+ device: str = "cpu",
479
+ ):
480
+ """Loads data from file. Used internally when AudioSignal
481
+ is instantiated with a path to a file.
482
+
483
+ Parameters
484
+ ----------
485
+ audio_path : typing.Union[str, Path]
486
+ Path to file
487
+ offset : float
488
+ Offset in seconds
489
+ duration : float
490
+ Duration in seconds
491
+ device : str, optional
492
+ Device to put AudioSignal on, by default "cpu"
493
+
494
+ Returns
495
+ -------
496
+ AudioSignal
497
+ AudioSignal loaded from file
498
+ """
499
+ import librosa
500
+
501
+ data, sample_rate = librosa.load(
502
+ audio_path,
503
+ offset=offset,
504
+ duration=duration,
505
+ sr=None,
506
+ mono=False,
507
+ )
508
+ data = util.ensure_tensor(data)
509
+ if data.shape[-1] == 0:
510
+ raise RuntimeError(
511
+ f"Audio file {audio_path} with offset {offset} and duration {duration} is empty!"
512
+ )
513
+
514
+ if data.ndim < 2:
515
+ data = data.unsqueeze(0)
516
+ if data.ndim < 3:
517
+ data = data.unsqueeze(0)
518
+ self.audio_data = data
519
+
520
+ self.original_signal_length = self.signal_length
521
+
522
+ self.sample_rate = sample_rate
523
+ self.path_to_file = audio_path
524
+ return self.to(device)
525
+
526
+ def load_from_array(
527
+ self,
528
+ audio_array: typing.Union[torch.Tensor, np.ndarray],
529
+ sample_rate: int,
530
+ device: str = "cpu",
531
+ ):
532
+ """Loads data from array, reshaping it to be exactly 3
533
+ dimensions. Used internally when AudioSignal is called
534
+ with a tensor or an array.
535
+
536
+ Parameters
537
+ ----------
538
+ audio_array : typing.Union[torch.Tensor, np.ndarray]
539
+ Array/tensor of audio of samples.
540
+ sample_rate : int
541
+ Sample rate of audio
542
+ device : str, optional
543
+ Device to move audio onto, by default "cpu"
544
+
545
+ Returns
546
+ -------
547
+ AudioSignal
548
+ AudioSignal loaded from array
549
+ """
550
+ audio_data = util.ensure_tensor(audio_array)
551
+
552
+ if audio_data.dtype == torch.double:
553
+ audio_data = audio_data.float()
554
+
555
+ if audio_data.ndim < 2:
556
+ audio_data = audio_data.unsqueeze(0)
557
+ if audio_data.ndim < 3:
558
+ audio_data = audio_data.unsqueeze(0)
559
+ self.audio_data = audio_data
560
+
561
+ self.original_signal_length = self.signal_length
562
+
563
+ self.sample_rate = sample_rate
564
+ return self.to(device)
565
+
566
+ def write(self, audio_path: typing.Union[str, Path]):
567
+ """Writes audio to a file. Only writes the audio
568
+ that is in the very first item of the batch. To write other items
569
+ in the batch, index the signal along the batch dimension
570
+ before writing. After writing, the signal's ``path_to_file``
571
+ attribute is updated to the new path.
572
+
573
+ Parameters
574
+ ----------
575
+ audio_path : typing.Union[str, Path]
576
+ Path to write audio to.
577
+
578
+ Returns
579
+ -------
580
+ AudioSignal
581
+ Returns original AudioSignal, so you can use this in a fluent
582
+ interface.
583
+
584
+ Examples
585
+ --------
586
+ Creating and writing a signal to disk:
587
+
588
+ >>> signal = AudioSignal(torch.randn(10, 1, 44100), 44100)
589
+ >>> signal.write("/tmp/out.wav")
590
+
591
+ Writing a different element of the batch:
592
+
593
+ >>> signal[5].write("/tmp/out.wav")
594
+
595
+ Using this in a fluent interface:
596
+
597
+ >>> signal.write("/tmp/original.wav").low_pass(4000).write("/tmp/lowpass.wav")
598
+
599
+ """
600
+ if self.audio_data[0].abs().max() > 1:
601
+ warnings.warn("Audio amplitude > 1 clipped when saving")
602
+ soundfile.write(str(audio_path), self.audio_data[0].numpy().T, self.sample_rate)
603
+
604
+ self.path_to_file = audio_path
605
+ return self
606
+
607
+ def deepcopy(self):
608
+ """Copies the signal and all of its attributes.
609
+
610
+ Returns
611
+ -------
612
+ AudioSignal
613
+ Deep copy of the audio signal.
614
+ """
615
+ return copy.deepcopy(self)
616
+
617
+ def copy(self):
618
+ """Shallow copy of signal.
619
+
620
+ Returns
621
+ -------
622
+ AudioSignal
623
+ Shallow copy of the audio signal.
624
+ """
625
+ return copy.copy(self)
626
+
627
+ def clone(self):
628
+ """Clones all tensors contained in the AudioSignal,
629
+ and returns a copy of the signal with everything
630
+ cloned. Useful when using AudioSignal within autograd
631
+ computation graphs.
632
+
633
+ Relevant attributes are the stft data, the audio data,
634
+ and the loudness of the file.
635
+
636
+ Returns
637
+ -------
638
+ AudioSignal
639
+ Clone of AudioSignal.
640
+ """
641
+ clone = type(self)(
642
+ self.audio_data.clone(),
643
+ self.sample_rate,
644
+ stft_params=self.stft_params,
645
+ )
646
+ if self.stft_data is not None:
647
+ clone.stft_data = self.stft_data.clone()
648
+ if self._loudness is not None:
649
+ clone._loudness = self._loudness.clone()
650
+ clone.path_to_file = copy.deepcopy(self.path_to_file)
651
+ clone.metadata = copy.deepcopy(self.metadata)
652
+ return clone
653
+
654
+ def detach(self):
655
+ """Detaches tensors contained in AudioSignal.
656
+
657
+ Relevant attributes are the stft data, the audio data,
658
+ and the loudness of the file.
659
+
660
+ Returns
661
+ -------
662
+ AudioSignal
663
+ Same signal, but with all tensors detached.
664
+ """
665
+ if self._loudness is not None:
666
+ self._loudness = self._loudness.detach()
667
+ if self.stft_data is not None:
668
+ self.stft_data = self.stft_data.detach()
669
+
670
+ self.audio_data = self.audio_data.detach()
671
+ return self
672
+
673
+ def hash(self):
674
+ """Writes the audio data to a temporary file, and then
675
+ hashes it using hashlib. Useful for creating a file
676
+ name based on the audio content.
677
+
678
+ Returns
679
+ -------
680
+ str
681
+ Hash of audio data.
682
+
683
+ Examples
684
+ --------
685
+ Creating a signal, and writing it to a unique file name:
686
+
687
+ >>> signal = AudioSignal(torch.randn(44100), 44100)
688
+ >>> hash = signal.hash()
689
+ >>> signal.write(f"{hash}.wav")
690
+
691
+ """
692
+ with tempfile.NamedTemporaryFile(suffix=".wav") as f:
693
+ self.write(f.name)
694
+ h = hashlib.sha256()
695
+ b = bytearray(128 * 1024)
696
+ mv = memoryview(b)
697
+ with open(f.name, "rb", buffering=0) as f:
698
+ for n in iter(lambda: f.readinto(mv), 0):
699
+ h.update(mv[:n])
700
+ file_hash = h.hexdigest()
701
+ return file_hash
702
+
703
+ # Signal operations
704
+ def to_mono(self):
705
+ """Converts audio data to mono audio, by taking the mean
706
+ along the channels dimension.
707
+
708
+ Returns
709
+ -------
710
+ AudioSignal
711
+ AudioSignal with mean of channels.
712
+ """
713
+ self.audio_data = self.audio_data.mean(1, keepdim=True)
714
+ return self
715
+
716
+ def resample(self, sample_rate: int):
717
+ """Resamples the audio, using sinc interpolation. This works on both
718
+ cpu and gpu, and is much faster on gpu.
719
+
720
+ Parameters
721
+ ----------
722
+ sample_rate : int
723
+ Sample rate to resample to.
724
+
725
+ Returns
726
+ -------
727
+ AudioSignal
728
+ Resampled AudioSignal
729
+ """
730
+ if sample_rate == self.sample_rate:
731
+ return self
732
+ self.audio_data = julius.resample_frac(
733
+ self.audio_data, self.sample_rate, sample_rate
734
+ )
735
+ self.sample_rate = sample_rate
736
+ return self
737
+
738
+ # Tensor operations
739
+ def to(self, device: str):
740
+ """Moves all tensors contained in signal to the specified device.
741
+
742
+ Parameters
743
+ ----------
744
+ device : str
745
+ Device to move AudioSignal onto. Typical values are
746
+ "cuda", "cpu", or "cuda:n" to specify the nth gpu.
747
+
748
+ Returns
749
+ -------
750
+ AudioSignal
751
+ AudioSignal with all tensors moved to specified device.
752
+ """
753
+ if self._loudness is not None:
754
+ self._loudness = self._loudness.to(device)
755
+ if self.stft_data is not None:
756
+ self.stft_data = self.stft_data.to(device)
757
+ if self.audio_data is not None:
758
+ self.audio_data = self.audio_data.to(device)
759
+ return self
760
+
761
+ def float(self):
762
+ """Calls ``.float()`` on ``self.audio_data``.
763
+
764
+ Returns
765
+ -------
766
+ AudioSignal
767
+ """
768
+ self.audio_data = self.audio_data.float()
769
+ return self
770
+
771
+ def cpu(self):
772
+ """Moves AudioSignal to cpu.
773
+
774
+ Returns
775
+ -------
776
+ AudioSignal
777
+ """
778
+ return self.to("cpu")
779
+
780
+ def cuda(self): # pragma: no cover
781
+ """Moves AudioSignal to cuda.
782
+
783
+ Returns
784
+ -------
785
+ AudioSignal
786
+ """
787
+ return self.to("cuda")
788
+
789
+ def numpy(self):
790
+ """Detaches ``self.audio_data``, moves to cpu, and converts to numpy.
791
+
792
+ Returns
793
+ -------
794
+ np.ndarray
795
+ Audio data as a numpy array.
796
+ """
797
+ return self.audio_data.detach().cpu().numpy()
798
+
799
+ def zero_pad(self, before: int, after: int):
800
+ """Zero pads the audio_data tensor before and after.
801
+
802
+ Parameters
803
+ ----------
804
+ before : int
805
+ How many zeros to prepend to audio.
806
+ after : int
807
+ How many zeros to append to audio.
808
+
809
+ Returns
810
+ -------
811
+ AudioSignal
812
+ AudioSignal with padding applied.
813
+ """
814
+ self.audio_data = torch.nn.functional.pad(self.audio_data, (before, after))
815
+ return self
816
+
817
+ def zero_pad_to(self, length: int, mode: str = "after"):
818
+ """Pad with zeros to a specified length, either before or after
819
+ the audio data.
820
+
821
+ Parameters
822
+ ----------
823
+ length : int
824
+ Length to pad to
825
+ mode : str, optional
826
+ Whether to prepend or append zeros to signal, by default "after"
827
+
828
+ Returns
829
+ -------
830
+ AudioSignal
831
+ AudioSignal with padding applied.
832
+ """
833
+ if mode == "before":
834
+ self.zero_pad(max(length - self.signal_length, 0), 0)
835
+ elif mode == "after":
836
+ self.zero_pad(0, max(length - self.signal_length, 0))
837
+ return self
838
+
839
+ def trim(self, before: int, after: int):
840
+ """Trims the audio_data tensor before and after.
841
+
842
+ Parameters
843
+ ----------
844
+ before : int
845
+ How many samples to trim from beginning.
846
+ after : int
847
+ How many samples to trim from end.
848
+
849
+ Returns
850
+ -------
851
+ AudioSignal
852
+ AudioSignal with trimming applied.
853
+ """
854
+ if after == 0:
855
+ self.audio_data = self.audio_data[..., before:]
856
+ else:
857
+ self.audio_data = self.audio_data[..., before:-after]
858
+ return self
859
+
860
+ def truncate_samples(self, length_in_samples: int):
861
+ """Truncate signal to specified length.
862
+
863
+ Parameters
864
+ ----------
865
+ length_in_samples : int
866
+ Truncate to this many samples.
867
+
868
+ Returns
869
+ -------
870
+ AudioSignal
871
+ AudioSignal with truncation applied.
872
+ """
873
+ self.audio_data = self.audio_data[..., :length_in_samples]
874
+ return self
875
+
876
+ @property
877
+ def device(self):
878
+ """Get device that AudioSignal is on.
879
+
880
+ Returns
881
+ -------
882
+ torch.device
883
+ Device that AudioSignal is on.
884
+ """
885
+ if self.audio_data is not None:
886
+ device = self.audio_data.device
887
+ elif self.stft_data is not None:
888
+ device = self.stft_data.device
889
+ return device
890
+
891
+ # Properties
892
+ @property
893
+ def audio_data(self):
894
+ """Returns the audio data tensor in the object.
895
+
896
+ Audio data is always of the shape
897
+ (batch_size, num_channels, num_samples). If value has less
898
+ than 3 dims (e.g. is (num_channels, num_samples)), then it will
899
+ be reshaped to (1, num_channels, num_samples) - a batch size of 1.
900
+
901
+ Parameters
902
+ ----------
903
+ data : typing.Union[torch.Tensor, np.ndarray]
904
+ Audio data to set.
905
+
906
+ Returns
907
+ -------
908
+ torch.Tensor
909
+ Audio samples.
910
+ """
911
+ return self._audio_data
912
+
913
+ @audio_data.setter
914
+ def audio_data(self, data: typing.Union[torch.Tensor, np.ndarray]):
915
+ if data is not None:
916
+ assert torch.is_tensor(data), "audio_data should be torch.Tensor"
917
+ assert data.ndim == 3, "audio_data should be 3-dim (B, C, T)"
918
+ self._audio_data = data
919
+ # Old loudness value not guaranteed to be right, reset it.
920
+ self._loudness = None
921
+ return
922
+
923
+ # alias for audio_data
924
+ samples = audio_data
925
+
926
+ @property
927
+ def stft_data(self):
928
+ """Returns the STFT data inside the signal. Shape is
929
+ (batch, channels, frequencies, time).
930
+
931
+ Returns
932
+ -------
933
+ torch.Tensor
934
+ Complex spectrogram data.
935
+ """
936
+ return self._stft_data
937
+
938
+ @stft_data.setter
939
+ def stft_data(self, data: typing.Union[torch.Tensor, np.ndarray]):
940
+ if data is not None:
941
+ assert torch.is_tensor(data) and torch.is_complex(data)
942
+ if self.stft_data is not None and self.stft_data.shape != data.shape:
943
+ warnings.warn("stft_data changed shape")
944
+ self._stft_data = data
945
+ return
946
+
+    @property
+    def batch_size(self):
+        """Batch size of audio signal.
+
+        Returns
+        -------
+        int
+            Batch size of signal.
+        """
+        return self.audio_data.shape[0]
+
+    @property
+    def signal_length(self):
+        """Length of audio signal.
+
+        Returns
+        -------
+        int
+            Length of signal in samples.
+        """
+        return self.audio_data.shape[-1]
+
+    # alias for signal_length
+    length = signal_length
+
+    @property
+    def shape(self):
+        """Shape of audio data.
+
+        Returns
+        -------
+        tuple
+            Shape of audio data.
+        """
+        return self.audio_data.shape
+
+    @property
+    def signal_duration(self):
+        """Length of audio signal in seconds.
+
+        Returns
+        -------
+        float
+            Length of signal in seconds.
+        """
+        return self.signal_length / self.sample_rate
+
+    # alias for signal_duration
+    duration = signal_duration
+
+    @property
+    def num_channels(self):
+        """Number of audio channels.
+
+        Returns
+        -------
+        int
+            Number of audio channels.
+        """
+        return self.audio_data.shape[1]
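+    # A quick sketch tying the shape properties together for a stereo,
+    # one-second signal at 44.1 kHz:
+    #
+    #   >>> signal = AudioSignal(torch.randn(2, 44100), 44100)
+    #   >>> signal.shape
+    #   torch.Size([1, 2, 44100])
+    #   >>> (signal.batch_size, signal.num_channels, signal.length)
+    #   (1, 2, 44100)
+    #   >>> signal.duration  # signal_length / sample_rate
+    #   1.0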
+
+    # STFT
+    @staticmethod
+    @functools.lru_cache(None)
+    def get_window(window_type: str, window_length: int, device: str):
+        """Wrapper around scipy.signal.get_window so one can also get the
+        popular sqrt-hann window. This function caches for efficiency
+        using functools.lru_cache.
+
+        Parameters
+        ----------
+        window_type : str
+            Type of window to get
+        window_length : int
+            Length of the window
+        device : str
+            Device to put window onto.
+
+        Returns
+        -------
+        torch.Tensor
+            Window returned by scipy.signal.get_window, as a tensor.
+        """
+        from scipy import signal
+
+        if window_type == "average":
+            window = np.ones(window_length) / window_length
+        elif window_type == "sqrt_hann":
+            window = np.sqrt(signal.get_window("hann", window_length))
+        else:
+            window = signal.get_window(window_type, window_length)
+        window = torch.from_numpy(window).to(device).float()
+        return window
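+    # Sketch: the "sqrt_hann" branch above is the square root of a hann
+    # window, and lru_cache means repeated calls with the same
+    # (type, length, device) arguments return the same cached tensor:
+    #
+    #   >>> w1 = AudioSignal.get_window("sqrt_hann", 2048, "cpu")
+    #   >>> w2 = AudioSignal.get_window("sqrt_hann", 2048, "cpu")
+    #   >>> w1 is w2
+    #   True
+    #   >>> w1.shape
+    #   torch.Size([2048])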
+
+    @property
+    def stft_params(self):
+        """Returns STFTParams object, which can be re-used for other
+        AudioSignals.
+
+        This property can be set as well. If values are not defined in STFTParams,
+        they are inferred automatically from the signal properties. The default is to use
+        32ms windows, with 8ms hop length, and the hann window.
+
+        Returns
+        -------
+        STFTParams
+            STFT parameters for the AudioSignal.
+
+        Examples
+        --------
+        >>> stft_params = STFTParams(128, 32)
+        >>> signal1 = AudioSignal(torch.randn(44100), 44100, stft_params=stft_params)
+        >>> signal2 = AudioSignal(torch.randn(44100), 44100, stft_params=signal1.stft_params)
+        >>> signal1.stft_params = STFTParams() # Defaults
+        """
+        return self._stft_params
+
+    @stft_params.setter
+    def stft_params(self, value: STFTParams):
+        default_win_len = int(2 ** (np.ceil(np.log2(0.032 * self.sample_rate))))
+        default_hop_len = default_win_len // 4
+        default_win_type = "hann"
+        default_match_stride = False
+        default_padding_type = "reflect"
+
+        default_stft_params = STFTParams(
+            window_length=default_win_len,
+            hop_length=default_hop_len,
+            window_type=default_win_type,
+            match_stride=default_match_stride,
+            padding_type=default_padding_type,
+        )._asdict()
+
+        value = value._asdict() if value else default_stft_params
+
+        for key in default_stft_params:
+            if value[key] is None:
+                value[key] = default_stft_params[key]
+
+        self._stft_params = STFTParams(**value)
+        self.stft_data = None
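+    # The defaults above, worked through for sample_rate == 44100:
+    # 0.032 * 44100 = 1411.2 samples, and 2 ** ceil(log2(1411.2)) = 2048,
+    # so window_length defaults to 2048 and hop_length to 2048 // 4 == 512
+    # (roughly the 32 ms / 8 ms mentioned in the property docstring).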
+
+    def compute_stft_padding(
+        self, window_length: int, hop_length: int, match_stride: bool
+    ):
+        """Compute how the STFT should be padded, based on match_stride.
+
+        Parameters
+        ----------
+        window_length : int
+            Window length of STFT.
+        hop_length : int
+            Hop length of STFT.
+        match_stride : bool
+            Whether or not to match stride, making the STFT have the same alignment as
+            convolutional layers.
+
+        Returns
+        -------
+        tuple
+            Amount to pad on either side of audio.
+        """
+        length = self.signal_length
+
+        if match_stride:
+            assert (
+                hop_length == window_length // 4
+            ), "For match_stride, hop must equal n_fft // 4"
+            right_pad = math.ceil(length / hop_length) * hop_length - length
+            pad = (window_length - hop_length) // 2
+        else:
+            right_pad = 0
+            pad = 0
+
+        return right_pad, pad
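+    # Worked example: for a 1 s, 44.1 kHz signal with window_length=2048,
+    # hop_length=512 and match_stride=True,
+    #   right_pad = ceil(44100 / 512) * 512 - 44100 = 44544 - 44100 = 444
+    #   pad       = (2048 - 512) // 2 = 768
+    # and with match_stride=False both are 0.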
+
+    def stft(
+        self,
+        window_length: int = None,
+        hop_length: int = None,
+        window_type: str = None,
+        match_stride: bool = None,
+        padding_type: str = None,
+    ):
+        """Computes the short-time Fourier transform of the audio data,
+        with specified STFT parameters.
+
+        Parameters
+        ----------
+        window_length : int, optional
+            Window length of STFT, by default ``0.032 * self.sample_rate``.
+        hop_length : int, optional
+            Hop length of STFT, by default ``window_length // 4``.
+        window_type : str, optional
+            Type of window to use, by default ``sqrt_hann``.
+        match_stride : bool, optional
+            Whether to match the stride of convolutional layers, by default False
+        padding_type : str, optional
+            Type of padding to use, by default 'reflect'
+
+        Returns
+        -------
+        torch.Tensor
+            STFT of audio data.
+
+        Examples
+        --------
+        Compute the STFT of an AudioSignal:
+
+        >>> signal = AudioSignal(torch.randn(44100), 44100)
+        >>> signal.stft()
+
+        Vary the window and hop length:
+
+        >>> stft_params = [STFTParams(128, 32), STFTParams(512, 128)]
+        >>> for stft_param in stft_params:
+        >>>     signal.stft_params = stft_param
+        >>>     signal.stft()
+
+        """
+        window_length = (
+            self.stft_params.window_length
+            if window_length is None
+            else int(window_length)
+        )
+        hop_length = (
+            self.stft_params.hop_length if hop_length is None else int(hop_length)
+        )
+        window_type = (
+            self.stft_params.window_type if window_type is None else window_type
+        )
+        match_stride = (
+            self.stft_params.match_stride if match_stride is None else match_stride
+        )
+        padding_type = (
+            self.stft_params.padding_type if padding_type is None else padding_type
+        )
+
+        window = self.get_window(window_type, window_length, self.audio_data.device)
+        window = window.to(self.audio_data.device)
+
+        audio_data = self.audio_data
+        right_pad, pad = self.compute_stft_padding(
+            window_length, hop_length, match_stride
+        )
+        audio_data = torch.nn.functional.pad(
+            audio_data, (pad, pad + right_pad), padding_type
+        )
+        stft_data = torch.stft(
+            audio_data.reshape(-1, audio_data.shape[-1]),
+            n_fft=window_length,
+            hop_length=hop_length,
+            window=window,
+            return_complex=True,
+            center=True,
+        )
+        _, nf, nt = stft_data.shape
+        stft_data = stft_data.reshape(self.batch_size, self.num_channels, nf, nt)
+
+        if match_stride:
+            # Drop first two and last two frames, which are added
+            # because of padding. Now num_frames * hop_length = num_samples.
+            stft_data = stft_data[..., 2:-2]
+        self.stft_data = stft_data
+
+        return stft_data
+
+    def istft(
+        self,
+        window_length: int = None,
+        hop_length: int = None,
+        window_type: str = None,
+        match_stride: bool = None,
+        length: int = None,
+    ):
+        """Computes inverse STFT and sets it to audio_data.
+
+        Parameters
+        ----------
+        window_length : int, optional
+            Window length of STFT, by default ``0.032 * self.sample_rate``.
+        hop_length : int, optional
+            Hop length of STFT, by default ``window_length // 4``.
+        window_type : str, optional
+            Type of window to use, by default ``sqrt_hann``.
+        match_stride : bool, optional
+            Whether to match the stride of convolutional layers, by default False
+        length : int, optional
+            Original length of signal, by default None
+
+        Returns
+        -------
+        AudioSignal
+            AudioSignal with istft applied.
+
+        Raises
+        ------
+        RuntimeError
+            Raises an error if stft was not called prior to istft on the signal,
+            or if stft_data is not set.
+        """
+        if self.stft_data is None:
+            raise RuntimeError("Cannot do inverse STFT without self.stft_data!")
+
+        window_length = (
+            self.stft_params.window_length
+            if window_length is None
+            else int(window_length)
+        )
+        hop_length = (
+            self.stft_params.hop_length if hop_length is None else int(hop_length)
+        )
+        window_type = (
+            self.stft_params.window_type if window_type is None else window_type
+        )
+        match_stride = (
+            self.stft_params.match_stride if match_stride is None else match_stride
+        )
+
+        window = self.get_window(window_type, window_length, self.stft_data.device)
+
+        nb, nch, nf, nt = self.stft_data.shape
+        stft_data = self.stft_data.reshape(nb * nch, nf, nt)
+        right_pad, pad = self.compute_stft_padding(
+            window_length, hop_length, match_stride
+        )
+
+        if length is None:
+            length = self.original_signal_length
+            length = length + 2 * pad + right_pad
+
+        if match_stride:
+            # Zero-pad the STFT on either side, putting back the frames that were
+            # dropped in stft().
+            stft_data = torch.nn.functional.pad(stft_data, (2, 2))
+
+        audio_data = torch.istft(
+            stft_data,
+            n_fft=window_length,
+            hop_length=hop_length,
+            window=window,
+            length=length,
+            center=True,
+        )
+        audio_data = audio_data.reshape(nb, nch, -1)
+        if match_stride:
+            audio_data = audio_data[..., pad : -(pad + right_pad)]
+        self.audio_data = audio_data
+
+        return self
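+    # Round-trip sketch: stft() followed by istft() reconstructs the
+    # original samples up to numerical precision, with the length taken
+    # from original_signal_length when `length` is not passed:
+    #
+    #   >>> signal = AudioSignal(torch.randn(44100), 44100)
+    #   >>> reference = signal.audio_data.clone()
+    #   >>> _ = signal.stft()
+    #   >>> _ = signal.istft()
+    #   >>> torch.allclose(signal.audio_data, reference, atol=1e-5)
+    #   True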
+
+    @staticmethod
+    @functools.lru_cache(None)
+    def get_mel_filters(
+        sr: int, n_fft: int, n_mels: int, fmin: float = 0.0, fmax: float = None
+    ):
+        """Create a Filterbank matrix to combine FFT bins into Mel-frequency bins.
+
+        Parameters
+        ----------
+        sr : int
+            Sample rate of audio
+        n_fft : int
+            Number of FFT bins
+        n_mels : int
+            Number of mels
+        fmin : float, optional
+            Lowest frequency, in Hz, by default 0.0
+        fmax : float, optional
+            Highest frequency, by default None
+
+        Returns
+        -------
+        np.ndarray [shape=(n_mels, 1 + n_fft/2)]
+            Mel transform matrix
+        """
+        from librosa.filters import mel as librosa_mel_fn
+
+        return librosa_mel_fn(
+            sr=sr,
+            n_fft=n_fft,
+            n_mels=n_mels,
+            fmin=fmin,
+            fmax=fmax,
+        )
+
+    def mel_spectrogram(
+        self, n_mels: int = 80, mel_fmin: float = 0.0, mel_fmax: float = None, **kwargs
+    ):
+        """Computes a Mel spectrogram.
+
+        Parameters
+        ----------
+        n_mels : int, optional
+            Number of mels, by default 80
+        mel_fmin : float, optional
+            Lowest frequency, in Hz, by default 0.0
+        mel_fmax : float, optional
+            Highest frequency, by default None
+        kwargs : dict, optional
+            Keyword arguments to self.stft().
+
+        Returns
+        -------
+        torch.Tensor [shape=(batch, channels, mels, time)]
+            Mel spectrogram.
+        """
+        stft = self.stft(**kwargs)
+        magnitude = torch.abs(stft)
+
+        nf = magnitude.shape[2]
+        mel_basis = self.get_mel_filters(
+            sr=self.sample_rate,
+            n_fft=2 * (nf - 1),
+            n_mels=n_mels,
+            fmin=mel_fmin,
+            fmax=mel_fmax,
+        )
+        mel_basis = torch.from_numpy(mel_basis).to(self.device)
+
+        mel_spectrogram = magnitude.transpose(2, -1) @ mel_basis.T
+        mel_spectrogram = mel_spectrogram.transpose(-1, 2)
+        return mel_spectrogram
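+    # Shape sketch: with the default 2048/512 STFT on a mono one-second
+    # 44.1 kHz signal, the output is (batch, channels, n_mels, frames):
+    #
+    #   >>> signal = AudioSignal(torch.randn(44100), 44100)
+    #   >>> signal.mel_spectrogram(n_mels=80).shape
+    #   torch.Size([1, 1, 80, 87])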
+
+    @staticmethod
+    @functools.lru_cache(None)
+    def get_dct(n_mfcc: int, n_mels: int, norm: str = "ortho", device: str = None):
+        """Create a discrete cosine transform (DCT) transformation matrix with shape (``n_mels``, ``n_mfcc``),
+        normalized depending on norm. For more information about dct:
+        http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II
+
+        Parameters
+        ----------
+        n_mfcc : int
+            Number of mfccs
+        n_mels : int
+            Number of mels
+        norm : str
+            Use "ortho" to get an orthogonal matrix, or None; by default "ortho"
+        device : str, optional
+            Device to load the transformation matrix on, by default None
+
+        Returns
+        -------
+        torch.Tensor [shape=(n_mels, n_mfcc)]
+            The dct transformation matrix.
+        """
+        from torchaudio.functional import create_dct
+
+        return create_dct(n_mfcc, n_mels, norm).to(device)
+
+    def mfcc(
+        self, n_mfcc: int = 40, n_mels: int = 80, log_offset: float = 1e-6, **kwargs
+    ):
+        """Computes mel-frequency cepstral coefficients (MFCCs).
+
+        Parameters
+        ----------
+        n_mfcc : int, optional
+            Number of mfccs, by default 40
+        n_mels : int, optional
+            Number of mels, by default 80
+        log_offset : float, optional
+            Small value to prevent numerical issues when trying to compute log(0), by default 1e-6
+        kwargs : dict, optional
+            Keyword arguments to self.mel_spectrogram(), note that some of them will be used for self.stft()
+
+        Returns
+        -------
+        torch.Tensor [shape=(batch, channels, mfccs, time)]
+            MFCCs.
+        """
+
+        mel_spectrogram = self.mel_spectrogram(n_mels, **kwargs)
+        mel_spectrogram = torch.log(mel_spectrogram + log_offset)
+        dct_mat = self.get_dct(n_mfcc, n_mels, "ortho", self.device)
+
+        mfcc = mel_spectrogram.transpose(-1, -2) @ dct_mat
+        mfcc = mfcc.transpose(-1, -2)
+        return mfcc
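+    # MFCC sketch following the same pipeline: a log-mel spectrogram times
+    # the cached DCT matrix, giving (batch, channels, n_mfcc, frames):
+    #
+    #   >>> signal = AudioSignal(torch.randn(44100), 44100)
+    #   >>> signal.mfcc(n_mfcc=40, n_mels=80).shape
+    #   torch.Size([1, 1, 40, 87])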
+
+    @property
+    def magnitude(self):
+        """Computes and returns the absolute value of the STFT, which
+        is the magnitude. This value can also be set to some tensor.
+        When set, ``self.stft_data`` is manipulated so that its magnitude
+        matches what this is set to, modulated by the phase.
+
+        Returns
+        -------
+        torch.Tensor
+            Magnitude of STFT.
+
+        Examples
+        --------
+        >>> signal = AudioSignal(torch.randn(44100), 44100)
+        >>> magnitude = signal.magnitude # Computes stft if not computed
+        >>> magnitude[magnitude < magnitude.mean()] = 0
+        >>> signal.magnitude = magnitude
+        >>> signal.istft()
+        """
+        if self.stft_data is None:
+            self.stft()
+        return torch.abs(self.stft_data)
+
+    @magnitude.setter
+    def magnitude(self, value):
+        self.stft_data = value * torch.exp(1j * self.phase)
+        return
+
+    def log_magnitude(
+        self, ref_value: float = 1.0, amin: float = 1e-5, top_db: float = 80.0
+    ):
+        """Computes the log-magnitude of the spectrogram.
+
+        Parameters
+        ----------
+        ref_value : float, optional
+            The magnitude is scaled relative to ``ref``: ``20 * log10(S / ref)``.
+            Zeros in the output correspond to positions where ``S == ref``,
+            by default 1.0
+        amin : float, optional
+            Minimum threshold for ``S`` and ``ref``, by default 1e-5
+        top_db : float, optional
+            Threshold the output at ``top_db`` below the peak:
+            ``max(10 * log10(S/ref)) - top_db``, by default 80.0
+
+        Returns
+        -------
+        torch.Tensor
+            Log-magnitude spectrogram
+        """
+        magnitude = self.magnitude
+
+        amin = amin**2
+        log_spec = 10.0 * torch.log10(magnitude.pow(2).clamp(min=amin))
+        log_spec -= 10.0 * np.log10(np.maximum(amin, ref_value))
+
+        if top_db is not None:
+            log_spec = torch.maximum(log_spec, log_spec.max() - top_db)
+        return log_spec
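+    # With the defaults (ref_value=1.0, amin=1e-5, top_db=80.0): a magnitude
+    # of 1.0 maps to 10 * log10(1.0**2) = 0 dB, anything below amin is
+    # clamped to 10 * log10(1e-10) = -100 dB, and values more than 80 dB
+    # below the global peak are raised to peak - 80.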
+
+    @property
+    def phase(self):
+        """Computes and returns the phase of the STFT.
+        This value can also be set to some tensor.
+        When set, ``self.stft_data`` is manipulated so that its phase
+        matches what this is set to, modulated by the original magnitude.
+
+        Returns
+        -------
+        torch.Tensor
+            Phase of STFT.
+
+        Examples
+        --------
+        >>> signal = AudioSignal(torch.randn(44100), 44100)
+        >>> phase = signal.phase # Computes stft if not computed
+        >>> phase[phase < phase.mean()] = 0
+        >>> signal.phase = phase
+        >>> signal.istft()
+        """
+        if self.stft_data is None:
+            self.stft()
+        return torch.angle(self.stft_data)
+
+    @phase.setter
+    def phase(self, value):
+        self.stft_data = self.magnitude * torch.exp(1j * value)
+        return
+
+    # Operator overloading
+    def __add__(self, other):
+        new_signal = self.clone()
+        new_signal.audio_data += util._get_value(other)
+        return new_signal
+
+    def __iadd__(self, other):
+        self.audio_data += util._get_value(other)
+        return self
+
+    def __radd__(self, other):
+        return self + other
+
+    def __sub__(self, other):
+        new_signal = self.clone()
+        new_signal.audio_data -= util._get_value(other)
+        return new_signal
+
+    def __isub__(self, other):
+        self.audio_data -= util._get_value(other)
+        return self
+
+    def __mul__(self, other):
+        new_signal = self.clone()
+        new_signal.audio_data *= util._get_value(other)
+        return new_signal
+
+    def __imul__(self, other):
+        self.audio_data *= util._get_value(other)
+        return self
+
+    def __rmul__(self, other):
+        return self * other
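+    # Operator sketch: the binary forms return a new AudioSignal (via
+    # clone()), while the augmented forms mutate in place. Assuming `clean`
+    # and `noise` are AudioSignals with matching shapes:
+    #
+    #   >>> noisy = clean + 0.1 * noise   # new signal; operands unchanged
+    #   >>> clean *= 0.5                  # in-place gain via __imul__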
+
+    # Representation
+    def _info(self):
+        dur = f"{self.signal_duration:0.3f}" if self.signal_duration else "[unknown]"
+        info = {
+            "duration": f"{dur} seconds",
+            "batch_size": self.batch_size,
+            "path": self.path_to_file if self.path_to_file else "path unknown",
+            "sample_rate": self.sample_rate,
+            "num_channels": self.num_channels if self.num_channels else "[unknown]",
+            "audio_data.shape": self.audio_data.shape,
+            "stft_params": self.stft_params,
+            "device": self.device,
+        }
+
+        return info
+
+    def markdown(self):
+        """Produces a markdown representation of AudioSignal, in a markdown table.
+
+        Returns
+        -------
+        str
+            Markdown representation of AudioSignal.
+
+        Examples
+        --------
+        >>> signal = AudioSignal(torch.randn(44100), 44100)
+        >>> print(signal.markdown())
+        | Key | Value
+        |---|---
+        | duration | 1.000 seconds |
+        | batch_size | 1 |
+        | path | path unknown |
+        | sample_rate | 44100 |
+        | num_channels | 1 |
+        | audio_data.shape | torch.Size([1, 1, 44100]) |
+        | stft_params | STFTParams(window_length=2048, hop_length=512, window_type='sqrt_hann', match_stride=False) |
+        | device | cpu |
+        """
+        info = self._info()
+
+        FORMAT = "| Key | Value \n" "|---|--- \n"
+        for k, v in info.items():
+            row = f"| {k} | {v} |\n"
+            FORMAT += row
+        return FORMAT
+
+    def __str__(self):
+        info = self._info()
+
+        desc = ""
+        for k, v in info.items():
+            desc += f"{k}: {v}\n"
+        return desc
+
+    def __rich__(self):
+        from rich.table import Table
+
+        info = self._info()
+
+        table = Table(title=f"{self.__class__.__name__}")
+        table.add_column("Key", style="green")
+        table.add_column("Value", style="cyan")
+
+        for k, v in info.items():
+            table.add_row(k, str(v))
+        return table
+
+    # Comparison
+    def __eq__(self, other):
+        for k, v in list(self.__dict__.items()):
+            if torch.is_tensor(v):
+                if not torch.allclose(v, other.__dict__[k], atol=1e-6):
+                    max_error = (v - other.__dict__[k]).abs().max()
+                    print(f"Max abs error for {k}: {max_error}")
+                    return False
+        return True
+
+    # Indexing
+    def __getitem__(self, key):
+        if torch.is_tensor(key) and key.ndim == 0 and key.item() is True:
+            assert self.batch_size == 1
+            audio_data = self.audio_data
+            _loudness = self._loudness
+            stft_data = self.stft_data
+
+        elif isinstance(key, (bool, int, list, slice, tuple)) or (
+            torch.is_tensor(key) and key.ndim <= 1
+        ):
+            # Indexing only on the batch dimension.
+            # Then let's copy over relevant stuff.
+            # Future work: make this work for time-indexing
+            # as well, using the hop length.
+            audio_data = self.audio_data[key]
+            _loudness = self._loudness[key] if self._loudness is not None else None
+            stft_data = self.stft_data[key] if self.stft_data is not None else None
+
+        sources = None
+
+        copy = type(self)(audio_data, self.sample_rate, stft_params=self.stft_params)
+        copy._loudness = _loudness
+        copy._stft_data = stft_data
+        copy.sources = sources
+
+        return copy
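+    # Indexing sketch: indexing acts on the batch dimension only, carrying
+    # the loudness and STFT caches along to the copy:
+    #
+    #   >>> batch = AudioSignal(torch.randn(4, 1, 44100), 44100)
+    #   >>> first = batch[0]
+    #   >>> first.batch_size
+    #   1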
+
+    def __setitem__(self, key, value):
+        if not isinstance(value, type(self)):
+            self.audio_data[key] = value
+            return
+
+        if torch.is_tensor(key) and key.ndim == 0 and key.item() is True:
+            assert self.batch_size == 1
+            self.audio_data = value.audio_data
+            self._loudness = value._loudness
+            self.stft_data = value.stft_data
+            return
+
+        elif isinstance(key, (bool, int, list, slice, tuple)) or (
+            torch.is_tensor(key) and key.ndim <= 1
+        ):
+            if self.audio_data is not None and value.audio_data is not None:
+                self.audio_data[key] = value.audio_data
+            if self._loudness is not None and value._loudness is not None:
+                self._loudness[key] = value._loudness
+            if self.stft_data is not None and value.stft_data is not None:
+                self.stft_data[key] = value.stft_data
+            return
+
+    def __ne__(self, other):
+        return not self == other