pytorch-ignite 0.6.0.dev20251103__py3-none-any.whl → 0.6.0.dev20260102__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pytorch-ignite might be problematic. Click here for more details.
- ignite/__init__.py +1 -1
- ignite/contrib/engines/common.py +1 -0
- ignite/contrib/handlers/base_logger.py +1 -1
- ignite/contrib/handlers/clearml_logger.py +1 -1
- ignite/contrib/handlers/lr_finder.py +1 -1
- ignite/contrib/handlers/mlflow_logger.py +1 -1
- ignite/contrib/handlers/neptune_logger.py +1 -1
- ignite/contrib/handlers/param_scheduler.py +1 -1
- ignite/contrib/handlers/polyaxon_logger.py +1 -1
- ignite/contrib/handlers/tensorboard_logger.py +1 -1
- ignite/contrib/handlers/time_profilers.py +1 -1
- ignite/contrib/handlers/tqdm_logger.py +1 -1
- ignite/contrib/handlers/visdom_logger.py +1 -1
- ignite/contrib/handlers/wandb_logger.py +1 -1
- ignite/contrib/metrics/average_precision.py +1 -1
- ignite/contrib/metrics/cohen_kappa.py +1 -1
- ignite/contrib/metrics/gpu_info.py +1 -1
- ignite/contrib/metrics/precision_recall_curve.py +1 -1
- ignite/contrib/metrics/regression/canberra_metric.py +2 -3
- ignite/contrib/metrics/regression/fractional_absolute_error.py +2 -3
- ignite/contrib/metrics/regression/fractional_bias.py +2 -3
- ignite/contrib/metrics/regression/geometric_mean_absolute_error.py +2 -3
- ignite/contrib/metrics/regression/geometric_mean_relative_absolute_error.py +2 -3
- ignite/contrib/metrics/regression/manhattan_distance.py +2 -3
- ignite/contrib/metrics/regression/maximum_absolute_error.py +2 -3
- ignite/contrib/metrics/regression/mean_absolute_relative_error.py +2 -3
- ignite/contrib/metrics/regression/mean_error.py +2 -3
- ignite/contrib/metrics/regression/mean_normalized_bias.py +2 -3
- ignite/contrib/metrics/regression/median_absolute_error.py +2 -3
- ignite/contrib/metrics/regression/median_absolute_percentage_error.py +2 -3
- ignite/contrib/metrics/regression/median_relative_absolute_error.py +2 -3
- ignite/contrib/metrics/regression/r2_score.py +2 -3
- ignite/contrib/metrics/regression/wave_hedges_distance.py +2 -3
- ignite/contrib/metrics/roc_auc.py +1 -1
- ignite/distributed/auto.py +1 -0
- ignite/distributed/comp_models/horovod.py +8 -1
- ignite/distributed/comp_models/native.py +1 -0
- ignite/distributed/comp_models/xla.py +2 -0
- ignite/distributed/launcher.py +4 -8
- ignite/engine/deterministic.py +1 -1
- ignite/engine/engine.py +9 -11
- ignite/engine/events.py +2 -1
- ignite/handlers/checkpoint.py +2 -2
- ignite/handlers/clearml_logger.py +2 -2
- ignite/handlers/lr_finder.py +10 -10
- ignite/handlers/neptune_logger.py +1 -0
- ignite/handlers/param_scheduler.py +7 -3
- ignite/handlers/state_param_scheduler.py +8 -2
- ignite/handlers/time_profilers.py +6 -3
- ignite/handlers/tqdm_logger.py +7 -2
- ignite/handlers/visdom_logger.py +2 -2
- ignite/handlers/wandb_logger.py +9 -8
- ignite/metrics/accuracy.py +2 -0
- ignite/metrics/metric.py +1 -0
- ignite/metrics/nlp/rouge.py +6 -3
- ignite/metrics/roc_auc.py +1 -0
- ignite/metrics/ssim.py +4 -0
- ignite/metrics/vision/object_detection_average_precision_recall.py +3 -0
- {pytorch_ignite-0.6.0.dev20251103.dist-info → pytorch_ignite-0.6.0.dev20260102.dist-info}/METADATA +1 -1
- {pytorch_ignite-0.6.0.dev20251103.dist-info → pytorch_ignite-0.6.0.dev20260102.dist-info}/RECORD +62 -62
- {pytorch_ignite-0.6.0.dev20251103.dist-info → pytorch_ignite-0.6.0.dev20260102.dist-info}/WHEEL +1 -1
- {pytorch_ignite-0.6.0.dev20251103.dist-info → pytorch_ignite-0.6.0.dev20260102.dist-info}/licenses/LICENSE +0 -0
ignite/__init__.py
CHANGED
ignite/contrib/engines/common.py
CHANGED
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to /ignite/handlers/base_logger.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to /ignite/handlers/clearml_logger.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to /ignite/handlers/lr_finder.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to /ignite/handlers/mlflow_logger.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to /ignite/handlers/neptune_logger.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to /ignite/handlers/param_scheduler.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to /ignite/handlers/polyaxon_logger.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to /ignite/handlers/tensorboard_logger.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to /ignite/handlers/time_profilers.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to /ignite/handlers/tqdm_logger.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to /ignite/handlers/visdom_logger.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to /ignite/handlers/wandb_logger.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to /ignite/metrics/average_precision.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/cohen_kappa.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/gpu_info.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/precision_recall_curve.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/canberra_metric.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.canberra_metric import CanberraMetric
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/fractional_absolute_error.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.fractional_absolute_error import FractionalAbsoluteError
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/fractional_bias.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.fractional_bias import FractionalBias
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/geometric_mean_absolute_error.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.geometric_mean_absolute_error import GeometricMeanAbsoluteError
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/geometric_mean_relative_absolute_error.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.geometric_mean_relative_absolute_error import GeometricMeanRelativeAbsoluteError
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/manhattan_distance.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.manhattan_distance import ManhattanDistance
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/maximum_absolute_error.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.maximum_absolute_error import MaximumAbsoluteError
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/mean_absolute_relative_error.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.mean_absolute_relative_error import MeanAbsoluteRelativeError
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/mean_error.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.mean_error import MeanError
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/mean_normalized_bias.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.mean_normalized_bias import MeanNormalizedBias
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/median_absolute_error.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.median_absolute_error import MedianAbsoluteError
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/median_absolute_percentage_error.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.median_absolute_percentage_error import MedianAbsolutePercentageError
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/median_relative_absolute_error.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.median_relative_absolute_error import MedianRelativeAbsoluteError
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/r2_score.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.r2_score import R2Score
|
|
@@ -9,9 +9,8 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/regression/wave_hedges_distance.py"
|
|
12
|
-
f" and will be removed in version {removed_in}"
|
|
13
|
-
|
|
14
|
-
else "" ".\n Please refer to the documentation for more details."
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
|
+
+ ".\n Please refer to the documentation for more details."
|
|
15
14
|
)
|
|
16
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
|
17
16
|
from ignite.metrics.regression.wave_hedges_distance import WaveHedgesDistance
|
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
removed_in = "0.6.0"
|
|
10
10
|
deprecation_warning = (
|
|
11
11
|
f"{__file__} has been moved to ignite/metrics/roc_auc.py"
|
|
12
|
-
+
|
|
12
|
+
+ f" and will be removed in version {removed_in}"
|
|
13
13
|
+ ".\n Please refer to the documentation for more details."
|
|
14
14
|
)
|
|
15
15
|
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
|
ignite/distributed/auto.py
CHANGED
|
@@ -336,6 +336,7 @@ if idist.has_xla_support:
|
|
|
336
336
|
# From pytorch/xla if `torch_xla.distributed.parallel_loader.MpDeviceLoader` is not available
|
|
337
337
|
def __init__(self, loader: Any, device: torch.device, **kwargs: Any) -> None:
|
|
338
338
|
self._loader = loader
|
|
339
|
+
# pyrefly: ignore [read-only]
|
|
339
340
|
self._device = device
|
|
340
341
|
self._parallel_loader_kwargs = kwargs
|
|
341
342
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import warnings
|
|
3
|
-
from typing import Any, Callable, cast, List, Mapping, Optional, Tuple
|
|
3
|
+
from typing import Any, Callable, cast, List, Mapping, Optional, Tuple, TYPE_CHECKING
|
|
4
4
|
|
|
5
5
|
import torch
|
|
6
6
|
|
|
@@ -20,6 +20,11 @@ try:
|
|
|
20
20
|
except ImportError:
|
|
21
21
|
has_hvd_support = False
|
|
22
22
|
|
|
23
|
+
if TYPE_CHECKING:
|
|
24
|
+
# Tell the type checker that hvd imports are always defined.
|
|
25
|
+
import horovod.torch as hvd
|
|
26
|
+
from horovod import run as hvd_mp_spawn
|
|
27
|
+
|
|
23
28
|
|
|
24
29
|
if has_hvd_support:
|
|
25
30
|
HOROVOD = "horovod"
|
|
@@ -70,6 +75,7 @@ if has_hvd_support:
|
|
|
70
75
|
self._init_from_context()
|
|
71
76
|
|
|
72
77
|
def _create_from_backend(self, backend: str, **kwargs: Any) -> None:
|
|
78
|
+
# pyrefly: ignore [bad-override]
|
|
73
79
|
self._backend: str = backend
|
|
74
80
|
comm = kwargs.get("comm", None)
|
|
75
81
|
hvd.init(comm=comm)
|
|
@@ -129,6 +135,7 @@ if has_hvd_support:
|
|
|
129
135
|
finalize()
|
|
130
136
|
|
|
131
137
|
@staticmethod
|
|
138
|
+
# pyrefly: ignore [bad-override]
|
|
132
139
|
def spawn(
|
|
133
140
|
fn: Callable,
|
|
134
141
|
args: Tuple,
|
|
@@ -54,6 +54,7 @@ if has_xla_support:
|
|
|
54
54
|
def _create_from_backend(self, backend: str, **kwargs: Any) -> None:
|
|
55
55
|
xm.rendezvous("init")
|
|
56
56
|
|
|
57
|
+
# pyrefly: ignore [bad-override]
|
|
57
58
|
self._backend: str = backend
|
|
58
59
|
self._setup_attrs()
|
|
59
60
|
|
|
@@ -106,6 +107,7 @@ if has_xla_support:
|
|
|
106
107
|
finalize()
|
|
107
108
|
|
|
108
109
|
@staticmethod
|
|
110
|
+
# pyrefly: ignore [bad-override]
|
|
109
111
|
def spawn(
|
|
110
112
|
fn: Callable,
|
|
111
113
|
args: Tuple,
|
ignite/distributed/launcher.py
CHANGED
|
@@ -322,19 +322,15 @@ class Parallel:
|
|
|
322
322
|
idist.initialize(self.backend, init_method=self.init_method)
|
|
323
323
|
|
|
324
324
|
# The logger can be setup from now since idist.initialize() has been called (if needed)
|
|
325
|
-
self._logger = setup_logger(__name__ + "." + self.__class__.__name__)
|
|
325
|
+
self._logger = setup_logger(__name__ + "." + self.__class__.__name__)
|
|
326
326
|
|
|
327
327
|
if self.backend is not None:
|
|
328
328
|
if self._spawn_params is None:
|
|
329
|
-
self._logger.info(
|
|
330
|
-
f"Initialized processing group with backend: '{self.backend}'"
|
|
331
|
-
)
|
|
329
|
+
self._logger.info(f"Initialized processing group with backend: '{self.backend}'")
|
|
332
330
|
else:
|
|
333
|
-
self._logger.info(
|
|
334
|
-
f"Initialized distributed launcher with backend: '{self.backend}'"
|
|
335
|
-
)
|
|
331
|
+
self._logger.info(f"Initialized distributed launcher with backend: '{self.backend}'")
|
|
336
332
|
msg = "\n\t".join([f"{k}: {v}" for k, v in self._spawn_params.items() if v is not None])
|
|
337
|
-
self._logger.info(f"- Parameters to spawn processes: \n\t{msg}")
|
|
333
|
+
self._logger.info(f"- Parameters to spawn processes: \n\t{msg}")
|
|
338
334
|
|
|
339
335
|
return self
|
|
340
336
|
|
ignite/engine/deterministic.py
CHANGED
ignite/engine/engine.py
CHANGED
|
@@ -148,7 +148,7 @@ class Engine(Serializable):
|
|
|
148
148
|
self.should_interrupt = False
|
|
149
149
|
self.state = State()
|
|
150
150
|
self._state_dict_user_keys: List[str] = []
|
|
151
|
-
self._allowed_events: List[EventEnum] = []
|
|
151
|
+
self._allowed_events: List[Union[str, EventEnum]] = []
|
|
152
152
|
|
|
153
153
|
self._dataloader_iter: Optional[Iterator[Any]] = None
|
|
154
154
|
self._init_iter: Optional[int] = None
|
|
@@ -163,9 +163,7 @@ class Engine(Serializable):
|
|
|
163
163
|
# generator provided by self._internal_run_as_gen
|
|
164
164
|
self._internal_run_generator: Optional[Generator[Any, None, State]] = None
|
|
165
165
|
|
|
166
|
-
def register_events(
|
|
167
|
-
self, *event_names: Union[List[str], List[EventEnum]], event_to_attr: Optional[dict] = None
|
|
168
|
-
) -> None:
|
|
166
|
+
def register_events(self, *event_names: Union[str, EventEnum], event_to_attr: Optional[dict] = None) -> None:
|
|
169
167
|
"""Add events that can be fired.
|
|
170
168
|
|
|
171
169
|
Registering an event will let the user trigger these events at any point.
|
|
@@ -450,7 +448,7 @@ class Engine(Serializable):
|
|
|
450
448
|
first, others = ((resolved_engine,), args[1:])
|
|
451
449
|
else:
|
|
452
450
|
# metrics do not provide engine when registered
|
|
453
|
-
first, others = (tuple(), args)
|
|
451
|
+
first, others = (tuple(), args)
|
|
454
452
|
|
|
455
453
|
func(*first, *(event_args + others), **kwargs)
|
|
456
454
|
|
|
@@ -989,9 +987,9 @@ class Engine(Serializable):
|
|
|
989
987
|
def _internal_run_as_gen(self) -> Generator[Any, None, State]:
|
|
990
988
|
self.should_terminate = self.should_terminate_single_epoch = self.should_interrupt = False
|
|
991
989
|
self._init_timers(self.state)
|
|
990
|
+
start_time = time.time()
|
|
992
991
|
try:
|
|
993
992
|
try:
|
|
994
|
-
start_time = time.time()
|
|
995
993
|
self._fire_event(Events.STARTED)
|
|
996
994
|
yield from self._maybe_terminate_or_interrupt()
|
|
997
995
|
|
|
@@ -1010,7 +1008,7 @@ class Engine(Serializable):
|
|
|
1010
1008
|
# time is available for handlers but must be updated after fire
|
|
1011
1009
|
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken
|
|
1012
1010
|
|
|
1013
|
-
if self.should_terminate_single_epoch != "skip_epoch_completed":
|
|
1011
|
+
if self.should_terminate_single_epoch != "skip_epoch_completed":
|
|
1014
1012
|
handlers_start_time = time.time()
|
|
1015
1013
|
self._fire_event(Events.EPOCH_COMPLETED)
|
|
1016
1014
|
epoch_time_taken += time.time() - handlers_start_time
|
|
@@ -1043,7 +1041,7 @@ class Engine(Serializable):
|
|
|
1043
1041
|
self.state.times[Events.COMPLETED.name] = time_taken
|
|
1044
1042
|
|
|
1045
1043
|
# do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True`
|
|
1046
|
-
if self.should_terminate != "skip_completed":
|
|
1044
|
+
if self.should_terminate != "skip_completed":
|
|
1047
1045
|
handlers_start_time = time.time()
|
|
1048
1046
|
self._fire_event(Events.COMPLETED)
|
|
1049
1047
|
time_taken += time.time() - handlers_start_time
|
|
@@ -1189,9 +1187,9 @@ class Engine(Serializable):
|
|
|
1189
1187
|
# internal_run without generator for BC
|
|
1190
1188
|
self.should_terminate = self.should_terminate_single_epoch = self.should_interrupt = False
|
|
1191
1189
|
self._init_timers(self.state)
|
|
1190
|
+
start_time = time.time()
|
|
1192
1191
|
try:
|
|
1193
1192
|
try:
|
|
1194
|
-
start_time = time.time()
|
|
1195
1193
|
self._fire_event(Events.STARTED)
|
|
1196
1194
|
self._maybe_terminate_legacy()
|
|
1197
1195
|
|
|
@@ -1210,7 +1208,7 @@ class Engine(Serializable):
|
|
|
1210
1208
|
# time is available for handlers but must be updated after fire
|
|
1211
1209
|
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken
|
|
1212
1210
|
|
|
1213
|
-
if self.should_terminate_single_epoch != "skip_epoch_completed":
|
|
1211
|
+
if self.should_terminate_single_epoch != "skip_epoch_completed":
|
|
1214
1212
|
handlers_start_time = time.time()
|
|
1215
1213
|
self._fire_event(Events.EPOCH_COMPLETED)
|
|
1216
1214
|
epoch_time_taken += time.time() - handlers_start_time
|
|
@@ -1243,7 +1241,7 @@ class Engine(Serializable):
|
|
|
1243
1241
|
self.state.times[Events.COMPLETED.name] = time_taken
|
|
1244
1242
|
|
|
1245
1243
|
# do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True`
|
|
1246
|
-
if self.should_terminate != "skip_completed":
|
|
1244
|
+
if self.should_terminate != "skip_completed":
|
|
1247
1245
|
handlers_start_time = time.time()
|
|
1248
1246
|
self._fire_event(Events.COMPLETED)
|
|
1249
1247
|
time_taken += time.time() - handlers_start_time
|
ignite/engine/events.py
CHANGED
|
@@ -91,7 +91,7 @@ class CallableEventWithFilter:
|
|
|
91
91
|
raise ValueError("Argument every should be integer and greater than zero")
|
|
92
92
|
|
|
93
93
|
if once is not None:
|
|
94
|
-
c1 = isinstance(once,
|
|
94
|
+
c1 = isinstance(once, int) and once > 0
|
|
95
95
|
c2 = isinstance(once, Sequence) and len(once) > 0 and all(isinstance(e, int) and e > 0 for e in once)
|
|
96
96
|
if not (c1 or c2):
|
|
97
97
|
raise ValueError(
|
|
@@ -240,6 +240,7 @@ class EventEnum(CallableEventWithFilter, Enum):
|
|
|
240
240
|
def __new__(cls, value: str) -> "EventEnum":
|
|
241
241
|
obj = CallableEventWithFilter.__new__(cls)
|
|
242
242
|
obj._value_ = value
|
|
243
|
+
# pyrefly: ignore [bad-return]
|
|
243
244
|
return obj
|
|
244
245
|
|
|
245
246
|
|
ignite/handlers/checkpoint.py
CHANGED
|
@@ -315,7 +315,7 @@ class Checkpoint(Serializable):
|
|
|
315
315
|
"""
|
|
316
316
|
|
|
317
317
|
SAVED_CHECKPOINT = CheckpointEvents.SAVED_CHECKPOINT
|
|
318
|
-
Item = NamedTuple("Item", [("priority", int), ("filename", str)])
|
|
318
|
+
Item = NamedTuple("Item", [("priority", Union[int, float]), ("filename", str)])
|
|
319
319
|
_state_dict_all_req_keys = ("_saved",)
|
|
320
320
|
|
|
321
321
|
def __init__(
|
|
@@ -323,7 +323,7 @@ class Checkpoint(Serializable):
|
|
|
323
323
|
to_save: Mapping,
|
|
324
324
|
save_handler: Union[str, Path, Callable, BaseSaveHandler],
|
|
325
325
|
filename_prefix: str = "",
|
|
326
|
-
score_function: Optional[Callable] = None,
|
|
326
|
+
score_function: Optional[Callable[[Engine], Union[int, float]]] = None,
|
|
327
327
|
score_name: Optional[str] = None,
|
|
328
328
|
n_saved: Union[int, None] = 1,
|
|
329
329
|
global_step_transform: Optional[Callable] = None,
|