pytorch-ignite 0.6.0.dev20250927__py3-none-any.whl → 0.6.0.dev20260101__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pytorch-ignite might be problematic. Click here for more details.

Files changed (65)
  1. ignite/__init__.py +1 -1
  2. ignite/contrib/engines/common.py +1 -0
  3. ignite/contrib/handlers/base_logger.py +1 -1
  4. ignite/contrib/handlers/clearml_logger.py +1 -1
  5. ignite/contrib/handlers/lr_finder.py +1 -1
  6. ignite/contrib/handlers/mlflow_logger.py +1 -1
  7. ignite/contrib/handlers/neptune_logger.py +1 -1
  8. ignite/contrib/handlers/param_scheduler.py +1 -1
  9. ignite/contrib/handlers/polyaxon_logger.py +1 -1
  10. ignite/contrib/handlers/tensorboard_logger.py +1 -1
  11. ignite/contrib/handlers/time_profilers.py +1 -1
  12. ignite/contrib/handlers/tqdm_logger.py +1 -1
  13. ignite/contrib/handlers/visdom_logger.py +1 -1
  14. ignite/contrib/handlers/wandb_logger.py +1 -1
  15. ignite/contrib/metrics/average_precision.py +1 -1
  16. ignite/contrib/metrics/cohen_kappa.py +1 -1
  17. ignite/contrib/metrics/gpu_info.py +1 -1
  18. ignite/contrib/metrics/precision_recall_curve.py +1 -1
  19. ignite/contrib/metrics/regression/canberra_metric.py +2 -3
  20. ignite/contrib/metrics/regression/fractional_absolute_error.py +2 -3
  21. ignite/contrib/metrics/regression/fractional_bias.py +2 -3
  22. ignite/contrib/metrics/regression/geometric_mean_absolute_error.py +2 -3
  23. ignite/contrib/metrics/regression/geometric_mean_relative_absolute_error.py +2 -3
  24. ignite/contrib/metrics/regression/manhattan_distance.py +2 -3
  25. ignite/contrib/metrics/regression/maximum_absolute_error.py +2 -3
  26. ignite/contrib/metrics/regression/mean_absolute_relative_error.py +2 -3
  27. ignite/contrib/metrics/regression/mean_error.py +2 -3
  28. ignite/contrib/metrics/regression/mean_normalized_bias.py +2 -3
  29. ignite/contrib/metrics/regression/median_absolute_error.py +2 -3
  30. ignite/contrib/metrics/regression/median_absolute_percentage_error.py +2 -3
  31. ignite/contrib/metrics/regression/median_relative_absolute_error.py +2 -3
  32. ignite/contrib/metrics/regression/r2_score.py +2 -3
  33. ignite/contrib/metrics/regression/wave_hedges_distance.py +2 -3
  34. ignite/contrib/metrics/roc_auc.py +1 -1
  35. ignite/distributed/auto.py +1 -0
  36. ignite/distributed/comp_models/horovod.py +8 -1
  37. ignite/distributed/comp_models/native.py +2 -1
  38. ignite/distributed/comp_models/xla.py +2 -0
  39. ignite/distributed/launcher.py +4 -8
  40. ignite/engine/__init__.py +9 -9
  41. ignite/engine/deterministic.py +1 -1
  42. ignite/engine/engine.py +9 -11
  43. ignite/engine/events.py +2 -1
  44. ignite/handlers/__init__.py +2 -0
  45. ignite/handlers/checkpoint.py +2 -2
  46. ignite/handlers/clearml_logger.py +2 -2
  47. ignite/handlers/fbresearch_logger.py +2 -2
  48. ignite/handlers/lr_finder.py +10 -10
  49. ignite/handlers/neptune_logger.py +1 -0
  50. ignite/handlers/param_scheduler.py +7 -3
  51. ignite/handlers/state_param_scheduler.py +8 -2
  52. ignite/handlers/time_profilers.py +6 -3
  53. ignite/handlers/tqdm_logger.py +7 -2
  54. ignite/handlers/visdom_logger.py +2 -2
  55. ignite/handlers/wandb_logger.py +9 -8
  56. ignite/metrics/accuracy.py +2 -0
  57. ignite/metrics/metric.py +1 -0
  58. ignite/metrics/nlp/rouge.py +6 -3
  59. ignite/metrics/roc_auc.py +1 -0
  60. ignite/metrics/ssim.py +4 -0
  61. ignite/metrics/vision/object_detection_average_precision_recall.py +3 -0
  62. {pytorch_ignite-0.6.0.dev20250927.dist-info → pytorch_ignite-0.6.0.dev20260101.dist-info}/METADATA +2 -2
  63. {pytorch_ignite-0.6.0.dev20250927.dist-info → pytorch_ignite-0.6.0.dev20260101.dist-info}/RECORD +65 -65
  64. {pytorch_ignite-0.6.0.dev20250927.dist-info → pytorch_ignite-0.6.0.dev20260101.dist-info}/WHEEL +1 -1
  65. {pytorch_ignite-0.6.0.dev20250927.dist-info → pytorch_ignite-0.6.0.dev20260101.dist-info}/licenses/LICENSE +0 -0
ignite/__init__.py CHANGED
@@ -6,4 +6,4 @@ import ignite.handlers
6
6
  import ignite.metrics
7
7
  import ignite.utils
8
8
 
9
- __version__ = "0.6.0.dev20250927"
9
+ __version__ = "0.6.0.dev20260101"
ignite/contrib/engines/common.py CHANGED
@@ -265,6 +265,7 @@ def _setup_common_distrib_training_handlers(
265
265
 
266
266
  @trainer.on(Events.EPOCH_STARTED)
267
267
  def distrib_set_epoch(engine: Engine) -> None:
268
+ # pyrefly: ignore [missing-attribute]
268
269
  train_sampler.set_epoch(engine.state.epoch - 1)
269
270
 
270
271
 
ignite/contrib/handlers/base_logger.py CHANGED
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to /ignite/handlers/base_logger.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
ignite/contrib/handlers/clearml_logger.py CHANGED
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to /ignite/handlers/clearml_logger.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
ignite/contrib/handlers/lr_finder.py CHANGED
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to /ignite/handlers/lr_finder.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
ignite/contrib/handlers/mlflow_logger.py CHANGED
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to /ignite/handlers/mlflow_logger.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
ignite/contrib/handlers/neptune_logger.py CHANGED
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to /ignite/handlers/neptune_logger.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
ignite/contrib/handlers/param_scheduler.py CHANGED
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to /ignite/handlers/param_scheduler.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
ignite/contrib/handlers/polyaxon_logger.py CHANGED
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to /ignite/handlers/polyaxon_logger.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
ignite/contrib/handlers/tensorboard_logger.py CHANGED
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to /ignite/handlers/tensorboard_logger.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
ignite/contrib/handlers/time_profilers.py CHANGED
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to /ignite/handlers/time_profilers.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
ignite/contrib/handlers/tqdm_logger.py CHANGED
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to /ignite/handlers/tqdm_logger.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
ignite/contrib/handlers/visdom_logger.py CHANGED
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to /ignite/handlers/visdom_logger.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
ignite/contrib/handlers/wandb_logger.py CHANGED
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to /ignite/handlers/wandb_logger.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
ignite/contrib/metrics/average_precision.py CHANGED
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to /ignite/metrics/average_precision.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
ignite/contrib/metrics/cohen_kappa.py CHANGED
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/cohen_kappa.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
ignite/contrib/metrics/gpu_info.py CHANGED
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/gpu_info.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
ignite/contrib/metrics/precision_recall_curve.py CHANGED
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/precision_recall_curve.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
ignite/contrib/metrics/regression/canberra_metric.py CHANGED
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/canberra_metric.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.canberra_metric import CanberraMetric
ignite/contrib/metrics/regression/fractional_absolute_error.py CHANGED
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/fractional_absolute_error.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.fractional_absolute_error import FractionalAbsoluteError
ignite/contrib/metrics/regression/fractional_bias.py CHANGED
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/fractional_bias.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.fractional_bias import FractionalBias
ignite/contrib/metrics/regression/geometric_mean_absolute_error.py CHANGED
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/geometric_mean_absolute_error.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.geometric_mean_absolute_error import GeometricMeanAbsoluteError
ignite/contrib/metrics/regression/geometric_mean_relative_absolute_error.py CHANGED
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/geometric_mean_relative_absolute_error.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.geometric_mean_relative_absolute_error import GeometricMeanRelativeAbsoluteError
ignite/contrib/metrics/regression/manhattan_distance.py CHANGED
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/manhattan_distance.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.manhattan_distance import ManhattanDistance
ignite/contrib/metrics/regression/maximum_absolute_error.py CHANGED
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/maximum_absolute_error.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.maximum_absolute_error import MaximumAbsoluteError
ignite/contrib/metrics/regression/mean_absolute_relative_error.py CHANGED
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/mean_absolute_relative_error.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.mean_absolute_relative_error import MeanAbsoluteRelativeError
ignite/contrib/metrics/regression/mean_error.py CHANGED
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/mean_error.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.mean_error import MeanError
ignite/contrib/metrics/regression/mean_normalized_bias.py CHANGED
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/mean_normalized_bias.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.mean_normalized_bias import MeanNormalizedBias
ignite/contrib/metrics/regression/median_absolute_error.py CHANGED
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/median_absolute_error.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.median_absolute_error import MedianAbsoluteError
ignite/contrib/metrics/regression/median_absolute_percentage_error.py CHANGED
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/median_absolute_percentage_error.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.median_absolute_percentage_error import MedianAbsolutePercentageError
ignite/contrib/metrics/regression/median_relative_absolute_error.py CHANGED
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/median_relative_absolute_error.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.median_relative_absolute_error import MedianRelativeAbsoluteError
ignite/contrib/metrics/regression/r2_score.py CHANGED
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/r2_score.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.r2_score import R2Score
ignite/contrib/metrics/regression/wave_hedges_distance.py CHANGED
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/wave_hedges_distance.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.wave_hedges_distance import WaveHedgesDistance
ignite/contrib/metrics/roc_auc.py CHANGED
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/roc_auc.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
ignite/distributed/auto.py CHANGED
@@ -336,6 +336,7 @@ if idist.has_xla_support:
336
336
  # From pytorch/xla if `torch_xla.distributed.parallel_loader.MpDeviceLoader` is not available
337
337
  def __init__(self, loader: Any, device: torch.device, **kwargs: Any) -> None:
338
338
  self._loader = loader
339
+ # pyrefly: ignore [read-only]
339
340
  self._device = device
340
341
  self._parallel_loader_kwargs = kwargs
341
342
 
ignite/distributed/comp_models/horovod.py CHANGED
@@ -1,6 +1,6 @@
1
1
  import os
2
2
  import warnings
3
- from typing import Any, Callable, cast, List, Mapping, Optional, Tuple
3
+ from typing import Any, Callable, cast, List, Mapping, Optional, Tuple, TYPE_CHECKING
4
4
 
5
5
  import torch
6
6
 
@@ -20,6 +20,11 @@ try:
20
20
  except ImportError:
21
21
  has_hvd_support = False
22
22
 
23
+ if TYPE_CHECKING:
24
+ # Tell the type checker that hvd imports are always defined.
25
+ import horovod.torch as hvd
26
+ from horovod import run as hvd_mp_spawn
27
+
23
28
 
24
29
  if has_hvd_support:
25
30
  HOROVOD = "horovod"
@@ -70,6 +75,7 @@ if has_hvd_support:
70
75
  self._init_from_context()
71
76
 
72
77
  def _create_from_backend(self, backend: str, **kwargs: Any) -> None:
78
+ # pyrefly: ignore [bad-override]
73
79
  self._backend: str = backend
74
80
  comm = kwargs.get("comm", None)
75
81
  hvd.init(comm=comm)
@@ -129,6 +135,7 @@ if has_hvd_support:
129
135
  finalize()
130
136
 
131
137
  @staticmethod
138
+ # pyrefly: ignore [bad-override]
132
139
  def spawn(
133
140
  fn: Callable,
134
141
  args: Tuple,
ignite/distributed/comp_models/native.py CHANGED
@@ -178,7 +178,7 @@ if has_native_dist_support:
178
178
  c: Counter = Counter(hostnames)
179
179
  sizes = torch.tensor([0] + list(c.values()))
180
180
  cumsum_sizes = torch.cumsum(sizes, dim=0)
181
- node_rank = (rank // cumsum_sizes[1:]).clamp(0, 1).sum().item()
181
+ node_rank = cast(int, (rank // cumsum_sizes[1:]).clamp(0, 1).sum().item())
182
182
  local_rank = rank - cumsum_sizes[node_rank].item()
183
183
  return int(local_rank), node_rank
184
184
 
@@ -345,6 +345,7 @@ if has_native_dist_support:
345
345
  os.environ.update(copy_env_vars)
346
346
 
347
347
  @staticmethod
348
+ # pyrefly: ignore [bad-override]
348
349
  def spawn(
349
350
  fn: Callable,
350
351
  args: Tuple,
ignite/distributed/comp_models/xla.py CHANGED
@@ -54,6 +54,7 @@ if has_xla_support:
54
54
  def _create_from_backend(self, backend: str, **kwargs: Any) -> None:
55
55
  xm.rendezvous("init")
56
56
 
57
+ # pyrefly: ignore [bad-override]
57
58
  self._backend: str = backend
58
59
  self._setup_attrs()
59
60
 
@@ -106,6 +107,7 @@ if has_xla_support:
106
107
  finalize()
107
108
 
108
109
  @staticmethod
110
+ # pyrefly: ignore [bad-override]
109
111
  def spawn(
110
112
  fn: Callable,
111
113
  args: Tuple,
ignite/distributed/launcher.py CHANGED
@@ -322,19 +322,15 @@ class Parallel:
322
322
  idist.initialize(self.backend, init_method=self.init_method)
323
323
 
324
324
  # The logger can be setup from now since idist.initialize() has been called (if needed)
325
- self._logger = setup_logger(__name__ + "." + self.__class__.__name__) # type: ignore[assignment]
325
+ self._logger = setup_logger(__name__ + "." + self.__class__.__name__)
326
326
 
327
327
  if self.backend is not None:
328
328
  if self._spawn_params is None:
329
- self._logger.info( # type: ignore[attr-defined]
330
- f"Initialized processing group with backend: '{self.backend}'"
331
- )
329
+ self._logger.info(f"Initialized processing group with backend: '{self.backend}'")
332
330
  else:
333
- self._logger.info( # type: ignore[attr-defined]
334
- f"Initialized distributed launcher with backend: '{self.backend}'"
335
- )
331
+ self._logger.info(f"Initialized distributed launcher with backend: '{self.backend}'")
336
332
  msg = "\n\t".join([f"{k}: {v}" for k, v in self._spawn_params.items() if v is not None])
337
- self._logger.info(f"- Parameters to spawn processes: \n\t{msg}") # type: ignore[attr-defined]
333
+ self._logger.info(f"- Parameters to spawn processes: \n\t{msg}")
338
334
 
339
335
  return self
340
336
 
ignite/engine/__init__.py CHANGED
@@ -133,11 +133,11 @@ def supervised_training_step_amp(
133
133
  prepare_batch: Callable = _prepare_batch,
134
134
  model_transform: Callable[[Any], Any] = lambda output: output,
135
135
  output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
136
- scaler: Optional["torch.cuda.amp.GradScaler"] = None,
136
+ scaler: Optional["torch.amp.GradScaler"] = None,
137
137
  gradient_accumulation_steps: int = 1,
138
138
  model_fn: Callable[[torch.nn.Module, Any], Any] = lambda model, x: model(x),
139
139
  ) -> Callable:
140
- """Factory function for supervised training using ``torch.cuda.amp``.
140
+ """Factory function for supervised training using ``torch.amp``.
141
141
 
142
142
  Args:
143
143
  model: the model to train.
@@ -170,7 +170,7 @@ def supervised_training_step_amp(
170
170
  model = ...
171
171
  optimizer = ...
172
172
  loss_fn = ...
173
- scaler = torch.cuda.amp.GradScaler(2**10)
173
+ scaler = torch.amp.GradScaler('cuda', 2**10)
174
174
 
175
175
  update_fn = supervised_training_step_amp(model, optimizer, loss_fn, 'cuda', scaler=scaler)
176
176
  trainer = Engine(update_fn)
@@ -393,8 +393,8 @@ def supervised_training_step_tpu(
393
393
 
394
394
 
395
395
  def _check_arg(
396
- on_tpu: bool, on_mps: bool, amp_mode: Optional[str], scaler: Optional[Union[bool, "torch.cuda.amp.GradScaler"]]
397
- ) -> Tuple[Optional[str], Optional["torch.cuda.amp.GradScaler"]]:
396
+ on_tpu: bool, on_mps: bool, amp_mode: Optional[str], scaler: Optional[Union[bool, "torch.amp.GradScaler"]]
397
+ ) -> Tuple[Optional[str], Optional["torch.amp.GradScaler"]]:
398
398
  """Checking tpu, mps, amp and GradScaler instance combinations."""
399
399
  if on_mps and amp_mode:
400
400
  raise ValueError("amp_mode cannot be used with mps device. Consider using amp_mode=None or device='cuda'.")
@@ -410,9 +410,9 @@ def _check_arg(
410
410
  raise ValueError(f"scaler argument is {scaler}, but amp_mode is {amp_mode}. Consider using amp_mode='amp'.")
411
411
  elif amp_mode == "amp" and isinstance(scaler, bool):
412
412
  try:
413
- from torch.cuda.amp import GradScaler
413
+ from torch.amp import GradScaler
414
414
  except ImportError:
415
- raise ImportError("Please install torch>=1.6.0 to use scaler argument.")
415
+ raise ImportError("Please install torch>=2.3.1 to use scaler argument.")
416
416
  scaler = GradScaler(enabled=True)
417
417
 
418
418
  if on_tpu:
@@ -434,7 +434,7 @@ def create_supervised_trainer(
434
434
  output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
435
435
  deterministic: bool = False,
436
436
  amp_mode: Optional[str] = None,
437
- scaler: Union[bool, "torch.cuda.amp.GradScaler"] = False,
437
+ scaler: Union[bool, "torch.amp.GradScaler"] = False,
438
438
  gradient_accumulation_steps: int = 1,
439
439
  model_fn: Callable[[torch.nn.Module, Any], Any] = lambda model, x: model(x),
440
440
  ) -> Engine:
@@ -459,7 +459,7 @@ def create_supervised_trainer(
459
459
  :class:`~ignite.engine.deterministic.DeterministicEngine`, otherwise :class:`~ignite.engine.engine.Engine`
460
460
  (default: False).
461
461
  amp_mode: can be ``amp`` or ``apex``, model and optimizer will be casted to float16 using
462
- `torch.cuda.amp <https://pytorch.org/docs/stable/amp.html>`_ for ``amp`` and
462
+ `torch.amp <https://pytorch.org/docs/stable/amp.html>`_ for ``amp`` and
463
463
  using `apex <https://nvidia.github.io/apex>`_ for ``apex``. (default: None)
464
464
  scaler: GradScaler instance for gradient scaling if `torch>=1.6.0`
465
465
  and ``amp_mode`` is ``amp``. If ``amp_mode`` is ``apex``, this argument will be ignored.
ignite/engine/deterministic.py CHANGED
@@ -85,7 +85,7 @@ class ReproducibleBatchSampler(BatchSampler):
85
85
 
86
86
 
87
87
  def _get_rng_states() -> List[Any]:
88
- output = [random.getstate(), torch.get_rng_state()]
88
+ output: List[Any] = [random.getstate(), torch.get_rng_state()]
89
89
  try:
90
90
  import numpy as np
91
91