pytorch-ignite 0.6.0.dev20250324__py3-none-any.whl → 0.6.0.dev20251103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pytorch-ignite might be problematic. See the registry's release page for more details.

Files changed (37)
  1. ignite/__init__.py +1 -1
  2. ignite/distributed/comp_models/native.py +1 -1
  3. ignite/engine/__init__.py +9 -9
  4. ignite/engine/engine.py +30 -4
  5. ignite/handlers/__init__.py +2 -0
  6. ignite/handlers/base_logger.py +47 -12
  7. ignite/handlers/checkpoint.py +44 -3
  8. ignite/handlers/clearml_logger.py +18 -6
  9. ignite/handlers/fbresearch_logger.py +2 -2
  10. ignite/handlers/lr_finder.py +1 -1
  11. ignite/handlers/mlflow_logger.py +43 -0
  12. ignite/handlers/neptune_logger.py +7 -0
  13. ignite/handlers/polyaxon_logger.py +7 -0
  14. ignite/handlers/tensorboard_logger.py +43 -0
  15. ignite/handlers/tqdm_logger.py +2 -3
  16. ignite/handlers/visdom_logger.py +9 -2
  17. ignite/handlers/wandb_logger.py +7 -1
  18. ignite/metrics/clustering/calinski_harabasz_score.py +1 -1
  19. ignite/metrics/clustering/silhouette_score.py +1 -1
  20. ignite/metrics/fbeta.py +17 -8
  21. ignite/metrics/gan/fid.py +3 -3
  22. ignite/metrics/js_divergence.py +1 -1
  23. ignite/metrics/maximum_mean_discrepancy.py +1 -1
  24. ignite/metrics/metric.py +2 -0
  25. ignite/metrics/nlp/bleu.py +8 -6
  26. ignite/metrics/nlp/rouge.py +3 -3
  27. ignite/metrics/nlp/utils.py +1 -1
  28. ignite/metrics/precision_recall_curve.py +5 -5
  29. ignite/metrics/regression/_base.py +4 -0
  30. ignite/metrics/regression/fractional_bias.py +1 -1
  31. ignite/metrics/roc_auc.py +3 -3
  32. ignite/metrics/ssim.py +58 -20
  33. {pytorch_ignite-0.6.0.dev20250324.dist-info → pytorch_ignite-0.6.0.dev20251103.dist-info}/METADATA +11 -17
  34. {pytorch_ignite-0.6.0.dev20250324.dist-info → pytorch_ignite-0.6.0.dev20251103.dist-info}/RECORD +36 -37
  35. {pytorch_ignite-0.6.0.dev20250324.dist-info → pytorch_ignite-0.6.0.dev20251103.dist-info}/WHEEL +1 -2
  36. pytorch_ignite-0.6.0.dev20250324.dist-info/top_level.txt +0 -1
  37. {pytorch_ignite-0.6.0.dev20250324.dist-info → pytorch_ignite-0.6.0.dev20251103.dist-info/licenses}/LICENSE +0 -0
ignite/__init__.py CHANGED
@@ -6,4 +6,4 @@ import ignite.handlers
6
6
  import ignite.metrics
7
7
  import ignite.utils
8
8
 
9
- __version__ = "0.6.0.dev20250324"
9
+ __version__ = "0.6.0.dev20251103"
ignite/distributed/comp_models/native.py CHANGED
@@ -178,7 +178,7 @@ if has_native_dist_support:
178
178
  c: Counter = Counter(hostnames)
179
179
  sizes = torch.tensor([0] + list(c.values()))
180
180
  cumsum_sizes = torch.cumsum(sizes, dim=0)
181
- node_rank = (rank // cumsum_sizes[1:]).clamp(0, 1).sum().item()
181
+ node_rank = cast(int, (rank // cumsum_sizes[1:]).clamp(0, 1).sum().item())
182
182
  local_rank = rank - cumsum_sizes[node_rank].item()
183
183
  return int(local_rank), node_rank
184
184
 
ignite/engine/__init__.py CHANGED
@@ -133,11 +133,11 @@ def supervised_training_step_amp(
133
133
  prepare_batch: Callable = _prepare_batch,
134
134
  model_transform: Callable[[Any], Any] = lambda output: output,
135
135
  output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
136
- scaler: Optional["torch.cuda.amp.GradScaler"] = None,
136
+ scaler: Optional["torch.amp.GradScaler"] = None,
137
137
  gradient_accumulation_steps: int = 1,
138
138
  model_fn: Callable[[torch.nn.Module, Any], Any] = lambda model, x: model(x),
139
139
  ) -> Callable:
140
- """Factory function for supervised training using ``torch.cuda.amp``.
140
+ """Factory function for supervised training using ``torch.amp``.
141
141
 
142
142
  Args:
143
143
  model: the model to train.
@@ -170,7 +170,7 @@ def supervised_training_step_amp(
170
170
  model = ...
171
171
  optimizer = ...
172
172
  loss_fn = ...
173
- scaler = torch.cuda.amp.GradScaler(2**10)
173
+ scaler = torch.amp.GradScaler('cuda', 2**10)
174
174
 
175
175
  update_fn = supervised_training_step_amp(model, optimizer, loss_fn, 'cuda', scaler=scaler)
176
176
  trainer = Engine(update_fn)
@@ -393,8 +393,8 @@ def supervised_training_step_tpu(
393
393
 
394
394
 
395
395
  def _check_arg(
396
- on_tpu: bool, on_mps: bool, amp_mode: Optional[str], scaler: Optional[Union[bool, "torch.cuda.amp.GradScaler"]]
397
- ) -> Tuple[Optional[str], Optional["torch.cuda.amp.GradScaler"]]:
396
+ on_tpu: bool, on_mps: bool, amp_mode: Optional[str], scaler: Optional[Union[bool, "torch.amp.GradScaler"]]
397
+ ) -> Tuple[Optional[str], Optional["torch.amp.GradScaler"]]:
398
398
  """Checking tpu, mps, amp and GradScaler instance combinations."""
399
399
  if on_mps and amp_mode:
400
400
  raise ValueError("amp_mode cannot be used with mps device. Consider using amp_mode=None or device='cuda'.")
@@ -410,9 +410,9 @@ def _check_arg(
410
410
  raise ValueError(f"scaler argument is {scaler}, but amp_mode is {amp_mode}. Consider using amp_mode='amp'.")
411
411
  elif amp_mode == "amp" and isinstance(scaler, bool):
412
412
  try:
413
- from torch.cuda.amp import GradScaler
413
+ from torch.amp import GradScaler
414
414
  except ImportError:
415
- raise ImportError("Please install torch>=1.6.0 to use scaler argument.")
415
+ raise ImportError("Please install torch>=2.3.1 to use scaler argument.")
416
416
  scaler = GradScaler(enabled=True)
417
417
 
418
418
  if on_tpu:
@@ -434,7 +434,7 @@ def create_supervised_trainer(
434
434
  output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
435
435
  deterministic: bool = False,
436
436
  amp_mode: Optional[str] = None,
437
- scaler: Union[bool, "torch.cuda.amp.GradScaler"] = False,
437
+ scaler: Union[bool, "torch.amp.GradScaler"] = False,
438
438
  gradient_accumulation_steps: int = 1,
439
439
  model_fn: Callable[[torch.nn.Module, Any], Any] = lambda model, x: model(x),
440
440
  ) -> Engine:
@@ -459,7 +459,7 @@ def create_supervised_trainer(
459
459
  :class:`~ignite.engine.deterministic.DeterministicEngine`, otherwise :class:`~ignite.engine.engine.Engine`
460
460
  (default: False).
461
461
  amp_mode: can be ``amp`` or ``apex``, model and optimizer will be casted to float16 using
462
- `torch.cuda.amp <https://pytorch.org/docs/stable/amp.html>`_ for ``amp`` and
462
+ `torch.amp <https://pytorch.org/docs/stable/amp.html>`_ for ``amp`` and
463
463
  using `apex <https://nvidia.github.io/apex>`_ for ``apex``. (default: None)
464
464
  scaler: GradScaler instance for gradient scaling if `torch>=1.6.0`
465
465
  and ``amp_mode`` is ``amp``. If ``amp_mode`` is ``apex``, this argument will be ignored.
ignite/engine/engine.py CHANGED
@@ -249,6 +249,17 @@ class Engine(Serializable):
249
249
  # we need to update state attributes associated with new custom events
250
250
  self.state._update_attrs()
251
251
 
252
+ def has_registered_events(self, event: Any) -> bool:
253
+ """Check whether engine has a registered event.
254
+
255
+ Args:
256
+ event: Event to check for registration.
257
+
258
+ Returns:
259
+ bool: True if the event is registered, False otherwise.
260
+ """
261
+ return event in self._allowed_events
262
+
252
263
  def _handler_wrapper(self, handler: Callable, event_name: Any, event_filter: Callable) -> Callable:
253
264
  # signature of the following wrapper will be inspected during registering to check if engine is necessary
254
265
  # we have to build a wrapper with relevant signature : solution is functools.wraps
@@ -328,7 +339,7 @@ class Engine(Serializable):
328
339
 
329
340
  try:
330
341
  _check_signature(handler, "handler", self, *(event_args + args), **kwargs)
331
- self._event_handlers[event_name].append((handler, (self,) + args, kwargs))
342
+ self._event_handlers[event_name].append((handler, (weakref.ref(self),) + args, kwargs))
332
343
  except ValueError:
333
344
  _check_signature(handler, "handler", *(event_args + args), **kwargs)
334
345
  self._event_handlers[event_name].append((handler, args, kwargs))
@@ -432,7 +443,15 @@ class Engine(Serializable):
432
443
  self.last_event_name = event_name
433
444
  for func, args, kwargs in self._event_handlers[event_name]:
434
445
  kwargs.update(event_kwargs)
435
- first, others = ((args[0],), args[1:]) if (args and args[0] == self) else ((), args)
446
+ if args and isinstance(args[0], weakref.ref):
447
+ resolved_engine = args[0]()
448
+ if resolved_engine is None:
449
+ raise RuntimeError("Engine reference not resolved. Cannot execute event handler.")
450
+ first, others = ((resolved_engine,), args[1:])
451
+ else:
452
+ # metrics do not provide engine when registered
453
+ first, others = (tuple(), args) # type: ignore[assignment]
454
+
436
455
  func(*first, *(event_args + others), **kwargs)
437
456
 
438
457
  def fire_event(self, event_name: Any) -> None:
@@ -1069,7 +1088,7 @@ class Engine(Serializable):
1069
1088
  )
1070
1089
 
1071
1090
  while True:
1072
- self.state.batch = self.state.output = None
1091
+ self.state.batch = None
1073
1092
 
1074
1093
  try:
1075
1094
  # Avoid Events.GET_BATCH_STARTED triggered twice when data iter is restarted
@@ -1081,6 +1100,9 @@ class Engine(Serializable):
1081
1100
  yield from self._maybe_terminate_or_interrupt()
1082
1101
 
1083
1102
  self.state.batch = next(self._dataloader_iter)
1103
+ # We on purpose reset state.output here as for iterable dataloaders
1104
+ # we accidentally can remove it when one epoch is completed.
1105
+ self.state.output = None
1084
1106
 
1085
1107
  # We should not trigger GET_BATCH_STARTED, GET_BATCH_COMPLETED, DATALOADER_STOP_ITERATION events
1086
1108
  # if no data was provided to engine.run(data=None, ...)
@@ -1254,7 +1276,7 @@ class Engine(Serializable):
1254
1276
  )
1255
1277
 
1256
1278
  while True:
1257
- self.state.batch = self.state.output = None
1279
+ self.state.batch = None
1258
1280
  try:
1259
1281
  # Avoid Events.GET_BATCH_STARTED triggered twice when data iter is restarted
1260
1282
  if self.last_event_name != Events.DATALOADER_STOP_ITERATION:
@@ -1265,6 +1287,10 @@ class Engine(Serializable):
1265
1287
  self._maybe_terminate_legacy()
1266
1288
 
1267
1289
  self.state.batch = next(self._dataloader_iter)
1290
+ # We on purpose reset state.output here as for iterable dataloaders
1291
+ # we accidentally can remove it when one epoch is completed.
1292
+ self.state.output = None
1293
+
1268
1294
  # We should not trigger GET_BATCH_STARTED, GET_BATCH_COMPLETED, DATALOADER_STOP_ITERATION events
1269
1295
  # if no data was provided to engine.run(data=None, ...)
1270
1296
  if self.state.dataloader is not None:
ignite/handlers/__init__.py CHANGED
@@ -6,6 +6,7 @@ from ignite.handlers.checkpoint import Checkpoint, DiskSaver, ModelCheckpoint
6
6
  from ignite.handlers.clearml_logger import ClearMLLogger
7
7
  from ignite.handlers.early_stopping import EarlyStopping
8
8
  from ignite.handlers.ema_handler import EMAHandler
9
+ from ignite.handlers.fbresearch_logger import FBResearchLogger
9
10
  from ignite.handlers.lr_finder import FastaiLRFinder
10
11
  from ignite.handlers.mlflow_logger import MLflowLogger
11
12
  from ignite.handlers.neptune_logger import NeptuneLogger
@@ -64,6 +65,7 @@ __all__ = [
64
65
  "CyclicalScheduler",
65
66
  "create_lr_scheduler_with_warmup",
66
67
  "FastaiLRFinder",
68
+ "FBResearchLogger",
67
69
  "EMAHandler",
68
70
  "BasicTimeProfiler",
69
71
  "HandlersTimeProfiler",
ignite/handlers/base_logger.py CHANGED
@@ -1,5 +1,6 @@
1
1
  """Base logger and its helper handlers."""
2
2
 
3
+ import collections.abc as collections
3
4
  import numbers
4
5
  import warnings
5
6
  from abc import ABCMeta, abstractmethod
@@ -145,30 +146,64 @@ class BaseOutputHandler(BaseHandler):
145
146
 
146
147
  metrics_state_attrs_dict: Dict[Any, Union[str, float, numbers.Number]] = OrderedDict()
147
148
 
148
- def key_tuple_tf(tag: str, name: str, *args: str) -> Tuple[str, ...]:
149
- return (tag, name) + args
149
+ def key_tuple_fn(parent_key: Union[str, Tuple[str, ...]], *args: str) -> Tuple[str, ...]:
150
+ if parent_key is None or isinstance(parent_key, str):
151
+ return (parent_key,) + args
152
+ return parent_key + args
150
153
 
151
- def key_str_tf(tag: str, name: str, *args: str) -> str:
152
- return "/".join((tag, name) + args)
154
+ def key_str_fn(parent_key: str, *args: str) -> str:
155
+ args_str = "/".join(args)
156
+ return f"{parent_key}/{args_str}"
153
157
 
154
- key_tf = key_tuple_tf if key_tuple else key_str_tf
158
+ key_fn = key_tuple_fn if key_tuple else key_str_fn
155
159
 
156
- for name, value in metrics_state_attrs.items():
160
+ def handle_value_fn(
161
+ value: Union[str, int, float, numbers.Number, torch.Tensor]
162
+ ) -> Union[None, str, float, numbers.Number]:
157
163
  if isinstance(value, numbers.Number):
158
- metrics_state_attrs_dict[key_tf(self.tag, name)] = value
164
+ return value
159
165
  elif isinstance(value, torch.Tensor) and value.ndimension() == 0:
160
- metrics_state_attrs_dict[key_tf(self.tag, name)] = value.item()
161
- elif isinstance(value, torch.Tensor) and value.ndimension() == 1:
162
- for i, v in enumerate(value):
163
- metrics_state_attrs_dict[key_tf(self.tag, name, str(i))] = v.item()
166
+ return value.item()
164
167
  else:
165
168
  if isinstance(value, str) and log_text:
166
- metrics_state_attrs_dict[key_tf(self.tag, name)] = value
169
+ return value
167
170
  else:
168
171
  warnings.warn(f"Logger output_handler can not log metrics value type {type(value)}")
172
+ return None
173
+
174
+ metrics_state_attrs_dict = _flatten_dict(metrics_state_attrs, key_fn, handle_value_fn, parent_key=self.tag)
169
175
  return metrics_state_attrs_dict
170
176
 
171
177
 
178
+ def _flatten_dict(
179
+ in_dict: collections.Mapping,
180
+ key_fn: Callable,
181
+ value_fn: Callable,
182
+ parent_key: Optional[Union[str, Tuple[str, ...]]] = None,
183
+ ) -> Dict:
184
+ items = {}
185
+ for key, value in in_dict.items():
186
+ new_key = key_fn(parent_key, key)
187
+ if isinstance(value, collections.Mapping):
188
+ items.update(_flatten_dict(value, key_fn, value_fn, new_key))
189
+ elif any(
190
+ [
191
+ isinstance(value, tuple) and hasattr(value, "_fields"), # namedtuple
192
+ not isinstance(value, str) and isinstance(value, collections.Sequence),
193
+ ]
194
+ ):
195
+ for i, item in enumerate(value):
196
+ items.update(_flatten_dict({str(i): item}, key_fn, value_fn, new_key))
197
+ elif isinstance(value, torch.Tensor) and value.ndimension() == 1:
198
+ for i, item in enumerate(value):
199
+ items.update(_flatten_dict({str(i): item.item()}, key_fn, value_fn, new_key))
200
+ else:
201
+ new_value = value_fn(value)
202
+ if new_value is not None:
203
+ items[new_key] = new_value
204
+ return items
205
+
206
+
172
207
  class BaseWeightsScalarHandler(BaseWeightsHandler):
173
208
  """
174
209
  Helper handler to log model's weights or gradients as scalars.
ignite/handlers/checkpoint.py CHANGED
@@ -21,10 +21,21 @@ else:
21
21
 
22
22
  import ignite.distributed as idist
23
23
  from ignite.base import Serializable
24
- from ignite.engine import Engine, Events
24
+ from ignite.engine import Engine, Events, EventEnum
25
25
  from ignite.utils import _tree_apply2, _tree_map
26
26
 
27
- __all__ = ["Checkpoint", "DiskSaver", "ModelCheckpoint", "BaseSaveHandler"]
27
+ __all__ = ["Checkpoint", "DiskSaver", "ModelCheckpoint", "BaseSaveHandler", "CheckpointEvents"]
28
+
29
+
30
+ class CheckpointEvents(EventEnum):
31
+ """Events fired by :class:`~ignite.handlers.checkpoint.Checkpoint`
32
+
33
+ - SAVED_CHECKPOINT : triggered when checkpoint handler has saved objects
34
+
35
+ .. versionadded:: 0.5.3
36
+ """
37
+
38
+ SAVED_CHECKPOINT = "saved_checkpoint"
28
39
 
29
40
 
30
41
  class BaseSaveHandler(metaclass=ABCMeta):
@@ -264,6 +275,29 @@ class Checkpoint(Serializable):
264
275
  to_save, save_handler=DiskSaver('/tmp/models', create_dir=True, **kwargs), n_saved=2
265
276
  )
266
277
 
278
+ Respond to checkpoint events:
279
+
280
+ .. code-block:: python
281
+
282
+ from ignite.handlers import Checkpoint
283
+ from ignite.engine import Engine, Events
284
+
285
+ checkpoint_handler = Checkpoint(
286
+ {'model': model, 'optimizer': optimizer},
287
+ save_dir,
288
+ n_saved=2
289
+ )
290
+
291
+ @trainer.on(Checkpoint.SAVED_CHECKPOINT)
292
+ def on_checkpoint_saved(engine):
293
+ print(f"Checkpoint saved at epoch {engine.state.epoch}")
294
+
295
+ trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler)
296
+
297
+ Attributes:
298
+ SAVED_CHECKPOINT: Alias of ``SAVED_CHECKPOINT`` from
299
+ :class:`~ignite.handlers.checkpoint.CheckpointEvents`.
300
+
267
301
  .. versionchanged:: 0.4.3
268
302
 
269
303
  - Checkpoint can save model with same filename.
@@ -274,8 +308,13 @@ class Checkpoint(Serializable):
274
308
  - `score_name` can be used to define `score_function` automatically without providing `score_function`.
275
309
  - `save_handler` automatically saves to disk if path to directory is provided.
276
310
  - `save_on_rank` saves objects on this rank in a distributed configuration.
311
+
312
+ .. versionchanged:: 0.5.3
313
+
314
+ - Added ``SAVED_CHECKPOINT`` class attribute.
277
315
  """
278
316
 
317
+ SAVED_CHECKPOINT = CheckpointEvents.SAVED_CHECKPOINT
279
318
  Item = NamedTuple("Item", [("priority", int), ("filename", str)])
280
319
  _state_dict_all_req_keys = ("_saved",)
281
320
 
@@ -400,6 +439,8 @@ class Checkpoint(Serializable):
400
439
  return new > self._saved[0].priority
401
440
 
402
441
  def __call__(self, engine: Engine) -> None:
442
+ if not engine.has_registered_events(CheckpointEvents.SAVED_CHECKPOINT):
443
+ engine.register_events(*CheckpointEvents)
403
444
  global_step = None
404
445
  if self.global_step_transform is not None:
405
446
  global_step = self.global_step_transform(engine, engine.last_event_name)
@@ -460,11 +501,11 @@ class Checkpoint(Serializable):
460
501
  if self.include_self:
461
502
  # Now that we've updated _saved, we can add our own state_dict.
462
503
  checkpoint["checkpointer"] = self.state_dict()
463
-
464
504
  try:
465
505
  self.save_handler(checkpoint, filename, metadata)
466
506
  except TypeError:
467
507
  self.save_handler(checkpoint, filename)
508
+ engine.fire_event(CheckpointEvents.SAVED_CHECKPOINT)
468
509
 
469
510
  def _setup_checkpoint(self) -> Dict[str, Any]:
470
511
  if self.to_save is not None:
ignite/handlers/clearml_logger.py CHANGED
@@ -109,8 +109,17 @@ class ClearMLLogger(BaseLogger):
109
109
  log_handler=WeightsScalarHandler(model)
110
110
  )
111
111
 
112
+ Note:
113
+ :class:`~ignite.handlers.clearml_logger.OutputHandler` can handle
114
+ metrics, state attributes and engine output values of the following format:
115
+ - scalar values (i.e. int, float)
116
+ - 0d and 1d pytorch tensors
117
+ - dicts and list/tuples of previous types
118
+
112
119
  """
113
120
 
121
+ _task: Any
122
+
114
123
  def __init__(self, **kwargs: Any):
115
124
  try:
116
125
  from clearml import Task
@@ -342,9 +351,10 @@ class OutputHandler(BaseOutputHandler):
342
351
  for key, value in metrics.items():
343
352
  if len(key) == 2:
344
353
  logger.clearml_logger.report_scalar(title=key[0], series=key[1], iteration=global_step, value=value)
345
- elif len(key) == 3:
354
+ elif len(key) >= 3:
355
+ series = "/".join(key[2:])
346
356
  logger.clearml_logger.report_scalar(
347
- title=f"{key[0]}/{key[1]}", series=key[2], iteration=global_step, value=value
357
+ title=f"{key[0]}/{key[1]}", series=series, iteration=global_step, value=value
348
358
  )
349
359
 
350
360
 
@@ -815,6 +825,8 @@ class ClearMLSaver(DiskSaver):
815
825
 
816
826
  """
817
827
 
828
+ _task: Any
829
+
818
830
  def __init__(
819
831
  self,
820
832
  logger: Optional[ClearMLLogger] = None,
@@ -850,7 +862,7 @@ class ClearMLSaver(DiskSaver):
850
862
  except ImportError:
851
863
  try:
852
864
  # Backwards-compatibility for legacy Trains SDK
853
- from trains import Task
865
+ from trains import Task # type: ignore[no-redef]
854
866
  except ImportError:
855
867
  raise ModuleNotFoundError(
856
868
  "This contrib module requires clearml to be installed. "
@@ -925,7 +937,7 @@ class ClearMLSaver(DiskSaver):
925
937
  except ImportError:
926
938
  try:
927
939
  # Backwards-compatibility for legacy Trains SDK
928
- from trains.binding.frameworks import WeightsFileHandler
940
+ from trains.binding.frameworks import WeightsFileHandler # type: ignore[no-redef]
929
941
  except ImportError:
930
942
  raise ModuleNotFoundError(
931
943
  "This contrib module requires clearml to be installed. "
@@ -949,8 +961,8 @@ class ClearMLSaver(DiskSaver):
949
961
  metadata=metadata,
950
962
  )
951
963
 
952
- pre_cb_id = WeightsFileHandler.add_pre_callback(cb_context.pre_callback)
953
- post_cb_id = WeightsFileHandler.add_post_callback(cb_context.post_callback)
964
+ pre_cb_id = WeightsFileHandler.add_pre_callback(cb_context.pre_callback) # type: ignore[arg-type]
965
+ post_cb_id = WeightsFileHandler.add_post_callback(cb_context.post_callback) # type: ignore[arg-type]
954
966
 
955
967
  try:
956
968
  super(ClearMLSaver, self).__call__(checkpoint, filename, metadata)
ignite/handlers/fbresearch_logger.py CHANGED
@@ -7,7 +7,7 @@ import torch
7
7
 
8
8
  from ignite import utils
9
9
  from ignite.engine import Engine, Events
10
- from ignite.handlers import Timer
10
+ from ignite.handlers.timing import Timer
11
11
 
12
12
  MB = 1024.0 * 1024.0
13
13
 
@@ -154,7 +154,7 @@ class FBResearchLogger:
154
154
  if torch.cuda.is_available():
155
155
  cuda_max_mem = f"GPU Max Mem: {torch.cuda.max_memory_allocated() / MB:.0f} MB"
156
156
 
157
- current_iter = engine.state.iteration % (engine.state.epoch_length + 1)
157
+ current_iter = ((engine.state.iteration - 1) % engine.state.epoch_length) + 1
158
158
  iter_avg_time = self.iter_timer.value()
159
159
 
160
160
  eta_seconds = iter_avg_time * (engine.state.epoch_length - current_iter)
ignite/handlers/lr_finder.py CHANGED
@@ -179,7 +179,7 @@ class FastaiLRFinder:
179
179
  lr = self._lr_schedule.get_param()
180
180
  self._history["lr"].append(lr)
181
181
  if trainer.state.iteration == 1:
182
- self._best_loss = loss
182
+ self._best_loss = loss # type: ignore[assignment]
183
183
  else:
184
184
  if smooth_f > 0:
185
185
  loss = smooth_f * loss + (1 - smooth_f) * self._history["loss"][-1]
ignite/handlers/mlflow_logger.py CHANGED
@@ -84,6 +84,49 @@ class MLflowLogger(BaseLogger):
84
84
  optimizer=optimizer,
85
85
  param_name='lr' # optional
86
86
  )
87
+
88
+ Note:
89
+ :class:`~ignite.handlers.mlflow_logger.OutputHandler` can handle
90
+ metrics, state attributes and engine output values of the following format:
91
+ - scalar values (i.e. int, float)
92
+ - 0d and 1d pytorch tensors
93
+ - dicts and list/tuples of previous types
94
+
95
+ .. code-block:: python
96
+
97
+ # !!! This is not a runnable code !!!
98
+ evalutator.state.metrics = {
99
+ "a": 0,
100
+ "dict_value": {
101
+ "a": 111,
102
+ "c": {"d": 23, "e": [123, 234]},
103
+ },
104
+ "list_value": [12, 13, {"aa": 33, "bb": 44}],
105
+ "tuple_value": (112, 113, {"aaa": 33, "bbb": 44}),
106
+ }
107
+
108
+ handler = OutputHandler(
109
+ tag="tag",
110
+ metric_names="all",
111
+ )
112
+
113
+ handler(evaluator, mlflow_logger, event_name=Events.EPOCH_COMPLETED)
114
+ # Behind it would call `mlflow_logger.log_metrics` on
115
+ # {
116
+ # "tag/a": 0,
117
+ # "tag/dict_value/a": 111,
118
+ # "tag/dict_value/c/d": 23,
119
+ # "tag/dict_value/c/e/0": 123,
120
+ # "tag/dict_value/c/e/1": 234,
121
+ # "tag/list_value/0": 12,
122
+ # "tag/list_value/1": 13,
123
+ # "tag/list_value/2/aa": 33,
124
+ # "tag/list_value/2/bb": 44,
125
+ # "tag/tuple_value/0": 112,
126
+ # "tag/tuple_value/1": 113,
127
+ # "tag/tuple_value/2/aaa": 33,
128
+ # "tag/tuple_value/2/bbb": 44,
129
+ # }
87
130
  """
88
131
 
89
132
  def __init__(self, tracking_uri: Optional[str] = None):
ignite/handlers/neptune_logger.py CHANGED
@@ -153,6 +153,13 @@ class NeptuneLogger(BaseLogger):
153
153
  output_transform=lambda loss: {"loss": loss},
154
154
  )
155
155
 
156
+ Note:
157
+ :class:`~ignite.handlers.neptune_logger.OutputHandler` can handle
158
+ metrics, state attributes and engine output values of the following format:
159
+ - scalar values (i.e. int, float)
160
+ - 0d and 1d pytorch tensors
161
+ - dicts and list/tuples of previous types
162
+
156
163
  """
157
164
 
158
165
  def __getattr__(self, attr: Any) -> Any:
ignite/handlers/polyaxon_logger.py CHANGED
@@ -92,6 +92,13 @@ class PolyaxonLogger(BaseLogger):
92
92
  )
93
93
  # to manually end a run
94
94
  plx_logger.close()
95
+
96
+ Note:
97
+ :class:`~ignite.handlers.polyaxon_logger.OutputHandler` can handle
98
+ metrics, state attributes and engine output values of the following format:
99
+ - scalar values (i.e. int, float)
100
+ - 0d and 1d pytorch tensors
101
+ - dicts and list/tuples of previous types
95
102
  """
96
103
 
97
104
  def __init__(self, *args: Any, **kwargs: Any):
ignite/handlers/tensorboard_logger.py CHANGED
@@ -145,6 +145,49 @@ class TensorboardLogger(BaseLogger):
145
145
  output_transform=lambda loss: {"loss": loss}
146
146
  )
147
147
 
148
+ Note:
149
+ :class:`~ignite.handlers.tensorboard_logger.OutputHandler` can handle
150
+ metrics, state attributes and engine output values of the following format:
151
+ - scalar values (i.e. int, float)
152
+ - 0d and 1d pytorch tensors
153
+ - dicts and list/tuples of previous types
154
+
155
+ .. code-block:: python
156
+
157
+ # !!! This is not a runnable code !!!
158
+ evalutator.state.metrics = {
159
+ "a": 0,
160
+ "dict_value": {
161
+ "a": 111,
162
+ "c": {"d": 23, "e": [123, 234]},
163
+ },
164
+ "list_value": [12, 13, {"aa": 33, "bb": 44}],
165
+ "tuple_value": (112, 113, {"aaa": 33, "bbb": 44}),
166
+ }
167
+
168
+ handler = OutputHandler(
169
+ tag="tag",
170
+ metric_names="all",
171
+ )
172
+
173
+ handler(evaluator, tb_logger, event_name=Events.EPOCH_COMPLETED)
174
+ # Behind it would call `tb_logger.writer.add_scalar` on
175
+ # {
176
+ # "tag/a": 0,
177
+ # "tag/dict_value/a": 111,
178
+ # "tag/dict_value/c/d": 23,
179
+ # "tag/dict_value/c/e/0": 123,
180
+ # "tag/dict_value/c/e/1": 234,
181
+ # "tag/list_value/0": 12,
182
+ # "tag/list_value/1": 13,
183
+ # "tag/list_value/2/aa": 33,
184
+ # "tag/list_value/2/bb": 44,
185
+ # "tag/tuple_value/0": 112,
186
+ # "tag/tuple_value/1": 113,
187
+ # "tag/tuple_value/2/aaa": 33,
188
+ # "tag/tuple_value/2/bbb": 44,
189
+ # }
190
+
148
191
  """
149
192
 
150
193
  def __init__(self, *args: Any, **kwargs: Any):
ignite/handlers/tqdm_logger.py CHANGED
@@ -200,7 +200,7 @@ class ProgressBar(BaseLogger):
200
200
  Accepted output value types are numbers, 0d and 1d torch tensors and strings.
201
201
 
202
202
  """
203
- desc = self.tqdm_kwargs.get("desc", None)
203
+ desc = self.tqdm_kwargs.get("desc", "")
204
204
 
205
205
  if event_name not in engine._allowed_events:
206
206
  raise ValueError(f"Logging event {event_name.name} is not in allowed events for this engine")
@@ -298,8 +298,7 @@ class _OutputHandler(BaseOutputHandler):
298
298
  rendered_metrics = self._setup_output_metrics_state_attrs(engine, log_text=True)
299
299
  metrics = OrderedDict()
300
300
  for key, value in rendered_metrics.items():
301
- key = "_".join(key[1:]) # tqdm has tag as description
302
-
301
+ key = "_".join(key[1:]) # skip tag as tqdm has tag as description
303
302
  metrics[key] = value
304
303
 
305
304
  if metrics:
ignite/handlers/visdom_logger.py CHANGED
@@ -1,7 +1,7 @@
1
1
  """Visdom logger and its helper handlers."""
2
2
 
3
3
  import os
4
- from typing import Any, Callable, cast, Dict, List, Optional, Union
4
+ from typing import Any, Callable, Dict, List, Optional, Union
5
5
 
6
6
  import torch
7
7
  import torch.nn as nn
@@ -137,6 +137,13 @@ class VisdomLogger(BaseLogger):
137
137
  output_transform=lambda loss: {"loss": loss}
138
138
  )
139
139
 
140
+ Note:
141
+ :class:`~ignite.handlers.visdom_logger.OutputHandler` can handle
142
+ metrics, state attributes and engine output values of the following format:
143
+ - scalar values (i.e. int, float)
144
+ - 0d and 1d pytorch tensors
145
+ - dicts and list/tuples of previous types
146
+
140
147
  .. versionchanged:: 0.4.7
141
148
  accepts an optional list of `state_attributes`
142
149
  """
@@ -172,7 +179,7 @@ class VisdomLogger(BaseLogger):
172
179
  )
173
180
 
174
181
  if server is None:
175
- server = cast(str, os.environ.get("VISDOM_SERVER_URL", "localhost"))
182
+ server = os.environ.get("VISDOM_SERVER_URL", "localhost")
176
183
 
177
184
  if port is None:
178
185
  port = int(os.environ.get("VISDOM_PORT", 8097))
ignite/handlers/wandb_logger.py CHANGED
@@ -26,7 +26,7 @@ class WandBLogger(BaseLogger):
26
26
  Args:
27
27
  args: Positional arguments accepted by `wandb.init`.
28
28
  kwargs: Keyword arguments accepted by `wandb.init`.
29
- Please see `wandb.init <https://docs.wandb.ai/ref/python/init>`_ for documentation of possible parameters.
29
+ Please see `wandb.init <https://docs.wandb.ai/ref/python/sdk/functions/init/>`_ for documentation of possible parameters.
30
30
 
31
31
  Examples:
32
32
  .. code-block:: python
@@ -120,6 +120,12 @@ class WandBLogger(BaseLogger):
120
120
  )
121
121
  evaluator.add_event_handler(Events.COMPLETED, model_checkpoint, {'model': model})
122
122
 
123
+ Note:
124
+ :class:`~ignite.handlers.wandb_logger.OutputHandler` can handle
125
+ metrics, state attributes and engine output values of the following format:
126
+ - scalar values (i.e. int, float)
127
+ - 0d and 1d pytorch tensors
128
+ - dicts and list/tuples of previous types
123
129
 
124
130
  """
125
131
 
ignite/metrics/clustering/calinski_harabasz_score.py CHANGED
@@ -86,7 +86,7 @@ class CalinskiHarabaszScore(_ClusteringMetricBase):
86
86
 
87
87
  .. testoutput::
88
88
 
89
- 5.733936
89
+ 5.733935832977295
90
90
 
91
91
  .. versionadded:: 0.5.2
92
92
  """