xinference 1.10.0__py3-none-any.whl → 1.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic. Click here for more details.

Files changed (317)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +11 -28
  3. xinference/client/restful/async_restful_client.py +20 -3
  4. xinference/client/restful/restful_client.py +20 -3
  5. xinference/core/supervisor.py +87 -53
  6. xinference/core/worker.py +10 -0
  7. xinference/deploy/cmdline.py +15 -0
  8. xinference/model/audio/core.py +21 -6
  9. xinference/model/audio/indextts2.py +166 -0
  10. xinference/model/audio/model_spec.json +38 -1
  11. xinference/model/image/model_spec.json +69 -0
  12. xinference/model/image/stable_diffusion/core.py +13 -4
  13. xinference/model/llm/__init__.py +4 -0
  14. xinference/model/llm/llm_family.json +464 -2
  15. xinference/model/llm/sglang/core.py +30 -11
  16. xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +94 -32
  17. xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
  18. xinference/model/llm/utils.py +12 -9
  19. xinference/model/llm/vllm/core.py +93 -17
  20. xinference/thirdparty/audiotools/__init__.py +10 -0
  21. xinference/thirdparty/audiotools/core/__init__.py +4 -0
  22. xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
  23. xinference/thirdparty/audiotools/core/display.py +194 -0
  24. xinference/thirdparty/audiotools/core/dsp.py +390 -0
  25. xinference/thirdparty/audiotools/core/effects.py +647 -0
  26. xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
  27. xinference/thirdparty/audiotools/core/loudness.py +320 -0
  28. xinference/thirdparty/audiotools/core/playback.py +252 -0
  29. xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
  30. xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
  31. xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
  32. xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
  33. xinference/thirdparty/audiotools/core/util.py +671 -0
  34. xinference/thirdparty/audiotools/core/whisper.py +97 -0
  35. xinference/thirdparty/audiotools/data/__init__.py +3 -0
  36. xinference/thirdparty/audiotools/data/datasets.py +517 -0
  37. xinference/thirdparty/audiotools/data/preprocess.py +81 -0
  38. xinference/thirdparty/audiotools/data/transforms.py +1592 -0
  39. xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
  40. xinference/thirdparty/audiotools/metrics/distance.py +131 -0
  41. xinference/thirdparty/audiotools/metrics/quality.py +159 -0
  42. xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
  43. xinference/thirdparty/audiotools/ml/__init__.py +5 -0
  44. xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
  45. xinference/thirdparty/audiotools/ml/decorators.py +440 -0
  46. xinference/thirdparty/audiotools/ml/experiment.py +90 -0
  47. xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
  48. xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
  49. xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
  50. xinference/thirdparty/audiotools/post.py +140 -0
  51. xinference/thirdparty/audiotools/preference.py +600 -0
  52. xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
  53. xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
  54. xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
  55. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
  56. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
  57. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
  58. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
  59. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  60. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
  61. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
  62. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
  63. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
  64. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
  65. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
  66. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
  67. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
  68. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
  69. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
  70. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
  71. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
  72. xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
  73. xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
  74. xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
  75. xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
  76. xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
  77. xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
  78. xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
  79. xinference/thirdparty/indextts/__init__.py +0 -0
  80. xinference/thirdparty/indextts/cli.py +65 -0
  81. xinference/thirdparty/indextts/gpt/__init__.py +0 -0
  82. xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
  83. xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
  84. xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
  85. xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
  86. xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
  87. xinference/thirdparty/indextts/gpt/model.py +713 -0
  88. xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
  89. xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
  90. xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
  91. xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
  92. xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
  93. xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
  94. xinference/thirdparty/indextts/infer.py +690 -0
  95. xinference/thirdparty/indextts/infer_v2.py +739 -0
  96. xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
  97. xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
  98. xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
  99. xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
  100. xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
  101. xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
  102. xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
  103. xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
  104. xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
  105. xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
  106. xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
  107. xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
  108. xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
  109. xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
  110. xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
  111. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
  112. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
  113. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
  114. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
  115. xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
  116. xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
  117. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
  118. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
  119. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  120. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
  121. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
  122. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
  123. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
  124. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
  125. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
  126. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
  127. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
  128. xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
  129. xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
  130. xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
  131. xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
  132. xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
  133. xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
  134. xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
  135. xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
  136. xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
  137. xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
  138. xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
  139. xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
  140. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
  141. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
  142. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
  143. xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
  144. xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
  145. xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
  146. xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
  147. xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
  148. xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
  149. xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
  150. xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
  151. xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
  152. xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
  153. xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
  154. xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
  155. xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
  156. xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
  157. xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
  158. xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
  159. xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
  160. xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
  161. xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
  162. xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
  163. xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
  164. xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
  165. xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
  166. xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
  167. xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
  168. xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
  169. xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
  170. xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
  171. xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
  172. xinference/thirdparty/indextts/utils/__init__.py +0 -0
  173. xinference/thirdparty/indextts/utils/arch_util.py +120 -0
  174. xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
  175. xinference/thirdparty/indextts/utils/common.py +121 -0
  176. xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
  177. xinference/thirdparty/indextts/utils/front.py +536 -0
  178. xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
  179. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
  180. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
  181. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
  182. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
  183. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
  184. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
  185. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
  186. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
  187. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
  188. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
  189. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
  190. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
  191. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
  192. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
  193. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
  194. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
  195. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
  196. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
  197. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
  198. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
  199. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
  200. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
  201. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
  202. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
  203. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
  204. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
  205. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
  206. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
  207. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
  208. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
  209. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
  210. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
  211. xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
  212. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
  213. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
  214. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
  215. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
  216. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
  217. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
  218. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
  219. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
  220. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
  221. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
  222. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
  223. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
  224. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
  225. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
  226. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
  227. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
  228. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
  229. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
  230. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
  231. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
  232. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
  233. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
  234. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
  235. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
  236. xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
  237. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
  238. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
  239. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
  240. xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
  241. xinference/thirdparty/indextts/utils/text_utils.py +41 -0
  242. xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
  243. xinference/thirdparty/indextts/utils/utils.py +93 -0
  244. xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
  245. xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
  246. xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
  247. xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
  248. xinference/ui/gradio/media_interface.py +66 -8
  249. xinference/ui/web/ui/build/asset-manifest.json +6 -6
  250. xinference/ui/web/ui/build/index.html +1 -1
  251. xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
  252. xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
  253. xinference/ui/web/ui/build/static/js/main.d192c4f3.js +3 -0
  254. xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.d192c4f3.js.LICENSE.txt} +0 -7
  255. xinference/ui/web/ui/build/static/js/main.d192c4f3.js.map +1 -0
  256. xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
  257. xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
  258. xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
  259. xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
  260. xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
  261. xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
  262. xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
  263. xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
  264. xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
  265. xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
  266. xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
  267. xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
  268. xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
  269. xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
  270. xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
  271. xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
  272. xinference/ui/web/ui/node_modules/.cache/babel-loader/f995a2425dfb0822fd07127f66ffe9b026883bc156b402eb8bd0b83d52460a93.json +1 -0
  273. xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
  274. xinference/ui/web/ui/package-lock.json +0 -34
  275. xinference/ui/web/ui/package.json +0 -1
  276. xinference/ui/web/ui/src/locales/en.json +9 -3
  277. xinference/ui/web/ui/src/locales/ja.json +9 -3
  278. xinference/ui/web/ui/src/locales/ko.json +9 -3
  279. xinference/ui/web/ui/src/locales/zh.json +9 -3
  280. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/METADATA +18 -2
  281. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/RECORD +285 -67
  282. xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
  283. xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
  284. xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
  285. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
  286. xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
  287. xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
  288. xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
  289. xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
  290. xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
  291. xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
  292. xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
  293. xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
  294. xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
  295. xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
  296. xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
  297. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
  298. xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
  299. xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
  300. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
  301. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
  302. xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
  303. xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
  304. xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
  305. xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
  306. xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
  307. xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
  308. xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
  309. xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
  310. xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
  311. xinference/ui/web/ui/node_modules/select/bower.json +0 -13
  312. xinference/ui/web/ui/node_modules/select/package.json +0 -29
  313. xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
  314. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/WHEEL +0 -0
  315. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/entry_points.txt +0 -0
  316. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/licenses/LICENSE +0 -0
  317. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,184 @@
1
+ import os
2
+ import typing
3
+
4
+ import torch
5
+ import torch.distributed as dist
6
+ from torch.nn.parallel import DataParallel
7
+ from torch.nn.parallel import DistributedDataParallel
8
+
9
+ from ..data.datasets import ResumableDistributedSampler as DistributedSampler
10
+ from ..data.datasets import ResumableSequentialSampler as SequentialSampler
11
+
12
+
13
class Accelerator:  # pragma: no cover
    """Prepare models and dataloaders for single-device, DP, or DDP training.

    Use :meth:`prepare_model` and :meth:`prepare_dataloader` to prepare the
    respective objects. Models are moved to the selected device; under DDP
    their BatchNorm layers are converted to ``SyncBatchNorm`` and the model is
    wrapped in ``DistributedDataParallel``. Dataloaders are built around a
    resumable sampler: a distributed one under DDP, a sequential one otherwise.

    With fewer than two visible GPUs no wrapping happens: ``prepare_model``
    only moves the model to the device and ``prepare_dataloader`` uses the
    sequential sampler. If the environment variable ``LOCAL_RANK`` is not set,
    the script was launched without ``torchrun``, and ``DataParallel`` is used
    instead of ``DistributedDataParallel`` (not recommended) when more than
    one GPU is visible.

    Parameters
    ----------
    amp : bool, optional
        Whether or not to enable automatic mixed precision, by default False
    """

    def __init__(self, amp: bool = False):
        env_rank = os.getenv("LOCAL_RANK", None)
        # Parse once so ``self.local_rank`` is always an int. Previously it
        # stayed a string outside the DDP branch (e.g. torchrun on one GPU),
        # which is not a valid device index for ``torch.cuda.device``.
        local_rank = None if env_rank is None else int(env_rank)
        # Number of GPUs visible to this process; also used as DDP world size.
        self.world_size = torch.cuda.device_count()

        self.use_ddp = self.world_size > 1 and local_rank is not None
        self.use_dp = self.world_size > 1 and local_rank is None
        self.device = "cpu" if self.world_size == 0 else "cuda"

        if self.use_ddp:
            # NOTE: world size is the local GPU count and rank is LOCAL_RANK,
            # so this initialisation assumes a single-node launch.
            dist.init_process_group(
                "nccl",
                init_method="env://",
                world_size=self.world_size,
                rank=local_rank,
            )

        self.local_rank = 0 if local_rank is None else local_rank
        self.amp = amp

        class DummyScaler:
            """No-op stand-in for ``GradScaler`` used when AMP is disabled."""

            def __init__(self):
                pass

            def step(self, optimizer):
                optimizer.step()

            def scale(self, loss):
                return loss

            def unscale_(self, optimizer):
                return optimizer

            def update(self):
                pass

        self.scaler = torch.cuda.amp.GradScaler() if amp else DummyScaler()
        self.device_ctx = (
            torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
        )

    def __enter__(self):
        """Enter the CUDA device context for ``local_rank`` (if CUDA exists)."""
        if self.device_ctx is not None:
            self.device_ctx.__enter__()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Leave the CUDA device context entered by :meth:`__enter__`."""
        if self.device_ctx is not None:
            self.device_ctx.__exit__(exc_type, exc_value, traceback)

    def prepare_model(self, model: torch.nn.Module, **kwargs):
        """Prepares model for DDP or DP. The model is moved to
        the device of the correct rank.

        Parameters
        ----------
        model : torch.nn.Module
            Model that is converted for DDP or DP.
        **kwargs
            Forwarded to ``DistributedDataParallel`` / ``DataParallel``.

        Returns
        -------
        torch.nn.Module
            Wrapped model, or original model if DDP and DP are turned off.
        """
        model = model.to(self.device)
        if self.use_ddp:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model = DistributedDataParallel(
                model, device_ids=[self.local_rank], **kwargs
            )
        elif self.use_dp:
            model = DataParallel(model, **kwargs)
        return model

    # Automatic mixed-precision utilities
    def autocast(self, *args, **kwargs):
        """Context manager for autocasting. ``self.amp`` is passed as the
        ``enabled`` flag; remaining arguments go to ``torch.cuda.amp.autocast``.
        """
        return torch.cuda.amp.autocast(self.amp, *args, **kwargs)

    def backward(self, loss: torch.Tensor):
        """Backwards pass, after scaling the loss if ``amp`` is
        enabled.

        Parameters
        ----------
        loss : torch.Tensor
            Loss value.
        """
        self.scaler.scale(loss).backward()

    def step(self, optimizer: torch.optim.Optimizer):
        """Steps the optimizer, using a ``scaler`` if ``amp`` is
        enabled.

        Parameters
        ----------
        optimizer : torch.optim.Optimizer
            Optimizer to step forward.
        """
        self.scaler.step(optimizer)

    def update(self):
        """Updates the scale factor."""
        self.scaler.update()

    @staticmethod
    def _scale_loader_kwargs(kwargs: dict, world_size: int) -> dict:
        """Return a copy of DataLoader kwargs split evenly across ranks.

        ``num_workers`` and ``batch_size`` are divided by ``world_size``
        (floored, minimum 1) when present. A missing or ``None``
        ``batch_size`` is left untouched so DataLoader's own default applies.
        """
        scaled = dict(kwargs)
        if "num_workers" in scaled:
            scaled["num_workers"] = max(scaled["num_workers"] // world_size, 1)
        if scaled.get("batch_size") is not None:
            scaled["batch_size"] = max(scaled["batch_size"] // world_size, 1)
        return scaled

    def prepare_dataloader(
        self,
        dataset: typing.Iterable,
        start_idx: typing.Optional[int] = None,
        **kwargs,
    ):
        """Wraps a dataset with a DataLoader, using the correct sampler if DDP is
        enabled.

        Parameters
        ----------
        dataset : typing.Iterable
            Dataset to build Dataloader around.
        start_idx : int, optional
            Start index of sampler, useful if resuming from some epoch,
            by default None
        **kwargs
            Forwarded to ``torch.utils.data.DataLoader``. Under DDP,
            ``num_workers`` and ``batch_size`` (when given) are divided by
            the world size so the totals are shared across ranks.

        Returns
        -------
        torch.utils.data.DataLoader
            DataLoader over ``dataset`` driven by the chosen sampler.
        """
        if self.use_ddp:
            sampler = DistributedSampler(
                dataset,
                start_idx,
                num_replicas=self.world_size,
                rank=self.local_rank,
            )
            kwargs = self._scale_loader_kwargs(kwargs, self.world_size)
        else:
            sampler = SequentialSampler(dataset, start_idx)

        return torch.utils.data.DataLoader(dataset, sampler=sampler, **kwargs)

    @staticmethod
    def unwrap(model):
        """Unwraps the model if it was wrapped in DDP or DP, otherwise
        just returns the model. Use this to unwrap the model returned by
        :py:func:`audiotools.ml.accelerator.Accelerator.prepare_model`.
        """
        if hasattr(model, "module"):
            return model.module
        return model
@@ -0,0 +1,440 @@
1
+ import math
2
+ import os
3
+ import time
4
+ from collections import defaultdict
5
+ from functools import wraps
6
+
7
+ import torch
8
+ import torch.distributed as dist
9
+ from rich import box
10
+ from rich.console import Console
11
+ from rich.console import Group
12
+ from rich.live import Live
13
+ from rich.markdown import Markdown
14
+ from rich.padding import Padding
15
+ from rich.panel import Panel
16
+ from rich.progress import BarColumn
17
+ from rich.progress import Progress
18
+ from rich.progress import SpinnerColumn
19
+ from rich.progress import TimeElapsedColumn
20
+ from rich.progress import TimeRemainingColumn
21
+ from rich.rule import Rule
22
+ from rich.table import Table
23
+ from torch.utils.tensorboard import SummaryWriter
24
+
25
+
26
# Defined as a named module-level function (rather than a lambda) so that
# defaultdicts using it as their factory, such as the tracker history,
# can still be pickled.
def default_list():
    """Return a new, empty list; a picklable ``defaultdict`` factory."""
    return list()
29
+
30
+
31
class Mean:
    """Running average of every finite value seen since the last reset.

    Call the instance to read the current mean. Non-finite inputs
    (``nan`` / ``inf``) are ignored by :meth:`update`, and an empty
    accumulator reports ``0``.
    """

    def __init__(self):
        self.reset()

    def __call__(self):
        # Guard the denominator so an empty accumulator yields 0 instead
        # of raising ZeroDivisionError.
        return self.total / (self.count if self.count > 0 else 1)

    def reset(self):
        """Forget all accumulated values."""
        self.total = 0
        self.count = 0

    def update(self, val):
        """Fold ``val`` into the mean unless it is nan or infinite."""
        if not math.isfinite(val):
            return
        self.count += 1
        self.total += val
51
+
52
+
53
def when(condition):
    """Make a function run only while ``condition()`` is truthy.

    ``condition`` is a zero-argument callable that is evaluated on every
    call of the decorated function. When it returns a falsy value, the
    wrapped function is skipped and ``None`` is returned instead.

    Parameters
    ----------
    condition : Callable
        Zero-argument predicate deciding whether the decorated function
        runs.

    Example
    -------
    Checkpoint only every 100 iterations, and only on local rank 0::

        i = 0
        rank = 0

        @when(lambda: i % 100 == 0 and rank == 0)
        def checkpoint():
            print("Saving to /runs/exp1")

        for i in range(1000):
            checkpoint()
    """

    def wrap(func):
        @wraps(func)
        def gated(*args, **kwargs):
            if not condition():
                return None
            return func(*args, **kwargs)

        return gated

    return wrap
89
+
90
+
91
def timer(prefix: str = "time"):
    """Record how long the decorated function takes.

    The decorated function must return a ``dict``. This is checked with
    ``assert``. The elapsed time in seconds, measured with
    :func:`time.perf_counter`, is stored in that same dict under the key
    ``"<prefix>/<function name>"``, and the dict is returned.

    Parameters
    ----------
    prefix : str, optional
        First component of the inserted key, by default "time".
    """

    def wrap(func):
        @wraps(func)
        def timed(*args, **kwargs):
            started = time.perf_counter()
            result = func(*args, **kwargs)
            assert isinstance(result, dict)
            finished = time.perf_counter()
            result[f"{prefix}/{func.__name__}"] = finished - started
            return result

        return timed

    return wrap
116
+
117
+
118
class Tracker:
    """
    A tracker class that helps to monitor the progress of training and logging the metrics.

    Attributes
    ----------
    metrics : dict
        Per-label dictionaries holding the latest scalar value ("value") and
        the running mean ("mean") of each metric key.
    history : dict
        Per-label dictionaries mapping metric keys (and "step") to the list of
        values recorded by :meth:`log`.
    writer : SummaryWriter
        A SummaryWriter object for logging the metrics, or None.
    rank : int
        The rank of the current process; only rank 0 prints, draws and logs.
    step : int
        The current step of the training.
    tasks : dict
        A dictionary containing the progress bars and tables for each label.
    pbar : Progress
        A progress bar object for displaying the progress.
    consoles : list
        A list of console objects for logging.
    live : Live
        A Live object for updating the display live.

    Methods
    -------
    print(msg: str)
        Prints the given message to all consoles.
    update(label: str, fn_name: str)
        Updates the progress bar and table for the given label.
    done(label: str, title: str)
        Resets the running means of every label and the progress bar of the
        given label, then prints a summary.
    track(label: str, length: int, completed: int = 0, op: dist.ReduceOp = dist.ReduceOp.AVG, ddp_active: bool = "LOCAL_RANK" in os.environ)
        A decorator for tracking the progress and metrics of a function.
    log(label: str, value_type: str = "value", history: bool = True)
        A decorator for logging the metrics of a function.
    is_best(label: str, key: str) -> bool
        Checks if the latest value of the given key in the label is the lowest so far.
    state_dict() -> dict
        Returns a dictionary containing the state of the tracker.
    load_state_dict(state_dict: dict) -> Tracker
        Loads the state of the tracker from the given state dictionary.
    """

    def __init__(
        self,
        writer: SummaryWriter = None,
        log_file: str = None,
        rank: int = 0,
        console_width: int = 100,
        step: int = 0,
    ):
        """
        Initializes the Tracker object.

        Parameters
        ----------
        writer : SummaryWriter, optional
            A SummaryWriter object for logging the metrics, by default None.
        log_file : str, optional
            The path to a log file that console output is appended to,
            by default None.
        rank : int, optional
            The rank of the current process, by default 0.
        console_width : int, optional
            The width of the console, by default 100.
        step : int, optional
            The current step of the training, by default 0.
        """
        self.metrics = {}
        self.history = {}
        self.writer = writer
        self.rank = rank
        self.step = step

        # Create progress bars etc.
        self.tasks = {}
        self.pbar = Progress(
            SpinnerColumn(),
            "[progress.description]{task.description}",
            "{task.completed}/{task.total}",
            BarColumn(),
            TimeElapsedColumn(),
            "/",
            TimeRemainingColumn(),
        )
        self.consoles = [Console(width=console_width)]
        self.live = Live(console=self.consoles[0], refresh_per_second=10)
        if log_file is not None:
            # NOTE(review): this file handle is handed to the Console and is
            # never explicitly closed by the Tracker.
            self.consoles.append(Console(width=console_width, file=open(log_file, "a")))

    def print(self, msg):
        """
        Prints the given message to all consoles (rank 0 only).

        Parameters
        ----------
        msg : str
            The message to be printed.
        """
        if self.rank == 0:
            for c in self.consoles:
                c.log(msg)

    def update(self, label, fn_name):
        """
        Advances the progress bar of ``label`` and redraws its metric table
        (rank 0 only).

        Parameters
        ----------
        label : str
            The label of the progress bar and table to be updated.
        fn_name : str
            The name of the function associated with the label; shown in the
            rule above the progress panel.
        """
        if self.rank == 0:
            self.pbar.advance(self.tasks[label]["pbar"])

            # Create table
            table = Table(title=label, expand=True, box=box.MINIMAL)
            table.add_column("key", style="cyan")
            table.add_column("value", style="bright_blue")
            table.add_column("mean", style="bright_green")

            values = self.metrics[label]["value"]
            means = self.metrics[label]["mean"]
            for k, value in values.items():
                mean = means[k]()
                table.add_row(k, f"{value:10.6f}", f"{mean:10.6f}")

            self.tasks[label]["table"] = table
            tables = [t["table"] for t in self.tasks.values()]
            group = Group(*tables, self.pbar)
            self.live.update(
                Group(
                    Padding("", (0, 0)),
                    Rule(f"[italic]{fn_name}()", style="white"),
                    Padding("", (0, 0)),
                    Panel.fit(
                        group, padding=(0, 5), title="[b]Progress", border_style="blue"
                    ),
                )
            )

    def done(self, label: str, title: str):
        """
        Resets the running means of every tracked label and the progress bar
        of ``label``, then prints all tables under ``title``.

        Parameters
        ----------
        label : str
            The label whose progress bar is reset.
        title : str
            The title to be displayed when printing the final result.
        """
        # Running means are reset for *all* labels, not just ``label``. A
        # distinct loop variable keeps the ``label`` argument intact for the
        # progress-bar reset below (reusing the name made it reset whichever
        # label happened to be registered last).
        for label_metrics in self.metrics.values():
            for mean in label_metrics["mean"].values():
                mean.reset()

        if self.rank == 0:
            self.pbar.reset(self.tasks[label]["pbar"])
            tables = [t["table"] for t in self.tasks.values()]
            group = Group(Markdown(f"# {title}"), *tables, self.pbar)
            self.print(group)

    def track(
        self,
        label: str,
        length: int,
        completed: int = 0,
        op: dist.ReduceOp = dist.ReduceOp.AVG,
        ddp_active: bool = "LOCAL_RANK" in os.environ,
    ):
        """
        A decorator for tracking the progress and metrics of a function.

        Calling this registers a new progress-bar task and fresh metric
        storage for ``label``, replacing any previous ones. When the
        decorated function returns a dict:

        - Python ints/floats are wrapped in tensors.
        - CUDA tensors are all-reduced with ``op`` when ``ddp_active``.
        - Every tensor entry is detached.
        - Single-element tensors are replaced by Python scalars, which are
          stored as the latest value and folded into the running mean.

        Non-tensor entries are left untouched. A non-dict output only
        advances the progress display.

        Parameters
        ----------
        label : str
            The label to be associated with the progress and metrics.
        length : int
            The total number of iterations to be completed.
        completed : int, optional
            The number of iterations already completed, by default 0.
        op : dist.ReduceOp, optional
            The reduce operation to be used, by default dist.ReduceOp.AVG.
        ddp_active : bool, optional
            Whether DistributedDataParallel is active. The default,
            ``"LOCAL_RANK" in os.environ``, is evaluated once when this
            module is imported, not on each call.
        """
        self.tasks[label] = {
            "pbar": self.pbar.add_task(
                f"[white]Iteration ({label})", total=length, completed=completed
            ),
            "table": Table(),
        }
        self.metrics[label] = {
            "value": defaultdict(),
            "mean": defaultdict(Mean),
        }

        def decorator(fn):
            @wraps(fn)
            def decorated(*args, **kwargs):
                output = fn(*args, **kwargs)
                if not isinstance(output, dict):
                    self.update(label, fn.__name__)
                    return output
                # Collect across all DDP processes
                scalar_keys = []
                for k, v in output.items():
                    if isinstance(v, (int, float)):
                        v = torch.tensor([v])
                    if not torch.is_tensor(v):
                        continue
                    if ddp_active and v.is_cuda:  # pragma: no cover
                        dist.all_reduce(v, op=op)
                    output[k] = v.detach()
                    if torch.numel(v) == 1:
                        scalar_keys.append(k)
                        output[k] = v.item()

                # Save the scalar outputs to the tracker and update the
                # running means (in the same order they appear in output).
                for k in scalar_keys:
                    value = output[k]
                    self.metrics[label]["value"][k] = value
                    self.metrics[label]["mean"][k].update(value)

                self.update(label, fn.__name__)
                return output

            return decorated

        return decorator

    def log(self, label: str, value_type: str = "value", history: bool = True):
        """
        A decorator for logging the metrics of a function.

        After the decorated function runs (rank 0 only), every metric stored
        under ``label`` is written to ``writer`` (if any) as
        ``"<key>/<label>"`` at ``self.step``. It is also appended to the
        history whenever ``label`` has a history entry. That includes an
        entry created by an earlier ``log`` call or restored via
        :meth:`load_state_dict`.

        Parameters
        ----------
        label : str
            The label to be associated with the logging; must already be
            registered via :meth:`track` by the time the function runs.
        value_type : str, optional
            Either "value" (latest values) or "mean" (running means),
            by default "value".
        history : bool, optional
            Whether to create a history entry for ``label``, by default True.
        """
        assert value_type in ["mean", "value"]
        if history:
            if label not in self.history:
                self.history[label] = defaultdict(default_list)

        def decorator(fn):
            @wraps(fn)
            def decorated(*args, **kwargs):
                output = fn(*args, **kwargs)
                if self.rank == 0:
                    metrics = self.metrics[label][value_type]
                    for k, v in metrics.items():
                        v = v() if isinstance(v, Mean) else v
                        if self.writer is not None:
                            self.writer.add_scalar(f"{k}/{label}", v, self.step)
                        if label in self.history:
                            self.history[label][k].append(v)

                    if label in self.history:
                        self.history[label]["step"].append(self.step)

                return output

            return decorated

        return decorator

    def is_best(self, label, key):
        """
        Checks if the latest logged value of ``key`` under ``label`` is the
        lowest so far (lower is treated as better).

        Parameters
        ----------
        label : str
            The label of the metrics to be checked.
        key : str
            The key of the metric to be checked.

        Returns
        -------
        bool
            True if the latest value equals the minimum of the history,
            otherwise False.
        """
        return self.history[label][key][-1] == min(self.history[label][key])

    def state_dict(self):
        """
        Returns a dictionary containing the state of the tracker.

        Returns
        -------
        dict
            A dictionary containing the history and step of the tracker.
        """
        return {"history": self.history, "step": self.step}

    def load_state_dict(self, state_dict):
        """
        Loads the state of the tracker from the given state dictionary.

        Parameters
        ----------
        state_dict : dict
            A dictionary containing the history and step of the tracker.

        Returns
        -------
        Tracker
            The tracker object with the loaded state.
        """
        self.history = state_dict["history"]
        self.step = state_dict["step"]
        return self
@@ -0,0 +1,90 @@
1
+ """
2
+ Useful class for Experiment tracking, and ensuring code is
3
+ saved alongside files.
4
+ """ # fmt: skip
5
+ import datetime
6
+ import os
7
+ import shlex
8
+ import shutil
9
+ import subprocess
10
+ import typing
11
+ from pathlib import Path
12
+
13
+ import randomname
14
+
15
+
16
class Experiment:
    """Utilities for managing an experiment directory.

    Constructing an ``Experiment`` does three things:

    - creates ``<exp_directory>/<exp_name>`` (including parents);
    - records the files tracked by git at ``HEAD``;
    - records the repository root and the current working directory.

    Used as a context manager, it changes the working directory into the
    experiment folder on entry and restores the previous working directory
    on exit.

    Parameters
    ----------
    exp_directory : str
        Folder where all experiments are saved, by default "runs/".
    exp_name : str, optional
        Name of the experiment. By default a name is generated from the
        current date and a random adjective-noun pair
        (see :meth:`generate_exp_name`).

    Raises
    ------
    subprocess.CalledProcessError
        If a git command exits with an error (for example when not run
        inside a git repository).
    """

    def __init__(
        self,
        exp_directory: str = "runs/",
        exp_name: str = None,
    ):
        if exp_name is None:
            exp_name = self.generate_exp_name()
        exp_dir = Path(exp_directory) / exp_name
        exp_dir.mkdir(parents=True, exist_ok=True)

        self.exp_dir = exp_dir
        self.exp_name = exp_name
        # ``--full-tree`` lists every file in HEAD with a path relative to
        # the repository root, not to the current working directory.
        self.git_tracked_files = (
            subprocess.check_output(
                shlex.split("git ls-tree --full-tree --name-only -r HEAD")
            )
            .decode("utf-8")
            .splitlines()
        )
        # Root that the paths above are relative to; used as the copy
        # source in ``snapshot`` so it also works from a subdirectory.
        self.repo_root = Path(
            subprocess.check_output(
                shlex.split("git rev-parse --show-toplevel")
            )
            .decode("utf-8")
            .strip()
        )
        # Working directory at construction time; ``exp_dir`` may be
        # relative to it.
        self.parent_directory = Path(".").absolute()

    def __enter__(self):
        self.prev_dir = os.getcwd()
        # Anchor a relative ``exp_dir`` to the directory it was created from,
        # so entering still targets the folder made in ``__init__`` even if
        # the working directory has changed since then. An absolute
        # ``exp_dir`` is used as-is by ``Path.__truediv__``.
        os.chdir(self.parent_directory / self.exp_dir)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        os.chdir(self.prev_dir)

    @staticmethod
    def generate_exp_name():
        """Generates a random experiment name based on the date
        and a randomly generated adjective-noun tuple.

        Returns
        -------
        str
            Name of the form ``"<yymmdd>-<random name>"``.
        """
        date = datetime.datetime.now().strftime("%y%m%d")
        name = f"{date}-{randomname.get_name()}"
        return name

    def snapshot(self, filter_fn: typing.Callable = lambda f: True):
        """Copy the git-tracked files into the current working directory.

        Each path in ``git_tracked_files`` accepted by ``filter_fn`` is copied
        from the repository's working tree into the same relative path under
        the current working directory, creating intermediate folders as
        needed. Call this inside the ``with`` block so the copies land in the
        experiment folder. Only working-tree file contents are copied, which
        may include uncommitted edits; no diff file is written.

        Parameters
        ----------
        filter_fn : typing.Callable, optional
            Predicate on the repository-relative path; files for which it
            returns a falsy value are skipped. By default accepts all files.
        """
        for f in self.git_tracked_files:
            if filter_fn(f):
                Path(f).parent.mkdir(parents=True, exist_ok=True)
                shutil.copyfile(self.repo_root / f, f)