xinference 1.10.0__py3-none-any.whl → 1.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic. Click here for more details.

Files changed (317)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +11 -28
  3. xinference/client/restful/async_restful_client.py +20 -3
  4. xinference/client/restful/restful_client.py +20 -3
  5. xinference/core/supervisor.py +87 -53
  6. xinference/core/worker.py +10 -0
  7. xinference/deploy/cmdline.py +15 -0
  8. xinference/model/audio/core.py +21 -6
  9. xinference/model/audio/indextts2.py +166 -0
  10. xinference/model/audio/model_spec.json +38 -1
  11. xinference/model/image/model_spec.json +69 -0
  12. xinference/model/image/stable_diffusion/core.py +13 -4
  13. xinference/model/llm/__init__.py +4 -0
  14. xinference/model/llm/llm_family.json +464 -2
  15. xinference/model/llm/sglang/core.py +30 -11
  16. xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +94 -32
  17. xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
  18. xinference/model/llm/utils.py +12 -9
  19. xinference/model/llm/vllm/core.py +93 -17
  20. xinference/thirdparty/audiotools/__init__.py +10 -0
  21. xinference/thirdparty/audiotools/core/__init__.py +4 -0
  22. xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
  23. xinference/thirdparty/audiotools/core/display.py +194 -0
  24. xinference/thirdparty/audiotools/core/dsp.py +390 -0
  25. xinference/thirdparty/audiotools/core/effects.py +647 -0
  26. xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
  27. xinference/thirdparty/audiotools/core/loudness.py +320 -0
  28. xinference/thirdparty/audiotools/core/playback.py +252 -0
  29. xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
  30. xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
  31. xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
  32. xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
  33. xinference/thirdparty/audiotools/core/util.py +671 -0
  34. xinference/thirdparty/audiotools/core/whisper.py +97 -0
  35. xinference/thirdparty/audiotools/data/__init__.py +3 -0
  36. xinference/thirdparty/audiotools/data/datasets.py +517 -0
  37. xinference/thirdparty/audiotools/data/preprocess.py +81 -0
  38. xinference/thirdparty/audiotools/data/transforms.py +1592 -0
  39. xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
  40. xinference/thirdparty/audiotools/metrics/distance.py +131 -0
  41. xinference/thirdparty/audiotools/metrics/quality.py +159 -0
  42. xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
  43. xinference/thirdparty/audiotools/ml/__init__.py +5 -0
  44. xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
  45. xinference/thirdparty/audiotools/ml/decorators.py +440 -0
  46. xinference/thirdparty/audiotools/ml/experiment.py +90 -0
  47. xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
  48. xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
  49. xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
  50. xinference/thirdparty/audiotools/post.py +140 -0
  51. xinference/thirdparty/audiotools/preference.py +600 -0
  52. xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
  53. xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
  54. xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
  55. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
  56. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
  57. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
  58. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
  59. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  60. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
  61. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
  62. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
  63. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
  64. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
  65. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
  66. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
  67. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
  68. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
  69. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
  70. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
  71. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
  72. xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
  73. xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
  74. xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
  75. xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
  76. xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
  77. xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
  78. xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
  79. xinference/thirdparty/indextts/__init__.py +0 -0
  80. xinference/thirdparty/indextts/cli.py +65 -0
  81. xinference/thirdparty/indextts/gpt/__init__.py +0 -0
  82. xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
  83. xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
  84. xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
  85. xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
  86. xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
  87. xinference/thirdparty/indextts/gpt/model.py +713 -0
  88. xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
  89. xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
  90. xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
  91. xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
  92. xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
  93. xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
  94. xinference/thirdparty/indextts/infer.py +690 -0
  95. xinference/thirdparty/indextts/infer_v2.py +739 -0
  96. xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
  97. xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
  98. xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
  99. xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
  100. xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
  101. xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
  102. xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
  103. xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
  104. xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
  105. xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
  106. xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
  107. xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
  108. xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
  109. xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
  110. xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
  111. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
  112. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
  113. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
  114. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
  115. xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
  116. xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
  117. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
  118. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
  119. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  120. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
  121. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
  122. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
  123. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
  124. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
  125. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
  126. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
  127. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
  128. xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
  129. xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
  130. xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
  131. xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
  132. xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
  133. xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
  134. xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
  135. xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
  136. xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
  137. xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
  138. xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
  139. xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
  140. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
  141. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
  142. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
  143. xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
  144. xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
  145. xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
  146. xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
  147. xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
  148. xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
  149. xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
  150. xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
  151. xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
  152. xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
  153. xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
  154. xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
  155. xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
  156. xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
  157. xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
  158. xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
  159. xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
  160. xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
  161. xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
  162. xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
  163. xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
  164. xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
  165. xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
  166. xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
  167. xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
  168. xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
  169. xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
  170. xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
  171. xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
  172. xinference/thirdparty/indextts/utils/__init__.py +0 -0
  173. xinference/thirdparty/indextts/utils/arch_util.py +120 -0
  174. xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
  175. xinference/thirdparty/indextts/utils/common.py +121 -0
  176. xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
  177. xinference/thirdparty/indextts/utils/front.py +536 -0
  178. xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
  179. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
  180. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
  181. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
  182. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
  183. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
  184. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
  185. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
  186. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
  187. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
  188. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
  189. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
  190. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
  191. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
  192. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
  193. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
  194. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
  195. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
  196. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
  197. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
  198. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
  199. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
  200. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
  201. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
  202. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
  203. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
  204. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
  205. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
  206. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
  207. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
  208. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
  209. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
  210. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
  211. xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
  212. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
  213. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
  214. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
  215. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
  216. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
  217. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
  218. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
  219. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
  220. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
  221. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
  222. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
  223. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
  224. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
  225. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
  226. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
  227. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
  228. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
  229. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
  230. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
  231. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
  232. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
  233. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
  234. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
  235. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
  236. xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
  237. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
  238. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
  239. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
  240. xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
  241. xinference/thirdparty/indextts/utils/text_utils.py +41 -0
  242. xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
  243. xinference/thirdparty/indextts/utils/utils.py +93 -0
  244. xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
  245. xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
  246. xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
  247. xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
  248. xinference/ui/gradio/media_interface.py +66 -8
  249. xinference/ui/web/ui/build/asset-manifest.json +6 -6
  250. xinference/ui/web/ui/build/index.html +1 -1
  251. xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
  252. xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
  253. xinference/ui/web/ui/build/static/js/main.d192c4f3.js +3 -0
  254. xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.d192c4f3.js.LICENSE.txt} +0 -7
  255. xinference/ui/web/ui/build/static/js/main.d192c4f3.js.map +1 -0
  256. xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
  257. xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
  258. xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
  259. xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
  260. xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
  261. xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
  262. xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
  263. xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
  264. xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
  265. xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
  266. xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
  267. xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
  268. xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
  269. xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
  270. xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
  271. xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
  272. xinference/ui/web/ui/node_modules/.cache/babel-loader/f995a2425dfb0822fd07127f66ffe9b026883bc156b402eb8bd0b83d52460a93.json +1 -0
  273. xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
  274. xinference/ui/web/ui/package-lock.json +0 -34
  275. xinference/ui/web/ui/package.json +0 -1
  276. xinference/ui/web/ui/src/locales/en.json +9 -3
  277. xinference/ui/web/ui/src/locales/ja.json +9 -3
  278. xinference/ui/web/ui/src/locales/ko.json +9 -3
  279. xinference/ui/web/ui/src/locales/zh.json +9 -3
  280. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/METADATA +18 -2
  281. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/RECORD +285 -67
  282. xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
  283. xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
  284. xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
  285. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
  286. xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
  287. xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
  288. xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
  289. xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
  290. xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
  291. xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
  292. xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
  293. xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
  294. xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
  295. xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
  296. xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
  297. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
  298. xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
  299. xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
  300. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
  301. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
  302. xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
  303. xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
  304. xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
  305. xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
  306. xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
  307. xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
  308. xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
  309. xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
  310. xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
  311. xinference/ui/web/ui/node_modules/select/bower.json +0 -13
  312. xinference/ui/web/ui/node_modules/select/package.json +0 -29
  313. xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
  314. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/WHEEL +0 -0
  315. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/entry_points.txt +0 -0
  316. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/licenses/LICENSE +0 -0
  317. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,6 @@
1
+ """
2
+ Functions for comparing AudioSignal objects to one another.
3
+ """ # fmt: skip
4
+ from . import distance
5
+ from . import quality
6
+ from . import spectral
@@ -0,0 +1,131 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+ from .. import AudioSignal
5
+
6
+
7
class L1Loss(nn.L1Loss):
    """L1 loss between AudioSignals (or between plain tensors).

    Defaults to comparing ``audio_data``, but any attribute of an
    AudioSignal can be used.

    Parameters
    ----------
    attribute : str, optional
        Attribute of signal to compare, defaults to ``audio_data``.
    weight : float, optional
        Weight of this loss, defaults to 1.0. The value is only stored on
        the instance; ``forward`` does not multiply the loss by it.
    **kwargs
        Forwarded unchanged to ``torch.nn.L1Loss``.
    """

    def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
        # Initialise the nn.Module machinery before attaching attributes.
        super().__init__(**kwargs)
        self.attribute = attribute
        self.weight = weight

    def forward(self, x: AudioSignal, y: AudioSignal):
        """
        Parameters
        ----------
        x : AudioSignal
            Estimate AudioSignal
        y : AudioSignal
            Reference AudioSignal

        Returns
        -------
        torch.Tensor
            L1 loss between AudioSignal attributes.
        """
        if isinstance(x, AudioSignal):
            # Original contract: when the estimate is a signal, the same
            # attribute is read from both arguments.
            x = getattr(x, self.attribute)
            y = getattr(y, self.attribute)
        elif isinstance(y, AudioSignal):
            # Mixed input (tensor estimate, signal reference): unwrap the
            # reference so torch never receives the raw signal object.
            y = getattr(y, self.attribute)
        return super().forward(x, y)
43
+
44
+
45
class SISDRLoss(nn.Module):
    """
    Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
    of estimated and reference audio signals or aligned features.

    The returned value is the *negated* ratio in dB, so minimising it
    maximises the SDR.

    Parameters
    ----------
    scaling : bool, optional
        Whether to use scale-invariant (True) or
        signal-to-noise ratio (False), by default True
    reduction : str, optional
        How to reduce across the batch: ``"mean"`` or ``"sum"``. Any other
        value leaves the per-item losses unreduced. By default ``"mean"``
    zero_mean : bool, optional
        Zero mean the references and estimates before
        computing the loss, by default True
    clip_min : float, optional
        The minimum possible loss value. Helps network
        to not focus on making already good examples better, by default None
    weight : float, optional
        Weight of this loss, defaults to 1.0. The value is only stored on
        the instance; ``forward`` does not multiply the loss by it.
    """

    def __init__(
        self,
        scaling: bool = True,
        reduction: str = "mean",
        zero_mean: bool = True,
        clip_min: float = None,
        weight: float = 1.0,
    ):
        # Initialise the nn.Module machinery before attaching attributes.
        super().__init__()
        self.scaling = scaling
        self.reduction = reduction
        self.zero_mean = zero_mean
        self.clip_min = clip_min
        self.weight = weight

    def forward(self, x: AudioSignal, y: AudioSignal):
        """
        Parameters
        ----------
        x : AudioSignal or torch.Tensor
            Treated as the reference by this implementation.
        y : AudioSignal or torch.Tensor
            Treated as the estimate by this implementation.

        Returns
        -------
        torch.Tensor
            Negated SDR in dB, reduced according to ``reduction``.
        """
        eps = 1e-8
        # nb, nc, nt
        if isinstance(x, AudioSignal):
            references = x.audio_data
            estimates = y.audio_data
        else:
            references = x
            # Mixed input (tensor reference, signal estimate): unwrap the
            # signal instead of failing on a missing ``reshape``.
            estimates = y.audio_data if isinstance(y, AudioSignal) else y

        nb = references.shape[0]
        # Flatten everything after the batch axis into one sample axis,
        # giving tensors of shape (nb, samples, 1).
        references = references.reshape(nb, 1, -1).permute(0, 2, 1)
        estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)

        # samples now on axis 1
        if self.zero_mean:
            mean_reference = references.mean(dim=1, keepdim=True)
            mean_estimate = estimates.mean(dim=1, keepdim=True)
        else:
            mean_reference = 0
            mean_estimate = 0

        _references = references - mean_reference
        _estimates = estimates - mean_estimate

        references_projection = (_references**2).sum(dim=-2) + eps
        references_on_estimates = (_estimates * _references).sum(dim=-2) + eps

        # Optimal scaling of the reference towards the estimate; disabled
        # (scale = 1) when plain signal-to-noise ratio is requested.
        scale = (
            (references_on_estimates / references_projection).unsqueeze(1)
            if self.scaling
            else 1
        )

        e_true = scale * _references
        e_res = _estimates - e_true

        signal = (e_true**2).sum(dim=1)
        noise = (e_res**2).sum(dim=1)
        # eps in the denominator guards the division: a perfect estimate has
        # zero residual energy, which would otherwise give inf (or NaN when
        # the signal energy is zero too) and poison the gradients. The outer
        # eps keeps log10 finite when the signal energy is zero.
        sdr = -10 * torch.log10(signal / (noise + eps) + eps)

        if self.clip_min is not None:
            sdr = torch.clamp(sdr, min=self.clip_min)

        if self.reduction == "mean":
            sdr = sdr.mean()
        elif self.reduction == "sum":
            sdr = sdr.sum()
        return sdr
@@ -0,0 +1,159 @@
1
+ import os
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+ from .. import AudioSignal
7
+
8
+
9
def stoi(
    estimates: AudioSignal,
    references: AudioSignal,
    extended: int = False,
):
    """Short term objective intelligibility
    Computes the STOI (See [1][2]) of a denoised signal compared to a clean
    signal. The output is expected to have a monotonic relation with the
    subjective speech-intelligibility, where a higher score denotes better
    speech intelligibility. Uses pystoi under the hood.

    Both inputs are cloned and passed through ``to_mono()`` before scoring,
    so the caller's signals are not the objects being scored. One score is
    computed per batch item, using the sample rate of ``references``.

    Parameters
    ----------
    estimates : AudioSignal
        Denoised speech
    references : AudioSignal
        Clean original speech
    extended : int, optional
        Boolean, whether to use the extended STOI described in [3], by default False

    Returns
    -------
    Tensor[float]
        Short time objective intelligibility measure between clean and
        denoised speech

    References
    ----------
    1.  C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'A Short-Time
        Objective Intelligibility Measure for Time-Frequency Weighted Noisy
        Speech', ICASSP 2010, Texas, Dallas.
    2.  C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'An Algorithm for
        Intelligibility Prediction of Time-Frequency Weighted Noisy Speech',
        IEEE Transactions on Audio, Speech, and Language Processing, 2011.
    3.  Jesper Jensen and Cees H. Taal, 'An Algorithm for Predicting the
        Intelligibility of Speech Masked by Modulated Noise Maskers',
        IEEE Transactions on Audio, Speech and Language Processing, 2016.
    """
    import pystoi

    estimates = estimates.clone().to_mono()
    references = references.clone().to_mono()

    def _channel_as_numpy(signal, index):
        # First channel of one batch item as a detached CPU numpy array.
        return signal.audio_data[index, 0].detach().cpu().numpy()

    scores = [
        pystoi.stoi(
            _channel_as_numpy(references, idx),
            _channel_as_numpy(estimates, idx),
            references.sample_rate,
            extended=extended,
        )
        for idx in range(estimates.batch_size)
    ]
    return torch.from_numpy(np.array(scores))
62
+
63
+
64
def pesq(
    estimates: AudioSignal,
    references: AudioSignal,
    mode: str = "wb",
    target_sr: float = 16000,
):
    """Compute a PESQ score per batch item using the ``pesq`` package.

    Both inputs are cloned, passed through ``to_mono()`` and resampled to
    ``target_sr`` before scoring.

    Parameters
    ----------
    estimates : AudioSignal
        Degraded AudioSignal
    references : AudioSignal
        Reference AudioSignal
    mode : str, optional
        'wb' (wide-band) or 'nb' (narrow-band), by default "wb"
    target_sr : float, optional
        Target sample rate, by default 16000

    Returns
    -------
    Tensor[float]
        PESQ score: P.862.2 Prediction (MOS-LQO)
    """
    from pesq import pesq as pesq_fn

    estimates = estimates.clone().to_mono().resample(target_sr)
    references = references.clone().to_mono().resample(target_sr)

    def _channel_as_numpy(signal, index):
        # First channel of one batch item as a detached CPU numpy array.
        return signal.audio_data[index, 0].detach().cpu().numpy()

    scores = [
        pesq_fn(
            estimates.sample_rate,
            _channel_as_numpy(references, idx),
            _channel_as_numpy(estimates, idx),
            mode,
        )
        for idx in range(estimates.batch_size)
    ]
    return torch.from_numpy(np.array(scores))
103
+
104
+
105
def visqol(
    estimates: AudioSignal,
    references: AudioSignal,
    mode: str = "audio",
):  # pragma: no cover
    """ViSQOL score.

    The mode selects the sample rate both signals are resampled to
    (48000 for 'audio', 16000 for 'speech'), whether speech scoring is
    enabled, and which model file under the ``visqol_lib_py`` package's
    ``model`` directory is used.

    Parameters
    ----------
    estimates : AudioSignal
        Degraded AudioSignal
    references : AudioSignal
        Reference AudioSignal
    mode : str, optional
        'audio' or 'speech', by default 'audio'

    Returns
    -------
    Tensor[float]
        ViSQOL score (MOS-LQO)

    Raises
    ------
    ValueError
        If ``mode`` is neither 'audio' nor 'speech'.
    """
    from visqol import visqol_lib_py
    from visqol.pb2 import visqol_config_pb2
    # Unused by name; kept exactly as in the original.
    # NOTE(review): presumably imported for its side effects -- confirm.
    from visqol.pb2 import similarity_result_pb2

    config = visqol_config_pb2.VisqolConfig()

    if mode == "audio":
        target_sr, speech_scoring = 48000, False
        model_file = "libsvm_nu_svr_model.txt"
    elif mode == "speech":
        target_sr, speech_scoring = 16000, True
        model_file = "lattice_tcditugenmeetpackhref_ls2_nl60_lr12_bs2048_learn.005_ep2400_train1_7_raw.tflite"
    else:
        raise ValueError(f"Unrecognized mode: {mode}")

    config.options.use_speech_scoring = speech_scoring
    config.audio.sample_rate = target_sr
    # Model files are looked up next to the installed visqol_lib_py module.
    package_dir = os.path.dirname(visqol_lib_py.__file__)
    config.options.svr_model_path = os.path.join(package_dir, "model", model_file)

    api = visqol_lib_py.VisqolApi()
    api.Create(config)

    estimates = estimates.clone().to_mono().resample(target_sr)
    references = references.clone().to_mono().resample(target_sr)

    def _channel_as_float_array(signal, index):
        # First channel of one batch item as a Python-float numpy array.
        return signal.audio_data[index, 0].detach().cpu().numpy().astype(float)

    scores = []
    for idx in range(estimates.batch_size):
        result = api.Measure(
            _channel_as_float_array(references, idx),
            _channel_as_float_array(estimates, idx),
        )
        scores.append(result.moslqo)
    return torch.from_numpy(np.array(scores))
@@ -0,0 +1,247 @@
1
+ import typing
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ from torch import nn
6
+
7
+ from .. import AudioSignal
8
+ from .. import STFTParams
9
+
10
+
11
class MultiScaleSTFTLoss(nn.Module):
    """Computes the multi-scale STFT loss from [1].

    For every configured window length the loss adds a weighted distance
    between log-magnitudes and a weighted distance between raw magnitudes.

    Parameters
    ----------
    window_lengths : List[int], optional
        Length of each window of each STFT, by default [2048, 512].
        The hop length of each STFT is a quarter of its window length.
    loss_fn : typing.Callable, optional
        How to compare each loss, by default nn.L1Loss()
    clamp_eps : float, optional
        Clamp on the log magnitude, below, by default 1e-5
    mag_weight : float, optional
        Weight of raw magnitude portion of loss, by default 1.0
    log_weight : float, optional
        Weight of log magnitude portion of loss, by default 1.0
    pow : float, optional
        Power to raise magnitude to before taking log, by default 2.0
    weight : float, optional
        Weight of this loss, by default 1.0. Stored as ``self.weight``;
        it is not applied inside ``forward``.
    match_stride : bool, optional
        Whether to match the stride of convolutional layers, by default False.
        Recorded in each ``STFTParams``; ``forward`` does not pass it to
        ``stft``.
    window_type : str, optional
        Window type recorded in each ``STFTParams`` and passed to
        ``stft``, by default None

    References
    ----------

    1.  Engel, Jesse, Chenjie Gu, and Adam Roberts.
        "DDSP: Differentiable Digital Signal Processing."
        International Conference on Learning Representations. 2019.
    """

    def __init__(
        self,
        window_lengths: List[int] = [2048, 512],
        loss_fn: typing.Callable = nn.L1Loss(),
        clamp_eps: float = 1e-5,
        mag_weight: float = 1.0,
        log_weight: float = 1.0,
        pow: float = 2.0,
        weight: float = 1.0,
        match_stride: bool = False,
        window_type: str = None,
    ):
        super().__init__()

        # One STFT configuration per requested resolution.
        resolutions = []
        for length in window_lengths:
            resolutions.append(
                STFTParams(
                    window_length=length,
                    hop_length=length // 4,
                    match_stride=match_stride,
                    window_type=window_type,
                )
            )
        self.stft_params = resolutions

        self.loss_fn = loss_fn
        self.clamp_eps = clamp_eps
        self.pow = pow
        self.mag_weight = mag_weight
        self.log_weight = log_weight
        self.weight = weight

    def forward(self, x: AudioSignal, y: AudioSignal):
        """Computes multi-scale STFT between an estimate and a reference
        signal.

        Parameters
        ----------
        x : AudioSignal
            Estimate signal
        y : AudioSignal
            Reference signal

        Returns
        -------
        torch.Tensor
            Multi-scale STFT loss.
        """
        loss = 0.0
        for params in self.stft_params:
            # Recompute both spectrograms at this resolution.
            x.stft(params.window_length, params.hop_length, params.window_type)
            y.stft(params.window_length, params.hop_length, params.window_type)

            # Log term: clamp from below, raise to ``pow``, then log10.
            log_x = x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10()
            log_y = y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10()
            log_term = self.loss_fn(log_x, log_y)
            loss += self.log_weight * log_term

            # Linear term on the raw magnitudes.
            mag_term = self.loss_fn(x.magnitude, y.magnitude)
            loss += self.mag_weight * mag_term
        return loss
96
+
97
+
98
class MelSpectrogramLoss(nn.Module):
    """Compute distance between mel spectrograms. Can be used
    in a multi-scale way.

    ``n_mels``, ``mel_fmin``, ``mel_fmax`` and ``window_lengths`` are
    consumed together, one entry per scale; iteration stops at the
    shortest of them.

    Parameters
    ----------
    n_mels : List[int]
        Number of mels per STFT, by default [150, 80],
    window_lengths : List[int], optional
        Length of each window of each STFT, by default [2048, 512].
        The hop length of each STFT is a quarter of its window length.
    loss_fn : typing.Callable, optional
        How to compare each loss, by default nn.L1Loss()
    clamp_eps : float, optional
        Clamp on the log magnitude, below, by default 1e-5
    mag_weight : float, optional
        Weight of raw magnitude portion of loss, by default 1.0
    log_weight : float, optional
        Weight of log magnitude portion of loss, by default 1.0
    pow : float, optional
        Power to raise magnitude to before taking log, by default 2.0
    weight : float, optional
        Weight of this loss, by default 1.0. Stored as ``self.weight``;
        it is not applied inside ``forward``.
    match_stride : bool, optional
        Whether to match the stride of convolutional layers, by default False
    mel_fmin : List[float], optional
        ``mel_fmin`` passed to ``mel_spectrogram`` for each scale,
        by default [0.0, 0.0]
    mel_fmax : List[float], optional
        ``mel_fmax`` passed to ``mel_spectrogram`` for each scale,
        by default [None, None]
    window_type : str, optional
        Window type recorded in each ``STFTParams`` and passed to
        ``mel_spectrogram``, by default None
    """

    def __init__(
        self,
        n_mels: List[int] = [150, 80],
        window_lengths: List[int] = [2048, 512],
        loss_fn: typing.Callable = nn.L1Loss(),
        clamp_eps: float = 1e-5,
        mag_weight: float = 1.0,
        log_weight: float = 1.0,
        pow: float = 2.0,
        weight: float = 1.0,
        match_stride: bool = False,
        mel_fmin: List[float] = [0.0, 0.0],
        mel_fmax: List[float] = [None, None],
        window_type: str = None,
    ):
        super().__init__()

        # One STFT configuration per requested resolution.
        resolutions = []
        for length in window_lengths:
            resolutions.append(
                STFTParams(
                    window_length=length,
                    hop_length=length // 4,
                    match_stride=match_stride,
                    window_type=window_type,
                )
            )
        self.stft_params = resolutions

        # Per-scale mel filterbank settings.
        self.n_mels = n_mels
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax

        self.loss_fn = loss_fn
        self.clamp_eps = clamp_eps
        self.pow = pow
        self.log_weight = log_weight
        self.mag_weight = mag_weight
        self.weight = weight

    def forward(self, x: AudioSignal, y: AudioSignal):
        """Computes mel loss between an estimate and a reference
        signal.

        Parameters
        ----------
        x : AudioSignal
            Estimate signal
        y : AudioSignal
            Reference signal

        Returns
        -------
        torch.Tensor
            Mel loss.
        """
        loss = 0.0
        scales = zip(self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params)
        for bands, low_hz, high_hz, params in scales:
            x_mels = x.mel_spectrogram(
                bands,
                mel_fmin=low_hz,
                mel_fmax=high_hz,
                window_length=params.window_length,
                hop_length=params.hop_length,
                window_type=params.window_type,
            )
            y_mels = y.mel_spectrogram(
                bands,
                mel_fmin=low_hz,
                mel_fmax=high_hz,
                window_length=params.window_length,
                hop_length=params.hop_length,
                window_type=params.window_type,
            )

            # Log term: clamp from below, raise to ``pow``, then log10.
            log_x = x_mels.clamp(self.clamp_eps).pow(self.pow).log10()
            log_y = y_mels.clamp(self.clamp_eps).pow(self.pow).log10()
            loss += self.log_weight * self.loss_fn(log_x, log_y)

            # Linear term on the raw mel magnitudes.
            loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
        return loss
193
+
194
+
195
class PhaseLoss(nn.Module):
    """Difference between phase spectrograms.

    The phase difference is wrapped into [-pi, pi], weighted by the
    min-max normalised magnitude of ``x``, squared and averaged.

    Parameters
    ----------
    window_length : int, optional
        Length of STFT window, by default 2048
    hop_length : int, optional
        Hop length of STFT window, by default 512
    weight : float, optional
        Weight of loss, by default 1.0. Stored as ``self.weight``;
        it is not applied inside ``forward``.
    """

    def __init__(
        self, window_length: int = 2048, hop_length: int = 512, weight: float = 1.0
    ):
        super().__init__()

        self.weight = weight
        self.stft_params = STFTParams(window_length, hop_length)

    def forward(self, x: AudioSignal, y: AudioSignal):
        """Computes phase loss between an estimate and a reference
        signal.

        Parameters
        ----------
        x : AudioSignal
            Estimate signal
        y : AudioSignal
            Reference signal

        Returns
        -------
        torch.Tensor
            Phase loss.
        """
        s = self.stft_params
        x.stft(s.window_length, s.hop_length, s.window_type)
        y.stft(s.window_length, s.hop_length, s.window_type)

        # Take circular difference: wrap the raw difference back into
        # [-pi, pi] by shifting out-of-range values one full turn.
        diff = x.phase - y.phase
        diff[diff < -np.pi] += 2 * np.pi
        # Values above pi must move DOWN by 2*pi. The previous
        # ``-= -2 * np.pi`` added 2*pi and pushed them further out of range.
        diff[diff > np.pi] -= 2 * np.pi

        # Min-max scale the magnitude of ``x`` to weights in [0, 1].
        # NOTE(review): a constant magnitude makes this a division by zero.
        x_min, x_max = x.magnitude.min(), x.magnitude.max()
        weights = (x.magnitude - x_min) / (x_max - x_min)

        # Take weighted mean of all phase errors
        loss = ((weights * diff) ** 2).mean()
        return loss
@@ -0,0 +1,5 @@
1
+ from . import decorators
2
+ from . import layers
3
+ from .accelerator import Accelerator
4
+ from .experiment import Experiment
5
+ from .layers import BaseModel