pytorch-ignite 0.6.0.dev20250310__py3-none-any.whl → 0.6.0.dev20260101__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of pytorch-ignite might be problematic.

Files changed (83)
  1. ignite/__init__.py +1 -1
  2. ignite/contrib/engines/common.py +1 -0
  3. ignite/contrib/handlers/base_logger.py +1 -1
  4. ignite/contrib/handlers/clearml_logger.py +1 -1
  5. ignite/contrib/handlers/lr_finder.py +1 -1
  6. ignite/contrib/handlers/mlflow_logger.py +1 -1
  7. ignite/contrib/handlers/neptune_logger.py +1 -1
  8. ignite/contrib/handlers/param_scheduler.py +1 -1
  9. ignite/contrib/handlers/polyaxon_logger.py +1 -1
  10. ignite/contrib/handlers/tensorboard_logger.py +1 -1
  11. ignite/contrib/handlers/time_profilers.py +1 -1
  12. ignite/contrib/handlers/tqdm_logger.py +1 -1
  13. ignite/contrib/handlers/visdom_logger.py +1 -1
  14. ignite/contrib/handlers/wandb_logger.py +1 -1
  15. ignite/contrib/metrics/average_precision.py +1 -1
  16. ignite/contrib/metrics/cohen_kappa.py +1 -1
  17. ignite/contrib/metrics/gpu_info.py +1 -1
  18. ignite/contrib/metrics/precision_recall_curve.py +1 -1
  19. ignite/contrib/metrics/regression/canberra_metric.py +2 -3
  20. ignite/contrib/metrics/regression/fractional_absolute_error.py +2 -3
  21. ignite/contrib/metrics/regression/fractional_bias.py +2 -3
  22. ignite/contrib/metrics/regression/geometric_mean_absolute_error.py +2 -3
  23. ignite/contrib/metrics/regression/geometric_mean_relative_absolute_error.py +2 -3
  24. ignite/contrib/metrics/regression/manhattan_distance.py +2 -3
  25. ignite/contrib/metrics/regression/maximum_absolute_error.py +2 -3
  26. ignite/contrib/metrics/regression/mean_absolute_relative_error.py +2 -3
  27. ignite/contrib/metrics/regression/mean_error.py +2 -3
  28. ignite/contrib/metrics/regression/mean_normalized_bias.py +2 -3
  29. ignite/contrib/metrics/regression/median_absolute_error.py +2 -3
  30. ignite/contrib/metrics/regression/median_absolute_percentage_error.py +2 -3
  31. ignite/contrib/metrics/regression/median_relative_absolute_error.py +2 -3
  32. ignite/contrib/metrics/regression/r2_score.py +2 -3
  33. ignite/contrib/metrics/regression/wave_hedges_distance.py +2 -3
  34. ignite/contrib/metrics/roc_auc.py +1 -1
  35. ignite/distributed/auto.py +1 -0
  36. ignite/distributed/comp_models/base.py +7 -0
  37. ignite/distributed/comp_models/horovod.py +35 -5
  38. ignite/distributed/comp_models/native.py +8 -4
  39. ignite/distributed/comp_models/xla.py +5 -0
  40. ignite/distributed/launcher.py +4 -8
  41. ignite/distributed/utils.py +12 -4
  42. ignite/engine/__init__.py +9 -9
  43. ignite/engine/deterministic.py +1 -1
  44. ignite/engine/engine.py +38 -14
  45. ignite/engine/events.py +2 -1
  46. ignite/handlers/__init__.py +2 -0
  47. ignite/handlers/base_logger.py +47 -12
  48. ignite/handlers/checkpoint.py +46 -5
  49. ignite/handlers/clearml_logger.py +16 -4
  50. ignite/handlers/fbresearch_logger.py +2 -2
  51. ignite/handlers/lr_finder.py +9 -9
  52. ignite/handlers/mlflow_logger.py +43 -0
  53. ignite/handlers/neptune_logger.py +8 -0
  54. ignite/handlers/param_scheduler.py +7 -3
  55. ignite/handlers/polyaxon_logger.py +7 -0
  56. ignite/handlers/state_param_scheduler.py +8 -2
  57. ignite/handlers/tensorboard_logger.py +43 -0
  58. ignite/handlers/time_profilers.py +6 -3
  59. ignite/handlers/tqdm_logger.py +9 -5
  60. ignite/handlers/visdom_logger.py +10 -3
  61. ignite/handlers/wandb_logger.py +16 -9
  62. ignite/metrics/accuracy.py +2 -0
  63. ignite/metrics/clustering/calinski_harabasz_score.py +1 -1
  64. ignite/metrics/clustering/silhouette_score.py +1 -1
  65. ignite/metrics/fbeta.py +17 -8
  66. ignite/metrics/gan/fid.py +3 -3
  67. ignite/metrics/js_divergence.py +1 -1
  68. ignite/metrics/maximum_mean_discrepancy.py +1 -1
  69. ignite/metrics/metric.py +3 -0
  70. ignite/metrics/nlp/bleu.py +8 -6
  71. ignite/metrics/nlp/rouge.py +9 -6
  72. ignite/metrics/nlp/utils.py +1 -1
  73. ignite/metrics/precision_recall_curve.py +5 -5
  74. ignite/metrics/regression/_base.py +4 -0
  75. ignite/metrics/regression/fractional_bias.py +1 -1
  76. ignite/metrics/roc_auc.py +4 -3
  77. ignite/metrics/ssim.py +63 -21
  78. ignite/metrics/vision/object_detection_average_precision_recall.py +3 -0
  79. {pytorch_ignite-0.6.0.dev20250310.dist-info → pytorch_ignite-0.6.0.dev20260101.dist-info}/METADATA +11 -17
  80. {pytorch_ignite-0.6.0.dev20250310.dist-info → pytorch_ignite-0.6.0.dev20260101.dist-info}/RECORD +82 -83
  81. {pytorch_ignite-0.6.0.dev20250310.dist-info → pytorch_ignite-0.6.0.dev20260101.dist-info}/WHEEL +1 -2
  82. pytorch_ignite-0.6.0.dev20250310.dist-info/top_level.txt +0 -1
  83. {pytorch_ignite-0.6.0.dev20250310.dist-info → pytorch_ignite-0.6.0.dev20260101.dist-info/licenses}/LICENSE +0 -0
ignite/distributed/utils.py CHANGED
@@ -2,10 +2,9 @@ import itertools
  import socket
  from contextlib import contextmanager
  from functools import wraps
- from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Union
+ from typing import Any, Callable, cast, List, Mapping, Optional, Sequence, Tuple, Union

  import torch
- from torch import distributed as dist

  from ignite.distributed.comp_models import (
  _SerialModel,
@@ -384,7 +383,7 @@ def all_gather_tensors_with_shapes(
  if isinstance(group, list) and all(isinstance(item, int) for item in group):
  group = _model.new_group(group)

- if isinstance(_model, _SerialModel) or group == dist.GroupMember.NON_GROUP_MEMBER:
+ if _rank_not_in_group(group):
  return [tensor]

  max_shape = torch.tensor(shapes).amax(dim=0)
@@ -392,7 +391,7 @@ def all_gather_tensors_with_shapes(
  padded_tensor = torch.nn.functional.pad(
  tensor, tuple(itertools.chain.from_iterable(map(lambda dim_size: (0, dim_size), reversed(padding_sizes))))
  )
- all_padded_tensors: torch.Tensor = _model.all_gather(padded_tensor, group=group)
+ all_padded_tensors: torch.Tensor = cast(torch.Tensor, _model.all_gather(padded_tensor, group=group))
  return [
  all_padded_tensors[
  [
@@ -731,3 +730,12 @@ def one_rank_first(rank: int = 0, local: bool = False) -> Any:

  if current_rank == rank:
  barrier()
+
+
+ def _rank_not_in_group(group: Optional[Union[Any, List[int]]]) -> bool:
+ """Check if the current process's rank is not in a given group."""
+ if group is None:
+ return False
+ if isinstance(group, list) and all(isinstance(item, int) for item in group):
+ group = new_group(group)
+ return _model._rank_not_in_group(group)
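For context, a hedged usage sketch of the public helper these hunks touch, `all_gather_tensors_with_shapes`; the two-dimensional shapes and rank-dependent sizes below are illustrative, not taken from the diff:

```python
import torch

import ignite.distributed as idist
from ignite.distributed.utils import all_gather_tensors_with_shapes


def gather_ragged() -> list:
    rank = idist.get_rank()
    world_size = idist.get_world_size()
    # Each rank holds a tensor whose first dimension depends on its rank.
    local = torch.full((rank + 1, 3), float(rank))
    shapes = [[r + 1, 3] for r in range(world_size)]
    # The helper pads to the largest shape, all-gathers, then slices every
    # tensor back to its true shape. In a serial (non-distributed) run, or when
    # the calling rank is not part of the given group, it early-returns [local]
    # via the new _rank_not_in_group check instead of comparing against
    # dist.GroupMember.NON_GROUP_MEMBER directly.
    return all_gather_tensors_with_shapes(local, shapes)
```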
ignite/engine/__init__.py CHANGED
@@ -133,11 +133,11 @@ def supervised_training_step_amp(
  prepare_batch: Callable = _prepare_batch,
  model_transform: Callable[[Any], Any] = lambda output: output,
  output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
- scaler: Optional["torch.cuda.amp.GradScaler"] = None,
+ scaler: Optional["torch.amp.GradScaler"] = None,
  gradient_accumulation_steps: int = 1,
  model_fn: Callable[[torch.nn.Module, Any], Any] = lambda model, x: model(x),
  ) -> Callable:
- """Factory function for supervised training using ``torch.cuda.amp``.
+ """Factory function for supervised training using ``torch.amp``.

  Args:
  model: the model to train.
@@ -170,7 +170,7 @@ def supervised_training_step_amp(
  model = ...
  optimizer = ...
  loss_fn = ...
- scaler = torch.cuda.amp.GradScaler(2**10)
+ scaler = torch.amp.GradScaler('cuda', 2**10)

  update_fn = supervised_training_step_amp(model, optimizer, loss_fn, 'cuda', scaler=scaler)
  trainer = Engine(update_fn)
@@ -393,8 +393,8 @@ def supervised_training_step_tpu(


  def _check_arg(
- on_tpu: bool, on_mps: bool, amp_mode: Optional[str], scaler: Optional[Union[bool, "torch.cuda.amp.GradScaler"]]
- ) -> Tuple[Optional[str], Optional["torch.cuda.amp.GradScaler"]]:
+ on_tpu: bool, on_mps: bool, amp_mode: Optional[str], scaler: Optional[Union[bool, "torch.amp.GradScaler"]]
+ ) -> Tuple[Optional[str], Optional["torch.amp.GradScaler"]]:
  """Checking tpu, mps, amp and GradScaler instance combinations."""
  if on_mps and amp_mode:
  raise ValueError("amp_mode cannot be used with mps device. Consider using amp_mode=None or device='cuda'.")
@@ -410,9 +410,9 @@ def _check_arg(
  raise ValueError(f"scaler argument is {scaler}, but amp_mode is {amp_mode}. Consider using amp_mode='amp'.")
  elif amp_mode == "amp" and isinstance(scaler, bool):
  try:
- from torch.cuda.amp import GradScaler
+ from torch.amp import GradScaler
  except ImportError:
- raise ImportError("Please install torch>=1.6.0 to use scaler argument.")
+ raise ImportError("Please install torch>=2.3.1 to use scaler argument.")
  scaler = GradScaler(enabled=True)

  if on_tpu:
@@ -434,7 +434,7 @@ def create_supervised_trainer(
  output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
  deterministic: bool = False,
  amp_mode: Optional[str] = None,
- scaler: Union[bool, "torch.cuda.amp.GradScaler"] = False,
+ scaler: Union[bool, "torch.amp.GradScaler"] = False,
  gradient_accumulation_steps: int = 1,
  model_fn: Callable[[torch.nn.Module, Any], Any] = lambda model, x: model(x),
  ) -> Engine:
@@ -459,7 +459,7 @@ def create_supervised_trainer(
  :class:`~ignite.engine.deterministic.DeterministicEngine`, otherwise :class:`~ignite.engine.engine.Engine`
  (default: False).
  amp_mode: can be ``amp`` or ``apex``, model and optimizer will be casted to float16 using
- `torch.cuda.amp <https://pytorch.org/docs/stable/amp.html>`_ for ``amp`` and
+ `torch.amp <https://pytorch.org/docs/stable/amp.html>`_ for ``amp`` and
  using `apex <https://nvidia.github.io/apex>`_ for ``apex``. (default: None)
  scaler: GradScaler instance for gradient scaling if `torch>=1.6.0`
  and ``amp_mode`` is ``amp``. If ``amp_mode`` is ``apex``, this argument will be ignored.
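A minimal sketch of the AMP wiring these hunks update: the scaler is now built with `torch.amp.GradScaler("cuda", ...)` (torch >= 2.3.1, per the new error message) rather than the deprecated `torch.cuda.amp.GradScaler`. The toy model, optimizer and loss are placeholders:

```python
import torch

from ignite.engine import create_supervised_trainer

model = torch.nn.Linear(10, 2).to("cuda")
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = torch.nn.CrossEntropyLoss()

# New-style scaler: device string first, then the usual options.
scaler = torch.amp.GradScaler("cuda", init_scale=2**10)

trainer = create_supervised_trainer(
    model, optimizer, loss_fn, device="cuda", amp_mode="amp", scaler=scaler
)
```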
ignite/engine/deterministic.py CHANGED
@@ -85,7 +85,7 @@ class ReproducibleBatchSampler(BatchSampler):


  def _get_rng_states() -> List[Any]:
- output = [random.getstate(), torch.get_rng_state()]
+ output: List[Any] = [random.getstate(), torch.get_rng_state()]
  try:
  import numpy as np

ignite/engine/engine.py CHANGED
@@ -148,7 +148,7 @@ class Engine(Serializable):
  self.should_interrupt = False
  self.state = State()
  self._state_dict_user_keys: List[str] = []
- self._allowed_events: List[EventEnum] = []
+ self._allowed_events: List[Union[str, EventEnum]] = []

  self._dataloader_iter: Optional[Iterator[Any]] = None
  self._init_iter: Optional[int] = None
@@ -163,9 +163,7 @@ class Engine(Serializable):
  # generator provided by self._internal_run_as_gen
  self._internal_run_generator: Optional[Generator[Any, None, State]] = None

- def register_events(
- self, *event_names: Union[List[str], List[EventEnum]], event_to_attr: Optional[dict] = None
- ) -> None:
+ def register_events(self, *event_names: Union[str, EventEnum], event_to_attr: Optional[dict] = None) -> None:
  """Add events that can be fired.

  Registering an event will let the user trigger these events at any point.
@@ -249,6 +247,17 @@ class Engine(Serializable):
  # we need to update state attributes associated with new custom events
  self.state._update_attrs()

+ def has_registered_events(self, event: Any) -> bool:
+ """Check whether engine has a registered event.
+
+ Args:
+ event: Event to check for registration.
+
+ Returns:
+ bool: True if the event is registered, False otherwise.
+ """
+ return event in self._allowed_events
+
  def _handler_wrapper(self, handler: Callable, event_name: Any, event_filter: Callable) -> Callable:
  # signature of the following wrapper will be inspected during registering to check if engine is necessary
  # we have to build a wrapper with relevant signature : solution is functools.wraps
@@ -328,7 +337,7 @@ class Engine(Serializable):

  try:
  _check_signature(handler, "handler", self, *(event_args + args), **kwargs)
- self._event_handlers[event_name].append((handler, (self,) + args, kwargs))
+ self._event_handlers[event_name].append((handler, (weakref.ref(self),) + args, kwargs))
  except ValueError:
  _check_signature(handler, "handler", *(event_args + args), **kwargs)
  self._event_handlers[event_name].append((handler, args, kwargs))
@@ -432,7 +441,15 @@ class Engine(Serializable):
  self.last_event_name = event_name
  for func, args, kwargs in self._event_handlers[event_name]:
  kwargs.update(event_kwargs)
- first, others = ((args[0],), args[1:]) if (args and args[0] == self) else ((), args)
+ if args and isinstance(args[0], weakref.ref):
+ resolved_engine = args[0]()
+ if resolved_engine is None:
+ raise RuntimeError("Engine reference not resolved. Cannot execute event handler.")
+ first, others = ((resolved_engine,), args[1:])
+ else:
+ # metrics do not provide engine when registered
+ first, others = (tuple(), args)
+
  func(*first, *(event_args + others), **kwargs)

  def fire_event(self, event_name: Any) -> None:
@@ -970,9 +987,9 @@ class Engine(Serializable):
  def _internal_run_as_gen(self) -> Generator[Any, None, State]:
  self.should_terminate = self.should_terminate_single_epoch = self.should_interrupt = False
  self._init_timers(self.state)
+ start_time = time.time()
  try:
  try:
- start_time = time.time()
  self._fire_event(Events.STARTED)
  yield from self._maybe_terminate_or_interrupt()

@@ -991,7 +1008,7 @@ class Engine(Serializable):
  # time is available for handlers but must be updated after fire
  self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken

- if self.should_terminate_single_epoch != "skip_epoch_completed": # type: ignore[comparison-overlap]
+ if self.should_terminate_single_epoch != "skip_epoch_completed":
  handlers_start_time = time.time()
  self._fire_event(Events.EPOCH_COMPLETED)
  epoch_time_taken += time.time() - handlers_start_time
@@ -1024,7 +1041,7 @@ class Engine(Serializable):
  self.state.times[Events.COMPLETED.name] = time_taken

  # do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True`
- if self.should_terminate != "skip_completed": # type: ignore[comparison-overlap]
+ if self.should_terminate != "skip_completed":
  handlers_start_time = time.time()
  self._fire_event(Events.COMPLETED)
  time_taken += time.time() - handlers_start_time
@@ -1069,7 +1086,7 @@ class Engine(Serializable):
  )

  while True:
- self.state.batch = self.state.output = None
+ self.state.batch = None

  try:
  # Avoid Events.GET_BATCH_STARTED triggered twice when data iter is restarted
@@ -1081,6 +1098,9 @@ class Engine(Serializable):
  yield from self._maybe_terminate_or_interrupt()

  self.state.batch = next(self._dataloader_iter)
+ # We on purpose reset state.output here as for iterable dataloaders
+ # we accidentally can remove it when one epoch is completed.
+ self.state.output = None

  # We should not trigger GET_BATCH_STARTED, GET_BATCH_COMPLETED, DATALOADER_STOP_ITERATION events
  # if no data was provided to engine.run(data=None, ...)
@@ -1167,9 +1187,9 @@ class Engine(Serializable):
  # internal_run without generator for BC
  self.should_terminate = self.should_terminate_single_epoch = self.should_interrupt = False
  self._init_timers(self.state)
+ start_time = time.time()
  try:
  try:
- start_time = time.time()
  self._fire_event(Events.STARTED)
  self._maybe_terminate_legacy()

@@ -1188,7 +1208,7 @@ class Engine(Serializable):
  # time is available for handlers but must be updated after fire
  self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken

- if self.should_terminate_single_epoch != "skip_epoch_completed": # type: ignore[comparison-overlap]
+ if self.should_terminate_single_epoch != "skip_epoch_completed":
  handlers_start_time = time.time()
  self._fire_event(Events.EPOCH_COMPLETED)
  epoch_time_taken += time.time() - handlers_start_time
@@ -1221,7 +1241,7 @@ class Engine(Serializable):
  self.state.times[Events.COMPLETED.name] = time_taken

  # do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True`
- if self.should_terminate != "skip_completed": # type: ignore[comparison-overlap]
+ if self.should_terminate != "skip_completed":
  handlers_start_time = time.time()
  self._fire_event(Events.COMPLETED)
  time_taken += time.time() - handlers_start_time
@@ -1254,7 +1274,7 @@ class Engine(Serializable):
  )

  while True:
- self.state.batch = self.state.output = None
+ self.state.batch = None
  try:
  # Avoid Events.GET_BATCH_STARTED triggered twice when data iter is restarted
  if self.last_event_name != Events.DATALOADER_STOP_ITERATION:
@@ -1265,6 +1285,10 @@ class Engine(Serializable):
  self._maybe_terminate_legacy()

  self.state.batch = next(self._dataloader_iter)
+ # We on purpose reset state.output here as for iterable dataloaders
+ # we accidentally can remove it when one epoch is completed.
+ self.state.output = None
+
  # We should not trigger GET_BATCH_STARTED, GET_BATCH_COMPLETED, DATALOADER_STOP_ITERATION events
  # if no data was provided to engine.run(data=None, ...)
  if self.state.dataloader is not None:
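A short sketch of the new `Engine.has_registered_events` guard added above, using a hypothetical custom event enum; it is the same register-once pattern `Checkpoint` now applies to its `SAVED_CHECKPOINT` event (see the checkpoint.py hunks further down):

```python
from ignite.engine import Engine, EventEnum


class BackupEvents(EventEnum):
    # hypothetical custom event, not part of ignite
    BACKUP_DONE = "backup_done"


trainer = Engine(lambda engine, batch: batch)


def ensure_backup_events(engine: Engine) -> None:
    # Register the custom events only once, even if this helper runs repeatedly.
    if not engine.has_registered_events(BackupEvents.BACKUP_DONE):
        engine.register_events(*BackupEvents)


ensure_backup_events(trainer)
trainer.add_event_handler(BackupEvents.BACKUP_DONE, lambda engine: print("backup finished"))
# Something in the pipeline would later call trainer.fire_event(BackupEvents.BACKUP_DONE).
```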
ignite/engine/events.py CHANGED
@@ -91,7 +91,7 @@ class CallableEventWithFilter:
  raise ValueError("Argument every should be integer and greater than zero")

  if once is not None:
- c1 = isinstance(once, numbers.Integral) and once > 0
+ c1 = isinstance(once, int) and once > 0
  c2 = isinstance(once, Sequence) and len(once) > 0 and all(isinstance(e, int) and e > 0 for e in once)
  if not (c1 or c2):
  raise ValueError(
@@ -240,6 +240,7 @@ class EventEnum(CallableEventWithFilter, Enum):
  def __new__(cls, value: str) -> "EventEnum":
  obj = CallableEventWithFilter.__new__(cls)
  obj._value_ = value
+ # pyrefly: ignore [bad-return]
  return obj


ignite/handlers/__init__.py CHANGED
@@ -6,6 +6,7 @@ from ignite.handlers.checkpoint import Checkpoint, DiskSaver, ModelCheckpoint
  from ignite.handlers.clearml_logger import ClearMLLogger
  from ignite.handlers.early_stopping import EarlyStopping
  from ignite.handlers.ema_handler import EMAHandler
+ from ignite.handlers.fbresearch_logger import FBResearchLogger
  from ignite.handlers.lr_finder import FastaiLRFinder
  from ignite.handlers.mlflow_logger import MLflowLogger
  from ignite.handlers.neptune_logger import NeptuneLogger
@@ -64,6 +65,7 @@ __all__ = [
  "CyclicalScheduler",
  "create_lr_scheduler_with_warmup",
  "FastaiLRFinder",
+ "FBResearchLogger",
  "EMAHandler",
  "BasicTimeProfiler",
  "HandlersTimeProfiler",
ignite/handlers/base_logger.py CHANGED
@@ -1,5 +1,6 @@
  """Base logger and its helper handlers."""

+ import collections.abc as collections
  import numbers
  import warnings
  from abc import ABCMeta, abstractmethod
@@ -145,30 +146,64 @@ class BaseOutputHandler(BaseHandler):

  metrics_state_attrs_dict: Dict[Any, Union[str, float, numbers.Number]] = OrderedDict()

- def key_tuple_tf(tag: str, name: str, *args: str) -> Tuple[str, ...]:
- return (tag, name) + args
+ def key_tuple_fn(parent_key: Union[str, Tuple[str, ...]], *args: str) -> Tuple[str, ...]:
+ if parent_key is None or isinstance(parent_key, str):
+ return (parent_key,) + args
+ return parent_key + args

- def key_str_tf(tag: str, name: str, *args: str) -> str:
- return "/".join((tag, name) + args)
+ def key_str_fn(parent_key: str, *args: str) -> str:
+ args_str = "/".join(args)
+ return f"{parent_key}/{args_str}"

- key_tf = key_tuple_tf if key_tuple else key_str_tf
+ key_fn = key_tuple_fn if key_tuple else key_str_fn

- for name, value in metrics_state_attrs.items():
+ def handle_value_fn(
+ value: Union[str, int, float, numbers.Number, torch.Tensor]
+ ) -> Union[None, str, float, numbers.Number]:
  if isinstance(value, numbers.Number):
- metrics_state_attrs_dict[key_tf(self.tag, name)] = value
+ return value
  elif isinstance(value, torch.Tensor) and value.ndimension() == 0:
- metrics_state_attrs_dict[key_tf(self.tag, name)] = value.item()
- elif isinstance(value, torch.Tensor) and value.ndimension() == 1:
- for i, v in enumerate(value):
- metrics_state_attrs_dict[key_tf(self.tag, name, str(i))] = v.item()
+ return value.item()
  else:
  if isinstance(value, str) and log_text:
- metrics_state_attrs_dict[key_tf(self.tag, name)] = value
+ return value
  else:
  warnings.warn(f"Logger output_handler can not log metrics value type {type(value)}")
+ return None
+
+ metrics_state_attrs_dict = _flatten_dict(metrics_state_attrs, key_fn, handle_value_fn, parent_key=self.tag)
  return metrics_state_attrs_dict


+ def _flatten_dict(
+ in_dict: collections.Mapping,
+ key_fn: Callable,
+ value_fn: Callable,
+ parent_key: Optional[Union[str, Tuple[str, ...]]] = None,
+ ) -> Dict:
+ items = {}
+ for key, value in in_dict.items():
+ new_key = key_fn(parent_key, key)
+ if isinstance(value, collections.Mapping):
+ items.update(_flatten_dict(value, key_fn, value_fn, new_key))
+ elif any(
+ [
+ isinstance(value, tuple) and hasattr(value, "_fields"), # namedtuple
+ not isinstance(value, str) and isinstance(value, collections.Sequence),
+ ]
+ ):
+ for i, item in enumerate(value):
+ items.update(_flatten_dict({str(i): item}, key_fn, value_fn, new_key))
+ elif isinstance(value, torch.Tensor) and value.ndimension() == 1:
+ for i, item in enumerate(value):
+ items.update(_flatten_dict({str(i): item.item()}, key_fn, value_fn, new_key))
+ else:
+ new_value = value_fn(value)
+ if new_value is not None:
+ items[new_key] = new_value
+ return items
+
+
  class BaseWeightsScalarHandler(BaseWeightsHandler):
  """
  Helper handler to log model's weights or gradients as scalars.
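A standalone sketch (not the ignite API itself) mirroring what the new `_flatten_dict` does for every `OutputHandler`: nested dicts, lists/tuples and 1d tensors collapse into `tag/a/b`-style scalar keys. The metric values are illustrative:

```python
from numbers import Number

import torch


def flatten(metrics: dict, parent: str = "tag") -> dict:
    out: dict = {}
    for name, value in metrics.items():
        key = f"{parent}/{name}"
        if isinstance(value, dict):
            out.update(flatten(value, key))
        elif isinstance(value, (list, tuple)):
            out.update(flatten({str(i): v for i, v in enumerate(value)}, key))
        elif isinstance(value, torch.Tensor) and value.ndimension() == 1:
            out.update(flatten({str(i): v.item() for i, v in enumerate(value)}, key))
        elif isinstance(value, torch.Tensor) and value.ndimension() == 0:
            out[key] = value.item()
        elif isinstance(value, Number):
            out[key] = value
    return out


metrics = {"loss": 0.1, "recall": torch.tensor([0.5, 0.25]), "extra": {"lr": 1e-3}}
print(flatten(metrics))
# {'tag/loss': 0.1, 'tag/recall/0': 0.5, 'tag/recall/1': 0.25, 'tag/extra/lr': 0.001}
```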
ignite/handlers/checkpoint.py CHANGED
@@ -21,10 +21,21 @@ else:

  import ignite.distributed as idist
  from ignite.base import Serializable
- from ignite.engine import Engine, Events
+ from ignite.engine import Engine, Events, EventEnum
  from ignite.utils import _tree_apply2, _tree_map

- __all__ = ["Checkpoint", "DiskSaver", "ModelCheckpoint", "BaseSaveHandler"]
+ __all__ = ["Checkpoint", "DiskSaver", "ModelCheckpoint", "BaseSaveHandler", "CheckpointEvents"]
+
+
+ class CheckpointEvents(EventEnum):
+ """Events fired by :class:`~ignite.handlers.checkpoint.Checkpoint`
+
+ - SAVED_CHECKPOINT : triggered when checkpoint handler has saved objects
+
+ .. versionadded:: 0.5.3
+ """
+
+ SAVED_CHECKPOINT = "saved_checkpoint"


  class BaseSaveHandler(metaclass=ABCMeta):
@@ -264,6 +275,29 @@ class Checkpoint(Serializable):
  to_save, save_handler=DiskSaver('/tmp/models', create_dir=True, **kwargs), n_saved=2
  )

+ Respond to checkpoint events:
+
+ .. code-block:: python
+
+ from ignite.handlers import Checkpoint
+ from ignite.engine import Engine, Events
+
+ checkpoint_handler = Checkpoint(
+ {'model': model, 'optimizer': optimizer},
+ save_dir,
+ n_saved=2
+ )
+
+ @trainer.on(Checkpoint.SAVED_CHECKPOINT)
+ def on_checkpoint_saved(engine):
+ print(f"Checkpoint saved at epoch {engine.state.epoch}")
+
+ trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler)
+
+ Attributes:
+ SAVED_CHECKPOINT: Alias of ``SAVED_CHECKPOINT`` from
+ :class:`~ignite.handlers.checkpoint.CheckpointEvents`.
+
  .. versionchanged:: 0.4.3

  - Checkpoint can save model with same filename.
@@ -274,9 +308,14 @@ class Checkpoint(Serializable):
  - `score_name` can be used to define `score_function` automatically without providing `score_function`.
  - `save_handler` automatically saves to disk if path to directory is provided.
  - `save_on_rank` saves objects on this rank in a distributed configuration.
+
+ .. versionchanged:: 0.5.3
+
+ - Added ``SAVED_CHECKPOINT`` class attribute.
  """

- Item = NamedTuple("Item", [("priority", int), ("filename", str)])
+ SAVED_CHECKPOINT = CheckpointEvents.SAVED_CHECKPOINT
+ Item = NamedTuple("Item", [("priority", Union[int, float]), ("filename", str)])
  _state_dict_all_req_keys = ("_saved",)

  def __init__(
@@ -284,7 +323,7 @@ class Checkpoint(Serializable):
  to_save: Mapping,
  save_handler: Union[str, Path, Callable, BaseSaveHandler],
  filename_prefix: str = "",
- score_function: Optional[Callable] = None,
+ score_function: Optional[Callable[[Engine], Union[int, float]]] = None,
  score_name: Optional[str] = None,
  n_saved: Union[int, None] = 1,
  global_step_transform: Optional[Callable] = None,
@@ -400,6 +439,8 @@ class Checkpoint(Serializable):
  return new > self._saved[0].priority

  def __call__(self, engine: Engine) -> None:
+ if not engine.has_registered_events(CheckpointEvents.SAVED_CHECKPOINT):
+ engine.register_events(*CheckpointEvents)
  global_step = None
  if self.global_step_transform is not None:
  global_step = self.global_step_transform(engine, engine.last_event_name)
@@ -460,11 +501,11 @@ class Checkpoint(Serializable):
  if self.include_self:
  # Now that we've updated _saved, we can add our own state_dict.
  checkpoint["checkpointer"] = self.state_dict()
-
  try:
  self.save_handler(checkpoint, filename, metadata)
  except TypeError:
  self.save_handler(checkpoint, filename)
+ engine.fire_event(CheckpointEvents.SAVED_CHECKPOINT)

  def _setup_checkpoint(self) -> Dict[str, Any]:
  if self.to_save is not None:
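A brief sketch matching the loosened typing above: `Item.priority` and `score_function` now explicitly accept floats. The model, metric name and save directory are placeholders:

```python
import torch

from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, DiskSaver

model = torch.nn.Linear(4, 2)
evaluator = Engine(lambda engine, batch: batch)


def neg_val_loss(engine: Engine) -> float:
    # A float-valued priority; checkpoints with the largest scores are kept.
    return -engine.state.metrics.get("loss", 1.0)


best_ckpt = Checkpoint(
    {"model": model},
    DiskSaver("/tmp/models", create_dir=True, require_empty=False),
    n_saved=2,
    score_function=neg_val_loss,
    score_name="neg_loss",
)
evaluator.add_event_handler(Events.COMPLETED, best_ckpt)
# After the first save, the evaluator also fires Checkpoint.SAVED_CHECKPOINT,
# which is registered lazily inside Checkpoint.__call__ (see the hunk above).
```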
ignite/handlers/clearml_logger.py CHANGED
@@ -109,8 +109,17 @@ class ClearMLLogger(BaseLogger):
  log_handler=WeightsScalarHandler(model)
  )

+ Note:
+ :class:`~ignite.handlers.clearml_logger.OutputHandler` can handle
+ metrics, state attributes and engine output values of the following format:
+ - scalar values (i.e. int, float)
+ - 0d and 1d pytorch tensors
+ - dicts and list/tuples of previous types
+
  """

+ _task: Any
+
  def __init__(self, **kwargs: Any):
  try:
  from clearml import Task
@@ -342,9 +351,10 @@ class OutputHandler(BaseOutputHandler):
  for key, value in metrics.items():
  if len(key) == 2:
  logger.clearml_logger.report_scalar(title=key[0], series=key[1], iteration=global_step, value=value)
- elif len(key) == 3:
+ elif len(key) >= 3:
+ series = "/".join(key[2:])
  logger.clearml_logger.report_scalar(
- title=f"{key[0]}/{key[1]}", series=key[2], iteration=global_step, value=value
+ title=f"{key[0]}/{key[1]}", series=series, iteration=global_step, value=value
  )


@@ -815,6 +825,8 @@ class ClearMLSaver(DiskSaver):

  """

+ _task: Any
+
  def __init__(
  self,
  logger: Optional[ClearMLLogger] = None,
@@ -949,8 +961,8 @@ class ClearMLSaver(DiskSaver):
  metadata=metadata,
  )

- pre_cb_id = WeightsFileHandler.add_pre_callback(cb_context.pre_callback)
- post_cb_id = WeightsFileHandler.add_post_callback(cb_context.post_callback)
+ pre_cb_id = WeightsFileHandler.add_pre_callback(cb_context.pre_callback) # type: ignore[arg-type]
+ post_cb_id = WeightsFileHandler.add_post_callback(cb_context.post_callback) # type: ignore[arg-type]

  try:
  super(ClearMLSaver, self).__call__(checkpoint, filename, metadata)
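A tiny sketch of the key-splitting rule changed in the `OutputHandler` hunk: two-element keys map directly to title/series, while longer keys (produced by the new nested-metric flattening) now join their tail into the series instead of being skipped. This mirrors the diff's logic outside of ClearML:

```python
from typing import Tuple


def to_title_series(key: Tuple[str, ...]) -> Tuple[str, str]:
    if len(key) == 2:
        return key[0], key[1]
    # len(key) >= 3: the first two elements form the title, the rest the series.
    return f"{key[0]}/{key[1]}", "/".join(key[2:])


print(to_title_series(("validation", "accuracy")))           # ('validation', 'accuracy')
print(to_title_series(("validation", "recall", "2")))         # ('validation/recall', '2')
print(to_title_series(("tag", "dict_value", "c", "e", "0")))  # ('tag/dict_value', 'c/e/0')
```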
ignite/handlers/fbresearch_logger.py CHANGED
@@ -7,7 +7,7 @@ import torch

  from ignite import utils
  from ignite.engine import Engine, Events
- from ignite.handlers import Timer
+ from ignite.handlers.timing import Timer

  MB = 1024.0 * 1024.0

@@ -154,7 +154,7 @@ class FBResearchLogger:
  if torch.cuda.is_available():
  cuda_max_mem = f"GPU Max Mem: {torch.cuda.max_memory_allocated() / MB:.0f} MB"

- current_iter = engine.state.iteration % (engine.state.epoch_length + 1)
+ current_iter = ((engine.state.iteration - 1) % engine.state.epoch_length) + 1
  iter_avg_time = self.iter_timer.value()

  eta_seconds = iter_avg_time * (engine.state.epoch_length - current_iter)
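A quick worked check of the iteration arithmetic fixed above, assuming epoch_length=10: the new formula keeps the per-epoch counter in 1..10, whereas the old one wrapped to 0 at iteration 11 and drifted in later epochs:

```python
epoch_length = 10


def old_iter(iteration: int) -> int:
    return iteration % (epoch_length + 1)


def new_iter(iteration: int) -> int:
    return ((iteration - 1) % epoch_length) + 1


for it in (1, 10, 11, 21, 22):
    print(f"iteration={it:2d}  old={old_iter(it):2d}  new={new_iter(it):2d}")
# iteration= 1  old= 1  new= 1
# iteration=10  old=10  new=10
# iteration=11  old= 0  new= 1
# iteration=21  old=10  new= 1
# iteration=22  old= 0  new= 2
```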
ignite/handlers/lr_finder.py CHANGED
@@ -98,15 +98,18 @@ class FastaiLRFinder:
  self._best_loss = None
  self._diverge_flag = False

+ assert trainer.state.epoch_length is not None
+ assert trainer.state.max_epochs is not None
+
  # attach LRScheduler to trainer.
  if num_iter is None:
  num_iter = trainer.state.epoch_length * trainer.state.max_epochs
  else:
- max_iter = trainer.state.epoch_length * trainer.state.max_epochs # type: ignore[operator]
+ max_iter = trainer.state.epoch_length * trainer.state.max_epochs
  if max_iter < num_iter:
  max_iter = num_iter
  trainer.state.max_iters = num_iter
- trainer.state.max_epochs = ceil(num_iter / trainer.state.epoch_length) # type: ignore[operator]
+ trainer.state.max_epochs = ceil(num_iter / trainer.state.epoch_length)

  if not trainer.has_event_handler(self._reached_num_iterations):
  trainer.add_event_handler(Events.ITERATION_COMPLETED, self._reached_num_iterations, num_iter)
@@ -178,17 +181,14 @@ class FastaiLRFinder:
  loss = idist.all_reduce(loss)
  lr = self._lr_schedule.get_param()
  self._history["lr"].append(lr)
- if trainer.state.iteration == 1:
+ if trainer.state.iteration != 1 and smooth_f > 0:
+ loss = smooth_f * loss + (1 - smooth_f) * self._history["loss"][-1]
+ if self._best_loss is None or loss < self._best_loss:
  self._best_loss = loss
- else:
- if smooth_f > 0:
- loss = smooth_f * loss + (1 - smooth_f) * self._history["loss"][-1]
- if loss < self._best_loss:
- self._best_loss = loss
  self._history["loss"].append(loss)

  # Check if the loss has diverged; if it has, stop the trainer
- if self._history["loss"][-1] > diverge_th * self._best_loss: # type: ignore[operator]
+ if self._history["loss"][-1] > diverge_th * self._best_loss:
  self._diverge_flag = True
  self.logger.info("Stopping early, the loss has diverged")
  trainer.terminate()
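A small numeric illustration of the smoothing branch restructured above: the formula is unchanged (smoothed = smooth_f * raw + (1 - smooth_f) * previous), it is simply applied uniformly from the second iteration on, and the best loss is tracked with an explicit None check. The raw losses are made up:

```python
smooth_f = 0.05
history: list = []
best = None

for raw in [1.00, 0.90, 0.80, 2.50]:
    loss = raw if not history else smooth_f * raw + (1 - smooth_f) * history[-1]
    if best is None or loss < best:
        best = loss
    history.append(loss)

print(history, best)
# history ≈ [1.0, 0.995, 0.985, 1.061], best ≈ 0.985
```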
ignite/handlers/mlflow_logger.py CHANGED
@@ -84,6 +84,49 @@ class MLflowLogger(BaseLogger):
  optimizer=optimizer,
  param_name='lr' # optional
  )
+
+ Note:
+ :class:`~ignite.handlers.mlflow_logger.OutputHandler` can handle
+ metrics, state attributes and engine output values of the following format:
+ - scalar values (i.e. int, float)
+ - 0d and 1d pytorch tensors
+ - dicts and list/tuples of previous types
+
+ .. code-block:: python
+
+ # !!! This is not a runnable code !!!
+ evalutator.state.metrics = {
+ "a": 0,
+ "dict_value": {
+ "a": 111,
+ "c": {"d": 23, "e": [123, 234]},
+ },
+ "list_value": [12, 13, {"aa": 33, "bb": 44}],
+ "tuple_value": (112, 113, {"aaa": 33, "bbb": 44}),
+ }
+
+ handler = OutputHandler(
+ tag="tag",
+ metric_names="all",
+ )
+
+ handler(evaluator, mlflow_logger, event_name=Events.EPOCH_COMPLETED)
+ # Behind it would call `mlflow_logger.log_metrics` on
+ # {
+ # "tag/a": 0,
+ # "tag/dict_value/a": 111,
+ # "tag/dict_value/c/d": 23,
+ # "tag/dict_value/c/e/0": 123,
+ # "tag/dict_value/c/e/1": 234,
+ # "tag/list_value/0": 12,
+ # "tag/list_value/1": 13,
+ # "tag/list_value/2/aa": 33,
+ # "tag/list_value/2/bb": 44,
+ # "tag/tuple_value/0": 112,
+ # "tag/tuple_value/1": 113,
+ # "tag/tuple_value/2/aaa": 33,
+ # "tag/tuple_value/2/bbb": 44,
+ # }
  """

  def __init__(self, tracking_uri: Optional[str] = None):