pytorch-ignite 0.6.0.dev20251103__py3-none-any.whl → 0.6.0.dev20260102__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of pytorch-ignite might be problematic. (More details are linked from the original diff page.)

Files changed (62)
  1. ignite/__init__.py +1 -1
  2. ignite/contrib/engines/common.py +1 -0
  3. ignite/contrib/handlers/base_logger.py +1 -1
  4. ignite/contrib/handlers/clearml_logger.py +1 -1
  5. ignite/contrib/handlers/lr_finder.py +1 -1
  6. ignite/contrib/handlers/mlflow_logger.py +1 -1
  7. ignite/contrib/handlers/neptune_logger.py +1 -1
  8. ignite/contrib/handlers/param_scheduler.py +1 -1
  9. ignite/contrib/handlers/polyaxon_logger.py +1 -1
  10. ignite/contrib/handlers/tensorboard_logger.py +1 -1
  11. ignite/contrib/handlers/time_profilers.py +1 -1
  12. ignite/contrib/handlers/tqdm_logger.py +1 -1
  13. ignite/contrib/handlers/visdom_logger.py +1 -1
  14. ignite/contrib/handlers/wandb_logger.py +1 -1
  15. ignite/contrib/metrics/average_precision.py +1 -1
  16. ignite/contrib/metrics/cohen_kappa.py +1 -1
  17. ignite/contrib/metrics/gpu_info.py +1 -1
  18. ignite/contrib/metrics/precision_recall_curve.py +1 -1
  19. ignite/contrib/metrics/regression/canberra_metric.py +2 -3
  20. ignite/contrib/metrics/regression/fractional_absolute_error.py +2 -3
  21. ignite/contrib/metrics/regression/fractional_bias.py +2 -3
  22. ignite/contrib/metrics/regression/geometric_mean_absolute_error.py +2 -3
  23. ignite/contrib/metrics/regression/geometric_mean_relative_absolute_error.py +2 -3
  24. ignite/contrib/metrics/regression/manhattan_distance.py +2 -3
  25. ignite/contrib/metrics/regression/maximum_absolute_error.py +2 -3
  26. ignite/contrib/metrics/regression/mean_absolute_relative_error.py +2 -3
  27. ignite/contrib/metrics/regression/mean_error.py +2 -3
  28. ignite/contrib/metrics/regression/mean_normalized_bias.py +2 -3
  29. ignite/contrib/metrics/regression/median_absolute_error.py +2 -3
  30. ignite/contrib/metrics/regression/median_absolute_percentage_error.py +2 -3
  31. ignite/contrib/metrics/regression/median_relative_absolute_error.py +2 -3
  32. ignite/contrib/metrics/regression/r2_score.py +2 -3
  33. ignite/contrib/metrics/regression/wave_hedges_distance.py +2 -3
  34. ignite/contrib/metrics/roc_auc.py +1 -1
  35. ignite/distributed/auto.py +1 -0
  36. ignite/distributed/comp_models/horovod.py +8 -1
  37. ignite/distributed/comp_models/native.py +1 -0
  38. ignite/distributed/comp_models/xla.py +2 -0
  39. ignite/distributed/launcher.py +4 -8
  40. ignite/engine/deterministic.py +1 -1
  41. ignite/engine/engine.py +9 -11
  42. ignite/engine/events.py +2 -1
  43. ignite/handlers/checkpoint.py +2 -2
  44. ignite/handlers/clearml_logger.py +2 -2
  45. ignite/handlers/lr_finder.py +10 -10
  46. ignite/handlers/neptune_logger.py +1 -0
  47. ignite/handlers/param_scheduler.py +7 -3
  48. ignite/handlers/state_param_scheduler.py +8 -2
  49. ignite/handlers/time_profilers.py +6 -3
  50. ignite/handlers/tqdm_logger.py +7 -2
  51. ignite/handlers/visdom_logger.py +2 -2
  52. ignite/handlers/wandb_logger.py +9 -8
  53. ignite/metrics/accuracy.py +2 -0
  54. ignite/metrics/metric.py +1 -0
  55. ignite/metrics/nlp/rouge.py +6 -3
  56. ignite/metrics/roc_auc.py +1 -0
  57. ignite/metrics/ssim.py +4 -0
  58. ignite/metrics/vision/object_detection_average_precision_recall.py +3 -0
  59. {pytorch_ignite-0.6.0.dev20251103.dist-info → pytorch_ignite-0.6.0.dev20260102.dist-info}/METADATA +1 -1
  60. {pytorch_ignite-0.6.0.dev20251103.dist-info → pytorch_ignite-0.6.0.dev20260102.dist-info}/RECORD +62 -62
  61. {pytorch_ignite-0.6.0.dev20251103.dist-info → pytorch_ignite-0.6.0.dev20260102.dist-info}/WHEEL +1 -1
  62. {pytorch_ignite-0.6.0.dev20251103.dist-info → pytorch_ignite-0.6.0.dev20260102.dist-info}/licenses/LICENSE +0 -0
ignite/__init__.py CHANGED
@@ -6,4 +6,4 @@ import ignite.handlers
6
6
  import ignite.metrics
7
7
  import ignite.utils
8
8
 
9
- __version__ = "0.6.0.dev20251103"
9
+ __version__ = "0.6.0.dev20260102"
ignite/contrib/engines/common.py CHANGED
@@ -265,6 +265,7 @@ def _setup_common_distrib_training_handlers(
265
265
 
266
266
  @trainer.on(Events.EPOCH_STARTED)
267
267
  def distrib_set_epoch(engine: Engine) -> None:
268
+ # pyrefly: ignore [missing-attribute]
268
269
  train_sampler.set_epoch(engine.state.epoch - 1)
269
270
 
270
271
 
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to /ignite/handlers/base_logger.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to /ignite/handlers/clearml_logger.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to /ignite/handlers/lr_finder.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to /ignite/handlers/mlflow_logger.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to /ignite/handlers/neptune_logger.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to /ignite/handlers/param_scheduler.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to /ignite/handlers/polyaxon_logger.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to /ignite/handlers/tensorboard_logger.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to /ignite/handlers/time_profilers.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to /ignite/handlers/tqdm_logger.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to /ignite/handlers/visdom_logger.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to /ignite/handlers/wandb_logger.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to /ignite/metrics/average_precision.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/cohen_kappa.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/gpu_info.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/precision_recall_curve.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/canberra_metric.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.canberra_metric import CanberraMetric
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/fractional_absolute_error.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.fractional_absolute_error import FractionalAbsoluteError
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/fractional_bias.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.fractional_bias import FractionalBias
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/geometric_mean_absolute_error.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.geometric_mean_absolute_error import GeometricMeanAbsoluteError
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/geometric_mean_relative_absolute_error.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.geometric_mean_relative_absolute_error import GeometricMeanRelativeAbsoluteError
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/manhattan_distance.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.manhattan_distance import ManhattanDistance
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/maximum_absolute_error.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.maximum_absolute_error import MaximumAbsoluteError
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/mean_absolute_relative_error.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.mean_absolute_relative_error import MeanAbsoluteRelativeError
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/mean_error.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.mean_error import MeanError
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/mean_normalized_bias.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.mean_normalized_bias import MeanNormalizedBias
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/median_absolute_error.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.median_absolute_error import MedianAbsoluteError
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/median_absolute_percentage_error.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.median_absolute_percentage_error import MedianAbsolutePercentageError
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/median_relative_absolute_error.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.median_relative_absolute_error import MedianRelativeAbsoluteError
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/r2_score.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.r2_score import R2Score
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/wave_hedges_distance.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.wave_hedges_distance import WaveHedgesDistance
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/roc_auc.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
ignite/distributed/auto.py CHANGED
@@ -336,6 +336,7 @@ if idist.has_xla_support:
336
336
  # From pytorch/xla if `torch_xla.distributed.parallel_loader.MpDeviceLoader` is not available
337
337
  def __init__(self, loader: Any, device: torch.device, **kwargs: Any) -> None:
338
338
  self._loader = loader
339
+ # pyrefly: ignore [read-only]
339
340
  self._device = device
340
341
  self._parallel_loader_kwargs = kwargs
341
342
 
ignite/distributed/comp_models/horovod.py CHANGED
@@ -1,6 +1,6 @@
1
1
  import os
2
2
  import warnings
3
- from typing import Any, Callable, cast, List, Mapping, Optional, Tuple
3
+ from typing import Any, Callable, cast, List, Mapping, Optional, Tuple, TYPE_CHECKING
4
4
 
5
5
  import torch
6
6
 
@@ -20,6 +20,11 @@ try:
20
20
  except ImportError:
21
21
  has_hvd_support = False
22
22
 
23
+ if TYPE_CHECKING:
24
+ # Tell the type checker that hvd imports are always defined.
25
+ import horovod.torch as hvd
26
+ from horovod import run as hvd_mp_spawn
27
+
23
28
 
24
29
  if has_hvd_support:
25
30
  HOROVOD = "horovod"
@@ -70,6 +75,7 @@ if has_hvd_support:
70
75
  self._init_from_context()
71
76
 
72
77
  def _create_from_backend(self, backend: str, **kwargs: Any) -> None:
78
+ # pyrefly: ignore [bad-override]
73
79
  self._backend: str = backend
74
80
  comm = kwargs.get("comm", None)
75
81
  hvd.init(comm=comm)
@@ -129,6 +135,7 @@ if has_hvd_support:
129
135
  finalize()
130
136
 
131
137
  @staticmethod
138
+ # pyrefly: ignore [bad-override]
132
139
  def spawn(
133
140
  fn: Callable,
134
141
  args: Tuple,
ignite/distributed/comp_models/native.py CHANGED
@@ -345,6 +345,7 @@ if has_native_dist_support:
345
345
  os.environ.update(copy_env_vars)
346
346
 
347
347
  @staticmethod
348
+ # pyrefly: ignore [bad-override]
348
349
  def spawn(
349
350
  fn: Callable,
350
351
  args: Tuple,
ignite/distributed/comp_models/xla.py CHANGED
@@ -54,6 +54,7 @@ if has_xla_support:
54
54
  def _create_from_backend(self, backend: str, **kwargs: Any) -> None:
55
55
  xm.rendezvous("init")
56
56
 
57
+ # pyrefly: ignore [bad-override]
57
58
  self._backend: str = backend
58
59
  self._setup_attrs()
59
60
 
@@ -106,6 +107,7 @@ if has_xla_support:
106
107
  finalize()
107
108
 
108
109
  @staticmethod
110
+ # pyrefly: ignore [bad-override]
109
111
  def spawn(
110
112
  fn: Callable,
111
113
  args: Tuple,
ignite/distributed/launcher.py CHANGED
@@ -322,19 +322,15 @@ class Parallel:
322
322
  idist.initialize(self.backend, init_method=self.init_method)
323
323
 
324
324
  # The logger can be setup from now since idist.initialize() has been called (if needed)
325
- self._logger = setup_logger(__name__ + "." + self.__class__.__name__) # type: ignore[assignment]
325
+ self._logger = setup_logger(__name__ + "." + self.__class__.__name__)
326
326
 
327
327
  if self.backend is not None:
328
328
  if self._spawn_params is None:
329
- self._logger.info( # type: ignore[attr-defined]
330
- f"Initialized processing group with backend: '{self.backend}'"
331
- )
329
+ self._logger.info(f"Initialized processing group with backend: '{self.backend}'")
332
330
  else:
333
- self._logger.info( # type: ignore[attr-defined]
334
- f"Initialized distributed launcher with backend: '{self.backend}'"
335
- )
331
+ self._logger.info(f"Initialized distributed launcher with backend: '{self.backend}'")
336
332
  msg = "\n\t".join([f"{k}: {v}" for k, v in self._spawn_params.items() if v is not None])
337
- self._logger.info(f"- Parameters to spawn processes: \n\t{msg}") # type: ignore[attr-defined]
333
+ self._logger.info(f"- Parameters to spawn processes: \n\t{msg}")
338
334
 
339
335
  return self
340
336
 
ignite/engine/deterministic.py CHANGED
@@ -85,7 +85,7 @@ class ReproducibleBatchSampler(BatchSampler):
85
85
 
86
86
 
87
87
  def _get_rng_states() -> List[Any]:
88
- output = [random.getstate(), torch.get_rng_state()]
88
+ output: List[Any] = [random.getstate(), torch.get_rng_state()]
89
89
  try:
90
90
  import numpy as np
91
91
 
ignite/engine/engine.py CHANGED
@@ -148,7 +148,7 @@ class Engine(Serializable):
148
148
  self.should_interrupt = False
149
149
  self.state = State()
150
150
  self._state_dict_user_keys: List[str] = []
151
- self._allowed_events: List[EventEnum] = []
151
+ self._allowed_events: List[Union[str, EventEnum]] = []
152
152
 
153
153
  self._dataloader_iter: Optional[Iterator[Any]] = None
154
154
  self._init_iter: Optional[int] = None
@@ -163,9 +163,7 @@ class Engine(Serializable):
163
163
  # generator provided by self._internal_run_as_gen
164
164
  self._internal_run_generator: Optional[Generator[Any, None, State]] = None
165
165
 
166
- def register_events(
167
- self, *event_names: Union[List[str], List[EventEnum]], event_to_attr: Optional[dict] = None
168
- ) -> None:
166
+ def register_events(self, *event_names: Union[str, EventEnum], event_to_attr: Optional[dict] = None) -> None:
169
167
  """Add events that can be fired.
170
168
 
171
169
  Registering an event will let the user trigger these events at any point.
@@ -450,7 +448,7 @@ class Engine(Serializable):
450
448
  first, others = ((resolved_engine,), args[1:])
451
449
  else:
452
450
  # metrics do not provide engine when registered
453
- first, others = (tuple(), args) # type: ignore[assignment]
451
+ first, others = (tuple(), args)
454
452
 
455
453
  func(*first, *(event_args + others), **kwargs)
456
454
 
@@ -989,9 +987,9 @@ class Engine(Serializable):
989
987
  def _internal_run_as_gen(self) -> Generator[Any, None, State]:
990
988
  self.should_terminate = self.should_terminate_single_epoch = self.should_interrupt = False
991
989
  self._init_timers(self.state)
990
+ start_time = time.time()
992
991
  try:
993
992
  try:
994
- start_time = time.time()
995
993
  self._fire_event(Events.STARTED)
996
994
  yield from self._maybe_terminate_or_interrupt()
997
995
 
@@ -1010,7 +1008,7 @@ class Engine(Serializable):
1010
1008
  # time is available for handlers but must be updated after fire
1011
1009
  self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken
1012
1010
 
1013
- if self.should_terminate_single_epoch != "skip_epoch_completed": # type: ignore[comparison-overlap]
1011
+ if self.should_terminate_single_epoch != "skip_epoch_completed":
1014
1012
  handlers_start_time = time.time()
1015
1013
  self._fire_event(Events.EPOCH_COMPLETED)
1016
1014
  epoch_time_taken += time.time() - handlers_start_time
@@ -1043,7 +1041,7 @@ class Engine(Serializable):
1043
1041
  self.state.times[Events.COMPLETED.name] = time_taken
1044
1042
 
1045
1043
  # do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True`
1046
- if self.should_terminate != "skip_completed": # type: ignore[comparison-overlap]
1044
+ if self.should_terminate != "skip_completed":
1047
1045
  handlers_start_time = time.time()
1048
1046
  self._fire_event(Events.COMPLETED)
1049
1047
  time_taken += time.time() - handlers_start_time
@@ -1189,9 +1187,9 @@ class Engine(Serializable):
1189
1187
  # internal_run without generator for BC
1190
1188
  self.should_terminate = self.should_terminate_single_epoch = self.should_interrupt = False
1191
1189
  self._init_timers(self.state)
1190
+ start_time = time.time()
1192
1191
  try:
1193
1192
  try:
1194
- start_time = time.time()
1195
1193
  self._fire_event(Events.STARTED)
1196
1194
  self._maybe_terminate_legacy()
1197
1195
 
@@ -1210,7 +1208,7 @@ class Engine(Serializable):
1210
1208
  # time is available for handlers but must be updated after fire
1211
1209
  self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken
1212
1210
 
1213
- if self.should_terminate_single_epoch != "skip_epoch_completed": # type: ignore[comparison-overlap]
1211
+ if self.should_terminate_single_epoch != "skip_epoch_completed":
1214
1212
  handlers_start_time = time.time()
1215
1213
  self._fire_event(Events.EPOCH_COMPLETED)
1216
1214
  epoch_time_taken += time.time() - handlers_start_time
@@ -1243,7 +1241,7 @@ class Engine(Serializable):
1243
1241
  self.state.times[Events.COMPLETED.name] = time_taken
1244
1242
 
1245
1243
  # do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True`
1246
- if self.should_terminate != "skip_completed": # type: ignore[comparison-overlap]
1244
+ if self.should_terminate != "skip_completed":
1247
1245
  handlers_start_time = time.time()
1248
1246
  self._fire_event(Events.COMPLETED)
1249
1247
  time_taken += time.time() - handlers_start_time
ignite/engine/events.py CHANGED
@@ -91,7 +91,7 @@ class CallableEventWithFilter:
91
91
  raise ValueError("Argument every should be integer and greater than zero")
92
92
 
93
93
  if once is not None:
94
- c1 = isinstance(once, numbers.Integral) and once > 0
94
+ c1 = isinstance(once, int) and once > 0
95
95
  c2 = isinstance(once, Sequence) and len(once) > 0 and all(isinstance(e, int) and e > 0 for e in once)
96
96
  if not (c1 or c2):
97
97
  raise ValueError(
@@ -240,6 +240,7 @@ class EventEnum(CallableEventWithFilter, Enum):
240
240
  def __new__(cls, value: str) -> "EventEnum":
241
241
  obj = CallableEventWithFilter.__new__(cls)
242
242
  obj._value_ = value
243
+ # pyrefly: ignore [bad-return]
243
244
  return obj
244
245
 
245
246
 
ignite/handlers/checkpoint.py CHANGED
@@ -315,7 +315,7 @@ class Checkpoint(Serializable):
315
315
  """
316
316
 
317
317
  SAVED_CHECKPOINT = CheckpointEvents.SAVED_CHECKPOINT
318
- Item = NamedTuple("Item", [("priority", int), ("filename", str)])
318
+ Item = NamedTuple("Item", [("priority", Union[int, float]), ("filename", str)])
319
319
  _state_dict_all_req_keys = ("_saved",)
320
320
 
321
321
  def __init__(
@@ -323,7 +323,7 @@ class Checkpoint(Serializable):
323
323
  to_save: Mapping,
324
324
  save_handler: Union[str, Path, Callable, BaseSaveHandler],
325
325
  filename_prefix: str = "",
326
- score_function: Optional[Callable] = None,
326
+ score_function: Optional[Callable[[Engine], Union[int, float]]] = None,
327
327
  score_name: Optional[str] = None,
328
328
  n_saved: Union[int, None] = 1,
329
329
  global_step_transform: Optional[Callable] = None,