pytorch-ignite 0.6.0.dev20250927__py3-none-any.whl → 0.6.0.dev20260101__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pytorch-ignite might be problematic. Click here for more details.
- ignite/__init__.py +1 -1
- ignite/contrib/engines/common.py +1 -0
- ignite/contrib/handlers/base_logger.py +1 -1
- ignite/contrib/handlers/clearml_logger.py +1 -1
- ignite/contrib/handlers/lr_finder.py +1 -1
- ignite/contrib/handlers/mlflow_logger.py +1 -1
- ignite/contrib/handlers/neptune_logger.py +1 -1
- ignite/contrib/handlers/param_scheduler.py +1 -1
- ignite/contrib/handlers/polyaxon_logger.py +1 -1
- ignite/contrib/handlers/tensorboard_logger.py +1 -1
- ignite/contrib/handlers/time_profilers.py +1 -1
- ignite/contrib/handlers/tqdm_logger.py +1 -1
- ignite/contrib/handlers/visdom_logger.py +1 -1
- ignite/contrib/handlers/wandb_logger.py +1 -1
- ignite/contrib/metrics/average_precision.py +1 -1
- ignite/contrib/metrics/cohen_kappa.py +1 -1
- ignite/contrib/metrics/gpu_info.py +1 -1
- ignite/contrib/metrics/precision_recall_curve.py +1 -1
- ignite/contrib/metrics/regression/canberra_metric.py +2 -3
- ignite/contrib/metrics/regression/fractional_absolute_error.py +2 -3
- ignite/contrib/metrics/regression/fractional_bias.py +2 -3
- ignite/contrib/metrics/regression/geometric_mean_absolute_error.py +2 -3
- ignite/contrib/metrics/regression/geometric_mean_relative_absolute_error.py +2 -3
- ignite/contrib/metrics/regression/manhattan_distance.py +2 -3
- ignite/contrib/metrics/regression/maximum_absolute_error.py +2 -3
- ignite/contrib/metrics/regression/mean_absolute_relative_error.py +2 -3
- ignite/contrib/metrics/regression/mean_error.py +2 -3
- ignite/contrib/metrics/regression/mean_normalized_bias.py +2 -3
- ignite/contrib/metrics/regression/median_absolute_error.py +2 -3
- ignite/contrib/metrics/regression/median_absolute_percentage_error.py +2 -3
- ignite/contrib/metrics/regression/median_relative_absolute_error.py +2 -3
- ignite/contrib/metrics/regression/r2_score.py +2 -3
- ignite/contrib/metrics/regression/wave_hedges_distance.py +2 -3
- ignite/contrib/metrics/roc_auc.py +1 -1
- ignite/distributed/auto.py +1 -0
- ignite/distributed/comp_models/horovod.py +8 -1
- ignite/distributed/comp_models/native.py +2 -1
- ignite/distributed/comp_models/xla.py +2 -0
- ignite/distributed/launcher.py +4 -8
- ignite/engine/__init__.py +9 -9
- ignite/engine/deterministic.py +1 -1
- ignite/engine/engine.py +9 -11
- ignite/engine/events.py +2 -1
- ignite/handlers/__init__.py +2 -0
- ignite/handlers/checkpoint.py +2 -2
- ignite/handlers/clearml_logger.py +2 -2
- ignite/handlers/fbresearch_logger.py +2 -2
- ignite/handlers/lr_finder.py +10 -10
- ignite/handlers/neptune_logger.py +1 -0
- ignite/handlers/param_scheduler.py +7 -3
- ignite/handlers/state_param_scheduler.py +8 -2
- ignite/handlers/time_profilers.py +6 -3
- ignite/handlers/tqdm_logger.py +7 -2
- ignite/handlers/visdom_logger.py +2 -2
- ignite/handlers/wandb_logger.py +9 -8
- ignite/metrics/accuracy.py +2 -0
- ignite/metrics/metric.py +1 -0
- ignite/metrics/nlp/rouge.py +6 -3
- ignite/metrics/roc_auc.py +1 -0
- ignite/metrics/ssim.py +4 -0
- ignite/metrics/vision/object_detection_average_precision_recall.py +3 -0
- {pytorch_ignite-0.6.0.dev20250927.dist-info → pytorch_ignite-0.6.0.dev20260101.dist-info}/METADATA +2 -2
- {pytorch_ignite-0.6.0.dev20250927.dist-info → pytorch_ignite-0.6.0.dev20260101.dist-info}/RECORD +65 -65
- {pytorch_ignite-0.6.0.dev20250927.dist-info → pytorch_ignite-0.6.0.dev20260101.dist-info}/WHEEL +1 -1
- {pytorch_ignite-0.6.0.dev20250927.dist-info → pytorch_ignite-0.6.0.dev20260101.dist-info}/licenses/LICENSE +0 -0
ignite/__init__.py
CHANGED
ignite/contrib/engines/common.py
CHANGED
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to /ignite/handlers/base_logger.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to /ignite/handlers/clearml_logger.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to /ignite/handlers/lr_finder.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to /ignite/handlers/mlflow_logger.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to /ignite/handlers/neptune_logger.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to /ignite/handlers/param_scheduler.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to /ignite/handlers/polyaxon_logger.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to /ignite/handlers/tensorboard_logger.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to /ignite/handlers/time_profilers.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to /ignite/handlers/tqdm_logger.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to /ignite/handlers/visdom_logger.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to /ignite/handlers/wandb_logger.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to /ignite/metrics/average_precision.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/cohen_kappa.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/gpu_info.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/precision_recall_curve.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/canberra_metric.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.canberra_metric import CanberraMetric
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/fractional_absolute_error.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.fractional_absolute_error import FractionalAbsoluteError
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/fractional_bias.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.fractional_bias import FractionalBias
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/geometric_mean_absolute_error.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.geometric_mean_absolute_error import GeometricMeanAbsoluteError
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/geometric_mean_relative_absolute_error.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.geometric_mean_relative_absolute_error import GeometricMeanRelativeAbsoluteError
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/manhattan_distance.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.manhattan_distance import ManhattanDistance
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/maximum_absolute_error.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.maximum_absolute_error import MaximumAbsoluteError
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/mean_absolute_relative_error.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.mean_absolute_relative_error import MeanAbsoluteRelativeError
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/mean_error.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.mean_error import MeanError
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/mean_normalized_bias.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.mean_normalized_bias import MeanNormalizedBias
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/median_absolute_error.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.median_absolute_error import MedianAbsoluteError
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/median_absolute_percentage_error.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.median_absolute_percentage_error import MedianAbsolutePercentageError
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/median_relative_absolute_error.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.median_relative_absolute_error import MedianRelativeAbsoluteError
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/r2_score.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.r2_score import R2Score
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/wave_hedges_distance.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.wave_hedges_distance import WaveHedgesDistance
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/roc_auc.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
ignite/distributed/auto.py
CHANGED
|
@@ -336,6 +336,7 @@ if idist.has_xla_support:
|
|
|
336
336
|
# From pytorch/xla if `torch_xla.distributed.parallel_loader.MpDeviceLoader` is not available
|
|
337
337
|
def __init__(self, loader: Any, device: torch.device, **kwargs: Any) -> None:
|
|
338
338
|
self._loader = loader
|
|
339
|
+
# pyrefly: ignore [read-only]
|
|
339
340
|
self._device = device
|
|
340
341
|
self._parallel_loader_kwargs = kwargs
|
|
341
342
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import warnings
|
|
3
|
-
from typing import Any, Callable, cast, List, Mapping, Optional, Tuple
|
|
3
|
+
from typing import Any, Callable, cast, List, Mapping, Optional, Tuple, TYPE_CHECKING
|
|
4
4
|
|
|
5
5
|
import torch
|
|
6
6
|
|
|
@@ -20,6 +20,11 @@ try:
|
|
|
20
20
|
except ImportError:
|
|
21
21
|
has_hvd_support = False
|
|
22
22
|
|
|
23
|
+
if TYPE_CHECKING:
|
|
24
|
+
# Tell the type checker that hvd imports are always defined.
|
|
25
|
+
import horovod.torch as hvd
|
|
26
|
+
from horovod import run as hvd_mp_spawn
|
|
27
|
+
|
|
23
28
|
|
|
24
29
|
if has_hvd_support:
|
|
25
30
|
HOROVOD = "horovod"
|
|
@@ -70,6 +75,7 @@ if has_hvd_support:
|
|
|
70
75
|
self._init_from_context()
|
|
71
76
|
|
|
72
77
|
def _create_from_backend(self, backend: str, **kwargs: Any) -> None:
|
|
78
|
+
# pyrefly: ignore [bad-override]
|
|
73
79
|
self._backend: str = backend
|
|
74
80
|
comm = kwargs.get("comm", None)
|
|
75
81
|
hvd.init(comm=comm)
|
|
@@ -129,6 +135,7 @@ if has_hvd_support:
|
|
|
129
135
|
finalize()
|
|
130
136
|
|
|
131
137
|
@staticmethod
|
|
138
|
+
# pyrefly: ignore [bad-override]
|
|
132
139
|
def spawn(
|
|
133
140
|
fn: Callable,
|
|
134
141
|
args: Tuple,
|
|
@@ -178,7 +178,7 @@ if has_native_dist_support:
|
|
|
178
178
|
c: Counter = Counter(hostnames)
|
|
179
179
|
sizes = torch.tensor([0] + list(c.values()))
|
|
180
180
|
cumsum_sizes = torch.cumsum(sizes, dim=0)
|
|
181
|
-
node_rank = (rank // cumsum_sizes[1:]).clamp(0, 1).sum().item()
|
|
181
|
+
node_rank = cast(int, (rank // cumsum_sizes[1:]).clamp(0, 1).sum().item())
|
|
182
182
|
local_rank = rank - cumsum_sizes[node_rank].item()
|
|
183
183
|
return int(local_rank), node_rank
|
|
184
184
|
|
|
@@ -345,6 +345,7 @@ if has_native_dist_support:
|
|
|
345
345
|
os.environ.update(copy_env_vars)
|
|
346
346
|
|
|
347
347
|
@staticmethod
|
|
348
|
+
# pyrefly: ignore [bad-override]
|
|
348
349
|
def spawn(
|
|
349
350
|
fn: Callable,
|
|
350
351
|
args: Tuple,
|
|
@@ -54,6 +54,7 @@ if has_xla_support:
|
|
|
54
54
|
def _create_from_backend(self, backend: str, **kwargs: Any) -> None:
|
|
55
55
|
xm.rendezvous("init")
|
|
56
56
|
|
|
57
|
+
# pyrefly: ignore [bad-override]
|
|
57
58
|
self._backend: str = backend
|
|
58
59
|
self._setup_attrs()
|
|
59
60
|
|
|
@@ -106,6 +107,7 @@ if has_xla_support:
|
|
|
106
107
|
finalize()
|
|
107
108
|
|
|
108
109
|
@staticmethod
|
|
110
|
+
# pyrefly: ignore [bad-override]
|
|
109
111
|
def spawn(
|
|
110
112
|
fn: Callable,
|
|
111
113
|
args: Tuple,
|
ignite/distributed/launcher.py
CHANGED
|
@@ -322,19 +322,15 @@ class Parallel:
|
|
|
322
322
|
idist.initialize(self.backend, init_method=self.init_method)
|
|
323
323
|
|
|
324
324
|
# The logger can be setup from now since idist.initialize() has been called (if needed)
|
|
325
|
-
self._logger = setup_logger(__name__ + "." + self.__class__.__name__)
|
|
325
|
+
self._logger = setup_logger(__name__ + "." + self.__class__.__name__)
|
|
326
326
|
|
|
327
327
|
if self.backend is not None:
|
|
328
328
|
if self._spawn_params is None:
|
|
329
|
-
self._logger.info(
|
|
330
|
-
f"Initialized processing group with backend: '{self.backend}'"
|
|
331
|
-
)
|
|
329
|
+
self._logger.info(f"Initialized processing group with backend: '{self.backend}'")
|
|
332
330
|
else:
|
|
333
|
-
self._logger.info(
|
|
334
|
-
f"Initialized distributed launcher with backend: '{self.backend}'"
|
|
335
|
-
)
|
|
331
|
+
self._logger.info(f"Initialized distributed launcher with backend: '{self.backend}'")
|
|
336
332
|
msg = "\n\t".join([f"{k}: {v}" for k, v in self._spawn_params.items() if v is not None])
|
|
337
|
-
self._logger.info(f"- Parameters to spawn processes: \n\t{msg}")
|
|
333
|
+
self._logger.info(f"- Parameters to spawn processes: \n\t{msg}")
|
|
338
334
|
|
|
339
335
|
return self
|
|
340
336
|
|
ignite/engine/__init__.py
CHANGED
|
@@ -133,11 +133,11 @@ def supervised_training_step_amp(
|
|
|
133
133
|
prepare_batch: Callable = _prepare_batch,
|
|
134
134
|
model_transform: Callable[[Any], Any] = lambda output: output,
|
|
135
135
|
output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
|
|
136
|
-
scaler: Optional["torch.
|
|
136
|
+
scaler: Optional["torch.amp.GradScaler"] = None,
|
|
137
137
|
gradient_accumulation_steps: int = 1,
|
|
138
138
|
model_fn: Callable[[torch.nn.Module, Any], Any] = lambda model, x: model(x),
|
|
139
139
|
) -> Callable:
|
|
140
|
-
"""Factory function for supervised training using ``torch.
|
|
140
|
+
"""Factory function for supervised training using ``torch.amp``.
|
|
141
141
|
|
|
142
142
|
Args:
|
|
143
143
|
model: the model to train.
|
|
@@ -170,7 +170,7 @@ def supervised_training_step_amp(
|
|
|
170
170
|
model = ...
|
|
171
171
|
optimizer = ...
|
|
172
172
|
loss_fn = ...
|
|
173
|
-
scaler = torch.
|
|
173
|
+
scaler = torch.amp.GradScaler('cuda', 2**10)
|
|
174
174
|
|
|
175
175
|
update_fn = supervised_training_step_amp(model, optimizer, loss_fn, 'cuda', scaler=scaler)
|
|
176
176
|
trainer = Engine(update_fn)
|
|
@@ -393,8 +393,8 @@ def supervised_training_step_tpu(
|
|
|
393
393
|
|
|
394
394
|
|
|
395
395
|
def _check_arg(
|
|
396
|
-
on_tpu: bool, on_mps: bool, amp_mode: Optional[str], scaler: Optional[Union[bool, "torch.
|
|
397
|
-
) -> Tuple[Optional[str], Optional["torch.
|
|
396
|
+
on_tpu: bool, on_mps: bool, amp_mode: Optional[str], scaler: Optional[Union[bool, "torch.amp.GradScaler"]]
|
|
397
|
+
) -> Tuple[Optional[str], Optional["torch.amp.GradScaler"]]:
|
|
398
398
|
"""Checking tpu, mps, amp and GradScaler instance combinations."""
|
|
399
399
|
if on_mps and amp_mode:
|
|
400
400
|
raise ValueError("amp_mode cannot be used with mps device. Consider using amp_mode=None or device='cuda'.")
|
|
@@ -410,9 +410,9 @@ def _check_arg(
|
|
|
410
410
|
raise ValueError(f"scaler argument is {scaler}, but amp_mode is {amp_mode}. Consider using amp_mode='amp'.")
|
|
411
411
|
elif amp_mode == "amp" and isinstance(scaler, bool):
|
|
412
412
|
try:
|
|
413
|
-
from torch.
|
|
413
|
+
from torch.amp import GradScaler
|
|
414
414
|
except ImportError:
|
|
415
|
-
raise ImportError("Please install torch>=
|
|
415
|
+
raise ImportError("Please install torch>=2.3.1 to use scaler argument.")
|
|
416
416
|
scaler = GradScaler(enabled=True)
|
|
417
417
|
|
|
418
418
|
if on_tpu:
|
|
@@ -434,7 +434,7 @@ def create_supervised_trainer(
|
|
|
434
434
|
output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
|
|
435
435
|
deterministic: bool = False,
|
|
436
436
|
amp_mode: Optional[str] = None,
|
|
437
|
-
scaler: Union[bool, "torch.
|
|
437
|
+
scaler: Union[bool, "torch.amp.GradScaler"] = False,
|
|
438
438
|
gradient_accumulation_steps: int = 1,
|
|
439
439
|
model_fn: Callable[[torch.nn.Module, Any], Any] = lambda model, x: model(x),
|
|
440
440
|
) -> Engine:
|
|
@@ -459,7 +459,7 @@ def create_supervised_trainer(
|
|
|
459
459
|
:class:`~ignite.engine.deterministic.DeterministicEngine`, otherwise :class:`~ignite.engine.engine.Engine`
|
|
460
460
|
(default: False).
|
|
461
461
|
amp_mode: can be ``amp`` or ``apex``, model and optimizer will be casted to float16 using
|
|
462
|
-
`torch.
|
|
462
|
+
`torch.amp <https://pytorch.org/docs/stable/amp.html>`_ for ``amp`` and
|
|
463
463
|
using `apex <https://nvidia.github.io/apex>`_ for ``apex``. (default: None)
|
|
464
464
|
scaler: GradScaler instance for gradient scaling if `torch>=1.6.0`
|
|
465
465
|
and ``amp_mode`` is ``amp``. If ``amp_mode`` is ``apex``, this argument will be ignored.
|
ignite/engine/deterministic.py
CHANGED