pytorch-ignite 0.6.0.dev20250310__py3-none-any.whl → 0.6.0.dev20260101__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pytorch-ignite might be problematic. Click here for more details.

Files changed (83)
  1. ignite/__init__.py +1 -1
  2. ignite/contrib/engines/common.py +1 -0
  3. ignite/contrib/handlers/base_logger.py +1 -1
  4. ignite/contrib/handlers/clearml_logger.py +1 -1
  5. ignite/contrib/handlers/lr_finder.py +1 -1
  6. ignite/contrib/handlers/mlflow_logger.py +1 -1
  7. ignite/contrib/handlers/neptune_logger.py +1 -1
  8. ignite/contrib/handlers/param_scheduler.py +1 -1
  9. ignite/contrib/handlers/polyaxon_logger.py +1 -1
  10. ignite/contrib/handlers/tensorboard_logger.py +1 -1
  11. ignite/contrib/handlers/time_profilers.py +1 -1
  12. ignite/contrib/handlers/tqdm_logger.py +1 -1
  13. ignite/contrib/handlers/visdom_logger.py +1 -1
  14. ignite/contrib/handlers/wandb_logger.py +1 -1
  15. ignite/contrib/metrics/average_precision.py +1 -1
  16. ignite/contrib/metrics/cohen_kappa.py +1 -1
  17. ignite/contrib/metrics/gpu_info.py +1 -1
  18. ignite/contrib/metrics/precision_recall_curve.py +1 -1
  19. ignite/contrib/metrics/regression/canberra_metric.py +2 -3
  20. ignite/contrib/metrics/regression/fractional_absolute_error.py +2 -3
  21. ignite/contrib/metrics/regression/fractional_bias.py +2 -3
  22. ignite/contrib/metrics/regression/geometric_mean_absolute_error.py +2 -3
  23. ignite/contrib/metrics/regression/geometric_mean_relative_absolute_error.py +2 -3
  24. ignite/contrib/metrics/regression/manhattan_distance.py +2 -3
  25. ignite/contrib/metrics/regression/maximum_absolute_error.py +2 -3
  26. ignite/contrib/metrics/regression/mean_absolute_relative_error.py +2 -3
  27. ignite/contrib/metrics/regression/mean_error.py +2 -3
  28. ignite/contrib/metrics/regression/mean_normalized_bias.py +2 -3
  29. ignite/contrib/metrics/regression/median_absolute_error.py +2 -3
  30. ignite/contrib/metrics/regression/median_absolute_percentage_error.py +2 -3
  31. ignite/contrib/metrics/regression/median_relative_absolute_error.py +2 -3
  32. ignite/contrib/metrics/regression/r2_score.py +2 -3
  33. ignite/contrib/metrics/regression/wave_hedges_distance.py +2 -3
  34. ignite/contrib/metrics/roc_auc.py +1 -1
  35. ignite/distributed/auto.py +1 -0
  36. ignite/distributed/comp_models/base.py +7 -0
  37. ignite/distributed/comp_models/horovod.py +35 -5
  38. ignite/distributed/comp_models/native.py +8 -4
  39. ignite/distributed/comp_models/xla.py +5 -0
  40. ignite/distributed/launcher.py +4 -8
  41. ignite/distributed/utils.py +12 -4
  42. ignite/engine/__init__.py +9 -9
  43. ignite/engine/deterministic.py +1 -1
  44. ignite/engine/engine.py +38 -14
  45. ignite/engine/events.py +2 -1
  46. ignite/handlers/__init__.py +2 -0
  47. ignite/handlers/base_logger.py +47 -12
  48. ignite/handlers/checkpoint.py +46 -5
  49. ignite/handlers/clearml_logger.py +16 -4
  50. ignite/handlers/fbresearch_logger.py +2 -2
  51. ignite/handlers/lr_finder.py +9 -9
  52. ignite/handlers/mlflow_logger.py +43 -0
  53. ignite/handlers/neptune_logger.py +8 -0
  54. ignite/handlers/param_scheduler.py +7 -3
  55. ignite/handlers/polyaxon_logger.py +7 -0
  56. ignite/handlers/state_param_scheduler.py +8 -2
  57. ignite/handlers/tensorboard_logger.py +43 -0
  58. ignite/handlers/time_profilers.py +6 -3
  59. ignite/handlers/tqdm_logger.py +9 -5
  60. ignite/handlers/visdom_logger.py +10 -3
  61. ignite/handlers/wandb_logger.py +16 -9
  62. ignite/metrics/accuracy.py +2 -0
  63. ignite/metrics/clustering/calinski_harabasz_score.py +1 -1
  64. ignite/metrics/clustering/silhouette_score.py +1 -1
  65. ignite/metrics/fbeta.py +17 -8
  66. ignite/metrics/gan/fid.py +3 -3
  67. ignite/metrics/js_divergence.py +1 -1
  68. ignite/metrics/maximum_mean_discrepancy.py +1 -1
  69. ignite/metrics/metric.py +3 -0
  70. ignite/metrics/nlp/bleu.py +8 -6
  71. ignite/metrics/nlp/rouge.py +9 -6
  72. ignite/metrics/nlp/utils.py +1 -1
  73. ignite/metrics/precision_recall_curve.py +5 -5
  74. ignite/metrics/regression/_base.py +4 -0
  75. ignite/metrics/regression/fractional_bias.py +1 -1
  76. ignite/metrics/roc_auc.py +4 -3
  77. ignite/metrics/ssim.py +63 -21
  78. ignite/metrics/vision/object_detection_average_precision_recall.py +3 -0
  79. {pytorch_ignite-0.6.0.dev20250310.dist-info → pytorch_ignite-0.6.0.dev20260101.dist-info}/METADATA +11 -17
  80. {pytorch_ignite-0.6.0.dev20250310.dist-info → pytorch_ignite-0.6.0.dev20260101.dist-info}/RECORD +82 -83
  81. {pytorch_ignite-0.6.0.dev20250310.dist-info → pytorch_ignite-0.6.0.dev20260101.dist-info}/WHEEL +1 -2
  82. pytorch_ignite-0.6.0.dev20250310.dist-info/top_level.txt +0 -1
  83. {pytorch_ignite-0.6.0.dev20250310.dist-info → pytorch_ignite-0.6.0.dev20260101.dist-info/licenses}/LICENSE +0 -0
ignite/__init__.py CHANGED
@@ -6,4 +6,4 @@ import ignite.handlers
6
6
  import ignite.metrics
7
7
  import ignite.utils
8
8
 
9
- __version__ = "0.6.0.dev20250310"
9
+ __version__ = "0.6.0.dev20260101"
@@ -265,6 +265,7 @@ def _setup_common_distrib_training_handlers(
265
265
 
266
266
  @trainer.on(Events.EPOCH_STARTED)
267
267
  def distrib_set_epoch(engine: Engine) -> None:
268
+ # pyrefly: ignore [missing-attribute]
268
269
  train_sampler.set_epoch(engine.state.epoch - 1)
269
270
 
270
271
 
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to /ignite/handlers/base_logger.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to /ignite/handlers/clearml_logger.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to /ignite/handlers/lr_finder.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to /ignite/handlers/mlflow_logger.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to /ignite/handlers/neptune_logger.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to /ignite/handlers/param_scheduler.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to /ignite/handlers/polyaxon_logger.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to /ignite/handlers/tensorboard_logger.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to /ignite/handlers/time_profilers.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to /ignite/handlers/tqdm_logger.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to /ignite/handlers/visdom_logger.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to /ignite/handlers/wandb_logger.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to /ignite/metrics/average_precision.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/cohen_kappa.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/gpu_info.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/precision_recall_curve.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/canberra_metric.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.canberra_metric import CanberraMetric
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/fractional_absolute_error.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.fractional_absolute_error import FractionalAbsoluteError
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/fractional_bias.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.fractional_bias import FractionalBias
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/geometric_mean_absolute_error.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.geometric_mean_absolute_error import GeometricMeanAbsoluteError
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/geometric_mean_relative_absolute_error.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.geometric_mean_relative_absolute_error import GeometricMeanRelativeAbsoluteError
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/manhattan_distance.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.manhattan_distance import ManhattanDistance
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/maximum_absolute_error.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.maximum_absolute_error import MaximumAbsoluteError
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/mean_absolute_relative_error.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.mean_absolute_relative_error import MeanAbsoluteRelativeError
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/mean_error.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.mean_error import MeanError
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/mean_normalized_bias.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.mean_normalized_bias import MeanNormalizedBias
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/median_absolute_error.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.median_absolute_error import MedianAbsoluteError
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/median_absolute_percentage_error.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.median_absolute_percentage_error import MedianAbsolutePercentageError
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/median_relative_absolute_error.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.median_relative_absolute_error import MedianRelativeAbsoluteError
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/r2_score.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.r2_score import R2Score
@@ -9,9 +9,8 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/regression/wave_hedges_distance.py"
12
- f" and will be removed in version {removed_in}"
13
- if removed_in
14
- else "" ".\n Please refer to the documentation for more details."
12
+ + f" and will be removed in version {removed_in}"
13
+ + ".\n Please refer to the documentation for more details."
15
14
  )
16
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
17
16
  from ignite.metrics.regression.wave_hedges_distance import WaveHedgesDistance
@@ -9,7 +9,7 @@ import warnings
9
9
  removed_in = "0.6.0"
10
10
  deprecation_warning = (
11
11
  f"{__file__} has been moved to ignite/metrics/roc_auc.py"
12
- + (f" and will be removed in version {removed_in}" if removed_in else "")
12
+ + f" and will be removed in version {removed_in}"
13
13
  + ".\n Please refer to the documentation for more details."
14
14
  )
15
15
  warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
@@ -336,6 +336,7 @@ if idist.has_xla_support:
336
336
  # From pytorch/xla if `torch_xla.distributed.parallel_loader.MpDeviceLoader` is not available
337
337
  def __init__(self, loader: Any, device: torch.device, **kwargs: Any) -> None:
338
338
  self._loader = loader
339
+ # pyrefly: ignore [read-only]
339
340
  self._device = device
340
341
  self._parallel_loader_kwargs = kwargs
341
342
 
@@ -298,6 +298,10 @@ class ComputationModel(metaclass=ABCMeta):
298
298
  def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
299
299
  pass
300
300
 
301
+ @abstractmethod
302
+ def _rank_not_in_group(self, group: Any) -> bool:
303
+ pass
304
+
301
305
 
302
306
  class _SerialModel(ComputationModel):
303
307
  """Private class defines non-distributed computation model for code compatibility with other distributed models."""
@@ -396,3 +400,6 @@ class _SerialModel(ComputationModel):
396
400
  return self._do_new_group(ranks, **kwargs)
397
401
  else:
398
402
  raise ValueError("Argument ranks should be list of int")
403
+
404
+ def _rank_not_in_group(self, group: Any) -> bool:
405
+ return False
@@ -1,5 +1,6 @@
1
+ import os
1
2
  import warnings
2
- from typing import Any, Callable, cast, List, Mapping, Optional, Tuple
3
+ from typing import Any, Callable, cast, List, Mapping, Optional, Tuple, TYPE_CHECKING
3
4
 
4
5
  import torch
5
6
 
@@ -19,10 +20,18 @@ try:
19
20
  except ImportError:
20
21
  has_hvd_support = False
21
22
 
23
+ if TYPE_CHECKING:
24
+ # Tell the type checker that hvd imports are always defined.
25
+ import horovod.torch as hvd
26
+ from horovod import run as hvd_mp_spawn
27
+
22
28
 
23
29
  if has_hvd_support:
24
30
  HOROVOD = "horovod"
25
31
 
32
+ # Enables dynamic process sets: new_group methods and passing group into collective ops
33
+ os.environ["HOROVOD_DYNAMIC_PROCESS_SETS"] = "1"
34
+
26
35
  class _HorovodDistModel(ComputationModel):
27
36
  """Private class for `Horovod <https://horovod.readthedocs.io/en/stable/>`_ distributed computation model."""
28
37
 
@@ -66,6 +75,7 @@ if has_hvd_support:
66
75
  self._init_from_context()
67
76
 
68
77
  def _create_from_backend(self, backend: str, **kwargs: Any) -> None:
78
+ # pyrefly: ignore [bad-override]
69
79
  self._backend: str = backend
70
80
  comm = kwargs.get("comm", None)
71
81
  hvd.init(comm=comm)
@@ -125,6 +135,7 @@ if has_hvd_support:
125
135
  finalize()
126
136
 
127
137
  @staticmethod
138
+ # pyrefly: ignore [bad-override]
128
139
  def spawn(
129
140
  fn: Callable,
130
141
  args: Tuple,
@@ -155,6 +166,15 @@ if has_hvd_support:
155
166
  **kwargs,
156
167
  )
157
168
 
169
+ def _setup_group(self, group: Any) -> hvd.ProcessSet:
170
+ if isinstance(group, list) and all(isinstance(item, int) for item in group):
171
+ group = self._do_new_group(group)
172
+ if not isinstance(group, hvd.ProcessSet):
173
+ raise ValueError(
174
+ f"Argument group should be list of int or hvd.ProcessSet, got {type(group)}, group={group}"
175
+ )
176
+ return group
177
+
158
178
  _reduce_op_map = {
159
179
  "SUM": hvd.mpi_ops.Sum,
160
180
  "AVERAGE": hvd.mpi_ops.Average,
@@ -187,10 +207,15 @@ if has_hvd_support:
187
207
 
188
208
  def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
189
209
  if group is not None:
190
- raise NotImplementedError("all_gather with group for horovod is not implemented")
210
+ group = self._setup_group(group)
211
+ if self._rank_not_in_group(group):
212
+ return tensor
191
213
  if tensor.ndimension() == 0:
192
214
  tensor = tensor.unsqueeze(0)
193
- return hvd.allgather(tensor)
215
+ if group is not None:
216
+ return hvd.allgather(tensor, process_set=group)
217
+ else:
218
+ return hvd.allgather(tensor)
194
219
 
195
220
  def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> List[Any]:
196
221
  if group is not None:
@@ -198,8 +223,8 @@ if has_hvd_support:
198
223
 
199
224
  return hvd.allgather_object(tensor)
200
225
 
201
- def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
202
- return hvd.ProcessSet(ranks)
226
+ def _do_new_group(self, ranks: List[int], **kwargs: Any) -> hvd.ProcessSet:
227
+ return hvd.add_process_set(ranks)
203
228
 
204
229
  def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
205
230
  return hvd.broadcast(tensor, root_rank=src)
@@ -208,3 +233,8 @@ if has_hvd_support:
208
233
  # https://github.com/horovod/horovod/issues/159#issuecomment-424834603
209
234
  # hvd.allreduce(torch.tensor(0, device=self.device()), name="barrier")
210
235
  hvd.allreduce(torch.tensor(0, device="cpu"), name="barrier")
236
+
237
+ def _rank_not_in_group(self, group: Optional[Any]) -> bool:
238
+ if group is None:
239
+ return False
240
+ return not group.included()
@@ -178,7 +178,7 @@ if has_native_dist_support:
178
178
  c: Counter = Counter(hostnames)
179
179
  sizes = torch.tensor([0] + list(c.values()))
180
180
  cumsum_sizes = torch.cumsum(sizes, dim=0)
181
- node_rank = (rank // cumsum_sizes[1:]).clamp(0, 1).sum().item()
181
+ node_rank = cast(int, (rank // cumsum_sizes[1:]).clamp(0, 1).sum().item())
182
182
  local_rank = rank - cumsum_sizes[node_rank].item()
183
183
  return int(local_rank), node_rank
184
184
 
@@ -345,6 +345,7 @@ if has_native_dist_support:
345
345
  os.environ.update(copy_env_vars)
346
346
 
347
347
  @staticmethod
348
+ # pyrefly: ignore [bad-override]
348
349
  def spawn(
349
350
  fn: Callable,
350
351
  args: Tuple,
@@ -408,7 +409,7 @@ if has_native_dist_support:
408
409
  **spawn_kwargs,
409
410
  )
410
411
 
411
- def _setup_group(self, group: Optional[Any]) -> dist.ProcessGroup:
412
+ def _setup_group(self, group: Any) -> dist.ProcessGroup:
412
413
  if isinstance(group, list) and all(isinstance(item, int) for item in group):
413
414
  group = self._do_new_group(group)
414
415
  if not (isinstance(group, dist.ProcessGroup) or group == dist.GroupMember.NON_GROUP_MEMBER):
@@ -442,7 +443,7 @@ if has_native_dist_support:
442
443
  def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
443
444
  if group is not None:
444
445
  group = self._setup_group(group)
445
- if group == dist.GroupMember.NON_GROUP_MEMBER:
446
+ if self._rank_not_in_group(group):
446
447
  return tensor
447
448
  if group is None:
448
449
  group_size = self.get_world_size()
@@ -466,7 +467,7 @@ if has_native_dist_support:
466
467
  )
467
468
  if group is not None:
468
469
  group = self._setup_group(group)
469
- if group == dist.GroupMember.NON_GROUP_MEMBER:
470
+ if self._rank_not_in_group(group):
470
471
  return tensor
471
472
  if group is None:
472
473
  group_size = self.get_world_size()
@@ -491,6 +492,9 @@ if has_native_dist_support:
491
492
  def barrier(self) -> None:
492
493
  dist.barrier()
493
494
 
495
+ def _rank_not_in_group(self, group: Optional[Any]) -> bool:
496
+ return dist._rank_not_in_group(group)
497
+
494
498
  def _expand_hostlist(nodelist: str) -> List[str]:
495
499
  """Expand a compressed hostlist string and returns all hosts listed.
496
500
 
@@ -54,6 +54,7 @@ if has_xla_support:
54
54
  def _create_from_backend(self, backend: str, **kwargs: Any) -> None:
55
55
  xm.rendezvous("init")
56
56
 
57
+ # pyrefly: ignore [bad-override]
57
58
  self._backend: str = backend
58
59
  self._setup_attrs()
59
60
 
@@ -106,6 +107,7 @@ if has_xla_support:
106
107
  finalize()
107
108
 
108
109
  @staticmethod
110
+ # pyrefly: ignore [bad-override]
109
111
  def spawn(
110
112
  fn: Callable,
111
113
  args: Tuple,
@@ -175,3 +177,6 @@ if has_xla_support:
175
177
  if isinstance(group, list) and all(isinstance(item, int) for item in group):
176
178
  return True
177
179
  return False
180
+
181
+ def _rank_not_in_group(self, group: Any) -> bool:
182
+ return self.get_rank() not in group
@@ -322,19 +322,15 @@ class Parallel:
322
322
  idist.initialize(self.backend, init_method=self.init_method)
323
323
 
324
324
  # The logger can be setup from now since idist.initialize() has been called (if needed)
325
- self._logger = setup_logger(__name__ + "." + self.__class__.__name__) # type: ignore[assignment]
325
+ self._logger = setup_logger(__name__ + "." + self.__class__.__name__)
326
326
 
327
327
  if self.backend is not None:
328
328
  if self._spawn_params is None:
329
- self._logger.info( # type: ignore[attr-defined]
330
- f"Initialized processing group with backend: '{self.backend}'"
331
- )
329
+ self._logger.info(f"Initialized processing group with backend: '{self.backend}'")
332
330
  else:
333
- self._logger.info( # type: ignore[attr-defined]
334
- f"Initialized distributed launcher with backend: '{self.backend}'"
335
- )
331
+ self._logger.info(f"Initialized distributed launcher with backend: '{self.backend}'")
336
332
  msg = "\n\t".join([f"{k}: {v}" for k, v in self._spawn_params.items() if v is not None])
337
- self._logger.info(f"- Parameters to spawn processes: \n\t{msg}") # type: ignore[attr-defined]
333
+ self._logger.info(f"- Parameters to spawn processes: \n\t{msg}")
338
334
 
339
335
  return self
340
336