xinference 1.10.0__py3-none-any.whl → 1.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic. Click here for more details.

Files changed (328) hide show
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +473 -31
  3. xinference/client/restful/async_restful_client.py +178 -8
  4. xinference/client/restful/restful_client.py +151 -3
  5. xinference/core/supervisor.py +99 -53
  6. xinference/core/worker.py +10 -0
  7. xinference/deploy/cmdline.py +15 -0
  8. xinference/model/audio/core.py +21 -6
  9. xinference/model/audio/indextts2.py +166 -0
  10. xinference/model/audio/model_spec.json +58 -21
  11. xinference/model/image/model_spec.json +159 -90
  12. xinference/model/image/stable_diffusion/core.py +13 -4
  13. xinference/model/llm/__init__.py +6 -2
  14. xinference/model/llm/llm_family.json +1299 -174
  15. xinference/model/llm/mlx/distributed_models/core.py +41 -0
  16. xinference/model/llm/mlx/distributed_models/qwen2.py +1 -2
  17. xinference/model/llm/sglang/core.py +44 -11
  18. xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +94 -32
  19. xinference/model/llm/tool_parsers/qwen_tool_parser.py +29 -4
  20. xinference/model/llm/transformers/chatglm.py +3 -0
  21. xinference/model/llm/transformers/core.py +129 -36
  22. xinference/model/llm/transformers/multimodal/minicpmv45.py +340 -0
  23. xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
  24. xinference/model/llm/transformers/utils.py +23 -0
  25. xinference/model/llm/utils.py +48 -32
  26. xinference/model/llm/vllm/core.py +207 -72
  27. xinference/model/utils.py +74 -31
  28. xinference/thirdparty/audiotools/__init__.py +10 -0
  29. xinference/thirdparty/audiotools/core/__init__.py +4 -0
  30. xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
  31. xinference/thirdparty/audiotools/core/display.py +194 -0
  32. xinference/thirdparty/audiotools/core/dsp.py +390 -0
  33. xinference/thirdparty/audiotools/core/effects.py +647 -0
  34. xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
  35. xinference/thirdparty/audiotools/core/loudness.py +320 -0
  36. xinference/thirdparty/audiotools/core/playback.py +252 -0
  37. xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
  38. xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
  39. xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
  40. xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
  41. xinference/thirdparty/audiotools/core/util.py +671 -0
  42. xinference/thirdparty/audiotools/core/whisper.py +97 -0
  43. xinference/thirdparty/audiotools/data/__init__.py +3 -0
  44. xinference/thirdparty/audiotools/data/datasets.py +517 -0
  45. xinference/thirdparty/audiotools/data/preprocess.py +81 -0
  46. xinference/thirdparty/audiotools/data/transforms.py +1592 -0
  47. xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
  48. xinference/thirdparty/audiotools/metrics/distance.py +131 -0
  49. xinference/thirdparty/audiotools/metrics/quality.py +159 -0
  50. xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
  51. xinference/thirdparty/audiotools/ml/__init__.py +5 -0
  52. xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
  53. xinference/thirdparty/audiotools/ml/decorators.py +440 -0
  54. xinference/thirdparty/audiotools/ml/experiment.py +90 -0
  55. xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
  56. xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
  57. xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
  58. xinference/thirdparty/audiotools/post.py +140 -0
  59. xinference/thirdparty/audiotools/preference.py +600 -0
  60. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +1 -1
  61. xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
  62. xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
  63. xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
  64. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
  65. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
  66. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
  67. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
  68. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  69. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
  70. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
  71. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
  72. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
  73. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
  74. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
  75. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
  76. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
  77. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
  78. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
  79. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
  80. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
  81. xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
  82. xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
  83. xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
  84. xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
  85. xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
  86. xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
  87. xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
  88. xinference/thirdparty/indextts/__init__.py +0 -0
  89. xinference/thirdparty/indextts/cli.py +65 -0
  90. xinference/thirdparty/indextts/gpt/__init__.py +0 -0
  91. xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
  92. xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
  93. xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
  94. xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
  95. xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
  96. xinference/thirdparty/indextts/gpt/model.py +713 -0
  97. xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
  98. xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
  99. xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
  100. xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
  101. xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
  102. xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
  103. xinference/thirdparty/indextts/infer.py +690 -0
  104. xinference/thirdparty/indextts/infer_v2.py +739 -0
  105. xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
  106. xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
  107. xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
  108. xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
  109. xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
  110. xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
  111. xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
  112. xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
  113. xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
  114. xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
  115. xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
  116. xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
  117. xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
  118. xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
  119. xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
  120. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
  121. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
  122. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
  123. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
  124. xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
  125. xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
  126. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
  127. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
  128. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  129. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
  130. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
  131. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
  132. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
  133. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
  134. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
  135. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
  136. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
  137. xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
  138. xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
  139. xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
  140. xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
  141. xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
  142. xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
  143. xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
  144. xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
  145. xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
  146. xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
  147. xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
  148. xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
  149. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
  150. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
  151. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
  152. xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
  153. xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
  154. xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
  155. xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
  156. xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
  157. xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
  158. xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
  159. xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
  160. xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
  161. xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
  162. xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
  163. xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
  164. xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
  165. xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
  166. xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
  167. xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
  168. xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
  169. xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
  170. xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
  171. xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
  172. xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
  173. xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
  174. xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
  175. xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
  176. xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
  177. xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
  178. xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
  179. xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
  180. xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
  181. xinference/thirdparty/indextts/utils/__init__.py +0 -0
  182. xinference/thirdparty/indextts/utils/arch_util.py +120 -0
  183. xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
  184. xinference/thirdparty/indextts/utils/common.py +121 -0
  185. xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
  186. xinference/thirdparty/indextts/utils/front.py +536 -0
  187. xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
  188. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
  189. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
  190. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
  191. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
  192. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
  193. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
  194. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
  195. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
  196. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
  197. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
  198. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
  199. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
  200. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
  201. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
  202. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
  203. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
  204. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
  205. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
  206. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
  207. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
  208. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
  209. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
  210. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
  211. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
  212. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
  213. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
  214. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
  215. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
  216. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
  217. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
  218. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
  219. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
  220. xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
  221. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
  222. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
  223. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
  224. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
  225. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
  226. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
  227. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
  228. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
  229. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
  230. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
  231. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
  232. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
  233. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
  234. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
  235. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
  236. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
  237. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
  238. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
  239. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
  240. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
  241. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
  242. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
  243. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
  244. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
  245. xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
  246. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
  247. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
  248. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
  249. xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
  250. xinference/thirdparty/indextts/utils/text_utils.py +41 -0
  251. xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
  252. xinference/thirdparty/indextts/utils/utils.py +93 -0
  253. xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
  254. xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
  255. xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
  256. xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
  257. xinference/thirdparty/melo/text/chinese_mix.py +2 -2
  258. xinference/types.py +9 -0
  259. xinference/ui/gradio/media_interface.py +66 -8
  260. xinference/ui/web/ui/build/asset-manifest.json +6 -6
  261. xinference/ui/web/ui/build/index.html +1 -1
  262. xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
  263. xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
  264. xinference/ui/web/ui/build/static/js/main.45e78536.js +3 -0
  265. xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.45e78536.js.LICENSE.txt} +0 -7
  266. xinference/ui/web/ui/build/static/js/main.45e78536.js.map +1 -0
  267. xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
  268. xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
  269. xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
  270. xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
  271. xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
  272. xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
  273. xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
  274. xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
  275. xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
  276. xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
  277. xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
  278. xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
  279. xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
  280. xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
  281. xinference/ui/web/ui/node_modules/.cache/babel-loader/ea2a26361204e70cf1018d6990fb6354bed82b3ac69690391e0f100385e7abb7.json +1 -0
  282. xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
  283. xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
  284. xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
  285. xinference/ui/web/ui/package-lock.json +0 -34
  286. xinference/ui/web/ui/package.json +0 -1
  287. xinference/ui/web/ui/src/locales/en.json +9 -3
  288. xinference/ui/web/ui/src/locales/ja.json +9 -3
  289. xinference/ui/web/ui/src/locales/ko.json +9 -3
  290. xinference/ui/web/ui/src/locales/zh.json +9 -3
  291. {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/METADATA +24 -6
  292. {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/RECORD +296 -77
  293. xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
  294. xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
  295. xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
  296. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
  297. xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
  298. xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
  299. xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
  300. xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
  301. xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
  302. xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
  303. xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
  304. xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
  305. xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
  306. xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
  307. xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
  308. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
  309. xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
  310. xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
  311. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
  312. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
  313. xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
  314. xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
  315. xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
  316. xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
  317. xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
  318. xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
  319. xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
  320. xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
  321. xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
  322. xinference/ui/web/ui/node_modules/select/bower.json +0 -13
  323. xinference/ui/web/ui/node_modules/select/package.json +0 -29
  324. xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
  325. {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/WHEEL +0 -0
  326. {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/entry_points.txt +0 -0
  327. {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/licenses/LICENSE +0 -0
  328. {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,671 @@
1
+ import csv
2
+ import glob
3
+ import math
4
+ import numbers
5
+ import os
6
+ import random
7
+ import typing
8
+ from contextlib import contextmanager
9
+ from dataclasses import dataclass
10
+ from pathlib import Path
11
+ from typing import Dict
12
+ from typing import List
13
+
14
+ import numpy as np
15
+ import torch
16
+ import torchaudio
17
+ from flatten_dict import flatten
18
+ from flatten_dict import unflatten
19
+
20
+
21
@dataclass
class Info:
    """Minimal audio-file metadata record.

    Shim for torchaudio.info API changes: holds only the two values the
    rest of this module needs, independent of which torchaudio return
    type produced them.
    """

    # Sampling rate of the audio.
    sample_rate: float
    # Number of frames in the audio.
    num_frames: int

    @property
    def duration(self) -> float:
        """Ratio of ``num_frames`` to ``sample_rate``."""
        frames, rate = self.num_frames, self.sample_rate
        return frames / rate
31
+
32
+
33
def info(audio_path: str):
    """Shim for torchaudio.info to make 0.7.2 API match 0.8.0.

    Parameters
    ----------
    audio_path : str
        Path to audio file. It is passed through ``str`` before being
        handed to torchaudio.

    Returns
    -------
    Info
        Record carrying ``sample_rate`` and ``num_frames`` for the file.
    """
    path = str(audio_path)

    # try default backend first, then fallback to soundfile
    try:
        raw = torchaudio.info(path)
    # Only ordinary errors trigger the fallback; a bare ``except`` would also
    # swallow KeyboardInterrupt / SystemExit raised while probing the file.
    except Exception:  # pragma: no cover
        raw = torchaudio.backend.soundfile_backend.info(path)

    if isinstance(raw, tuple):  # pragma: no cover
        # Tuple-style result (presumably the older torchaudio API): the first
        # element exposes ``rate`` and ``length``.
        signal_info = raw[0]
        return Info(sample_rate=signal_info.rate, num_frames=signal_info.length)

    return Info(sample_rate=raw.sample_rate, num_frames=raw.num_frames)
54
+
55
+
56
def ensure_tensor(
    x: typing.Union[np.ndarray, torch.Tensor, float, int],
    ndim: int = None,
    batch_size: int = None,
):
    """Coerce ``x`` into a tensor, optionally padding its rank and
    broadcasting its leading dimension.

    Parameters
    ----------
    x : typing.Union[np.ndarray, torch.Tensor, float, int]
        Value to convert. Tensors are passed through untouched; anything
        else goes through ``torch.as_tensor``.
    ndim : int, optional
        Target number of dimensions. Trailing singleton axes are appended
        until the rank matches; ``x`` must not already exceed it.
        By default None (rank left as-is).
    batch_size : int, optional
        Desired size of the first dimension. When it differs from the
        current one the tensor is expanded along dim 0.
        By default None (no expansion).

    Returns
    -------
    torch.Tensor
        Tensor version of ``x`` with the requested rank / batch size.
    """
    tensor = x if torch.is_tensor(x) else torch.as_tensor(x)

    if ndim is not None:
        assert tensor.ndim <= ndim
        # Append singleton axes on the right until the rank matches.
        for _ in range(ndim - tensor.ndim):
            tensor = tensor.unsqueeze(-1)

    if batch_size is not None and tensor.shape[0] != batch_size:
        # Keep every trailing dimension, replace only the leading one.
        tensor = tensor.expand(batch_size, *tensor.shape[1:])

    return tensor
90
+
91
+
92
def _get_value(other):
    """Unwrap an ``AudioSignal`` to its ``audio_data``; return anything
    else unchanged.
    """
    # Imported inside the function rather than at module level
    # (presumably to avoid a circular import with the package — confirm).
    from . import AudioSignal

    return other.audio_data if isinstance(other, AudioSignal) else other
98
+
99
+
100
def hz_to_bin(hz: torch.Tensor, n_fft: int, sample_rate: int):
    """Closest frequency bin given a frequency, number
    of bins, and a sampling rate.

    Frequencies above ``sample_rate / 2`` are clamped to that value before
    the lookup. The input tensor is never modified.

    Parameters
    ----------
    hz : torch.Tensor
        Tensor of frequencies in Hz (any shape, including 0-dim).
    n_fft : int
        Number of FFT bins. The lookup grid consists of
        ``2 + n_fft // 2`` evenly spaced points on ``[0, sample_rate / 2]``.
    sample_rate : int
        Sample rate of audio.

    Returns
    -------
    torch.Tensor
        Closest bins to the data, with the same shape as ``hz``.
    """
    shape = hz.shape
    nyquist = sample_rate / 2

    # ``flatten`` may hand back the input itself or a view of it, so copy
    # before clamping in place; otherwise the caller's tensor gets mutated.
    hz = hz.flatten().clone()
    hz[hz > nyquist] = nyquist

    # Build the grid on the same device as ``hz`` so the subtraction below
    # does not fail for non-CPU inputs.
    freqs = torch.linspace(0, nyquist, 2 + n_fft // 2, device=hz.device)

    # (num_freqs, num_hz) distance table; argmin over the frequency axis.
    closest = (hz[None, :] - freqs[:, None]).abs()
    closest_bins = closest.min(dim=0).indices

    # Pass the size object itself so a 0-dim input reshapes cleanly.
    return closest_bins.reshape(shape)
127
+
128
+
129
def random_state(seed: typing.Union[int, np.random.RandomState, None]):
    """
    Turn seed into a np.random.RandomState instance.

    Parameters
    ----------
    seed : typing.Union[int, np.random.RandomState] or None
        If seed is None (or the ``np.random`` module itself), return the
        RandomState singleton used by np.random.
        If seed is an int, return a new RandomState instance seeded with seed.
        If seed is already a RandomState instance, return it.
        Otherwise raise ValueError.

    Returns
    -------
    np.random.RandomState
        Random state object.

    Raises
    ------
    ValueError
        If seed is not valid, an error is thrown.
    """
    if seed is None or seed is np.random:
        return np.random.mtrand._rand
    if isinstance(seed, (numbers.Integral, np.integer, int)):
        return np.random.RandomState(seed)
    if isinstance(seed, np.random.RandomState):
        return seed
    # Format with an f-string rather than ``"%r" % seed``: the ``%`` operator
    # unpacks tuples, so a tuple seed would raise TypeError (or print the
    # wrong value) instead of the documented ValueError.
    raise ValueError(
        f"{seed!r} cannot be used to seed a numpy.random.RandomState instance"
    )
161
+
162
+
163
def seed(random_seed, set_cudnn=False):
    """
    Seed the ``torch``, ``numpy`` and ``random`` generators with one value
    so that runs are reproducible.

    For full reproducibility the torch documentation
    (https://pytorch.org/docs/stable/notes/randomness.html) asks for two
    further cudnn options to be set. They are only applied when
    ``set_cudnn`` is True, because enabling them results in a
    performance hit.

    Args:
        random_seed (int): integer corresponding to random seed to
            use.
        set_cudnn (bool): Whether or not to set cudnn into deterministic
            mode and off of benchmark mode. Defaults to False.
    """
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(random_seed)

    if not set_cudnn:
        return

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
189
+
190
+
191
@contextmanager
def _close_temp_files(tmpfiles: list):
    """Utility function for creating a context and closing all temporary files
    once the context is exited. For correct functionality, all temporary file
    handles created inside the context must be appended to the ```tmpfiles```
    list.

    Cleanup is best-effort and runs whether the context body finishes
    normally or raises: each handle is closed and the file at its ``name``
    is deleted. Ordinary errors during cleanup are ignored so that one bad
    handle does not prevent the others from being cleaned up.

    This function is adapted from Scaper.

    Parameters
    ----------
    tmpfiles : list
        List of temporary file handles
    """

    def _close():
        for handle in tmpfiles:
            # Close and delete are attempted independently, so a handle
            # that fails to close does not leave its file on disk.
            # ``Exception`` (not a bare ``except``) keeps KeyboardInterrupt
            # and SystemExit propagating.
            try:
                handle.close()
            except Exception:
                pass
            try:
                os.unlink(handle.name)
            except Exception:
                pass

    try:
        yield
    finally:
        _close()
220
+
221
+
222
AUDIO_EXTENSIONS = [".wav", ".flac", ".mp3", ".mp4"]


def find_audio(folder: str, ext: List[str] = AUDIO_EXTENSIONS):
    """Finds all audio files in a directory recursively.
    Returns a list.

    Parameters
    ----------
    folder : str
        Folder to look for audio files in, recursively. If the path itself
        ends with one of the extensions it is treated as a file (or, when
        it contains ``*``, as a glob pattern) instead of a folder.
    ext : List[str], optional
        Extensions to look for, including the leading ``.``, by default
        ``['.wav', '.flac', '.mp3', '.mp4']``.

    Returns
    -------
    list
        ``Path`` objects for a folder search or a single file; plain
        strings (as produced by ``glob.glob``) for a glob pattern.
    """
    root = Path(folder)
    root_str = str(root)

    # The caller may have handed us an audio file (or a glob over audio
    # files) rather than a directory.
    if root_str.endswith(tuple(ext)):
        if "*" not in root_str:
            return [root]
        # A glob in the path: return its matches, not the pattern itself.
        return glob.glob(root_str, recursive=("**" in root_str))

    return [match for suffix in ext for match in root.glob(f"**/*{suffix}")]
252
+
253
+
254
def read_sources(
    sources: List[str],
    remove_empty: bool = True,
    relative_path: str = "",
    ext: List[str] = AUDIO_EXTENSIONS,
):
    """Reads audio sources that can either be folders
    full of audio files, or CSV files that contain paths
    to audio files. CSV files that adhere to the expected
    format can be generated by
    :py:func:`audiotools.data.preprocess.create_csv`.

    Parameters
    ----------
    sources : List[str]
        List of audio sources to be converted into a
        list of lists of audio files.
    remove_empty : bool, optional
        Whether or not to remove rows with an empty "path"
        from each CSV file, by default True.
    relative_path : str, optional
        Path that every non-empty "path" entry is joined onto,
        by default "".
    ext : List[str], optional
        Extensions passed on to :py:func:`find_audio` for sources that
        are not CSV files, by default ``AUDIO_EXTENSIONS``.

    Returns
    -------
    list
        One list per source. Each inner list holds dictionaries with at
        least a "path" key (full CSV rows for CSV sources), sorted by
        "path".
    """
    files = []
    relative_path = Path(relative_path)
    for source in sources:
        source = str(source)
        if source.endswith(".csv"):
            _files = []
            # newline="" is what the csv module requires of file objects,
            # so that newlines embedded in quoted fields are parsed
            # correctly.
            with open(source, "r", newline="") as f:
                for row in csv.DictReader(f):
                    if row["path"] == "":
                        if remove_empty:
                            continue
                    else:
                        row["path"] = str(relative_path / row["path"])
                    _files.append(row)
        else:
            _files = [
                {"path": str(relative_path / found)}
                for found in find_audio(source, ext=ext)
            ]
        files.append(sorted(_files, key=lambda row: row["path"]))
    return files
300
+
301
+
302
def choose_from_list_of_lists(
    state: np.random.RandomState, list_of_lists: list, p: float = None
):
    """Choose a single item from a list of lists.

    A list is drawn first (optionally weighted by ``p``), then an item is
    drawn uniformly from that list.

    Parameters
    ----------
    state : np.random.RandomState
        Random state to use when choosing an item.
    list_of_lists : list
        A list of lists from which items will be drawn.
    p : float, optional
        Probabilities of each list, forwarded to ``state.choice``,
        by default None

    Returns
    -------
    tuple
        ``(item, source_idx, item_idx)``: the chosen item, the index of
        the list it came from, and its index within that list.
    """
    candidates = list(range(len(list_of_lists)))
    source_idx = state.choice(candidates, p=p)
    chosen_list = list_of_lists[source_idx]
    item_idx = state.randint(len(chosen_list))
    return chosen_list[item_idx], source_idx, item_idx
324
+
325
+
326
@contextmanager
def chdir(newdir: typing.Union[Path, str]):
    """
    Context manager that runs its body inside ``newdir`` and then
    switches back to the directory that was current on entry. Useful
    for when you want to use relative paths to different runs.

    Parameters
    ----------
    newdir : typing.Union[Path, str]
        Directory to switch to.
    """
    previous_dir = os.getcwd()
    try:
        os.chdir(newdir)
        yield
    finally:
        # Restore the starting directory even if the body raised.
        os.chdir(previous_dir)
344
+
345
+
346
def prepare_batch(batch: typing.Union[dict, list, torch.Tensor], device: str = "cpu"):
    """Moves items in a batch (typically generated by a DataLoader as a list
    or a dict) to the specified device. This works even if dictionaries
    are nested.

    Items that cannot be moved (no usable ``.to`` method, e.g. strings or
    numbers) are left as they are. Errors raised by an actual transfer
    are propagated instead of being silently ignored.

    Parameters
    ----------
    batch : typing.Union[dict, list, torch.Tensor]
        Batch, typically generated by a dataloader, that will be moved to
        the device. A list is updated in place; any other type is returned
        unchanged.
    device : str, optional
        Device to move batch to, by default "cpu"

    Returns
    -------
    typing.Union[dict, list, torch.Tensor]
        Batch with all values moved to the specified device.
    """

    def _move(value):
        # Only "this value is not movable" is swallowed: AttributeError
        # for a missing ``.to``, TypeError for a ``.to`` that is not
        # callable with a device.
        try:
            return value.to(device)
        except (AttributeError, TypeError):
            return value

    if isinstance(batch, dict):
        # Flatten so nested dictionaries are handled without recursion,
        # then restore the original nesting.
        flat = flatten(batch)
        batch = unflatten({key: _move(val) for key, val in flat.items()})
    elif torch.is_tensor(batch):
        batch = batch.to(device)
    elif isinstance(batch, list):
        for i, item in enumerate(batch):
            batch[i] = _move(item)
    return batch
381
+
382
+
383
def sample_from_dist(dist_tuple: tuple, state: np.random.RandomState = None):
    """Draws a sample from a distribution described by a tuple.

    The first entry of the tuple names the distribution and the remaining
    entries are its arguments. The special name ``"const"`` returns the
    second entry as-is; any other name is looked up as a method on the
    ``np.random.RandomState`` object and called with the arguments.

    Parameters
    ----------
    dist_tuple : tuple
        Distribution tuple
    state : np.random.RandomState, optional
        Random state, or seed to use, by default None

    Returns
    -------
    typing.Union[float, int, str]
        Draw from the distribution.

    Examples
    --------
    Sample from a uniform distribution:

    >>> dist_tuple = ("uniform", 0, 1)
    >>> sample_from_dist(dist_tuple)

    Sample from a constant distribution:

    >>> dist_tuple = ("const", 0)
    >>> sample_from_dist(dist_tuple)

    Sample from a normal distribution:

    >>> dist_tuple = ("normal", 0, 0.5)
    >>> sample_from_dist(dist_tuple)

    """
    dist_name = dist_tuple[0]
    if dist_name == "const":
        return dist_tuple[1]

    rng = random_state(state)
    sampler = getattr(rng, dist_name)
    dist_args = dist_tuple[1:]
    return sampler(*dist_args)
424
+
425
+
426
def collate(list_of_dicts: list, n_splits: int = None):
    """Collates a list of dictionaries (e.g. as returned by a
    dataloader) into a dictionary with batched values. Values are batched
    with the default torch collate function, except when every value for
    a key is an AudioSignal, in which case
    :py:func:`audiotools.core.audio_signal.AudioSignal.batch`
    is used (with ``pad_signals=True``).

    Passing ``n_splits`` splits the batch into multiple sub-batches,
    for the purposes of gradient accumulation, etc.

    Parameters
    ----------
    list_of_dicts : list
        List of dictionaries to be collated.
    n_splits : int
        Number of splits to make when creating the batches (split into
        sub-batches). Useful for things like gradient accumulation.

    Returns
    -------
    dict
        Dictionary containing batched data when ``n_splits`` is None;
        otherwise a list of such dictionaries, one per sub-batch.
    """

    from . import AudioSignal

    return_list = n_splits is not None
    if n_splits is None:
        n_splits = 1

    total = len(list_of_dicts)
    chunk_size = int(math.ceil(total / n_splits))

    batches = []
    for start in range(0, total, chunk_size):
        # Flatten the dictionaries to avoid recursion.
        flat_dicts = [flatten(d) for d in list_of_dicts[start : start + chunk_size]]
        # The keys of the first dictionary define the keys of the batch.
        grouped = {
            key: [flat[key] for flat in flat_dicts] for key in flat_dicts[0]
        }

        batch = {}
        for key, values in grouped.items():
            if all(isinstance(v, AudioSignal) for v in values):
                batch[key] = AudioSignal.batch(values, pad_signals=True)
            else:
                # Borrow the default collate fn from torch.
                batch[key] = torch.utils.data._utils.collate.default_collate(values)
        batches.append(unflatten(batch))

    return batches if return_list else batches[0]
480
+
481
+
482
BASE_SIZE = 864
DEFAULT_FIG_SIZE = (9, 3)


def format_figure(
    fig_size: tuple = None,
    title: str = None,
    fig=None,
    format_axes: bool = True,
    format: bool = True,
    font_color: str = "white",
):
    """Prettifies the spectrogram and waveform plots. A title
    can be inset into the top right corner, and the axes can be
    inset into the figure, allowing the data to take up the entire
    image. Used in

    - :py:func:`audiotools.core.display.DisplayMixin.specshow`
    - :py:func:`audiotools.core.display.DisplayMixin.waveplot`
    - :py:func:`audiotools.core.display.DisplayMixin.wavespec`

    Parameters
    ----------
    fig_size : tuple, optional
        Size of figure, by default (9, 3)
    title : str, optional
        Title to inset in top right, by default None
    fig : matplotlib.figure.Figure, optional
        Figure object, if None ``plt.gcf()`` will be used, by default None
    format_axes : bool, optional
        Format the axes to be inside the figure, by default True
    format : bool, optional
        This formatting can be skipped entirely by passing ``format=False``
        to any of the plotting functions that use this formatter, by default True
    font_color : str, optional
        Color of font of axes, by default "white"
    """
    # Return before importing matplotlib, so ``format=False`` is a true
    # no-op even where matplotlib is not installed.
    if not format:
        return

    import matplotlib.pyplot as plt

    if fig_size is None:
        fig_size = DEFAULT_FIG_SIZE
    if fig is None:
        fig = plt.gcf()
    fig.set_size_inches(*fig_size)
    axs = fig.axes

    # Scale fonts with the figure width in pixels, relative to BASE_SIZE.
    pixels = (fig.get_size_inches() * fig.dpi)[0]
    font_scale = pixels / BASE_SIZE

    if format_axes:
        for ax in axs:
            ymin, _ = ax.get_ylim()
            xmin, _ = ax.get_xlim()

            # NOTE(review): tick labels are read per-axes but always drawn
            # on the first axes; kept as in the original -- confirm intent.
            for tick in ax.get_yticks()[2:-1]:
                axs[0].annotate(
                    f"{(tick / 1000):2.1f}k",
                    xy=(xmin, tick),
                    xycoords="data",
                    xytext=(5, -5),
                    textcoords="offset points",
                    ha="left",
                    va="top",
                    color=font_color,
                    fontsize=12 * font_scale,
                    alpha=0.75,
                )

            for tick in ax.get_xticks()[2:][:-1]:
                axs[0].annotate(
                    f"{tick:2.1f}s",
                    xy=(tick, ymin),
                    xycoords="data",
                    xytext=(5, 5),
                    textcoords="offset points",
                    ha="center",
                    va="bottom",
                    color=font_color,
                    fontsize=12 * font_scale,
                    alpha=0.75,
                )

            ax.margins(0, 0)
            ax.set_axis_off()
            ax.xaxis.set_major_locator(plt.NullLocator())
            ax.yaxis.set_major_locator(plt.NullLocator())

        # Adjust the figure being formatted; ``plt.subplots_adjust`` would
        # act on the *current* figure, which is the wrong one whenever an
        # explicit, non-current ``fig`` is passed.
        fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)

    if title is not None:
        label = axs[0].annotate(
            title,
            xy=(1, 1),
            xycoords="axes fraction",
            fontsize=20 * font_scale,
            xytext=(-5, -5),
            textcoords="offset points",
            ha="right",
            va="top",
            color="white",
        )
        label.set_bbox(dict(facecolor="black", alpha=0.5, edgecolor="black"))
591
+
592
+
593
def generate_chord_dataset(
    max_voices: int = 8,
    sample_rate: int = 44100,
    num_items: int = 5,
    duration: float = 1.0,
    min_note: str = "C2",
    max_note: str = "C6",
    output_dir: Path = "chords",
):
    """
    Generates a toy multitrack dataset of chords, synthesized from sine waves.

    Each track gets a random number of voices (1 to ``max_voices``). Every
    voice is a sine wave at a random note between ``min_note`` and
    ``max_note``, lasting between 85% and 100% of ``duration``. Voices are
    written to ``output_dir/track_<i>/voice_<j>.wav`` and one CSV per voice
    name is created in ``output_dir``.

    Parameters
    ----------
    max_voices : int, optional
        Maximum number of voices in a chord, by default 8
    sample_rate : int, optional
        Sample rate of audio, by default 44100
    num_items : int, optional
        Number of items to generate, by default 5
    duration : float, optional
        Duration of each item, by default 1.0
    min_note : str, optional
        Minimum note in the dataset, by default "C2"
    max_note : str, optional
        Maximum note in the dataset, by default "C6"
    output_dir : Path, optional
        Directory to save the dataset, by default "chords". Missing parent
        directories are created.

    Returns
    -------
    Path
        The output directory, as a ``Path``.
    """
    import librosa
    from . import AudioSignal
    from ..data.preprocess import create_csv

    min_midi = librosa.note_to_midi(min_note)
    max_midi = librosa.note_to_midi(max_note)

    tracks = []
    for _ in range(num_items):
        track = {}
        # figure out how many voices to put in this track
        num_voices = random.randint(1, max_voices)
        for voice_idx in range(num_voices):
            # choose some random params
            midinote = random.randint(min_midi, max_midi)
            dur = random.uniform(0.85 * duration, duration)

            track[f"voice_{voice_idx}"] = AudioSignal.wave(
                frequency=librosa.midi_to_hz(midinote),
                duration=dur,
                sample_rate=sample_rate,
                shape="sine",
            )
        tracks.append(track)

    # save the tracks to disk
    output_dir = Path(output_dir)
    # parents=True so a nested output path works without the caller
    # having to create the intermediate directories first.
    output_dir.mkdir(parents=True, exist_ok=True)
    for idx, track in enumerate(tracks):
        track_dir = output_dir / f"track_{idx}"
        track_dir.mkdir(exist_ok=True)
        for voice_name, sig in track.items():
            sig.write(track_dir / f"{voice_name}.wav")

    # Sorted so voices are processed in a deterministic order (plain set
    # iteration order is not stable across runs).
    all_voices = sorted({name for track in tracks for name in track})

    # One path per track for every voice; "" marks a track without that
    # voice. Presumably ``path_to_file`` is populated by ``sig.write``
    # above -- TODO confirm against AudioSignal.
    voice_lists = {
        voice_name: [
            track[voice_name].path_to_file if voice_name in track else ""
            for track in tracks
        ]
        for voice_name in all_voices
    }

    for voice_name, paths in voice_lists.items():
        create_csv(paths, output_dir / f"{voice_name}.csv", loudness=True)

    return output_dir