pytorch-ignite 0.6.0.dev20250310__py3-none-any.whl → 0.6.0.dev20260101__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pytorch-ignite might be problematic. Click here for more details.
- ignite/__init__.py +1 -1
- ignite/contrib/engines/common.py +1 -0
- ignite/contrib/handlers/base_logger.py +1 -1
- ignite/contrib/handlers/clearml_logger.py +1 -1
- ignite/contrib/handlers/lr_finder.py +1 -1
- ignite/contrib/handlers/mlflow_logger.py +1 -1
- ignite/contrib/handlers/neptune_logger.py +1 -1
- ignite/contrib/handlers/param_scheduler.py +1 -1
- ignite/contrib/handlers/polyaxon_logger.py +1 -1
- ignite/contrib/handlers/tensorboard_logger.py +1 -1
- ignite/contrib/handlers/time_profilers.py +1 -1
- ignite/contrib/handlers/tqdm_logger.py +1 -1
- ignite/contrib/handlers/visdom_logger.py +1 -1
- ignite/contrib/handlers/wandb_logger.py +1 -1
- ignite/contrib/metrics/average_precision.py +1 -1
- ignite/contrib/metrics/cohen_kappa.py +1 -1
- ignite/contrib/metrics/gpu_info.py +1 -1
- ignite/contrib/metrics/precision_recall_curve.py +1 -1
- ignite/contrib/metrics/regression/canberra_metric.py +2 -3
- ignite/contrib/metrics/regression/fractional_absolute_error.py +2 -3
- ignite/contrib/metrics/regression/fractional_bias.py +2 -3
- ignite/contrib/metrics/regression/geometric_mean_absolute_error.py +2 -3
- ignite/contrib/metrics/regression/geometric_mean_relative_absolute_error.py +2 -3
- ignite/contrib/metrics/regression/manhattan_distance.py +2 -3
- ignite/contrib/metrics/regression/maximum_absolute_error.py +2 -3
- ignite/contrib/metrics/regression/mean_absolute_relative_error.py +2 -3
- ignite/contrib/metrics/regression/mean_error.py +2 -3
- ignite/contrib/metrics/regression/mean_normalized_bias.py +2 -3
- ignite/contrib/metrics/regression/median_absolute_error.py +2 -3
- ignite/contrib/metrics/regression/median_absolute_percentage_error.py +2 -3
- ignite/contrib/metrics/regression/median_relative_absolute_error.py +2 -3
- ignite/contrib/metrics/regression/r2_score.py +2 -3
- ignite/contrib/metrics/regression/wave_hedges_distance.py +2 -3
- ignite/contrib/metrics/roc_auc.py +1 -1
- ignite/distributed/auto.py +1 -0
- ignite/distributed/comp_models/base.py +7 -0
- ignite/distributed/comp_models/horovod.py +35 -5
- ignite/distributed/comp_models/native.py +8 -4
- ignite/distributed/comp_models/xla.py +5 -0
- ignite/distributed/launcher.py +4 -8
- ignite/distributed/utils.py +12 -4
- ignite/engine/__init__.py +9 -9
- ignite/engine/deterministic.py +1 -1
- ignite/engine/engine.py +38 -14
- ignite/engine/events.py +2 -1
- ignite/handlers/__init__.py +2 -0
- ignite/handlers/base_logger.py +47 -12
- ignite/handlers/checkpoint.py +46 -5
- ignite/handlers/clearml_logger.py +16 -4
- ignite/handlers/fbresearch_logger.py +2 -2
- ignite/handlers/lr_finder.py +9 -9
- ignite/handlers/mlflow_logger.py +43 -0
- ignite/handlers/neptune_logger.py +8 -0
- ignite/handlers/param_scheduler.py +7 -3
- ignite/handlers/polyaxon_logger.py +7 -0
- ignite/handlers/state_param_scheduler.py +8 -2
- ignite/handlers/tensorboard_logger.py +43 -0
- ignite/handlers/time_profilers.py +6 -3
- ignite/handlers/tqdm_logger.py +9 -5
- ignite/handlers/visdom_logger.py +10 -3
- ignite/handlers/wandb_logger.py +16 -9
- ignite/metrics/accuracy.py +2 -0
- ignite/metrics/clustering/calinski_harabasz_score.py +1 -1
- ignite/metrics/clustering/silhouette_score.py +1 -1
- ignite/metrics/fbeta.py +17 -8
- ignite/metrics/gan/fid.py +3 -3
- ignite/metrics/js_divergence.py +1 -1
- ignite/metrics/maximum_mean_discrepancy.py +1 -1
- ignite/metrics/metric.py +3 -0
- ignite/metrics/nlp/bleu.py +8 -6
- ignite/metrics/nlp/rouge.py +9 -6
- ignite/metrics/nlp/utils.py +1 -1
- ignite/metrics/precision_recall_curve.py +5 -5
- ignite/metrics/regression/_base.py +4 -0
- ignite/metrics/regression/fractional_bias.py +1 -1
- ignite/metrics/roc_auc.py +4 -3
- ignite/metrics/ssim.py +63 -21
- ignite/metrics/vision/object_detection_average_precision_recall.py +3 -0
- {pytorch_ignite-0.6.0.dev20250310.dist-info → pytorch_ignite-0.6.0.dev20260101.dist-info}/METADATA +11 -17
- {pytorch_ignite-0.6.0.dev20250310.dist-info → pytorch_ignite-0.6.0.dev20260101.dist-info}/RECORD +82 -83
- {pytorch_ignite-0.6.0.dev20250310.dist-info → pytorch_ignite-0.6.0.dev20260101.dist-info}/WHEEL +1 -2
- pytorch_ignite-0.6.0.dev20250310.dist-info/top_level.txt +0 -1
- {pytorch_ignite-0.6.0.dev20250310.dist-info → pytorch_ignite-0.6.0.dev20260101.dist-info/licenses}/LICENSE +0 -0
ignite/__init__.py
CHANGED
ignite/contrib/engines/common.py
CHANGED
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to /ignite/handlers/base_logger.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to /ignite/handlers/clearml_logger.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to /ignite/handlers/lr_finder.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to /ignite/handlers/mlflow_logger.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to /ignite/handlers/neptune_logger.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to /ignite/handlers/param_scheduler.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to /ignite/handlers/polyaxon_logger.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to /ignite/handlers/tensorboard_logger.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to /ignite/handlers/time_profilers.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to /ignite/handlers/tqdm_logger.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to /ignite/handlers/visdom_logger.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to /ignite/handlers/wandb_logger.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to /ignite/metrics/average_precision.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/cohen_kappa.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/gpu_info.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/precision_recall_curve.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/canberra_metric.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.canberra_metric import CanberraMetric
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/fractional_absolute_error.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.fractional_absolute_error import FractionalAbsoluteError
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/fractional_bias.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.fractional_bias import FractionalBias
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/geometric_mean_absolute_error.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.geometric_mean_absolute_error import GeometricMeanAbsoluteError
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/geometric_mean_relative_absolute_error.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.geometric_mean_relative_absolute_error import GeometricMeanRelativeAbsoluteError
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/manhattan_distance.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.manhattan_distance import ManhattanDistance
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/maximum_absolute_error.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.maximum_absolute_error import MaximumAbsoluteError
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/mean_absolute_relative_error.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.mean_absolute_relative_error import MeanAbsoluteRelativeError
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/mean_error.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.mean_error import MeanError
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/mean_normalized_bias.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.mean_normalized_bias import MeanNormalizedBias
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/median_absolute_error.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.median_absolute_error import MedianAbsoluteError
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/median_absolute_percentage_error.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.median_absolute_percentage_error import MedianAbsolutePercentageError
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/median_relative_absolute_error.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.median_relative_absolute_error import MedianRelativeAbsoluteError
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/r2_score.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.r2_score import R2Score
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/wave_hedges_distance.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.wave_hedges_distance import WaveHedgesDistance
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/roc_auc.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
ignite/distributed/auto.py
CHANGED
|
@@ -336,6 +336,7 @@ if idist.has_xla_support:
|
|
|
336
336
|
# From pytorch/xla if `torch_xla.distributed.parallel_loader.MpDeviceLoader` is not available
|
|
337
337
|
def __init__(self, loader: Any, device: torch.device, **kwargs: Any) -> None:
|
|
338
338
|
self._loader = loader
|
|
339
|
+
# pyrefly: ignore [read-only]
|
|
339
340
|
self._device = device
|
|
340
341
|
self._parallel_loader_kwargs = kwargs
|
|
341
342
|
|
|
@@ -298,6 +298,10 @@ class ComputationModel(metaclass=ABCMeta):
|
|
|
298
298
|
def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
|
|
299
299
|
pass
|
|
300
300
|
|
|
301
|
+
@abstractmethod
|
|
302
|
+
def _rank_not_in_group(self, group: Any) -> bool:
|
|
303
|
+
pass
|
|
304
|
+
|
|
301
305
|
|
|
302
306
|
class _SerialModel(ComputationModel):
|
|
303
307
|
"""Private class defines non-distributed computation model for code compatibility with other distributed models."""
|
|
@@ -396,3 +400,6 @@ class _SerialModel(ComputationModel):
|
|
|
396
400
|
return self._do_new_group(ranks, **kwargs)
|
|
397
401
|
else:
|
|
398
402
|
raise ValueError("Argument ranks should be list of int")
|
|
403
|
+
|
|
404
|
+
def _rank_not_in_group(self, group: Any) -> bool:
|
|
405
|
+
return False
|
|
@@ -1,5 +1,6 @@
|
|
|
1
|
+
import os
|
|
1
2
|
import warnings
|
|
2
|
-
from typing import Any, Callable, cast, List, Mapping, Optional, Tuple
|
|
3
|
+
from typing import Any, Callable, cast, List, Mapping, Optional, Tuple, TYPE_CHECKING
|
|
3
4
|
|
|
4
5
|
import torch
|
|
5
6
|
|
|
@@ -19,10 +20,18 @@ try:
|
|
|
19
20
|
except ImportError:
|
|
20
21
|
has_hvd_support = False
|
|
21
22
|
|
|
23
|
+
if TYPE_CHECKING:
|
|
24
|
+
# Tell the type checker that hvd imports are always defined.
|
|
25
|
+
import horovod.torch as hvd
|
|
26
|
+
from horovod import run as hvd_mp_spawn
|
|
27
|
+
|
|
22
28
|
|
|
23
29
|
if has_hvd_support:
|
|
24
30
|
HOROVOD = "horovod"
|
|
25
31
|
|
|
32
|
+
# Enables dynamic process sets: new_group methods and passing group into collective ops
|
|
33
|
+
os.environ["HOROVOD_DYNAMIC_PROCESS_SETS"] = "1"
|
|
34
|
+
|
|
26
35
|
class _HorovodDistModel(ComputationModel):
|
|
27
36
|
"""Private class for `Horovod <https://horovod.readthedocs.io/en/stable/>`_ distributed computation model."""
|
|
28
37
|
|
|
@@ -66,6 +75,7 @@ if has_hvd_support:
|
|
|
66
75
|
self._init_from_context()
|
|
67
76
|
|
|
68
77
|
def _create_from_backend(self, backend: str, **kwargs: Any) -> None:
|
|
78
|
+
# pyrefly: ignore [bad-override]
|
|
69
79
|
self._backend: str = backend
|
|
70
80
|
comm = kwargs.get("comm", None)
|
|
71
81
|
hvd.init(comm=comm)
|
|
@@ -125,6 +135,7 @@ if has_hvd_support:
|
|
|
125
135
|
finalize()
|
|
126
136
|
|
|
127
137
|
@staticmethod
|
|
138
|
+
# pyrefly: ignore [bad-override]
|
|
128
139
|
def spawn(
|
|
129
140
|
fn: Callable,
|
|
130
141
|
args: Tuple,
|
|
@@ -155,6 +166,15 @@ if has_hvd_support:
|
|
|
155
166
|
**kwargs,
|
|
156
167
|
)
|
|
157
168
|
|
|
169
|
+
def _setup_group(self, group: Any) -> hvd.ProcessSet:
|
|
170
|
+
if isinstance(group, list) and all(isinstance(item, int) for item in group):
|
|
171
|
+
group = self._do_new_group(group)
|
|
172
|
+
if not isinstance(group, hvd.ProcessSet):
|
|
173
|
+
raise ValueError(
|
|
174
|
+
f"Argument group should be list of int or hvd.ProcessSet, got {type(group)}, group={group}"
|
|
175
|
+
)
|
|
176
|
+
return group
|
|
177
|
+
|
|
158
178
|
_reduce_op_map = {
|
|
159
179
|
"SUM": hvd.mpi_ops.Sum,
|
|
160
180
|
"AVERAGE": hvd.mpi_ops.Average,
|
|
@@ -187,10 +207,15 @@ if has_hvd_support:
|
|
|
187
207
|
|
|
188
208
|
def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
|
|
189
209
|
if group is not None:
|
|
190
|
-
|
|
210
|
+
group = self._setup_group(group)
|
|
211
|
+
if self._rank_not_in_group(group):
|
|
212
|
+
return tensor
|
|
191
213
|
if tensor.ndimension() == 0:
|
|
192
214
|
tensor = tensor.unsqueeze(0)
|
|
193
|
-
|
|
215
|
+
if group is not None:
|
|
216
|
+
return hvd.allgather(tensor, process_set=group)
|
|
217
|
+
else:
|
|
218
|
+
return hvd.allgather(tensor)
|
|
194
219
|
|
|
195
220
|
def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> List[Any]:
|
|
196
221
|
if group is not None:
|
|
@@ -198,8 +223,8 @@ if has_hvd_support:
|
|
|
198
223
|
|
|
199
224
|
return hvd.allgather_object(tensor)
|
|
200
225
|
|
|
201
|
-
def _do_new_group(self, ranks: List[int], **kwargs: Any) ->
|
|
202
|
-
return hvd.
|
|
226
|
+
def _do_new_group(self, ranks: List[int], **kwargs: Any) -> hvd.ProcessSet:
|
|
227
|
+
return hvd.add_process_set(ranks)
|
|
203
228
|
|
|
204
229
|
def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
|
|
205
230
|
return hvd.broadcast(tensor, root_rank=src)
|
|
@@ -208,3 +233,8 @@ if has_hvd_support:
|
|
|
208
233
|
# https://github.com/horovod/horovod/issues/159#issuecomment-424834603
|
|
209
234
|
# hvd.allreduce(torch.tensor(0, device=self.device()), name="barrier")
|
|
210
235
|
hvd.allreduce(torch.tensor(0, device="cpu"), name="barrier")
|
|
236
|
+
|
|
237
|
+
def _rank_not_in_group(self, group: Optional[Any]) -> bool:
|
|
238
|
+
if group is None:
|
|
239
|
+
return False
|
|
240
|
+
return not group.included()
|
|
@@ -178,7 +178,7 @@ if has_native_dist_support:
|
|
|
178
178
|
c: Counter = Counter(hostnames)
|
|
179
179
|
sizes = torch.tensor([0] + list(c.values()))
|
|
180
180
|
cumsum_sizes = torch.cumsum(sizes, dim=0)
|
|
181
|
-
node_rank = (rank // cumsum_sizes[1:]).clamp(0, 1).sum().item()
|
|
181
|
+
node_rank = cast(int, (rank // cumsum_sizes[1:]).clamp(0, 1).sum().item())
|
|
182
182
|
local_rank = rank - cumsum_sizes[node_rank].item()
|
|
183
183
|
return int(local_rank), node_rank
|
|
184
184
|
|
|
@@ -345,6 +345,7 @@ if has_native_dist_support:
|
|
|
345
345
|
os.environ.update(copy_env_vars)
|
|
346
346
|
|
|
347
347
|
@staticmethod
|
|
348
|
+
# pyrefly: ignore [bad-override]
|
|
348
349
|
def spawn(
|
|
349
350
|
fn: Callable,
|
|
350
351
|
args: Tuple,
|
|
@@ -408,7 +409,7 @@ if has_native_dist_support:
|
|
|
408
409
|
**spawn_kwargs,
|
|
409
410
|
)
|
|
410
411
|
|
|
411
|
-
def _setup_group(self, group:
|
|
412
|
+
def _setup_group(self, group: Any) -> dist.ProcessGroup:
|
|
412
413
|
if isinstance(group, list) and all(isinstance(item, int) for item in group):
|
|
413
414
|
group = self._do_new_group(group)
|
|
414
415
|
if not (isinstance(group, dist.ProcessGroup) or group == dist.GroupMember.NON_GROUP_MEMBER):
|
|
@@ -442,7 +443,7 @@ if has_native_dist_support:
|
|
|
442
443
|
def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
|
|
443
444
|
if group is not None:
|
|
444
445
|
group = self._setup_group(group)
|
|
445
|
-
if group
|
|
446
|
+
if self._rank_not_in_group(group):
|
|
446
447
|
return tensor
|
|
447
448
|
if group is None:
|
|
448
449
|
group_size = self.get_world_size()
|
|
@@ -466,7 +467,7 @@ if has_native_dist_support:
|
|
|
466
467
|
)
|
|
467
468
|
if group is not None:
|
|
468
469
|
group = self._setup_group(group)
|
|
469
|
-
if group
|
|
470
|
+
if self._rank_not_in_group(group):
|
|
470
471
|
return tensor
|
|
471
472
|
if group is None:
|
|
472
473
|
group_size = self.get_world_size()
|
|
@@ -491,6 +492,9 @@ if has_native_dist_support:
|
|
|
491
492
|
def barrier(self) -> None:
|
|
492
493
|
dist.barrier()
|
|
493
494
|
|
|
495
|
+
def _rank_not_in_group(self, group: Optional[Any]) -> bool:
|
|
496
|
+
return dist._rank_not_in_group(group)
|
|
497
|
+
|
|
494
498
|
def _expand_hostlist(nodelist: str) -> List[str]:
|
|
495
499
|
"""Expand a compressed hostlist string and returns all hosts listed.
|
|
496
500
|
|
|
@@ -54,6 +54,7 @@ if has_xla_support:
|
|
|
54
54
|
def _create_from_backend(self, backend: str, **kwargs: Any) -> None:
|
|
55
55
|
xm.rendezvous("init")
|
|
56
56
|
|
|
57
|
+
# pyrefly: ignore [bad-override]
|
|
57
58
|
self._backend: str = backend
|
|
58
59
|
self._setup_attrs()
|
|
59
60
|
|
|
@@ -106,6 +107,7 @@ if has_xla_support:
|
|
|
106
107
|
finalize()
|
|
107
108
|
|
|
108
109
|
@staticmethod
|
|
110
|
+
# pyrefly: ignore [bad-override]
|
|
109
111
|
def spawn(
|
|
110
112
|
fn: Callable,
|
|
111
113
|
args: Tuple,
|
|
@@ -175,3 +177,6 @@ if has_xla_support:
|
|
|
175
177
|
if isinstance(group, list) and all(isinstance(item, int) for item in group):
|
|
176
178
|
return True
|
|
177
179
|
return False
|
|
180
|
+
|
|
181
|
+
def _rank_not_in_group(self, group: Any) -> bool:
|
|
182
|
+
return self.get_rank() not in group
|
ignite/distributed/launcher.py
CHANGED
|
@@ -322,19 +322,15 @@ class Parallel:
|
|
|
322
322
|
idist.initialize(self.backend, init_method=self.init_method)
|
|
323
323
|
|
|
324
324
|
# The logger can be setup from now since idist.initialize() has been called (if needed)
|
|
325
|
-
self._logger = setup_logger(__name__ + "." + self.__class__.__name__)
|
|
325
|
+
self._logger = setup_logger(__name__ + "." + self.__class__.__name__)
|
|
326
326
|
|
|
327
327
|
if self.backend is not None:
|
|
328
328
|
if self._spawn_params is None:
|
|
329
|
-
self._logger.info(
|
|
330
|
-
f"Initialized processing group with backend: '{self.backend}'"
|
|
331
|
-
)
|
|
329
|
+
self._logger.info(f"Initialized processing group with backend: '{self.backend}'")
|
|
332
330
|
else:
|
|
333
|
-
self._logger.info(
|
|
334
|
-
f"Initialized distributed launcher with backend: '{self.backend}'"
|
|
335
|
-
)
|
|
331
|
+
self._logger.info(f"Initialized distributed launcher with backend: '{self.backend}'")
|
|
336
332
|
msg = "\n\t".join([f"{k}: {v}" for k, v in self._spawn_params.items() if v is not None])
|
|
337
|
-
self._logger.info(f"- Parameters to spawn processes: \n\t{msg}")
|
|
333
|
+
self._logger.info(f"- Parameters to spawn processes: \n\t{msg}")
|
|
338
334
|
|
|
339
335
|
return self
|
|
340
336
|
|