pytorch-ignite 0.6.0.dev20250310__py3-none-any.whl → 0.6.0.dev20260101__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pytorch-ignite might be problematic. Click here for more details.

Files changed (83)
  1. ignite/__init__.py +1 -1
  2. ignite/contrib/engines/common.py +1 -0
  3. ignite/contrib/handlers/base_logger.py +1 -1
  4. ignite/contrib/handlers/clearml_logger.py +1 -1
  5. ignite/contrib/handlers/lr_finder.py +1 -1
  6. ignite/contrib/handlers/mlflow_logger.py +1 -1
  7. ignite/contrib/handlers/neptune_logger.py +1 -1
  8. ignite/contrib/handlers/param_scheduler.py +1 -1
  9. ignite/contrib/handlers/polyaxon_logger.py +1 -1
  10. ignite/contrib/handlers/tensorboard_logger.py +1 -1
  11. ignite/contrib/handlers/time_profilers.py +1 -1
  12. ignite/contrib/handlers/tqdm_logger.py +1 -1
  13. ignite/contrib/handlers/visdom_logger.py +1 -1
  14. ignite/contrib/handlers/wandb_logger.py +1 -1
  15. ignite/contrib/metrics/average_precision.py +1 -1
  16. ignite/contrib/metrics/cohen_kappa.py +1 -1
  17. ignite/contrib/metrics/gpu_info.py +1 -1
  18. ignite/contrib/metrics/precision_recall_curve.py +1 -1
  19. ignite/contrib/metrics/regression/canberra_metric.py +2 -3
  20. ignite/contrib/metrics/regression/fractional_absolute_error.py +2 -3
  21. ignite/contrib/metrics/regression/fractional_bias.py +2 -3
  22. ignite/contrib/metrics/regression/geometric_mean_absolute_error.py +2 -3
  23. ignite/contrib/metrics/regression/geometric_mean_relative_absolute_error.py +2 -3
  24. ignite/contrib/metrics/regression/manhattan_distance.py +2 -3
  25. ignite/contrib/metrics/regression/maximum_absolute_error.py +2 -3
  26. ignite/contrib/metrics/regression/mean_absolute_relative_error.py +2 -3
  27. ignite/contrib/metrics/regression/mean_error.py +2 -3
  28. ignite/contrib/metrics/regression/mean_normalized_bias.py +2 -3
  29. ignite/contrib/metrics/regression/median_absolute_error.py +2 -3
  30. ignite/contrib/metrics/regression/median_absolute_percentage_error.py +2 -3
  31. ignite/contrib/metrics/regression/median_relative_absolute_error.py +2 -3
  32. ignite/contrib/metrics/regression/r2_score.py +2 -3
  33. ignite/contrib/metrics/regression/wave_hedges_distance.py +2 -3
  34. ignite/contrib/metrics/roc_auc.py +1 -1
  35. ignite/distributed/auto.py +1 -0
  36. ignite/distributed/comp_models/base.py +7 -0
  37. ignite/distributed/comp_models/horovod.py +35 -5
  38. ignite/distributed/comp_models/native.py +8 -4
  39. ignite/distributed/comp_models/xla.py +5 -0
  40. ignite/distributed/launcher.py +4 -8
  41. ignite/distributed/utils.py +12 -4
  42. ignite/engine/__init__.py +9 -9
  43. ignite/engine/deterministic.py +1 -1
  44. ignite/engine/engine.py +38 -14
  45. ignite/engine/events.py +2 -1
  46. ignite/handlers/__init__.py +2 -0
  47. ignite/handlers/base_logger.py +47 -12
  48. ignite/handlers/checkpoint.py +46 -5
  49. ignite/handlers/clearml_logger.py +16 -4
  50. ignite/handlers/fbresearch_logger.py +2 -2
  51. ignite/handlers/lr_finder.py +9 -9
  52. ignite/handlers/mlflow_logger.py +43 -0
  53. ignite/handlers/neptune_logger.py +8 -0
  54. ignite/handlers/param_scheduler.py +7 -3
  55. ignite/handlers/polyaxon_logger.py +7 -0
  56. ignite/handlers/state_param_scheduler.py +8 -2
  57. ignite/handlers/tensorboard_logger.py +43 -0
  58. ignite/handlers/time_profilers.py +6 -3
  59. ignite/handlers/tqdm_logger.py +9 -5
  60. ignite/handlers/visdom_logger.py +10 -3
  61. ignite/handlers/wandb_logger.py +16 -9
  62. ignite/metrics/accuracy.py +2 -0
  63. ignite/metrics/clustering/calinski_harabasz_score.py +1 -1
  64. ignite/metrics/clustering/silhouette_score.py +1 -1
  65. ignite/metrics/fbeta.py +17 -8
  66. ignite/metrics/gan/fid.py +3 -3
  67. ignite/metrics/js_divergence.py +1 -1
  68. ignite/metrics/maximum_mean_discrepancy.py +1 -1
  69. ignite/metrics/metric.py +3 -0
  70. ignite/metrics/nlp/bleu.py +8 -6
  71. ignite/metrics/nlp/rouge.py +9 -6
  72. ignite/metrics/nlp/utils.py +1 -1
  73. ignite/metrics/precision_recall_curve.py +5 -5
  74. ignite/metrics/regression/_base.py +4 -0
  75. ignite/metrics/regression/fractional_bias.py +1 -1
  76. ignite/metrics/roc_auc.py +4 -3
  77. ignite/metrics/ssim.py +63 -21
  78. ignite/metrics/vision/object_detection_average_precision_recall.py +3 -0
  79. {pytorch_ignite-0.6.0.dev20250310.dist-info → pytorch_ignite-0.6.0.dev20260101.dist-info}/METADATA +11 -17
  80. {pytorch_ignite-0.6.0.dev20250310.dist-info → pytorch_ignite-0.6.0.dev20260101.dist-info}/RECORD +82 -83
  81. {pytorch_ignite-0.6.0.dev20250310.dist-info → pytorch_ignite-0.6.0.dev20260101.dist-info}/WHEEL +1 -2
  82. pytorch_ignite-0.6.0.dev20250310.dist-info/top_level.txt +0 -1
  83. {pytorch_ignite-0.6.0.dev20250310.dist-info → pytorch_ignite-0.6.0.dev20260101.dist-info/licenses}/LICENSE +0 -0
@@ -153,6 +153,13 @@ class NeptuneLogger(BaseLogger):
153
153
  output_transform=lambda loss: {"loss": loss},
154
154
  )
155
155
 
156
+ Note:
157
+ :class:`~ignite.handlers.neptune_logger.OutputHandler` can handle
158
+ metrics, state attributes and engine output values of the following format:
159
+ - scalar values (i.e. int, float)
160
+ - 0d and 1d pytorch tensors
161
+ - dicts and list/tuples of previous types
162
+
156
163
  """
157
164
 
158
165
  def __getattr__(self, attr: Any) -> Any:
@@ -691,6 +698,7 @@ class NeptuneSaver(BaseSaveHandler):
691
698
  # hold onto the file stream for uploading.
692
699
  # NOTE: This won't load the whole file in memory and upload
693
700
  # the stream in smaller chunks.
701
+ # pyrefly: ignore [bad-argument-type]
694
702
  self._logger[filename].upload(File.from_stream(tmp.file))
695
703
 
696
704
  @idist.one_rank_only(with_barrier=True)
@@ -1122,13 +1122,14 @@ def create_lr_scheduler_with_warmup(
1122
1122
  f"but given {type(lr_scheduler)}"
1123
1123
  )
1124
1124
 
1125
- if not isinstance(warmup_duration, numbers.Integral):
1125
+ if not isinstance(warmup_duration, int):
1126
1126
  raise TypeError(f"Argument warmup_duration should be integer, but given {warmup_duration}")
1127
1127
 
1128
1128
  if not (warmup_duration > 1):
1129
1129
  raise ValueError(f"Argument warmup_duration should be at least 2 events, but given {warmup_duration}")
1130
1130
 
1131
1131
  warmup_schedulers: List[ParamScheduler] = []
1132
+ milestones_values: List[Tuple[int, float]] = []
1132
1133
 
1133
1134
  for param_group_index, param_group in enumerate(lr_scheduler.optimizer.param_groups):
1134
1135
  if warmup_end_value is None:
@@ -1176,6 +1177,7 @@ def create_lr_scheduler_with_warmup(
1176
1177
  lr_scheduler,
1177
1178
  ]
1178
1179
  durations = [milestones_values[-1][0] + 1]
1180
+ # pyrefly: ignore [bad-argument-type]
1179
1181
  combined_scheduler = ConcatScheduler(schedulers, durations=durations, save_history=save_history)
1180
1182
 
1181
1183
  if output_simulated_values is not None:
@@ -1185,6 +1187,7 @@ def create_lr_scheduler_with_warmup(
1185
1187
  f"but given {type(output_simulated_values)}."
1186
1188
  )
1187
1189
  num_events = len(output_simulated_values)
1190
+ # pyrefly: ignore [bad-argument-type]
1188
1191
  result = ConcatScheduler.simulate_values(num_events=num_events, schedulers=schedulers, durations=durations)
1189
1192
  for i in range(num_events):
1190
1193
  output_simulated_values[i] = result[i]
@@ -1650,6 +1653,7 @@ class ReduceLROnPlateauScheduler(ParamScheduler):
1650
1653
  self.trainer = trainer
1651
1654
  self.optimizer = optimizer
1652
1655
 
1656
+ min_lr: Union[float, List[float]]
1653
1657
  if "min_lr" in scheduler_kwargs and param_group_index is not None:
1654
1658
  min_lr = scheduler_kwargs["min_lr"]
1655
1659
  if not isinstance(min_lr, float):
@@ -1670,11 +1674,11 @@ class ReduceLROnPlateauScheduler(ParamScheduler):
1670
1674
  _scheduler_kwargs["verbose"] = False
1671
1675
 
1672
1676
  self.scheduler = ReduceLROnPlateau(optimizer, **_scheduler_kwargs)
1673
- self.scheduler._reduce_lr = self._reduce_lr # type: ignore[method-assign]
1677
+ self.scheduler._reduce_lr = self._reduce_lr
1674
1678
 
1675
1679
  self._state_attrs += ["metric_name", "scheduler"]
1676
1680
 
1677
- def __call__(self, engine: Engine, name: Optional[str] = None) -> None: # type: ignore[override]
1681
+ def __call__(self, engine: Engine, name: Optional[str] = None) -> None:
1678
1682
  if not hasattr(engine.state, "metrics") or self.metric_name not in engine.state.metrics:
1679
1683
  raise ValueError(
1680
1684
  "Argument engine should have in its 'state', attribute 'metrics' "
@@ -92,6 +92,13 @@ class PolyaxonLogger(BaseLogger):
92
92
  )
93
93
  # to manually end a run
94
94
  plx_logger.close()
95
+
96
+ Note:
97
+ :class:`~ignite.handlers.polyaxon_logger.OutputHandler` can handle
98
+ metrics, state attributes and engine output values of the following format:
99
+ - scalar values (i.e. int, float)
100
+ - 0d and 1d pytorch tensors
101
+ - dicts and list/tuples of previous types
95
102
  """
96
103
 
97
104
  def __init__(self, *args: Any, **kwargs: Any):
@@ -1,7 +1,7 @@
1
1
  import numbers
2
2
  import warnings
3
3
  from bisect import bisect_right
4
- from typing import Any, List, Sequence, Tuple, Union
4
+ from typing import Any, Callable, List, Sequence, Tuple, Union
5
5
 
6
6
  from ignite.engine import CallableEventWithFilter, Engine, Events, EventsList
7
7
  from ignite.handlers.param_scheduler import BaseParamScheduler
@@ -183,7 +183,13 @@ class LambdaStateScheduler(StateParamScheduler):
183
183
 
184
184
  """
185
185
 
186
- def __init__(self, lambda_obj: Any, param_name: str, save_history: bool = False, create_new: bool = False):
186
+ def __init__(
187
+ self,
188
+ lambda_obj: Callable[[int], Union[List[float], float]],
189
+ param_name: str,
190
+ save_history: bool = False,
191
+ create_new: bool = False,
192
+ ):
187
193
  super(LambdaStateScheduler, self).__init__(param_name, save_history, create_new)
188
194
 
189
195
  if not callable(lambda_obj):
@@ -145,6 +145,49 @@ class TensorboardLogger(BaseLogger):
145
145
  output_transform=lambda loss: {"loss": loss}
146
146
  )
147
147
 
148
+ Note:
149
+ :class:`~ignite.handlers.tensorboard_logger.OutputHandler` can handle
150
+ metrics, state attributes and engine output values of the following format:
151
+ - scalar values (i.e. int, float)
152
+ - 0d and 1d pytorch tensors
153
+ - dicts and list/tuples of previous types
154
+
155
+ .. code-block:: python
156
+
157
+ # !!! This is not a runnable code !!!
158
+ evalutator.state.metrics = {
159
+ "a": 0,
160
+ "dict_value": {
161
+ "a": 111,
162
+ "c": {"d": 23, "e": [123, 234]},
163
+ },
164
+ "list_value": [12, 13, {"aa": 33, "bb": 44}],
165
+ "tuple_value": (112, 113, {"aaa": 33, "bbb": 44}),
166
+ }
167
+
168
+ handler = OutputHandler(
169
+ tag="tag",
170
+ metric_names="all",
171
+ )
172
+
173
+ handler(evaluator, tb_logger, event_name=Events.EPOCH_COMPLETED)
174
+ # Behind it would call `tb_logger.writer.add_scalar` on
175
+ # {
176
+ # "tag/a": 0,
177
+ # "tag/dict_value/a": 111,
178
+ # "tag/dict_value/c/d": 23,
179
+ # "tag/dict_value/c/e/0": 123,
180
+ # "tag/dict_value/c/e/1": 234,
181
+ # "tag/list_value/0": 12,
182
+ # "tag/list_value/1": 13,
183
+ # "tag/list_value/2/aa": 33,
184
+ # "tag/list_value/2/bb": 44,
185
+ # "tag/tuple_value/0": 112,
186
+ # "tag/tuple_value/1": 113,
187
+ # "tag/tuple_value/2/aaa": 33,
188
+ # "tag/tuple_value/2/bbb": 44,
189
+ # }
190
+
148
191
  """
149
192
 
150
193
  def __init__(self, *args: Any, **kwargs: Any):
@@ -251,6 +251,7 @@ class BasicTimeProfiler:
251
251
  total_eh_time: Union[int, torch.Tensor] = sum(
252
252
  [(self.event_handlers_times[e]).sum() for e in Events if e not in self.events_to_ignore]
253
253
  )
254
+ # pyrefly: ignore [no-matching-overload]
254
255
  event_handlers_stats = dict(
255
256
  [
256
257
  (str(e.name).replace(".", "_"), self._compute_basic_stats(self.event_handlers_times[e]))
@@ -334,6 +335,7 @@ class BasicTimeProfiler:
334
335
 
335
336
  results_df = pd.DataFrame(
336
337
  data=results_dump,
338
+ # pyrefly: ignore [bad-argument-type]
337
339
  columns=[
338
340
  "epoch",
339
341
  "iteration",
@@ -498,14 +500,14 @@ class HandlersTimeProfiler:
498
500
 
499
501
  self.dataflow_times: List[float] = []
500
502
  self.processing_times: List[float] = []
501
- self.event_handlers_times: Dict[EventEnum, Dict[str, List[float]]] = {}
503
+ self.event_handlers_times: Dict[Union[str, EventEnum], Dict[str, List[float]]] = {}
502
504
 
503
505
  @staticmethod
504
506
  def _get_callable_name(handler: Callable) -> str:
505
507
  # get name of the callable handler
506
508
  return getattr(handler, "__qualname__", handler.__class__.__name__)
507
509
 
508
- def _create_wrapped_handler(self, handler: Callable, event: EventEnum) -> Callable:
510
+ def _create_wrapped_handler(self, handler: Callable, event: Union[str, EventEnum]) -> Callable:
509
511
  @functools.wraps(handler)
510
512
  def _timeit_handler(*args: Any, **kwargs: Any) -> None:
511
513
  self._event_handlers_timer.reset()
@@ -530,7 +532,7 @@ class HandlersTimeProfiler:
530
532
  t = self._dataflow_timer.value()
531
533
  self.dataflow_times.append(t)
532
534
 
533
- def _reset(self, event_handlers_names: Mapping[EventEnum, List[str]]) -> None:
535
+ def _reset(self, event_handlers_names: Mapping[Union[str, EventEnum], List[str]]) -> None:
534
536
  # reset the variables used for profiling
535
537
  self.dataflow_times = []
536
538
  self.processing_times = []
@@ -689,6 +691,7 @@ class HandlersTimeProfiler:
689
691
 
690
692
  results_dump = torch.stack(cols, dim=1).numpy()
691
693
 
694
+ # pyrefly: ignore [bad-argument-type]
692
695
  results_df = pd.DataFrame(data=results_dump, columns=headers)
693
696
  results_df.to_csv(output_path, index=False)
694
697
 
@@ -200,7 +200,7 @@ class ProgressBar(BaseLogger):
200
200
  Accepted output value types are numbers, 0d and 1d torch tensors and strings.
201
201
 
202
202
  """
203
- desc = self.tqdm_kwargs.get("desc", None)
203
+ desc = self.tqdm_kwargs.get("desc", "")
204
204
 
205
205
  if event_name not in engine._allowed_events:
206
206
  raise ValueError(f"Logging event {event_name.name} is not in allowed events for this engine")
@@ -223,8 +223,13 @@ class ProgressBar(BaseLogger):
223
223
  super(ProgressBar, self).attach(engine, log_handler, event_name)
224
224
  engine.add_event_handler(closing_event_name, self._close)
225
225
 
226
- def attach_opt_params_handler( # type: ignore[empty-body]
227
- self, engine: Engine, event_name: Union[str, Events], *args: Any, **kwargs: Any
226
+ def attach_opt_params_handler(
227
+ self,
228
+ engine: Engine,
229
+ event_name: Union[str, Events],
230
+ *args: Any,
231
+ **kwargs: Any,
232
+ # pyrefly: ignore [bad-return]
228
233
  ) -> RemovableEventHandle:
229
234
  """Intentionally empty"""
230
235
  pass
@@ -298,8 +303,7 @@ class _OutputHandler(BaseOutputHandler):
298
303
  rendered_metrics = self._setup_output_metrics_state_attrs(engine, log_text=True)
299
304
  metrics = OrderedDict()
300
305
  for key, value in rendered_metrics.items():
301
- key = "_".join(key[1:]) # tqdm has tag as description
302
-
306
+ key = "_".join(key[1:]) # skip tag as tqdm has tag as description
303
307
  metrics[key] = value
304
308
 
305
309
  if metrics:
@@ -1,7 +1,7 @@
1
1
  """Visdom logger and its helper handlers."""
2
2
 
3
3
  import os
4
- from typing import Any, Callable, cast, Dict, List, Optional, Union
4
+ from typing import Any, Callable, Dict, List, Optional, Union, TYPE_CHECKING
5
5
 
6
6
  import torch
7
7
  import torch.nn as nn
@@ -137,6 +137,13 @@ class VisdomLogger(BaseLogger):
137
137
  output_transform=lambda loss: {"loss": loss}
138
138
  )
139
139
 
140
+ Note:
141
+ :class:`~ignite.handlers.visdom_logger.OutputHandler` can handle
142
+ metrics, state attributes and engine output values of the following format:
143
+ - scalar values (i.e. int, float)
144
+ - 0d and 1d pytorch tensors
145
+ - dicts and list/tuples of previous types
146
+
140
147
  .. versionchanged:: 0.4.7
141
148
  accepts an optional list of `state_attributes`
142
149
  """
@@ -158,7 +165,7 @@ class VisdomLogger(BaseLogger):
158
165
  "pip install git+https://github.com/fossasia/visdom.git"
159
166
  )
160
167
 
161
- if num_workers > 0:
168
+ if num_workers > 0 or TYPE_CHECKING:
162
169
  # If visdom is installed, one of its dependencies `tornado`
163
170
  # requires also `futures` to be installed.
164
171
  # Let's check anyway if we can import it.
@@ -172,7 +179,7 @@ class VisdomLogger(BaseLogger):
172
179
  )
173
180
 
174
181
  if server is None:
175
- server = cast(str, os.environ.get("VISDOM_SERVER_URL", "localhost"))
182
+ server = os.environ.get("VISDOM_SERVER_URL", "localhost")
176
183
 
177
184
  if port is None:
178
185
  port = int(os.environ.get("VISDOM_PORT", 8097))
@@ -1,6 +1,7 @@
1
1
  """WandB logger and its helper handlers."""
2
2
 
3
3
  from typing import Any, Callable, List, Optional, Union
4
+ from warnings import warn
4
5
 
5
6
  from torch.optim import Optimizer
6
7
 
@@ -26,7 +27,7 @@ class WandBLogger(BaseLogger):
26
27
  Args:
27
28
  args: Positional arguments accepted by `wandb.init`.
28
29
  kwargs: Keyword arguments accepted by `wandb.init`.
29
- Please see `wandb.init <https://docs.wandb.ai/ref/python/init>`_ for documentation of possible parameters.
30
+ Please see `wandb.init <https://docs.wandb.ai/ref/python/sdk/functions/init/>`_ for documentation of possible parameters.
30
31
 
31
32
  Examples:
32
33
  .. code-block:: python
@@ -120,6 +121,12 @@ class WandBLogger(BaseLogger):
120
121
  )
121
122
  evaluator.add_event_handler(Events.COMPLETED, model_checkpoint, {'model': model})
122
123
 
124
+ Note:
125
+ :class:`~ignite.handlers.wandb_logger.OutputHandler` can handle
126
+ metrics, state attributes and engine output values of the following format:
127
+ - scalar values (i.e. int, float)
128
+ - 0d and 1d pytorch tensors
129
+ - dicts and list/tuples of previous types
123
130
 
124
131
  """
125
132
 
@@ -166,8 +173,7 @@ class OutputHandler(BaseOutputHandler):
166
173
  Default is None, global_step based on attached engine. If provided,
167
174
  uses function output as global_step. To setup global step from another engine, please use
168
175
  :meth:`~ignite.handlers.wandb_logger.global_step_from_engine`.
169
- sync: If set to False, process calls to log in a seperate thread. Default (None) uses whatever
170
- the default value of wandb.log.
176
+ sync: Deprecated, has no function. Argument is kept here for compatibility with existing code.
171
177
 
172
178
  Examples:
173
179
  .. code-block:: python
@@ -278,7 +284,8 @@ class OutputHandler(BaseOutputHandler):
278
284
  state_attributes: Optional[List[str]] = None,
279
285
  ):
280
286
  super().__init__(tag, metric_names, output_transform, global_step_transform, state_attributes)
281
- self.sync = sync
287
+ if sync is not None:
288
+ warn("The sync argument for the WandBLoggers is no longer used, and may be removed in the future")
282
289
 
283
290
  def __call__(self, engine: Engine, logger: WandBLogger, event_name: Union[str, Events]) -> None:
284
291
  if not isinstance(logger, WandBLogger):
@@ -292,7 +299,7 @@ class OutputHandler(BaseOutputHandler):
292
299
  )
293
300
 
294
301
  metrics = self._setup_output_metrics_state_attrs(engine, log_text=True, key_tuple=False)
295
- logger.log(metrics, step=global_step, sync=self.sync)
302
+ logger.log(metrics, step=global_step)
296
303
 
297
304
 
298
305
  class OptimizerParamsHandler(BaseOptimizerParamsHandler):
@@ -303,8 +310,7 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler):
303
310
  as a sequence.
304
311
  param_name: parameter name
305
312
  tag: common title for all produced plots. For example, "generator"
306
- sync: If set to False, process calls to log in a seperate thread. Default (None) uses whatever
307
- the default value of wandb.log.
313
+ sync: Deprecated, has no function. Argument is kept here for compatibility with existing code.
308
314
 
309
315
  Examples:
310
316
  .. code-block:: python
@@ -340,7 +346,8 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler):
340
346
  self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None, sync: Optional[bool] = None
341
347
  ):
342
348
  super(OptimizerParamsHandler, self).__init__(optimizer, param_name, tag)
343
- self.sync = sync
349
+ if sync is not None:
350
+ warn("The sync argument for the WandBLoggers is no longer used, and may be removed in the future")
344
351
 
345
352
  def __call__(self, engine: Engine, logger: WandBLogger, event_name: Union[str, Events]) -> None:
346
353
  if not isinstance(logger, WandBLogger):
@@ -352,4 +359,4 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler):
352
359
  f"{tag_prefix}{self.param_name}/group_{i}": float(param_group[self.param_name])
353
360
  for i, param_group in enumerate(self.optimizer.param_groups)
354
361
  }
355
- logger.log(params, step=global_step, sync=self.sync)
362
+ logger.log(params, step=global_step)
@@ -254,6 +254,8 @@ class Accuracy(_BaseClassification):
254
254
  y_pred = torch.transpose(y_pred, 1, last_dim - 1).reshape(-1, num_classes)
255
255
  y = torch.transpose(y, 1, last_dim - 1).reshape(-1, num_classes)
256
256
  correct = torch.all(y == y_pred.type_as(y), dim=-1)
257
+ else:
258
+ raise ValueError(f"Unexpected type: {self._type}")
257
259
 
258
260
  self._num_correct += torch.sum(correct).to(self._device)
259
261
  self._num_examples += correct.shape[0]
@@ -86,7 +86,7 @@ class CalinskiHarabaszScore(_ClusteringMetricBase):
86
86
 
87
87
  .. testoutput::
88
88
 
89
- 5.733936
89
+ 5.733935832977295
90
90
 
91
91
  .. versionadded:: 0.5.2
92
92
  """
@@ -86,7 +86,7 @@ class SilhouetteScore(_ClusteringMetricBase):
86
86
 
87
87
  .. testoutput::
88
88
 
89
- 0.12607366
89
+ 0.1260736584663391
90
90
 
91
91
  .. versionadded:: 0.5.2
92
92
  """
ignite/metrics/fbeta.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Callable, Optional, Union
1
+ from typing import Callable, cast, Optional, Union
2
2
 
3
3
  import torch
4
4
 
@@ -15,7 +15,7 @@ def Fbeta(
15
15
  precision: Optional[Precision] = None,
16
16
  recall: Optional[Recall] = None,
17
17
  output_transform: Optional[Callable] = None,
18
- device: Union[str, torch.device] = torch.device("cpu"),
18
+ device: Optional[Union[str, torch.device]] = None,
19
19
  ) -> MetricsLambda:
20
20
  r"""Calculates F-beta score.
21
21
 
@@ -143,17 +143,26 @@ def Fbeta(
143
143
  if not (beta > 0):
144
144
  raise ValueError(f"Beta should be a positive integer, but given {beta}")
145
145
 
146
- if precision is not None and output_transform is not None:
147
- raise ValueError("If precision argument is provided, output_transform should be None")
146
+ if precision is not None:
147
+ if output_transform is not None:
148
+ raise ValueError("If precision argument is provided, output_transform should be None")
149
+ if device is not None:
150
+ raise ValueError("If precision argument is provided, device should be None")
148
151
 
149
- if recall is not None and output_transform is not None:
150
- raise ValueError("If recall argument is provided, output_transform should be None")
152
+ if recall is not None:
153
+ if output_transform is not None:
154
+ raise ValueError("If recall argument is provided, output_transform should be None")
155
+ if device is not None:
156
+ raise ValueError("If recall argument is provided, device should be None")
157
+
158
+ if precision is None and recall is None and device is None:
159
+ device = torch.device("cpu")
151
160
 
152
161
  if precision is None:
153
162
  precision = Precision(
154
163
  output_transform=(lambda x: x) if output_transform is None else output_transform,
155
164
  average=False,
156
- device=device,
165
+ device=cast(Union[str, torch.device], recall._device if recall else device),
157
166
  )
158
167
  elif precision._average:
159
168
  raise ValueError("Input precision metric should have average=False")
@@ -162,7 +171,7 @@ def Fbeta(
162
171
  recall = Recall(
163
172
  output_transform=(lambda x: x) if output_transform is None else output_transform,
164
173
  average=False,
165
- device=device,
174
+ device=cast(Union[str, torch.device], precision._device if precision else device),
166
175
  )
167
176
  elif recall._average:
168
177
  raise ValueError("Input recall metric should have average=False")
ignite/metrics/gan/fid.py CHANGED
@@ -31,13 +31,13 @@ def fid_score(
31
31
  except ImportError:
32
32
  raise ModuleNotFoundError("fid_score requires scipy to be installed.")
33
33
 
34
- mu1, mu2 = mu1.cpu(), mu2.cpu()
35
- sigma1, sigma2 = sigma1.cpu(), sigma2.cpu()
34
+ mu1, mu2 = mu1.detach().cpu(), mu2.detach().cpu()
35
+ sigma1, sigma2 = sigma1.detach().cpu(), sigma2.detach().cpu()
36
36
 
37
37
  diff = mu1 - mu2
38
38
 
39
39
  # Product might be almost singular
40
- covmean, _ = scipy.linalg.sqrtm(sigma1.mm(sigma2), disp=False)
40
+ covmean, _ = scipy.linalg.sqrtm(sigma1.mm(sigma2).numpy(), disp=False)
41
41
  # Numerical error might give slight imaginary component
42
42
  if np.iscomplexobj(covmean):
43
43
  if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
@@ -73,7 +73,7 @@ class JSDivergence(KLDivergence):
73
73
 
74
74
  .. testoutput::
75
75
 
76
- 0.16266516844431558
76
+ 0.1626...
77
77
 
78
78
  .. versionchanged:: 0.5.1
79
79
  ``skip_unrolling`` argument is added.
@@ -78,7 +78,7 @@ class MaximumMeanDiscrepancy(Metric):
78
78
 
79
79
  .. testoutput::
80
80
 
81
- 1.072697639465332
81
+ 1.0726...
82
82
 
83
83
  .. versionchanged:: 0.5.1
84
84
  ``skip_unrolling`` argument is added.
ignite/metrics/metric.py CHANGED
@@ -361,12 +361,15 @@ class Metric(Serializable, metaclass=ABCMeta):
361
361
  device: Union[str, torch.device] = torch.device("cpu"),
362
362
  skip_unrolling: bool = False,
363
363
  ):
364
+ if not callable(output_transform):
365
+ raise TypeError(f"Argument output_transform should be callable, got {type(output_transform)}")
364
366
  self._output_transform = output_transform
365
367
 
366
368
  # Some metrics have a large performance regression when run on XLA devices, so for now, we disallow it.
367
369
  if torch.device(device).type == "xla":
368
370
  raise ValueError("Cannot create metric on an XLA device. Use device='cpu' instead.")
369
371
 
372
+ # pyrefly: ignore [read-only]
370
373
  self._device = torch.device(device)
371
374
  self._skip_unrolling = skip_unrolling
372
375
 
@@ -2,6 +2,7 @@ import math
2
2
  from typing import Any, Callable, Sequence, Tuple, Union
3
3
 
4
4
  import torch
5
+ from torch import Tensor
5
6
 
6
7
  from ignite.exceptions import NotComputableError
7
8
  from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce
@@ -71,11 +72,11 @@ class Bleu(Metric):
71
72
 
72
73
  More details can be found in `Papineni et al. 2002`__.
73
74
 
74
- __ https://www.aclweb.org/anthology/P02-1040
75
+ __ https://aclanthology.org/P02-1040/
75
76
 
76
77
  In addition, a review of smoothing techniques can be found in `Chen et al. 2014`__
77
78
 
78
- __ https://aclanthology.org/W14-3346.pdf
79
+ __ https://aclanthology.org/W14-3346/
79
80
 
80
81
  - ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
81
82
  - `y_pred` (list(list(str))) - a list of hypotheses sentences.
@@ -236,12 +237,12 @@ class Bleu(Metric):
236
237
  @reinit__is_reduced
237
238
  def reset(self) -> None:
238
239
  if self.average == "macro":
239
- self._sum_of_bleu = torch.tensor(0.0, dtype=torch.double, device=self._device)
240
+ self._sum_of_bleu = torch.tensor(0.0, dtype=self._double_dtype, device=self._device)
240
241
  self._num_sentences = 0
241
242
 
242
243
  if self.average == "micro":
243
- self.p_numerators = torch.zeros(self.ngrams_order + 1)
244
- self.p_denominators = torch.zeros(self.ngrams_order + 1)
244
+ self.p_numerators = torch.zeros(self.ngrams_order + 1, dtype=self._double_dtype)
245
+ self.p_denominators = torch.zeros(self.ngrams_order + 1, dtype=self._double_dtype)
245
246
  self.hyp_length_sum = 0
246
247
  self.ref_length_sum = 0
247
248
 
@@ -278,8 +279,9 @@ class Bleu(Metric):
278
279
  )
279
280
  return bleu_score
280
281
 
281
- def compute(self) -> None:
282
+ def compute(self) -> Union[None, Tensor, float]:
282
283
  if self.average == "macro":
283
284
  return self._compute_macro()
284
285
  elif self.average == "micro":
285
286
  return self._compute_micro()
287
+ return None
@@ -1,6 +1,5 @@
1
1
  from abc import ABCMeta, abstractmethod
2
- from collections import namedtuple
3
- from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Union
2
+ from typing import Any, Callable, List, Mapping, NamedTuple, Optional, Sequence, Tuple, Union
4
3
 
5
4
  import torch
6
5
 
@@ -13,11 +12,15 @@ from ignite.metrics.nlp.utils import lcs, ngrams
13
12
  __all__ = ["Rouge", "RougeN", "RougeL"]
14
13
 
15
14
 
16
- class Score(namedtuple("Score", ["match", "candidate", "reference"])):
15
+ class Score(NamedTuple):
17
16
  r"""
18
17
  Computes precision and recall for given matches, candidate and reference lengths.
19
18
  """
20
19
 
20
+ match: int
21
+ candidate: int
22
+ reference: int
23
+
21
24
  def precision(self) -> float:
22
25
  """
23
26
  Calculates precision.
@@ -191,7 +194,7 @@ class RougeN(_BaseRouge):
191
194
 
192
195
  More details can be found in `Lin 2004`__.
193
196
 
194
- __ https://www.aclweb.org/anthology/W04-1013.pdf
197
+ __ https://aclanthology.org/W04-1013
195
198
 
196
199
  - ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
197
200
  - `y_pred` (list(list(str))) must be a sequence of tokens.
@@ -265,7 +268,7 @@ class RougeL(_BaseRouge):
265
268
 
266
269
  More details can be found in `Lin 2004`__.
267
270
 
268
- __ https://www.aclweb.org/anthology/W04-1013.pdf
271
+ __ https://aclanthology.org/W04-1013
269
272
 
270
273
  - ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
271
274
  - `y_pred` (list(list(str))) must be a sequence of tokens.
@@ -331,7 +334,7 @@ class Rouge(Metric):
331
334
 
332
335
  More details can be found in `Lin 2004`__.
333
336
 
334
- __ https://www.aclweb.org/anthology/W04-1013.pdf
337
+ __ https://aclanthology.org/W04-1013
335
338
 
336
339
  - ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
337
340
  - `y_pred` (list(list(str))) must be a sequence of tokens.
@@ -63,7 +63,7 @@ def modified_precision(references: Sequence[Sequence[Any]], candidate: Any, n: i
63
63
 
64
64
  More details can be found in `Papineni et al. 2002`__.
65
65
 
66
- __ https://www.aclweb.org/anthology/P02-1040.pdf
66
+ __ https://aclanthology.org/P02-1040
67
67
 
68
68
  Args:
69
69
  references: list of references R
@@ -97,7 +97,7 @@ class PrecisionRecallCurve(EpochMetric):
97
97
  if len(self._predictions) < 1 or len(self._targets) < 1:
98
98
  raise NotComputableError("PrecisionRecallCurve must have at least one example before it can be computed.")
99
99
 
100
- if self._result is None: # type: ignore
100
+ if self._result is None:
101
101
  _prediction_tensor = torch.cat(self._predictions, dim=0)
102
102
  _target_tensor = torch.cat(self._targets, dim=0)
103
103
 
@@ -110,11 +110,11 @@ class PrecisionRecallCurve(EpochMetric):
110
110
  if idist.get_rank() == 0:
111
111
  # Run compute_fn on zero rank only
112
112
  precision, recall, thresholds = cast(Tuple, self.compute_fn(_prediction_tensor, _target_tensor))
113
- precision = torch.tensor(precision, device=_prediction_tensor.device)
114
- recall = torch.tensor(recall, device=_prediction_tensor.device)
113
+ precision = torch.tensor(precision, device=_prediction_tensor.device, dtype=self._double_dtype)
114
+ recall = torch.tensor(recall, device=_prediction_tensor.device, dtype=self._double_dtype)
115
115
  # thresholds can have negative strides, not compatible with torch tensors
116
116
  # https://discuss.pytorch.org/t/negative-strides-in-tensor-error/134287/2
117
- thresholds = torch.tensor(thresholds.copy(), device=_prediction_tensor.device)
117
+ thresholds = torch.tensor(thresholds.copy(), device=_prediction_tensor.device, dtype=self._double_dtype)
118
118
  else:
119
119
  precision, recall, thresholds = None, None, None
120
120
 
@@ -126,4 +126,4 @@ class PrecisionRecallCurve(EpochMetric):
126
126
 
127
127
  self._result = (precision, recall, thresholds) # type: ignore[assignment]
128
128
 
129
- return cast(Tuple[torch.Tensor, torch.Tensor, torch.Tensor], self._result) # type: ignore
129
+ return cast(Tuple[torch.Tensor, torch.Tensor, torch.Tensor], self._result)