xinference 1.9.1__py3-none-any.whl → 1.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic. Click here for more details.

Files changed (334) hide show
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +400 -3
  3. xinference/client/restful/async_restful_client.py +20 -3
  4. xinference/client/restful/restful_client.py +20 -3
  5. xinference/constants.py +2 -0
  6. xinference/core/supervisor.py +111 -49
  7. xinference/core/worker.py +10 -0
  8. xinference/deploy/cmdline.py +15 -0
  9. xinference/model/audio/core.py +26 -6
  10. xinference/model/audio/indextts2.py +166 -0
  11. xinference/model/audio/kokoro.py +1 -1
  12. xinference/model/audio/kokoro_zh.py +124 -0
  13. xinference/model/audio/model_spec.json +58 -1
  14. xinference/model/embedding/sentence_transformers/core.py +4 -4
  15. xinference/model/embedding/vllm/core.py +7 -1
  16. xinference/model/image/model_spec.json +71 -3
  17. xinference/model/image/stable_diffusion/core.py +13 -4
  18. xinference/model/llm/__init__.py +4 -0
  19. xinference/model/llm/core.py +10 -0
  20. xinference/model/llm/llama_cpp/core.py +1 -0
  21. xinference/model/llm/llm_family.json +503 -21
  22. xinference/model/llm/llm_family.py +1 -0
  23. xinference/model/llm/mlx/core.py +52 -33
  24. xinference/model/llm/sglang/core.py +32 -55
  25. xinference/model/llm/tool_parsers/__init__.py +58 -0
  26. xinference/model/llm/tool_parsers/abstract_tool_parser.py +33 -0
  27. xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +190 -0
  28. xinference/model/llm/tool_parsers/deepseek_v3_tool_parser.py +145 -0
  29. xinference/model/llm/tool_parsers/glm4_tool_parser.py +123 -0
  30. xinference/model/llm/tool_parsers/llama3_tool_parser.py +77 -0
  31. xinference/model/llm/tool_parsers/qwen_tool_parser.py +320 -0
  32. xinference/model/llm/transformers/core.py +1 -1
  33. xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
  34. xinference/model/llm/utils.py +138 -53
  35. xinference/model/llm/vllm/core.py +95 -78
  36. xinference/thirdparty/audiotools/__init__.py +10 -0
  37. xinference/thirdparty/audiotools/core/__init__.py +4 -0
  38. xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
  39. xinference/thirdparty/audiotools/core/display.py +194 -0
  40. xinference/thirdparty/audiotools/core/dsp.py +390 -0
  41. xinference/thirdparty/audiotools/core/effects.py +647 -0
  42. xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
  43. xinference/thirdparty/audiotools/core/loudness.py +320 -0
  44. xinference/thirdparty/audiotools/core/playback.py +252 -0
  45. xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
  46. xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
  47. xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
  48. xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
  49. xinference/thirdparty/audiotools/core/util.py +671 -0
  50. xinference/thirdparty/audiotools/core/whisper.py +97 -0
  51. xinference/thirdparty/audiotools/data/__init__.py +3 -0
  52. xinference/thirdparty/audiotools/data/datasets.py +517 -0
  53. xinference/thirdparty/audiotools/data/preprocess.py +81 -0
  54. xinference/thirdparty/audiotools/data/transforms.py +1592 -0
  55. xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
  56. xinference/thirdparty/audiotools/metrics/distance.py +131 -0
  57. xinference/thirdparty/audiotools/metrics/quality.py +159 -0
  58. xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
  59. xinference/thirdparty/audiotools/ml/__init__.py +5 -0
  60. xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
  61. xinference/thirdparty/audiotools/ml/decorators.py +440 -0
  62. xinference/thirdparty/audiotools/ml/experiment.py +90 -0
  63. xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
  64. xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
  65. xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
  66. xinference/thirdparty/audiotools/post.py +140 -0
  67. xinference/thirdparty/audiotools/preference.py +600 -0
  68. xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
  69. xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
  70. xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
  71. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
  72. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
  73. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
  74. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
  75. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  76. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
  77. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
  78. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
  79. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
  80. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
  81. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
  82. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
  83. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
  84. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
  85. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
  86. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
  87. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
  88. xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
  89. xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
  90. xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
  91. xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
  92. xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
  93. xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
  94. xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
  95. xinference/thirdparty/indextts/__init__.py +0 -0
  96. xinference/thirdparty/indextts/cli.py +65 -0
  97. xinference/thirdparty/indextts/gpt/__init__.py +0 -0
  98. xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
  99. xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
  100. xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
  101. xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
  102. xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
  103. xinference/thirdparty/indextts/gpt/model.py +713 -0
  104. xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
  105. xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
  106. xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
  107. xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
  108. xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
  109. xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
  110. xinference/thirdparty/indextts/infer.py +690 -0
  111. xinference/thirdparty/indextts/infer_v2.py +739 -0
  112. xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
  113. xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
  114. xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
  115. xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
  116. xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
  117. xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
  118. xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
  119. xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
  120. xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
  121. xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
  122. xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
  123. xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
  124. xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
  125. xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
  126. xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
  127. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
  128. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
  129. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
  130. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
  131. xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
  132. xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
  133. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
  134. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
  135. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  136. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
  137. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
  138. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
  139. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
  140. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
  141. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
  142. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
  143. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
  144. xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
  145. xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
  146. xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
  147. xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
  148. xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
  149. xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
  150. xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
  151. xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
  152. xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
  153. xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
  154. xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
  155. xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
  156. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
  157. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
  158. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
  159. xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
  160. xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
  161. xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
  162. xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
  163. xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
  164. xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
  165. xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
  166. xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
  167. xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
  168. xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
  169. xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
  170. xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
  171. xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
  172. xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
  173. xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
  174. xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
  175. xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
  176. xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
  177. xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
  178. xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
  179. xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
  180. xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
  181. xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
  182. xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
  183. xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
  184. xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
  185. xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
  186. xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
  187. xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
  188. xinference/thirdparty/indextts/utils/__init__.py +0 -0
  189. xinference/thirdparty/indextts/utils/arch_util.py +120 -0
  190. xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
  191. xinference/thirdparty/indextts/utils/common.py +121 -0
  192. xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
  193. xinference/thirdparty/indextts/utils/front.py +536 -0
  194. xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
  195. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
  196. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
  197. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
  198. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
  199. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
  200. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
  201. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
  202. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
  203. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
  204. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
  205. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
  206. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
  207. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
  208. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
  209. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
  210. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
  211. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
  212. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
  213. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
  214. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
  215. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
  216. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
  217. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
  218. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
  219. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
  220. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
  221. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
  222. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
  223. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
  224. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
  225. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
  226. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
  227. xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
  228. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
  229. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
  230. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
  231. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
  232. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
  233. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
  234. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
  235. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
  236. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
  237. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
  238. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
  239. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
  240. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
  241. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
  242. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
  243. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
  244. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
  245. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
  246. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
  247. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
  248. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
  249. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
  250. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
  251. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
  252. xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
  253. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
  254. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
  255. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
  256. xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
  257. xinference/thirdparty/indextts/utils/text_utils.py +41 -0
  258. xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
  259. xinference/thirdparty/indextts/utils/utils.py +93 -0
  260. xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
  261. xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
  262. xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
  263. xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
  264. xinference/types.py +105 -2
  265. xinference/ui/gradio/media_interface.py +66 -8
  266. xinference/ui/web/ui/build/asset-manifest.json +6 -6
  267. xinference/ui/web/ui/build/index.html +1 -1
  268. xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
  269. xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
  270. xinference/ui/web/ui/build/static/js/main.d192c4f3.js +3 -0
  271. xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.d192c4f3.js.LICENSE.txt} +0 -7
  272. xinference/ui/web/ui/build/static/js/main.d192c4f3.js.map +1 -0
  273. xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
  274. xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
  275. xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
  276. xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
  277. xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
  278. xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
  279. xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
  280. xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
  281. xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
  282. xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
  283. xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
  284. xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
  285. xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
  286. xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
  287. xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
  288. xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
  289. xinference/ui/web/ui/node_modules/.cache/babel-loader/f995a2425dfb0822fd07127f66ffe9b026883bc156b402eb8bd0b83d52460a93.json +1 -0
  290. xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
  291. xinference/ui/web/ui/package-lock.json +0 -34
  292. xinference/ui/web/ui/package.json +0 -1
  293. xinference/ui/web/ui/src/locales/en.json +9 -3
  294. xinference/ui/web/ui/src/locales/ja.json +9 -3
  295. xinference/ui/web/ui/src/locales/ko.json +9 -3
  296. xinference/ui/web/ui/src/locales/zh.json +9 -3
  297. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/METADATA +24 -4
  298. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/RECORD +302 -76
  299. xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
  300. xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
  301. xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
  302. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
  303. xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
  304. xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
  305. xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
  306. xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
  307. xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
  308. xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
  309. xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
  310. xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
  311. xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
  312. xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
  313. xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
  314. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
  315. xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
  316. xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
  317. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
  318. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
  319. xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
  320. xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
  321. xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
  322. xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
  323. xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
  324. xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
  325. xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
  326. xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
  327. xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
  328. xinference/ui/web/ui/node_modules/select/bower.json +0 -13
  329. xinference/ui/web/ui/node_modules/select/package.json +0 -29
  330. xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
  331. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/WHEEL +0 -0
  332. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/entry_points.txt +0 -0
  333. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/licenses/LICENSE +0 -0
  334. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,97 @@
1
+ import torch
2
+
3
+
4
class WhisperMixin:
    """Mixin adding Whisper feature extraction, transcription and embedding
    helpers to an audio-signal class.

    The host class is expected to provide ``clone()``, ``resample(sample_rate)``
    (returning the signal) and an ``audio_data`` tensor indexed as
    ``[batch, channel, samples]``. NOTE(review): this matches how the methods
    below use the host; confirm for any host other than ``AudioSignal``.
    """

    # Flipped to True by ``setup_whisper``. Every accessor below lazily calls
    # ``setup_whisper()`` with its defaults while this is still False.
    is_initialized = False

    def setup_whisper(
        self,
        pretrained_model_name_or_path: str = "openai/whisper-base.en",
        device: "str | torch.device | None" = None,
    ):
        """Load the Whisper processor and model and mark the mixin initialized.

        Parameters
        ----------
        pretrained_model_name_or_path : str, optional
            Hugging Face model id or local path, by default
            ``"openai/whisper-base.en"``.
        device : str or torch.device, optional
            Device the model is moved to. ``None`` (the default) selects
            ``"cuda"`` when available and ``"cpu"`` otherwise. The choice is
            made when this method runs, not when the class is defined.
        """
        from transformers import WhisperForConditionalGeneration
        from transformers import WhisperProcessor

        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.whisper_device = device
        self.whisper_processor = WhisperProcessor.from_pretrained(
            pretrained_model_name_or_path
        )
        self.whisper_model = WhisperForConditionalGeneration.from_pretrained(
            pretrained_model_name_or_path
        ).to(self.whisper_device)
        self.is_initialized = True

    def get_whisper_features(self) -> torch.Tensor:
        """Preprocess the audio signal as per the Whisper model's training config.

        The signal is cloned and resampled to the feature extractor's sampling
        rate. Only the first channel of each batch item is passed to the
        Whisper processor.

        Returns
        -------
        torch.Tensor
            The ``input_features`` produced by the Whisper processor, one entry
            per batch item. Callers move it to ``self.whisper_device`` before
            running the model.
        """
        if not self.is_initialized:
            self.setup_whisper()

        sampling_rate = self.whisper_processor.feature_extractor.sampling_rate
        resampled = self.clone().resample(sampling_rate)
        # ``Tensor.numpy()`` only works on CPU tensors that do not require
        # grad. The signal may live on a GPU, so detach and copy to host first.
        first_channel = resampled.audio_data[:, 0, :].detach().cpu().numpy()
        # Split along the batch axis: one 1-D array per batch item.
        raw_speech = list(first_channel)

        with torch.inference_mode():
            input_features = self.whisper_processor(
                raw_speech,
                sampling_rate=sampling_rate,
                return_tensors="pt",
            ).input_features

        return input_features

    def get_whisper_transcript(self) -> str:
        """Get the transcript of the audio signal using the Whisper model.

        Only the transcript of the first batch item is returned.

        Returns
        -------
        str
            The transcript of the audio signal, including special tokens such
            as <|startoftranscript|> and <|endoftext|>.
        """
        if not self.is_initialized:
            self.setup_whisper()

        input_features = self.get_whisper_features()

        with torch.inference_mode():
            input_features = input_features.to(self.whisper_device)
            generated_ids = self.whisper_model.generate(inputs=input_features)

        transcription = self.whisper_processor.batch_decode(generated_ids)
        return transcription[0]

    def get_whisper_embeddings(self) -> torch.Tensor:
        """Get the encoder's last hidden state for the audio signal.

        Returns
        -------
        torch.Tensor
            ``last_hidden_state`` of the Whisper encoder. NOTE(review):
            presumably shaped (batch, seq_len, hidden_size); confirm against
            the transformers encoder output.
        """
        if not self.is_initialized:
            self.setup_whisper()

        input_features = self.get_whisper_features()
        encoder = self.whisper_model.get_encoder()

        with torch.inference_mode():
            input_features = input_features.to(self.whisper_device)
            embeddings = encoder(input_features)

        return embeddings.last_hidden_state
@@ -0,0 +1,3 @@
1
+ from . import datasets
2
+ from . import preprocess
3
+ from . import transforms
@@ -0,0 +1,517 @@
1
+ from pathlib import Path
2
+ from typing import Callable
3
+ from typing import Dict
4
+ from typing import List
5
+ from typing import Union
6
+
7
+ import numpy as np
8
+ from torch.utils.data import SequentialSampler
9
+ from torch.utils.data.distributed import DistributedSampler
10
+
11
+ from ..core import AudioSignal
12
+ from ..core import util
13
+
14
+
15
class AudioLoader:
    """Loads audio endlessly from a list of audio sources
    containing paths to audio files. Audio sources can be
    folders full of audio files (which are found via file
    extension) or by providing a CSV file which contains paths
    to audio files.

    Parameters
    ----------
    sources : List[str], optional
        Sources containing folders, or CSVs with
        paths to audio files, by default None
    weights : List[float], optional
        Weights to sample audio files from each source, by default None
    relative_path : str, optional
        Path audio should be loaded relative to, by default ""
    transform : Callable, optional
        Transform to instantiate alongside audio sample,
        by default None
    ext : List[str]
        List of extensions to find audio within each source by. Can
        also be a file name (e.g. "vocals.wav"). By default
        ``util.AUDIO_EXTENSIONS``.
    shuffle: bool
        Whether to shuffle the files within the dataloader. Defaults to True.
    shuffle_state: int
        Seed passed to ``util.random_state`` for the file shuffle.
        Defaults to 0.
    """

    def __init__(
        self,
        sources: List[str] = None,
        weights: List[float] = None,
        transform: Callable = None,
        relative_path: str = "",
        ext: List[str] = util.AUDIO_EXTENSIONS,
        shuffle: bool = True,
        shuffle_state: int = 0,
    ):
        # One list of file-info dicts (each with a "path" key) per source,
        # presumably in the same order as ``sources`` — TODO confirm in util.
        self.audio_lists = util.read_sources(
            sources, relative_path=relative_path, ext=ext
        )

        # Flat (source_idx, item_idx) index over every file, used when an
        # item is requested by ``global_idx``.
        self.audio_indices = [
            (src_idx, item_idx)
            for src_idx, src in enumerate(self.audio_lists)
            for item_idx in range(len(src))
        ]
        if shuffle:
            state = util.random_state(shuffle_state)
            state.shuffle(self.audio_indices)

        self.sources = sources
        self.weights = weights
        self.transform = transform

    def __call__(
        self,
        state,
        sample_rate: int,
        duration: float,
        loudness_cutoff: float = -40,
        num_channels: int = 1,
        offset: float = None,
        source_idx: int = None,
        item_idx: int = None,
        global_idx: int = None,
    ):
        """Load one audio item.

        The file is chosen, in order of precedence, by explicit
        ``source_idx``/``item_idx``, by ``global_idx`` (wrapped modulo the
        number of files), or by a weighted random draw using ``state``.

        Parameters
        ----------
        state
            Random state used for file choice, excerpt selection and
            transform instantiation (presumably as returned by
            ``util.random_state`` — TODO confirm accepted types).
        sample_rate : int
            Output sample rate.
        duration : float
            Excerpt length in seconds; shorter signals are zero-padded to
            ``int(duration * sample_rate)`` samples.
        loudness_cutoff : float, optional
            Loudness threshold passed to ``AudioSignal.salient_excerpt``.
        num_channels : int, optional
            Channel count of the silent fallback signal; ``1`` also forces
            the loaded signal to mono.
        offset : float, optional
            When given, the file is read from this offset instead of
            picking a salient excerpt.
        source_idx, item_idx : int, optional
            Explicit file position. An out-of-range position yields a
            silent signal whose path and source are ``"none"``.
        global_idx : int, optional
            Position in the (possibly shuffled) flat index of all files.

        Returns
        -------
        dict
            Keys ``signal``, ``source_idx``, ``item_idx``, ``source`` and
            ``path``, plus ``transform_args`` when a transform is set.
        """
        if source_idx is not None and item_idx is not None:
            try:
                audio_info = self.audio_lists[source_idx][item_idx]
            except IndexError:
                # e.g. an aligned loader with fewer sources/items than the
                # reference loader: fall back to a silent placeholder.
                audio_info = {"path": "none"}
        elif global_idx is not None:
            source_idx, item_idx = self.audio_indices[
                global_idx % len(self.audio_indices)
            ]
            audio_info = self.audio_lists[source_idx][item_idx]
        else:
            audio_info, source_idx, item_idx = util.choose_from_list_of_lists(
                state, self.audio_lists, p=self.weights
            )

        path = audio_info["path"]
        # Silence by default; replaced below when a real file is available.
        signal = AudioSignal.zeros(duration, sample_rate, num_channels)

        if path != "none":
            if offset is None:
                signal = AudioSignal.salient_excerpt(
                    path,
                    duration=duration,
                    state=state,
                    loudness_cutoff=loudness_cutoff,
                )
            else:
                signal = AudioSignal(
                    path,
                    offset=offset,
                    duration=duration,
                )

        if num_channels == 1:
            signal = signal.to_mono()
        signal = signal.resample(sample_rate)

        if signal.duration < duration:
            signal = signal.zero_pad_to(int(duration * sample_rate))

        for k, v in audio_info.items():
            signal.metadata[k] = v

        try:
            source = str(self.sources[source_idx])
        except IndexError:
            # Consistent with the silent-placeholder fallback above instead
            # of crashing after the placeholder was already built.
            source = "none"

        item = {
            "signal": signal,
            "source_idx": source_idx,
            "item_idx": item_idx,
            "source": source,
            "path": str(path),
        }
        if self.transform is not None:
            item["transform_args"] = self.transform.instantiate(state, signal=signal)
        return item
136
+
137
+
138
def default_matcher(x, y):
    """Return True when paths ``x`` and ``y`` share the same parent directory."""
    parent_x = Path(x).parent
    parent_y = Path(y).parent
    return parent_x == parent_y


def align_lists(lists, matcher: Callable = default_matcher):
    """Align several audio lists in place against the longest one.

    Walks the longest list index by index (the first one if several tie).
    For every list, a ``{"path": "none"}`` placeholder is appended when the
    list has run out of entries, or inserted at the current index when its
    entry there does not ``matcher``-match the reference entry's path
    (shifting the remaining entries right).

    Parameters
    ----------
    lists : list
        Lists of dicts carrying a ``"path"`` key. Modified in place.
    matcher : Callable
        Predicate ``matcher(path, reference_path) -> bool``; defaults to
        ``default_matcher`` (same parent directory).

    Returns
    -------
    list
        The same ``lists`` object, after alignment.
    """
    reference = max(lists, key=len)
    for position, ref_entry in enumerate(reference):
        for candidate in lists:
            if position >= len(candidate):
                candidate.append({"path": "none"})
                continue
            if not matcher(candidate[position]["path"], ref_entry["path"]):
                candidate.insert(position, {"path": "none"})
    return lists
151
+
152
+
153
class AudioDataset:
    """Loads audio from multiple loaders (with associated transforms)
    for a specified number of samples. Excerpts are drawn randomly
    of the specified duration, above a specified loudness threshold
    and are resampled on the fly to the desired sample rate
    (if it is different from the audio source sample rate).

    This takes either a single AudioLoader object,
    a list of AudioLoader objects, or a dictionary of AudioLoader
    objects. Each AudioLoader is called by the dataset, and the
    result is placed in the output dictionary. A transform can also be
    specified for the entire dataset, rather than for each specific
    loader. This transform can be applied to the output of all the
    loaders if desired.

    AudioLoader objects can be specified as aligned, which means the
    loaders correspond to multitrack audio (e.g. a vocals, bass,
    drums, and other loader for multitrack music mixtures).


    Parameters
    ----------
    loaders : Union[AudioLoader, List[AudioLoader], Dict[str, AudioLoader]]
        AudioLoaders to sample audio from.
    sample_rate : int
        Desired sample rate.
    n_examples : int, optional
        Number of examples (length of dataset), by default 1000
    duration : float, optional
        Duration of audio samples, by default 0.5
    offset : float, optional
        Offset forwarded to every loader. When None (the default) each
        loader picks a salient excerpt; when ``aligned`` is True, loaders
        after the first reuse the offset recorded in the first loader's
        signal metadata.
    loudness_cutoff : float, optional
        Loudness cutoff threshold for audio samples, by default -40
    num_channels : int, optional
        Number of channels in output audio, by default 1
    transform : Callable, optional
        Transform to instantiate alongside each dataset item, by default None
    aligned : bool, optional
        Whether the loaders should be sampled in an aligned manner (e.g. same
        offset, duration, and matched file name), by default False
    shuffle_loaders : bool, optional
        Whether to shuffle the loaders before sampling from them, by default False
    matcher : Callable
        How to match files from adjacent audio lists (e.g. for a multitrack audio loader),
        by default uses the parent directory of each file.
    without_replacement : bool
        Whether to choose files with or without replacement, by default True.


    Examples
    --------
    >>> from audiotools.data.datasets import AudioLoader
    >>> from audiotools.data.datasets import AudioDataset
    >>> from audiotools import transforms as tfm
    >>> import numpy as np
    >>>
    >>> loaders = [
    >>>     AudioLoader(
    >>>         sources=[f"tests/audio/spk"],
    >>>         transform=tfm.Equalizer(),
    >>>         ext=["wav"],
    >>>     )
    >>>     for i in range(5)
    >>> ]
    >>>
    >>> dataset = AudioDataset(
    >>>     loaders = loaders,
    >>>     sample_rate = 44100,
    >>>     duration = 1.0,
    >>>     transform = tfm.RescaleAudio(),
    >>> )
    >>>
    >>> item = dataset[np.random.randint(len(dataset))]
    >>>
    >>> for i in range(len(loaders)):
    >>>     item[i]["signal"] = loaders[i].transform(
    >>>         item[i]["signal"], **item[i]["transform_args"]
    >>>     )
    >>>     item[i]["signal"].widget(i)
    >>>
    >>> mix = sum([item[i]["signal"] for i in range(len(loaders))])
    >>> mix = dataset.transform(mix, **item["transform_args"])
    >>> mix.widget("mix")

    Below is an example of how one could load MUSDB multitrack data:

    >>> import audiotools as at
    >>> from pathlib import Path
    >>> from audiotools import transforms as tfm
    >>> import numpy as np
    >>> import torch
    >>>
    >>> def build_dataset(
    >>>     sample_rate: int = 44100,
    >>>     duration: float = 5.0,
    >>>     musdb_path: str = "~/.data/musdb/",
    >>> ):
    >>>     musdb_path = Path(musdb_path).expanduser()
    >>>     loaders = {
    >>>         src: at.datasets.AudioLoader(
    >>>             sources=[musdb_path],
    >>>             transform=tfm.Compose(
    >>>                 tfm.VolumeNorm(("uniform", -20, -10)),
    >>>                 tfm.Silence(prob=0.1),
    >>>             ),
    >>>             ext=[f"{src}.wav"],
    >>>         )
    >>>         for src in ["vocals", "bass", "drums", "other"]
    >>>     }
    >>>
    >>>     dataset = at.datasets.AudioDataset(
    >>>         loaders=loaders,
    >>>         sample_rate=sample_rate,
    >>>         duration=duration,
    >>>         num_channels=1,
    >>>         aligned=True,
    >>>         transform=tfm.RescaleAudio(),
    >>>         shuffle_loaders=True,
    >>>     )
    >>>     return dataset, list(loaders.keys())
    >>>
    >>> train_data, sources = build_dataset()
    >>> dataloader = torch.utils.data.DataLoader(
    >>>     train_data,
    >>>     batch_size=16,
    >>>     num_workers=0,
    >>>     collate_fn=train_data.collate,
    >>> )
    >>> batch = next(iter(dataloader))
    >>>
    >>> for k in sources:
    >>>     src = batch[k]
    >>>     src["transformed"] = train_data.loaders[k].transform(
    >>>         src["signal"].clone(), **src["transform_args"]
    >>>     )
    >>>
    >>> mixture = sum(batch[k]["transformed"] for k in sources)
    >>> mixture = train_data.transform(mixture, **batch["transform_args"])
    >>>
    >>> # Say a model takes the mix and gives back (n_batch, n_src, n_time).
    >>> # Construct the targets:
    >>> targets = at.AudioSignal.batch([batch[k]["transformed"] for k in sources], dim=1)

    Similarly, here's example code for loading Slakh data:

    >>> import audiotools as at
    >>> from pathlib import Path
    >>> from audiotools import transforms as tfm
    >>> import numpy as np
    >>> import torch
    >>> import glob
    >>>
    >>> def build_dataset(
    >>>     sample_rate: int = 16000,
    >>>     duration: float = 10.0,
    >>>     slakh_path: str = "~/.data/slakh/",
    >>> ):
    >>>     slakh_path = Path(slakh_path).expanduser()
    >>>
    >>>     # Find the max number of sources in Slakh
    >>>     src_names = [x.name for x in list(slakh_path.glob("**/*.wav")) if "S" in str(x.name)]
    >>>     n_sources = len(list(set(src_names)))
    >>>
    >>>     loaders = {
    >>>         f"S{i:02d}": at.datasets.AudioLoader(
    >>>             sources=[slakh_path],
    >>>             transform=tfm.Compose(
    >>>                 tfm.VolumeNorm(("uniform", -20, -10)),
    >>>                 tfm.Silence(prob=0.1),
    >>>             ),
    >>>             ext=[f"S{i:02d}.wav"],
    >>>         )
    >>>         for i in range(n_sources)
    >>>     }
    >>>     dataset = at.datasets.AudioDataset(
    >>>         loaders=loaders,
    >>>         sample_rate=sample_rate,
    >>>         duration=duration,
    >>>         num_channels=1,
    >>>         aligned=True,
    >>>         transform=tfm.RescaleAudio(),
    >>>         shuffle_loaders=False,
    >>>     )
    >>>
    >>>     return dataset, list(loaders.keys())
    >>>
    >>> train_data, sources = build_dataset()
    >>> dataloader = torch.utils.data.DataLoader(
    >>>     train_data,
    >>>     batch_size=16,
    >>>     num_workers=0,
    >>>     collate_fn=train_data.collate,
    >>> )
    >>> batch = next(iter(dataloader))
    >>>
    >>> for k in sources:
    >>>     src = batch[k]
    >>>     src["transformed"] = train_data.loaders[k].transform(
    >>>         src["signal"].clone(), **src["transform_args"]
    >>>     )
    >>>
    >>> mixture = sum(batch[k]["transformed"] for k in sources)
    >>> mixture = train_data.transform(mixture, **batch["transform_args"])

    """

    def __init__(
        self,
        loaders: Union[AudioLoader, List[AudioLoader], Dict[str, AudioLoader]],
        sample_rate: int,
        n_examples: int = 1000,
        duration: float = 0.5,
        offset: float = None,
        loudness_cutoff: float = -40,
        num_channels: int = 1,
        transform: Callable = None,
        aligned: bool = False,
        shuffle_loaders: bool = False,
        matcher: Callable = default_matcher,
        without_replacement: bool = True,
    ):
        # Internally we convert loaders to a dictionary
        if isinstance(loaders, list):
            loaders = dict(enumerate(loaders))
        elif isinstance(loaders, AudioLoader):
            loaders = {0: loaders}

        self.loaders = loaders
        self.loudness_cutoff = loudness_cutoff
        self.num_channels = num_channels

        self.length = n_examples
        self.transform = transform
        self.sample_rate = sample_rate
        self.duration = duration
        self.offset = offset
        self.aligned = aligned
        self.shuffle_loaders = shuffle_loaders
        self.without_replacement = without_replacement

        if aligned:
            loaders_list = list(loaders.values())
            for i in range(len(loaders_list[0].audio_lists)):
                input_lists = [loader.audio_lists[i] for loader in loaders_list]
                # Alignment happens in-place
                align_lists(input_lists, matcher)

    def __getitem__(self, idx):
        """Draw one item from every loader, deterministically seeded by ``idx``.

        Returns a dict keyed by loader key (or, with a single loader, that
        loader's fields merged into the top level), plus ``"idx"`` and,
        when a dataset transform is set, ``"transform_args"``.
        """
        state = util.random_state(idx)
        item = {}

        keys = list(self.loaders.keys())
        if self.shuffle_loaders:
            state.shuffle(keys)

        loader_kwargs = {
            "state": state,
            "sample_rate": self.sample_rate,
            "duration": self.duration,
            # Previously stored but never forwarded, so a user-supplied
            # offset was silently ignored.
            "offset": self.offset,
            "loudness_cutoff": self.loudness_cutoff,
            "num_channels": self.num_channels,
            "global_idx": idx if self.without_replacement else None,
        }

        # Draw item from first loader
        loader = self.loaders[keys[0]]
        item[keys[0]] = loader(**loader_kwargs)

        for key in keys[1:]:
            loader = self.loaders[key]
            if self.aligned:
                # Reuse the first loader's offset and file position so every
                # track of the item comes from the same aligned entry.
                offset = item[keys[0]]["signal"].metadata["offset"]
                loader_kwargs.update(
                    {
                        "offset": offset,
                        "source_idx": item[keys[0]]["source_idx"],
                        "item_idx": item[keys[0]]["item_idx"],
                    }
                )
            item[key] = loader(**loader_kwargs)

        # Sort dictionary back into original order
        keys = list(self.loaders.keys())
        item = {k: item[k] for k in keys}

        item["idx"] = idx
        if self.transform is not None:
            item["transform_args"] = self.transform.instantiate(
                state=state, signal=item[keys[0]]["signal"]
            )

        # If there's only one loader, pop it up
        # to the main dictionary, instead of keeping it
        # nested.
        # NOTE: if the loader also returned "transform_args", it overwrites
        # the dataset-level "transform_args" in this merge.
        if len(keys) == 1:
            item.update(item.pop(keys[0]))

        return item

    def __len__(self):
        return self.length

    @staticmethod
    def collate(list_of_dicts: Union[list, dict], n_splits: int = None):
        """Collates items drawn from this dataset. Uses
        :py:func:`audiotools.core.util.collate`.

        Parameters
        ----------
        list_of_dicts : typing.Union[list, dict]
            Data drawn from each item.
        n_splits : int
            Number of splits to make when creating the batches (split into
            sub-batches). Useful for things like gradient accumulation.

        Returns
        -------
        dict
            Dictionary of batched data.
        """
        return util.collate(list_of_dicts, n_splits=n_splits)
476
+
477
+
478
class ConcatDataset(AudioDataset):
    """Interleave several datasets into one.

    Index ``idx`` is served by ``datasets[idx % len(datasets)]`` at position
    ``idx // len(datasets)``, so consecutive indices cycle through the
    wrapped datasets. The reported length is the sum of their lengths.
    ``AudioDataset.__init__`` is not called; only ``datasets`` is stored.
    """

    def __init__(self, datasets: list):
        self.datasets = datasets

    def __len__(self):
        return sum(len(dataset) for dataset in self.datasets)

    def __getitem__(self, idx):
        position, which = divmod(idx, len(self.datasets))
        return self.datasets[which][position]
488
+
489
+
490
class ResumableDistributedSampler(DistributedSampler):  # pragma: no cover
    """Distributed sampler that can be resumed from a given start index.

    ``start_idx`` counts samples across all replicas, so each replica skips
    ``start_idx // num_replicas`` of its own indices on the first pass.
    Once a pass is fully consumed the skip is reset to 0, so later epochs
    start from the beginning.
    """

    def __init__(self, dataset, start_idx: int = None, **kwargs):
        super().__init__(dataset, **kwargs)
        # Per-replica number of leading indices to skip on the next pass.
        if start_idx is None:
            self.start_idx = 0
        else:
            self.start_idx = start_idx // self.num_replicas

    def __iter__(self):
        for position, index in enumerate(super().__iter__()):
            if position < self.start_idx:
                continue
            yield index
        # Only reached once the pass is exhausted: next epoch starts at 0.
        self.start_idx = 0
503
+
504
+
505
class ResumableSequentialSampler(SequentialSampler):  # pragma: no cover
    """Sequential sampler that can be resumed from a given start index.

    The first pass skips the first ``start_idx`` indices; once a pass is
    fully consumed the skip is reset to 0 so later epochs start from the
    beginning.
    """

    def __init__(self, dataset, start_idx: int = None, **kwargs):
        super().__init__(dataset, **kwargs)
        # Number of leading indices to skip on the next pass.
        if start_idx is None:
            self.start_idx = 0
        else:
            self.start_idx = start_idx

    def __iter__(self):
        for position, index in enumerate(super().__iter__()):
            if position < self.start_idx:
                continue
            yield index
        # Only reached once the pass is exhausted: next epoch starts at 0.
        self.start_idx = 0
@@ -0,0 +1,81 @@
1
+ import csv
2
+ import os
3
+ from pathlib import Path
4
+
5
+ from tqdm import tqdm
6
+
7
+ from ..core import AudioSignal
8
+
9
+
10
def create_csv(
    audio_files: list, output_csv: Path, loudness: bool = False, data_path: str = None
):
    """Converts a list of audio files to a CSV file. If ``loudness = True``,
    the output of this function will create a CSV file that looks something
    like:

    .. csv-table::
        :header: path,loudness

        daps/produced/f1_script1_produced.wav,-16.299999237060547
        daps/produced/f1_script2_produced.wav,-16.600000381469727
        daps/produced/f1_script3_produced.wav,-17.299999237060547
        daps/produced/f1_script4_produced.wav,-16.100000381469727
        daps/produced/f1_script5_produced.wav,-16.700000762939453
        daps/produced/f3_script1_produced.wav,-16.5

    .. note::
        The paths above are written relative to the ``data_path`` argument
        when it is given; when ``data_path`` is None, each path is written
        exactly as it appears in ``audio_files``.

    You can produce a CSV file from a directory of audio files via:

    >>> import audiotools
    >>> directory = ...
    >>> audio_files = audiotools.util.find_audio(directory)
    >>> output_csv = "train.csv"
    >>> audiotools.data.preprocess.create_csv(
    >>>     audio_files, output_csv, loudness=True
    >>> )

    Note that you can create empty rows in the CSV file by passing an empty
    string or None in the ``audio_files`` list. This is useful if you want to
    sync multiple CSV files in a multitrack setting. The loudness of these
    empty rows will be set to -inf.

    Parameters
    ----------
    audio_files : list
        List of audio files (``""`` or ``None`` produce an empty row).
        An empty list produces a CSV containing only the header.
    output_csv : Path
        Output CSV, with each row containing the relative path of every file
        to ``data_path``, if specified (defaults to None).
    loudness : bool
        Compute loudness of entire file and store alongside path.
    data_path : str, optional
        Directory the written paths are made relative to. If None (the
        default), paths are written unchanged.

    Raises
    ------
    ValueError
        If ``data_path`` is given and a file is not located under it
        (raised by ``pathlib.PurePath.relative_to``).
    """
    # Every row has the same keys, so the header is known up front; this
    # also lets an empty ``audio_files`` produce a header-only CSV.
    fieldnames = ["path", "loudness"] if loudness else ["path"]

    info = []
    pbar = tqdm(audio_files)
    for af in pbar:
        # Path(None) raises TypeError; map None to the empty path that
        # marks a placeholder row.
        af = Path(af) if af is not None else Path("")
        pbar.set_description(f"Processing {af.name}")
        _info = {}
        if af.name == "":
            _info["path"] = ""
            if loudness:
                _info["loudness"] = -float("inf")
        else:
            _info["path"] = af.relative_to(data_path) if data_path is not None else af
            if loudness:
                _info["loudness"] = AudioSignal(af).ffmpeg_loudness().item()

        info.append(_info)

    # newline="" is required by the csv module so the writer's "\r\n" row
    # terminator is not translated again (blank rows on Windows otherwise).
    with open(output_csv, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(info)