xinference 1.10.0__py3-none-any.whl → 1.10.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +11 -28
- xinference/client/restful/async_restful_client.py +20 -3
- xinference/client/restful/restful_client.py +20 -3
- xinference/core/supervisor.py +87 -53
- xinference/core/worker.py +10 -0
- xinference/deploy/cmdline.py +15 -0
- xinference/model/audio/core.py +21 -6
- xinference/model/audio/indextts2.py +166 -0
- xinference/model/audio/model_spec.json +38 -1
- xinference/model/image/model_spec.json +69 -0
- xinference/model/image/stable_diffusion/core.py +13 -4
- xinference/model/llm/__init__.py +4 -0
- xinference/model/llm/llm_family.json +464 -2
- xinference/model/llm/sglang/core.py +30 -11
- xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +94 -32
- xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
- xinference/model/llm/utils.py +12 -9
- xinference/model/llm/vllm/core.py +93 -17
- xinference/thirdparty/audiotools/__init__.py +10 -0
- xinference/thirdparty/audiotools/core/__init__.py +4 -0
- xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
- xinference/thirdparty/audiotools/core/display.py +194 -0
- xinference/thirdparty/audiotools/core/dsp.py +390 -0
- xinference/thirdparty/audiotools/core/effects.py +647 -0
- xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
- xinference/thirdparty/audiotools/core/loudness.py +320 -0
- xinference/thirdparty/audiotools/core/playback.py +252 -0
- xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
- xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
- xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
- xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
- xinference/thirdparty/audiotools/core/util.py +671 -0
- xinference/thirdparty/audiotools/core/whisper.py +97 -0
- xinference/thirdparty/audiotools/data/__init__.py +3 -0
- xinference/thirdparty/audiotools/data/datasets.py +517 -0
- xinference/thirdparty/audiotools/data/preprocess.py +81 -0
- xinference/thirdparty/audiotools/data/transforms.py +1592 -0
- xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
- xinference/thirdparty/audiotools/metrics/distance.py +131 -0
- xinference/thirdparty/audiotools/metrics/quality.py +159 -0
- xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
- xinference/thirdparty/audiotools/ml/__init__.py +5 -0
- xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
- xinference/thirdparty/audiotools/ml/decorators.py +440 -0
- xinference/thirdparty/audiotools/ml/experiment.py +90 -0
- xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
- xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
- xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
- xinference/thirdparty/audiotools/post.py +140 -0
- xinference/thirdparty/audiotools/preference.py +600 -0
- xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
- xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
- xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
- xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
- xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
- xinference/thirdparty/indextts/__init__.py +0 -0
- xinference/thirdparty/indextts/cli.py +65 -0
- xinference/thirdparty/indextts/gpt/__init__.py +0 -0
- xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
- xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
- xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
- xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
- xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
- xinference/thirdparty/indextts/gpt/model.py +713 -0
- xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
- xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
- xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
- xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
- xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
- xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
- xinference/thirdparty/indextts/infer.py +690 -0
- xinference/thirdparty/indextts/infer_v2.py +739 -0
- xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
- xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
- xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
- xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
- xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
- xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
- xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
- xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
- xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
- xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
- xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
- xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
- xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
- xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
- xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
- xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
- xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
- xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
- xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
- xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
- xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
- xinference/thirdparty/indextts/utils/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/arch_util.py +120 -0
- xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
- xinference/thirdparty/indextts/utils/common.py +121 -0
- xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
- xinference/thirdparty/indextts/utils/front.py +536 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
- xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
- xinference/thirdparty/indextts/utils/text_utils.py +41 -0
- xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
- xinference/thirdparty/indextts/utils/utils.py +93 -0
- xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
- xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
- xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
- xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
- xinference/ui/gradio/media_interface.py +66 -8
- xinference/ui/web/ui/build/asset-manifest.json +6 -6
- xinference/ui/web/ui/build/index.html +1 -1
- xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
- xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
- xinference/ui/web/ui/build/static/js/main.d192c4f3.js +3 -0
- xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.d192c4f3.js.LICENSE.txt} +0 -7
- xinference/ui/web/ui/build/static/js/main.d192c4f3.js.map +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f995a2425dfb0822fd07127f66ffe9b026883bc156b402eb8bd0b83d52460a93.json +1 -0
- xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
- xinference/ui/web/ui/package-lock.json +0 -34
- xinference/ui/web/ui/package.json +0 -1
- xinference/ui/web/ui/src/locales/en.json +9 -3
- xinference/ui/web/ui/src/locales/ja.json +9 -3
- xinference/ui/web/ui/src/locales/ko.json +9 -3
- xinference/ui/web/ui/src/locales/zh.json +9 -3
- {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/METADATA +18 -2
- {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/RECORD +285 -67
- xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
- xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
- xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
- xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
- xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
- xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
- xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
- xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
- xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
- xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
- xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
- xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
- xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
- xinference/ui/web/ui/node_modules/select/bower.json +0 -13
- xinference/ui/web/ui/node_modules/select/package.json +0 -29
- xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
- {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/WHEEL +0 -0
- {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/entry_points.txt +0 -0
- {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/top_level.txt +0 -0

xinference/thirdparty/audiotools/ml/accelerator.py
@@ -0,0 +1,184 @@
+import os
+import typing
+
+import torch
+import torch.distributed as dist
+from torch.nn.parallel import DataParallel
+from torch.nn.parallel import DistributedDataParallel
+
+from ..data.datasets import ResumableDistributedSampler as DistributedSampler
+from ..data.datasets import ResumableSequentialSampler as SequentialSampler
+
+
+class Accelerator:  # pragma: no cover
+    """This class is used to prepare models and dataloaders for
+    usage with DDP or DP. Use the functions prepare_model, prepare_dataloader to
+    prepare the respective objects. In the case of models, they are moved to
+    the appropriate GPU and SyncBatchNorm is applied to them. In the case of
+    dataloaders, a sampler is created and the dataloader is initialized with
+    that sampler.
+
+    If the world size is 1, prepare_model and prepare_dataloader are
+    no-ops. If the environment variable ``LOCAL_RANK`` is not set, then the
+    script was launched without ``torchrun``, and ``DataParallel``
+    will be used instead of ``DistributedDataParallel`` (not recommended), if
+    the world size (number of GPUs) is greater than 1.
+
+    Parameters
+    ----------
+    amp : bool, optional
+        Whether or not to enable automatic mixed precision, by default False
+    """
+
+    def __init__(self, amp: bool = False):
+        local_rank = os.getenv("LOCAL_RANK", None)
+        self.world_size = torch.cuda.device_count()
+
+        self.use_ddp = self.world_size > 1 and local_rank is not None
+        self.use_dp = self.world_size > 1 and local_rank is None
+        self.device = "cpu" if self.world_size == 0 else "cuda"
+
+        if self.use_ddp:
+            local_rank = int(local_rank)
+            dist.init_process_group(
+                "nccl",
+                init_method="env://",
+                world_size=self.world_size,
+                rank=local_rank,
+            )
+
+        self.local_rank = 0 if local_rank is None else local_rank
+        self.amp = amp
+
+        class DummyScaler:
+            def __init__(self):
+                pass
+
+            def step(self, optimizer):
+                optimizer.step()
+
+            def scale(self, loss):
+                return loss
+
+            def unscale_(self, optimizer):
+                return optimizer
+
+            def update(self):
+                pass
+
+        self.scaler = torch.cuda.amp.GradScaler() if amp else DummyScaler()
+        self.device_ctx = (
+            torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
+        )
+
+    def __enter__(self):
+        if self.device_ctx is not None:
+            self.device_ctx.__enter__()
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        if self.device_ctx is not None:
+            self.device_ctx.__exit__(exc_type, exc_value, traceback)
+
+    def prepare_model(self, model: torch.nn.Module, **kwargs):
+        """Prepares model for DDP or DP. The model is moved to
+        the device of the correct rank.
+
+        Parameters
+        ----------
+        model : torch.nn.Module
+            Model that is converted for DDP or DP.
+
+        Returns
+        -------
+        torch.nn.Module
+            Wrapped model, or original model if DDP and DP are turned off.
+        """
+        model = model.to(self.device)
+        if self.use_ddp:
+            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+            model = DistributedDataParallel(
+                model, device_ids=[self.local_rank], **kwargs
+            )
+        elif self.use_dp:
+            model = DataParallel(model, **kwargs)
+        return model
+
+    # Automatic mixed-precision utilities
+    def autocast(self, *args, **kwargs):
+        """Context manager for autocasting. Arguments
+        go to ``torch.cuda.amp.autocast``.
+        """
+        return torch.cuda.amp.autocast(self.amp, *args, **kwargs)
+
+    def backward(self, loss: torch.Tensor):
+        """Backwards pass, after scaling the loss if ``amp`` is
+        enabled.
+
+        Parameters
+        ----------
+        loss : torch.Tensor
+            Loss value.
+        """
+        self.scaler.scale(loss).backward()
+
+    def step(self, optimizer: torch.optim.Optimizer):
+        """Steps the optimizer, using a ``scaler`` if ``amp`` is
+        enabled.
+
+        Parameters
+        ----------
+        optimizer : torch.optim.Optimizer
+            Optimizer to step forward.
+        """
+        self.scaler.step(optimizer)
+
+    def update(self):
+        """Updates the scale factor."""
+        self.scaler.update()
+
+    def prepare_dataloader(
+        self, dataset: typing.Iterable, start_idx: int = None, **kwargs
+    ):
+        """Wraps a dataset with a DataLoader, using the correct sampler if DDP is
+        enabled.
+
+        Parameters
+        ----------
+        dataset : typing.Iterable
+            Dataset to build Dataloader around.
+        start_idx : int, optional
+            Start index of sampler, useful if resuming from some epoch,
+            by default None
+
+        Returns
+        -------
+        _type_
+            _description_
+        """
+
+        if self.use_ddp:
+            sampler = DistributedSampler(
+                dataset,
+                start_idx,
+                num_replicas=self.world_size,
+                rank=self.local_rank,
+            )
+            if "num_workers" in kwargs:
+                kwargs["num_workers"] = max(kwargs["num_workers"] // self.world_size, 1)
+            kwargs["batch_size"] = max(kwargs["batch_size"] // self.world_size, 1)
+        else:
+            sampler = SequentialSampler(dataset, start_idx)
+
+        dataloader = torch.utils.data.DataLoader(dataset, sampler=sampler, **kwargs)
+        return dataloader
+
+    @staticmethod
+    def unwrap(model):
+        """Unwraps the model if it was wrapped in DDP or DP, otherwise
+        just returns the model. Use this to unwrap the model returned by
+        :py:func:`audiotools.ml.accelerator.Accelerator.prepare_model`.
+        """
+        if hasattr(model, "module"):
+            return model.module
+        return model
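
The Accelerator added above wraps the usual DistributedDataParallel/DataParallel setup behind a small context-manager API (prepare_model, prepare_dataloader, autocast, backward, step, update, unwrap). The following is a minimal usage sketch, not code from the package: the import path is assumed from the vendored file location in the list above, and the model, optimizer, and dataset are placeholders.

import torch

# Import path assumed from the vendored file location shown in the file list above.
from xinference.thirdparty.audiotools.ml.accelerator import Accelerator

# Placeholder model/optimizer/dataset purely for illustration.
model = torch.nn.Linear(16, 1)
optimizer = torch.optim.Adam(model.parameters())
dataset = torch.utils.data.TensorDataset(torch.randn(128, 16), torch.randn(128, 1))

with Accelerator(amp=False) as accel:
    model = accel.prepare_model(model)                 # move to device; wrap in DDP/DP when multi-GPU
    loader = accel.prepare_dataloader(dataset, batch_size=8)
    for x, y in loader:
        optimizer.zero_grad()
        with accel.autocast():                         # forwards to torch.cuda.amp.autocast
            loss = torch.nn.functional.mse_loss(
                model(x.to(accel.device)), y.to(accel.device)
            )
        accel.backward(loss)                           # scales the loss only when amp=True
        accel.step(optimizer)                          # GradScaler.step, or plain optimizer.step without amp
        accel.update()                                 # updates the scale factor; no-op without amp
    model = Accelerator.unwrap(model)                  # strip the DDP/DP wrapper again
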
xinference/thirdparty/audiotools/ml/decorators.py
@@ -0,0 +1,440 @@
+import math
+import os
+import time
+from collections import defaultdict
+from functools import wraps
+
+import torch
+import torch.distributed as dist
+from rich import box
+from rich.console import Console
+from rich.console import Group
+from rich.live import Live
+from rich.markdown import Markdown
+from rich.padding import Padding
+from rich.panel import Panel
+from rich.progress import BarColumn
+from rich.progress import Progress
+from rich.progress import SpinnerColumn
+from rich.progress import TimeElapsedColumn
+from rich.progress import TimeRemainingColumn
+from rich.rule import Rule
+from rich.table import Table
+from torch.utils.tensorboard import SummaryWriter
+
+
+# This is here so that the history can be pickled.
+def default_list():
+    return []
+
+
+class Mean:
+    """Keeps track of the running mean, along with the latest
+    value.
+    """
+
+    def __init__(self):
+        self.reset()
+
+    def __call__(self):
+        mean = self.total / max(self.count, 1)
+        return mean
+
+    def reset(self):
+        self.count = 0
+        self.total = 0
+
+    def update(self, val):
+        if math.isfinite(val):
+            self.count += 1
+            self.total += val
+
+
+def when(condition):
+    """Runs a function only when the condition is met. The condition is
+    a function that is run.
+
+    Parameters
+    ----------
+    condition : Callable
+        Function to run to check whether or not to run the decorated
+        function.
+
+    Example
+    -------
+    Checkpoint only runs every 100 iterations, and only if the
+    local rank is 0.
+
+    >>> i = 0
+    >>> rank = 0
+    >>>
+    >>> @when(lambda: i % 100 == 0 and rank == 0)
+    >>> def checkpoint():
+    >>>     print("Saving to /runs/exp1")
+    >>>
+    >>> for i in range(1000):
+    >>>     checkpoint()
+
+    """
+
+    def decorator(fn):
+        @wraps(fn)
+        def decorated(*args, **kwargs):
+            if condition():
+                return fn(*args, **kwargs)
+
+        return decorated
+
+    return decorator
+
+
+def timer(prefix: str = "time"):
+    """Adds execution time to the output dictionary of the decorated
+    function. The function decorated by this must output a dictionary.
+    The key added will follow the form "[prefix]/[name_of_function]"
+
+    Parameters
+    ----------
+    prefix : str, optional
+        The key added will follow the form "[prefix]/[name_of_function]",
+        by default "time".
+    """
+
+    def decorator(fn):
+        @wraps(fn)
+        def decorated(*args, **kwargs):
+            s = time.perf_counter()
+            output = fn(*args, **kwargs)
+            assert isinstance(output, dict)
+            e = time.perf_counter()
+            output[f"{prefix}/{fn.__name__}"] = e - s
+            return output
+
+        return decorated
+
+    return decorator
+
+
+class Tracker:
+    """
+    A tracker class that helps to monitor the progress of training and logging the metrics.
+
+    Attributes
+    ----------
+    metrics : dict
+        A dictionary containing the metrics for each label.
+    history : dict
+        A dictionary containing the history of metrics for each label.
+    writer : SummaryWriter
+        A SummaryWriter object for logging the metrics.
+    rank : int
+        The rank of the current process.
+    step : int
+        The current step of the training.
+    tasks : dict
+        A dictionary containing the progress bars and tables for each label.
+    pbar : Progress
+        A progress bar object for displaying the progress.
+    consoles : list
+        A list of console objects for logging.
+    live : Live
+        A Live object for updating the display live.
+
+    Methods
+    -------
+    print(msg: str)
+        Prints the given message to all consoles.
+    update(label: str, fn_name: str)
+        Updates the progress bar and table for the given label.
+    done(label: str, title: str)
+        Resets the progress bar and table for the given label and prints the final result.
+    track(label: str, length: int, completed: int = 0, op: dist.ReduceOp = dist.ReduceOp.AVG, ddp_active: bool = "LOCAL_RANK" in os.environ)
+        A decorator for tracking the progress and metrics of a function.
+    log(label: str, value_type: str = "value", history: bool = True)
+        A decorator for logging the metrics of a function.
+    is_best(label: str, key: str) -> bool
+        Checks if the latest value of the given key in the label is the best so far.
+    state_dict() -> dict
+        Returns a dictionary containing the state of the tracker.
+    load_state_dict(state_dict: dict) -> Tracker
+        Loads the state of the tracker from the given state dictionary.
+    """
+
+    def __init__(
+        self,
+        writer: SummaryWriter = None,
+        log_file: str = None,
+        rank: int = 0,
+        console_width: int = 100,
+        step: int = 0,
+    ):
+        """
+        Initializes the Tracker object.
+
+        Parameters
+        ----------
+        writer : SummaryWriter, optional
+            A SummaryWriter object for logging the metrics, by default None.
+        log_file : str, optional
+            The path to the log file, by default None.
+        rank : int, optional
+            The rank of the current process, by default 0.
+        console_width : int, optional
+            The width of the console, by default 100.
+        step : int, optional
+            The current step of the training, by default 0.
+        """
+        self.metrics = {}
+        self.history = {}
+        self.writer = writer
+        self.rank = rank
+        self.step = step
+
+        # Create progress bars etc.
+        self.tasks = {}
+        self.pbar = Progress(
+            SpinnerColumn(),
+            "[progress.description]{task.description}",
+            "{task.completed}/{task.total}",
+            BarColumn(),
+            TimeElapsedColumn(),
+            "/",
+            TimeRemainingColumn(),
+        )
+        self.consoles = [Console(width=console_width)]
+        self.live = Live(console=self.consoles[0], refresh_per_second=10)
+        if log_file is not None:
+            self.consoles.append(Console(width=console_width, file=open(log_file, "a")))
+
+    def print(self, msg):
+        """
+        Prints the given message to all consoles.
+
+        Parameters
+        ----------
+        msg : str
+            The message to be printed.
+        """
+        if self.rank == 0:
+            for c in self.consoles:
+                c.log(msg)
+
+    def update(self, label, fn_name):
+        """
+        Updates the progress bar and table for the given label.
+
+        Parameters
+        ----------
+        label : str
+            The label of the progress bar and table to be updated.
+        fn_name : str
+            The name of the function associated with the label.
+        """
+        if self.rank == 0:
+            self.pbar.advance(self.tasks[label]["pbar"])
+
+            # Create table
+            table = Table(title=label, expand=True, box=box.MINIMAL)
+            table.add_column("key", style="cyan")
+            table.add_column("value", style="bright_blue")
+            table.add_column("mean", style="bright_green")
+
+            keys = self.metrics[label]["value"].keys()
+            for k in keys:
+                value = self.metrics[label]["value"][k]
+                mean = self.metrics[label]["mean"][k]()
+                table.add_row(k, f"{value:10.6f}", f"{mean:10.6f}")
+
+            self.tasks[label]["table"] = table
+            tables = [t["table"] for t in self.tasks.values()]
+            group = Group(*tables, self.pbar)
+            self.live.update(
+                Group(
+                    Padding("", (0, 0)),
+                    Rule(f"[italic]{fn_name}()", style="white"),
+                    Padding("", (0, 0)),
+                    Panel.fit(
+                        group, padding=(0, 5), title="[b]Progress", border_style="blue"
+                    ),
+                )
+            )
+
+    def done(self, label: str, title: str):
+        """
+        Resets the progress bar and table for the given label and prints the final result.
+
+        Parameters
+        ----------
+        label : str
+            The label of the progress bar and table to be reset.
+        title : str
+            The title to be displayed when printing the final result.
+        """
+        for label in self.metrics:
+            for v in self.metrics[label]["mean"].values():
+                v.reset()
+
+        if self.rank == 0:
+            self.pbar.reset(self.tasks[label]["pbar"])
+            tables = [t["table"] for t in self.tasks.values()]
+            group = Group(Markdown(f"# {title}"), *tables, self.pbar)
+            self.print(group)
+
+    def track(
+        self,
+        label: str,
+        length: int,
+        completed: int = 0,
+        op: dist.ReduceOp = dist.ReduceOp.AVG,
+        ddp_active: bool = "LOCAL_RANK" in os.environ,
+    ):
+        """
+        A decorator for tracking the progress and metrics of a function.
+
+        Parameters
+        ----------
+        label : str
+            The label to be associated with the progress and metrics.
+        length : int
+            The total number of iterations to be completed.
+        completed : int, optional
+            The number of iterations already completed, by default 0.
+        op : dist.ReduceOp, optional
+            The reduce operation to be used, by default dist.ReduceOp.AVG.
+        ddp_active : bool, optional
+            Whether the DistributedDataParallel is active, by default "LOCAL_RANK" in os.environ.
+        """
+        self.tasks[label] = {
+            "pbar": self.pbar.add_task(
+                f"[white]Iteration ({label})", total=length, completed=completed
+            ),
+            "table": Table(),
+        }
+        self.metrics[label] = {
+            "value": defaultdict(),
+            "mean": defaultdict(lambda: Mean()),
+        }
+
+        def decorator(fn):
+            @wraps(fn)
+            def decorated(*args, **kwargs):
+                output = fn(*args, **kwargs)
+                if not isinstance(output, dict):
+                    self.update(label, fn.__name__)
+                    return output
+                # Collect across all DDP processes
+                scalar_keys = []
+                for k, v in output.items():
+                    if isinstance(v, (int, float)):
+                        v = torch.tensor([v])
+                    if not torch.is_tensor(v):
+                        continue
+                    if ddp_active and v.is_cuda:  # pragma: no cover
+                        dist.all_reduce(v, op=op)
+                    output[k] = v.detach()
+                    if torch.numel(v) == 1:
+                        scalar_keys.append(k)
+                        output[k] = v.item()
+
+                # Save the outputs to tracker
+                for k, v in output.items():
+                    if k not in scalar_keys:
+                        continue
+                    self.metrics[label]["value"][k] = v
+                    # Update the running mean
+                    self.metrics[label]["mean"][k].update(v)
+
+                self.update(label, fn.__name__)
+                return output
+
+            return decorated
+
+        return decorator
+
+    def log(self, label: str, value_type: str = "value", history: bool = True):
+        """
+        A decorator for logging the metrics of a function.
+
+        Parameters
+        ----------
+        label : str
+            The label to be associated with the logging.
+        value_type : str, optional
+            The type of value to be logged, by default "value".
+        history : bool, optional
+            Whether to save the history of the metrics, by default True.
+        """
+        assert value_type in ["mean", "value"]
+        if history:
+            if label not in self.history:
+                self.history[label] = defaultdict(default_list)
+
+        def decorator(fn):
+            @wraps(fn)
+            def decorated(*args, **kwargs):
+                output = fn(*args, **kwargs)
+                if self.rank == 0:
+                    nonlocal value_type, label
+                    metrics = self.metrics[label][value_type]
+                    for k, v in metrics.items():
+                        v = v() if isinstance(v, Mean) else v
+                        if self.writer is not None:
+                            self.writer.add_scalar(f"{k}/{label}", v, self.step)
+                        if label in self.history:
+                            self.history[label][k].append(v)
+
+                    if label in self.history:
+                        self.history[label]["step"].append(self.step)
+
+                return output
+
+            return decorated
+
+        return decorator
+
+    def is_best(self, label, key):
+        """
+        Checks if the latest value of the given key in the label is the best so far.
+
+        Parameters
+        ----------
+        label : str
+            The label of the metrics to be checked.
+        key : str
+            The key of the metric to be checked.
+
+        Returns
+        -------
+        bool
+            True if the latest value is the best so far, otherwise False.
+        """
+        return self.history[label][key][-1] == min(self.history[label][key])
+
+    def state_dict(self):
+        """
+        Returns a dictionary containing the state of the tracker.
+
+        Returns
+        -------
+        dict
+            A dictionary containing the history and step of the tracker.
+        """
+        return {"history": self.history, "step": self.step}
+
+    def load_state_dict(self, state_dict):
+        """
+        Loads the state of the tracker from the given state dictionary.
+
+        Parameters
+        ----------
+        state_dict : dict
+            A dictionary containing the history and step of the tracker.
+
+        Returns
+        -------
+        Tracker
+            The tracker object with the loaded state.
+        """
+        self.history = state_dict["history"]
+        self.step = state_dict["step"]
+        return self
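
Tracker, also added above, combines a rich progress display with per-label metric accumulation: track() and log() are decorators applied to a step function that returns a dict of scalars, timer() adds the step duration to that dict, and state_dict()/load_state_dict() make the history resumable. A minimal sketch under the same assumptions (vendored import path, placeholder step function):

import torch

# Import path assumed from the vendored file location shown in the file list above.
from xinference.thirdparty.audiotools.ml.decorators import Tracker, timer

tracker = Tracker(writer=None, log_file=None, rank=0)
num_steps = 100

# log() is applied outside track() so the values collected by track()
# are already up to date when they are appended to the history.
@tracker.log("train", "value", history=True)
@tracker.track("train", num_steps)
@timer()                                   # adds a "time/train_step" entry to the returned dict
def train_step(step):
    loss = torch.tensor(1.0 / (step + 1))  # placeholder "loss"
    return {"loss": loss}

with tracker.live:                         # rich Live display for the progress bar and tables
    for step in range(num_steps):
        train_step(step)
        tracker.step += 1
    tracker.done("train", title="Epoch 0")

print(tracker.is_best("train", "loss"))    # compares the latest value against the minimum seen
state = tracker.state_dict()               # {"history": ..., "step": ...}
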
xinference/thirdparty/audiotools/ml/experiment.py
@@ -0,0 +1,90 @@
+"""
+Useful class for Experiment tracking, and ensuring code is
+saved alongside files.
+"""  # fmt: skip
+import datetime
+import os
+import shlex
+import shutil
+import subprocess
+import typing
+from pathlib import Path
+
+import randomname
+
+
+class Experiment:
+    """This class contains utilities for managing experiments.
+    It is a context manager, that when you enter it, changes
+    your directory to a specified experiment folder (which
+    optionally can have an automatically generated experiment
+    name, or a specified one), and changes the CUDA device used
+    to the specified device (or devices).
+
+    Parameters
+    ----------
+    exp_directory : str
+        Folder where all experiments are saved, by default "runs/".
+    exp_name : str, optional
+        Name of the experiment, by default uses the current time, date, and
+        hostname to save.
+    """
+
+    def __init__(
+        self,
+        exp_directory: str = "runs/",
+        exp_name: str = None,
+    ):
+        if exp_name is None:
+            exp_name = self.generate_exp_name()
+        exp_dir = Path(exp_directory) / exp_name
+        exp_dir.mkdir(parents=True, exist_ok=True)
+
+        self.exp_dir = exp_dir
+        self.exp_name = exp_name
+        self.git_tracked_files = (
+            subprocess.check_output(
+                shlex.split("git ls-tree --full-tree --name-only -r HEAD")
+            )
+            .decode("utf-8")
+            .splitlines()
+        )
+        self.parent_directory = Path(".").absolute()
+
+    def __enter__(self):
+        self.prev_dir = os.getcwd()
+        os.chdir(self.exp_dir)
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        os.chdir(self.prev_dir)
+
+    @staticmethod
+    def generate_exp_name():
+        """Generates a random experiment name based on the date
+        and a randomly generated adjective-noun tuple.
+
+        Returns
+        -------
+        str
+            Randomly generated experiment name.
+        """
+        date = datetime.datetime.now().strftime("%y%m%d")
+        name = f"{date}-{randomname.get_name()}"
+        return name
+
+    def snapshot(self, filter_fn: typing.Callable = lambda f: True):
+        """Captures a full snapshot of all the files tracked by git at the time
+        the experiment is run. It also captures the diff against the committed
+        code as a separate file.
+
+        Parameters
+        ----------
+        filter_fn : typing.Callable, optional
+            Function that can be used to exclude some files
+            from the snapshot, by default accepts all files
+        """
+        for f in self.git_tracked_files:
+            if filter_fn(f):
+                Path(f).parent.mkdir(parents=True, exist_ok=True)
+                shutil.copyfile(self.parent_directory / f, f)
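
Experiment, the last file shown above, is a small context manager that creates a run directory under exp_directory, changes into it for the duration of the block, and can snapshot all git-tracked files into that directory; constructing it shells out to git ls-tree, so it only works inside a git checkout. A minimal sketch, with the import path again assumed from the vendored location:

# Import path assumed from the vendored file location shown in the file list above.
from xinference.thirdparty.audiotools.ml.experiment import Experiment

with Experiment(exp_directory="runs/") as exp:    # chdir into runs/<generated-name>/
    exp.snapshot(lambda f: f.endswith(".py"))     # copy git-tracked .py files into the run folder
    print("running in", exp.exp_dir)
    # ... run training here; files written with relative paths land in the run folder
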