pytorch-ignite 0.6.0.dev20250310__py3-none-any.whl → 0.6.0.dev20260101__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ignite/__init__.py +1 -1
- ignite/contrib/engines/common.py +1 -0
- ignite/contrib/handlers/base_logger.py +1 -1
- ignite/contrib/handlers/clearml_logger.py +1 -1
- ignite/contrib/handlers/lr_finder.py +1 -1
- ignite/contrib/handlers/mlflow_logger.py +1 -1
- ignite/contrib/handlers/neptune_logger.py +1 -1
- ignite/contrib/handlers/param_scheduler.py +1 -1
- ignite/contrib/handlers/polyaxon_logger.py +1 -1
- ignite/contrib/handlers/tensorboard_logger.py +1 -1
- ignite/contrib/handlers/time_profilers.py +1 -1
- ignite/contrib/handlers/tqdm_logger.py +1 -1
- ignite/contrib/handlers/visdom_logger.py +1 -1
- ignite/contrib/handlers/wandb_logger.py +1 -1
- ignite/contrib/metrics/average_precision.py +1 -1
- ignite/contrib/metrics/cohen_kappa.py +1 -1
- ignite/contrib/metrics/gpu_info.py +1 -1
- ignite/contrib/metrics/precision_recall_curve.py +1 -1
- ignite/contrib/metrics/regression/canberra_metric.py +2 -3
- ignite/contrib/metrics/regression/fractional_absolute_error.py +2 -3
- ignite/contrib/metrics/regression/fractional_bias.py +2 -3
- ignite/contrib/metrics/regression/geometric_mean_absolute_error.py +2 -3
- ignite/contrib/metrics/regression/geometric_mean_relative_absolute_error.py +2 -3
- ignite/contrib/metrics/regression/manhattan_distance.py +2 -3
- ignite/contrib/metrics/regression/maximum_absolute_error.py +2 -3
- ignite/contrib/metrics/regression/mean_absolute_relative_error.py +2 -3
- ignite/contrib/metrics/regression/mean_error.py +2 -3
- ignite/contrib/metrics/regression/mean_normalized_bias.py +2 -3
- ignite/contrib/metrics/regression/median_absolute_error.py +2 -3
- ignite/contrib/metrics/regression/median_absolute_percentage_error.py +2 -3
- ignite/contrib/metrics/regression/median_relative_absolute_error.py +2 -3
- ignite/contrib/metrics/regression/r2_score.py +2 -3
- ignite/contrib/metrics/regression/wave_hedges_distance.py +2 -3
- ignite/contrib/metrics/roc_auc.py +1 -1
- ignite/distributed/auto.py +1 -0
- ignite/distributed/comp_models/base.py +7 -0
- ignite/distributed/comp_models/horovod.py +35 -5
- ignite/distributed/comp_models/native.py +8 -4
- ignite/distributed/comp_models/xla.py +5 -0
- ignite/distributed/launcher.py +4 -8
- ignite/distributed/utils.py +12 -4
- ignite/engine/__init__.py +9 -9
- ignite/engine/deterministic.py +1 -1
- ignite/engine/engine.py +38 -14
- ignite/engine/events.py +2 -1
- ignite/handlers/__init__.py +2 -0
- ignite/handlers/base_logger.py +47 -12
- ignite/handlers/checkpoint.py +46 -5
- ignite/handlers/clearml_logger.py +16 -4
- ignite/handlers/fbresearch_logger.py +2 -2
- ignite/handlers/lr_finder.py +9 -9
- ignite/handlers/mlflow_logger.py +43 -0
- ignite/handlers/neptune_logger.py +8 -0
- ignite/handlers/param_scheduler.py +7 -3
- ignite/handlers/polyaxon_logger.py +7 -0
- ignite/handlers/state_param_scheduler.py +8 -2
- ignite/handlers/tensorboard_logger.py +43 -0
- ignite/handlers/time_profilers.py +6 -3
- ignite/handlers/tqdm_logger.py +9 -5
- ignite/handlers/visdom_logger.py +10 -3
- ignite/handlers/wandb_logger.py +16 -9
- ignite/metrics/accuracy.py +2 -0
- ignite/metrics/clustering/calinski_harabasz_score.py +1 -1
- ignite/metrics/clustering/silhouette_score.py +1 -1
- ignite/metrics/fbeta.py +17 -8
- ignite/metrics/gan/fid.py +3 -3
- ignite/metrics/js_divergence.py +1 -1
- ignite/metrics/maximum_mean_discrepancy.py +1 -1
- ignite/metrics/metric.py +3 -0
- ignite/metrics/nlp/bleu.py +8 -6
- ignite/metrics/nlp/rouge.py +9 -6
- ignite/metrics/nlp/utils.py +1 -1
- ignite/metrics/precision_recall_curve.py +5 -5
- ignite/metrics/regression/_base.py +4 -0
- ignite/metrics/regression/fractional_bias.py +1 -1
- ignite/metrics/roc_auc.py +4 -3
- ignite/metrics/ssim.py +63 -21
- ignite/metrics/vision/object_detection_average_precision_recall.py +3 -0
- {pytorch_ignite-0.6.0.dev20250310.dist-info → pytorch_ignite-0.6.0.dev20260101.dist-info}/METADATA +11 -17
- {pytorch_ignite-0.6.0.dev20250310.dist-info → pytorch_ignite-0.6.0.dev20260101.dist-info}/RECORD +82 -83
- {pytorch_ignite-0.6.0.dev20250310.dist-info → pytorch_ignite-0.6.0.dev20260101.dist-info}/WHEEL +1 -2
- pytorch_ignite-0.6.0.dev20250310.dist-info/top_level.txt +0 -1
- {pytorch_ignite-0.6.0.dev20250310.dist-info → pytorch_ignite-0.6.0.dev20260101.dist-info/licenses}/LICENSE +0 -0
ignite/distributed/utils.py
CHANGED
@@ -2,10 +2,9 @@ import itertools
 import socket
 from contextlib import contextmanager
 from functools import wraps
-from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, cast, List, Mapping, Optional, Sequence, Tuple, Union

 import torch
-from torch import distributed as dist

 from ignite.distributed.comp_models import (
     _SerialModel,
@@ -384,7 +383,7 @@ def all_gather_tensors_with_shapes(
     if isinstance(group, list) and all(isinstance(item, int) for item in group):
         group = _model.new_group(group)

-    if isinstance(_model, _SerialModel) or group == dist.GroupMember.NON_GROUP_MEMBER:
+    if _rank_not_in_group(group):
         return [tensor]

     max_shape = torch.tensor(shapes).amax(dim=0)
@@ -392,7 +391,7 @@ def all_gather_tensors_with_shapes(
     padded_tensor = torch.nn.functional.pad(
         tensor, tuple(itertools.chain.from_iterable(map(lambda dim_size: (0, dim_size), reversed(padding_sizes))))
     )
-    all_padded_tensors: torch.Tensor = _model.all_gather(padded_tensor, group=group)
+    all_padded_tensors: torch.Tensor = cast(torch.Tensor, _model.all_gather(padded_tensor, group=group))
     return [
         all_padded_tensors[
             [
@@ -731,3 +730,12 @@ def one_rank_first(rank: int = 0, local: bool = False) -> Any:

     if current_rank == rank:
         barrier()
+
+
+def _rank_not_in_group(group: Optional[Union[Any, List[int]]]) -> bool:
+    """Check if the current process's rank is not in a given group."""
+    if group is None:
+        return False
+    if isinstance(group, list) and all(isinstance(item, int) for item in group):
+        group = new_group(group)
+    return _model._rank_not_in_group(group)
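As a sanity check, here is a hedged, runnable sketch of the gather call this change guards (a 4-process gloo run and the shapes-per-group-member interpretation are assumptions; `_rank_not_in_group` is the private helper added above, so ranks outside the group simply get their own tensor back):

# gather_demo.py -- run with: torchrun --nproc_per_node=4 gather_demo.py
import torch

import ignite.distributed as idist
from ignite.distributed.utils import all_gather_tensors_with_shapes


def main(local_rank):
    rank = idist.get_rank()
    group = [0, 1]  # only the first two ranks participate in the gather
    # Each rank holds a tensor whose leading dimension depends on its rank.
    tensor = torch.full((rank + 1, 2), float(rank))
    shapes = [(r + 1, 2) for r in group]  # shapes of the group members' tensors
    result = all_gather_tensors_with_shapes(tensor, shapes, group=group)
    # Ranks 2 and 3 are not in the group: the new guard returns [tensor] as-is.
    print(f"rank {rank}: {[tuple(t.shape) for t in result]}")


if __name__ == "__main__":
    with idist.Parallel(backend="gloo") as parallel:
        parallel.run(main)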
ignite/engine/__init__.py
CHANGED
@@ -133,11 +133,11 @@ def supervised_training_step_amp(
     prepare_batch: Callable = _prepare_batch,
     model_transform: Callable[[Any], Any] = lambda output: output,
     output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
-    scaler: Optional["torch.cuda.amp.GradScaler"] = None,
+    scaler: Optional["torch.amp.GradScaler"] = None,
     gradient_accumulation_steps: int = 1,
     model_fn: Callable[[torch.nn.Module, Any], Any] = lambda model, x: model(x),
 ) -> Callable:
-    """Factory function for supervised training using ``torch.cuda.amp``.
+    """Factory function for supervised training using ``torch.amp``.

     Args:
         model: the model to train.
@@ -170,7 +170,7 @@ def supervised_training_step_amp(
         model = ...
         optimizer = ...
         loss_fn = ...
-        scaler = torch.cuda.amp.GradScaler(2**10)
+        scaler = torch.amp.GradScaler('cuda', 2**10)

         update_fn = supervised_training_step_amp(model, optimizer, loss_fn, 'cuda', scaler=scaler)
         trainer = Engine(update_fn)
@@ -393,8 +393,8 @@ def supervised_training_step_tpu(


 def _check_arg(
-    on_tpu: bool, on_mps: bool, amp_mode: Optional[str], scaler: Optional[Union[bool, "torch.cuda.amp.GradScaler"]]
-) -> Tuple[Optional[str], Optional["torch.cuda.amp.GradScaler"]]:
+    on_tpu: bool, on_mps: bool, amp_mode: Optional[str], scaler: Optional[Union[bool, "torch.amp.GradScaler"]]
+) -> Tuple[Optional[str], Optional["torch.amp.GradScaler"]]:
     """Checking tpu, mps, amp and GradScaler instance combinations."""
     if on_mps and amp_mode:
         raise ValueError("amp_mode cannot be used with mps device. Consider using amp_mode=None or device='cuda'.")
@@ -410,9 +410,9 @@ def _check_arg(
         raise ValueError(f"scaler argument is {scaler}, but amp_mode is {amp_mode}. Consider using amp_mode='amp'.")
     elif amp_mode == "amp" and isinstance(scaler, bool):
         try:
-            from torch.cuda.amp import GradScaler
+            from torch.amp import GradScaler
         except ImportError:
-            raise ImportError("Please install torch>=1.6.0 to use scaler argument.")
+            raise ImportError("Please install torch>=2.3.1 to use scaler argument.")
         scaler = GradScaler(enabled=True)

     if on_tpu:
@@ -434,7 +434,7 @@ def create_supervised_trainer(
     output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
     deterministic: bool = False,
     amp_mode: Optional[str] = None,
-    scaler: Union[bool, "torch.cuda.amp.GradScaler"] = False,
+    scaler: Union[bool, "torch.amp.GradScaler"] = False,
     gradient_accumulation_steps: int = 1,
     model_fn: Callable[[torch.nn.Module, Any], Any] = lambda model, x: model(x),
 ) -> Engine:
@@ -459,7 +459,7 @@ def create_supervised_trainer(
         :class:`~ignite.engine.deterministic.DeterministicEngine`, otherwise :class:`~ignite.engine.engine.Engine`
         (default: False).
     amp_mode: can be ``amp`` or ``apex``, model and optimizer will be casted to float16 using
-        `torch.cuda.amp <https://pytorch.org/docs/stable/amp.html>`_ for ``amp`` and
+        `torch.amp <https://pytorch.org/docs/stable/amp.html>`_ for ``amp`` and
         using `apex <https://nvidia.github.io/apex>`_ for ``apex``. (default: None)
     scaler: GradScaler instance for gradient scaling if `torch>=1.6.0`
         and ``amp_mode`` is ``amp``. If ``amp_mode`` is ``apex``, this argument will be ignored.
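For user code, the migration is mostly transparent: the ignite factory is unchanged, but the scaler now comes from torch.amp (hence the bumped torch>=2.3.1 message). A minimal sketch, assuming a CUDA machine:

import torch

from ignite.engine import create_supervised_trainer

model = torch.nn.Linear(10, 2).to("cuda")
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = torch.nn.CrossEntropyLoss()

# scaler=True lets ignite build the torch.amp.GradScaler internally; an
# explicit instance also works: torch.amp.GradScaler("cuda", init_scale=2**10)
trainer = create_supervised_trainer(
    model, optimizer, loss_fn, device="cuda", amp_mode="amp", scaler=True
)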
ignite/engine/deterministic.py
CHANGED
ignite/engine/engine.py
CHANGED
@@ -148,7 +148,7 @@ class Engine(Serializable):
         self.should_interrupt = False
         self.state = State()
         self._state_dict_user_keys: List[str] = []
-        self._allowed_events: List[EventEnum] = []
+        self._allowed_events: List[Union[str, EventEnum]] = []

         self._dataloader_iter: Optional[Iterator[Any]] = None
         self._init_iter: Optional[int] = None
@@ -163,9 +163,7 @@ class Engine(Serializable):
         # generator provided by self._internal_run_as_gen
         self._internal_run_generator: Optional[Generator[Any, None, State]] = None

-    def register_events(
-        self, *event_names: Union[List[str], List[EventEnum]], event_to_attr: Optional[dict] = None
-    ) -> None:
+    def register_events(self, *event_names: Union[str, EventEnum], event_to_attr: Optional[dict] = None) -> None:
         """Add events that can be fired.

         Registering an event will let the user trigger these events at any point.
@@ -249,6 +247,17 @@ class Engine(Serializable):
         # we need to update state attributes associated with new custom events
         self.state._update_attrs()

+    def has_registered_events(self, event: Any) -> bool:
+        """Check whether engine has a registered event.
+
+        Args:
+            event: Event to check for registration.
+
+        Returns:
+            bool: True if the event is registered, False otherwise.
+        """
+        return event in self._allowed_events
+
     def _handler_wrapper(self, handler: Callable, event_name: Any, event_filter: Callable) -> Callable:
         # signature of the following wrapper will be inspected during registering to check if engine is necessary
         # we have to build a wrapper with relevant signature : solution is functools.wraps
@@ -328,7 +337,7 @@ class Engine(Serializable):

         try:
             _check_signature(handler, "handler", self, *(event_args + args), **kwargs)
-            self._event_handlers[event_name].append((handler, (self,) + args, kwargs))
+            self._event_handlers[event_name].append((handler, (weakref.ref(self),) + args, kwargs))
         except ValueError:
             _check_signature(handler, "handler", *(event_args + args), **kwargs)
             self._event_handlers[event_name].append((handler, args, kwargs))
@@ -432,7 +441,15 @@ class Engine(Serializable):
         self.last_event_name = event_name
         for func, args, kwargs in self._event_handlers[event_name]:
             kwargs.update(event_kwargs)
-            first, others = ((args[0],), args[1:]) if (args and args[0] == self) else ((), args)
+            if args and isinstance(args[0], weakref.ref):
+                resolved_engine = args[0]()
+                if resolved_engine is None:
+                    raise RuntimeError("Engine reference not resolved. Cannot execute event handler.")
+                first, others = ((resolved_engine,), args[1:])
+            else:
+                # metrics do not provide engine when registered
+                first, others = (tuple(), args)
+
             func(*first, *(event_args + others), **kwargs)

     def fire_event(self, event_name: Any) -> None:
@@ -970,9 +987,9 @@ class Engine(Serializable):
     def _internal_run_as_gen(self) -> Generator[Any, None, State]:
         self.should_terminate = self.should_terminate_single_epoch = self.should_interrupt = False
         self._init_timers(self.state)
+        start_time = time.time()
         try:
             try:
-                start_time = time.time()
                 self._fire_event(Events.STARTED)
                 yield from self._maybe_terminate_or_interrupt()

@@ -991,7 +1008,7 @@ class Engine(Serializable):
             # time is available for handlers but must be updated after fire
             self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken

-            if self.should_terminate_single_epoch != "skip_epoch_completed":
+            if self.should_terminate_single_epoch != "skip_epoch_completed":
                 handlers_start_time = time.time()
                 self._fire_event(Events.EPOCH_COMPLETED)
                 epoch_time_taken += time.time() - handlers_start_time
@@ -1024,7 +1041,7 @@ class Engine(Serializable):
             self.state.times[Events.COMPLETED.name] = time_taken

             # do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True`
-            if self.should_terminate != "skip_completed":
+            if self.should_terminate != "skip_completed":
                 handlers_start_time = time.time()
                 self._fire_event(Events.COMPLETED)
                 time_taken += time.time() - handlers_start_time
@@ -1069,7 +1086,7 @@ class Engine(Serializable):
             )

         while True:
-            self.state.batch = self.state.output = None
+            self.state.batch = None

             try:
                 # Avoid Events.GET_BATCH_STARTED triggered twice when data iter is restarted
@@ -1081,6 +1098,9 @@ class Engine(Serializable):
                     yield from self._maybe_terminate_or_interrupt()

                 self.state.batch = next(self._dataloader_iter)
+                # We on purpose reset state.output here as for iterable dataloaders
+                # we accidentally can remove it when one epoch is completed.
+                self.state.output = None

                 # We should not trigger GET_BATCH_STARTED, GET_BATCH_COMPLETED, DATALOADER_STOP_ITERATION events
                 # if no data was provided to engine.run(data=None, ...)
@@ -1167,9 +1187,9 @@ class Engine(Serializable):
         # internal_run without generator for BC
         self.should_terminate = self.should_terminate_single_epoch = self.should_interrupt = False
         self._init_timers(self.state)
+        start_time = time.time()
         try:
             try:
-                start_time = time.time()
                 self._fire_event(Events.STARTED)
                 self._maybe_terminate_legacy()

@@ -1188,7 +1208,7 @@ class Engine(Serializable):
             # time is available for handlers but must be updated after fire
             self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken

-            if self.should_terminate_single_epoch != "skip_epoch_completed":
+            if self.should_terminate_single_epoch != "skip_epoch_completed":
                 handlers_start_time = time.time()
                 self._fire_event(Events.EPOCH_COMPLETED)
                 epoch_time_taken += time.time() - handlers_start_time
@@ -1221,7 +1241,7 @@ class Engine(Serializable):
             self.state.times[Events.COMPLETED.name] = time_taken

             # do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True`
-            if self.should_terminate != "skip_completed":
+            if self.should_terminate != "skip_completed":
                 handlers_start_time = time.time()
                 self._fire_event(Events.COMPLETED)
                 time_taken += time.time() - handlers_start_time
@@ -1254,7 +1274,7 @@ class Engine(Serializable):
             )

         while True:
-            self.state.batch = self.state.output = None
+            self.state.batch = None
             try:
                 # Avoid Events.GET_BATCH_STARTED triggered twice when data iter is restarted
                 if self.last_event_name != Events.DATALOADER_STOP_ITERATION:
@@ -1265,6 +1285,10 @@ class Engine(Serializable):
                 self._maybe_terminate_legacy()

                 self.state.batch = next(self._dataloader_iter)
+                # We on purpose reset state.output here as for iterable dataloaders
+                # we accidentally can remove it when one epoch is completed.
+                self.state.output = None
+
                 # We should not trigger GET_BATCH_STARTED, GET_BATCH_COMPLETED, DATALOADER_STOP_ITERATION events
                 # if no data was provided to engine.run(data=None, ...)
                 if self.state.dataloader is not None:
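The new `has_registered_events` makes the register-once pattern explicit. A small runnable sketch with a hypothetical custom event enum (`BackupEvents` is illustrative, not part of ignite):

from ignite.engine import Engine, EventEnum, Events


class BackupEvents(EventEnum):
    BACKUP_DONE = "backup_done"  # hypothetical custom event


trainer = Engine(lambda engine, batch: batch)

# Same guard that Checkpoint.__call__ now uses to register its events once:
if not trainer.has_registered_events(BackupEvents.BACKUP_DONE):
    trainer.register_events(*BackupEvents)


@trainer.on(BackupEvents.BACKUP_DONE)
def on_backup(engine):
    print(f"backup event fired at iteration {engine.state.iteration}")


@trainer.on(Events.ITERATION_COMPLETED(every=2))
def maybe_backup(engine):
    engine.fire_event(BackupEvents.BACKUP_DONE)


trainer.run([0, 1, 2, 3], max_epochs=1)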
ignite/engine/events.py
CHANGED
@@ -91,7 +91,7 @@ class CallableEventWithFilter:
             raise ValueError("Argument every should be integer and greater than zero")

         if once is not None:
-            c1 = isinstance(once, numbers.Integral) and once > 0
+            c1 = isinstance(once, int) and once > 0
             c2 = isinstance(once, Sequence) and len(once) > 0 and all(isinstance(e, int) and e > 0 for e in once)
             if not (c1 or c2):
                 raise ValueError(
@@ -240,6 +240,7 @@ class EventEnum(CallableEventWithFilter, Enum):
     def __new__(cls, value: str) -> "EventEnum":
         obj = CallableEventWithFilter.__new__(cls)
         obj._value_ = value
+        # pyrefly: ignore [bad-return]
         return obj
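Usage of the `once` filter is unchanged by the stricter check (the removed line is reconstructed above as `numbers.Integral`, an assumption); both accepted forms still work:

from ignite.engine import Engine, Events

trainer = Engine(lambda engine, batch: None)


@trainer.on(Events.ITERATION_COMPLETED(once=3))  # plain positive int
def at_third_iteration(engine):
    print("fires only at iteration 3")


@trainer.on(Events.ITERATION_COMPLETED(once=[2, 5]))  # sequence of positive ints
def at_second_and_fifth(engine):
    print(f"fires at iteration {engine.state.iteration}")


trainer.run(range(6), max_epochs=1)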
ignite/handlers/__init__.py
CHANGED
@@ -6,6 +6,7 @@ from ignite.handlers.checkpoint import Checkpoint, DiskSaver, ModelCheckpoint
 from ignite.handlers.clearml_logger import ClearMLLogger
 from ignite.handlers.early_stopping import EarlyStopping
 from ignite.handlers.ema_handler import EMAHandler
+from ignite.handlers.fbresearch_logger import FBResearchLogger
 from ignite.handlers.lr_finder import FastaiLRFinder
 from ignite.handlers.mlflow_logger import MLflowLogger
 from ignite.handlers.neptune_logger import NeptuneLogger
@@ -64,6 +65,7 @@ __all__ = [
     "CyclicalScheduler",
     "create_lr_scheduler_with_warmup",
     "FastaiLRFinder",
+    "FBResearchLogger",
     "EMAHandler",
     "BasicTimeProfiler",
     "HandlersTimeProfiler",
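With the new export, the logger can be imported from the package root. A minimal sketch following the handler's documented usage (constructor and attach signatures per current ignite docs, worth double-checking):

import logging

from ignite.engine import Engine
from ignite.handlers import FBResearchLogger  # new top-level export

logging.basicConfig(level=logging.INFO)

trainer = Engine(lambda engine, batch: {"loss": 0.1})
fb_logger = FBResearchLogger(logger=logging.getLogger("demo"), show_output=True)
fb_logger.attach(trainer, name="Train", every=10)

trainer.run([0] * 100, max_epochs=2)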
ignite/handlers/base_logger.py
CHANGED
@@ -1,5 +1,6 @@
 """Base logger and its helper handlers."""

+import collections.abc as collections
 import numbers
 import warnings
 from abc import ABCMeta, abstractmethod
@@ -145,30 +146,64 @@ class BaseOutputHandler(BaseHandler):

         metrics_state_attrs_dict: Dict[Any, Union[str, float, numbers.Number]] = OrderedDict()

-        def key_str_tf(tag: str, name: str, *args: str) -> str:
-            return "/".join([tag, name] + list(args))
+        def key_tuple_fn(parent_key: Union[str, Tuple[str, ...]], *args: str) -> Tuple[str, ...]:
+            if parent_key is None or isinstance(parent_key, str):
+                return (parent_key,) + args
+            return parent_key + args

-        def key_tuple_tf(tag: str, name: str, *args: str) -> Tuple[str, ...]:
-            return (tag, name) + args
+        def key_str_fn(parent_key: str, *args: str) -> str:
+            args_str = "/".join(args)
+            return f"{parent_key}/{args_str}"

-        key_tf = key_tuple_tf if key_tuple else key_str_tf
+        key_fn = key_tuple_fn if key_tuple else key_str_fn

-        for name, value in metrics_state_attrs.items():
+        def handle_value_fn(
+            value: Union[str, int, float, numbers.Number, torch.Tensor]
+        ) -> Union[None, str, float, numbers.Number]:
             if isinstance(value, numbers.Number):
-                metrics_state_attrs_dict[key_tf(self.tag, name)] = value
+                return value
             elif isinstance(value, torch.Tensor) and value.ndimension() == 0:
-                metrics_state_attrs_dict[key_tf(self.tag, name)] = value.item()
-            elif isinstance(value, torch.Tensor) and value.ndimension() == 1:
-                for i, v in enumerate(value):
-                    metrics_state_attrs_dict[key_tf(self.tag, name, str(i))] = v.item()
+                return value.item()
             else:
                 if isinstance(value, str) and log_text:
-                    metrics_state_attrs_dict[key_tf(self.tag, name)] = value
+                    return value
                 else:
                     warnings.warn(f"Logger output_handler can not log metrics value type {type(value)}")
+                    return None
+
+        metrics_state_attrs_dict = _flatten_dict(metrics_state_attrs, key_fn, handle_value_fn, parent_key=self.tag)
         return metrics_state_attrs_dict


+def _flatten_dict(
+    in_dict: collections.Mapping,
+    key_fn: Callable,
+    value_fn: Callable,
+    parent_key: Optional[Union[str, Tuple[str, ...]]] = None,
+) -> Dict:
+    items = {}
+    for key, value in in_dict.items():
+        new_key = key_fn(parent_key, key)
+        if isinstance(value, collections.Mapping):
+            items.update(_flatten_dict(value, key_fn, value_fn, new_key))
+        elif any(
+            [
+                isinstance(value, tuple) and hasattr(value, "_fields"),  # namedtuple
+                not isinstance(value, str) and isinstance(value, collections.Sequence),
+            ]
+        ):
+            for i, item in enumerate(value):
+                items.update(_flatten_dict({str(i): item}, key_fn, value_fn, new_key))
+        elif isinstance(value, torch.Tensor) and value.ndimension() == 1:
+            for i, item in enumerate(value):
+                items.update(_flatten_dict({str(i): item.item()}, key_fn, value_fn, new_key))
+        else:
+            new_value = value_fn(value)
+            if new_value is not None:
+                items[new_key] = new_value
+    return items
+
+
 class BaseWeightsScalarHandler(BaseWeightsHandler):
     """
     Helper handler to log model's weights or gradients as scalars.
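A hedged sketch of what the new `_flatten_dict` produces (it is a private helper, so the import path may change); here with a string key function similar to `key_str_fn` above:

import torch

from ignite.handlers.base_logger import _flatten_dict


def key_str_fn(parent_key, *args):
    return "/".join((parent_key,) + args)


def value_fn(value):
    # Collapse 0d tensors to floats; pass plain numbers through.
    if isinstance(value, torch.Tensor) and value.ndimension() == 0:
        return value.item()
    return value if isinstance(value, (int, float)) else None


metrics = {"acc": 0.9, "loss": {"bce": torch.tensor(0.25), "l1": [0.1, 0.3]}}
print(_flatten_dict(metrics, key_str_fn, value_fn, parent_key="training"))
# {'training/acc': 0.9, 'training/loss/bce': 0.25,
#  'training/loss/l1/0': 0.1, 'training/loss/l1/1': 0.3}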
ignite/handlers/checkpoint.py
CHANGED
@@ -21,10 +21,21 @@ else:

 import ignite.distributed as idist
 from ignite.base import Serializable
-from ignite.engine import Engine, Events
+from ignite.engine import Engine, Events, EventEnum
 from ignite.utils import _tree_apply2, _tree_map

-__all__ = ["Checkpoint", "DiskSaver", "ModelCheckpoint", "BaseSaveHandler"]
+__all__ = ["Checkpoint", "DiskSaver", "ModelCheckpoint", "BaseSaveHandler", "CheckpointEvents"]
+
+
+class CheckpointEvents(EventEnum):
+    """Events fired by :class:`~ignite.handlers.checkpoint.Checkpoint`
+
+    - SAVED_CHECKPOINT : triggered when checkpoint handler has saved objects
+
+    .. versionadded:: 0.5.3
+    """
+
+    SAVED_CHECKPOINT = "saved_checkpoint"


 class BaseSaveHandler(metaclass=ABCMeta):
@@ -264,6 +275,29 @@ class Checkpoint(Serializable):
                 to_save, save_handler=DiskSaver('/tmp/models', create_dir=True, **kwargs), n_saved=2
             )

+        Respond to checkpoint events:
+
+        .. code-block:: python
+
+            from ignite.handlers import Checkpoint
+            from ignite.engine import Engine, Events
+
+            checkpoint_handler = Checkpoint(
+                {'model': model, 'optimizer': optimizer},
+                save_dir,
+                n_saved=2
+            )
+
+            @trainer.on(Checkpoint.SAVED_CHECKPOINT)
+            def on_checkpoint_saved(engine):
+                print(f"Checkpoint saved at epoch {engine.state.epoch}")
+
+            trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler)
+
+    Attributes:
+        SAVED_CHECKPOINT: Alias of ``SAVED_CHECKPOINT`` from
+            :class:`~ignite.handlers.checkpoint.CheckpointEvents`.
+
     .. versionchanged:: 0.4.3

         - Checkpoint can save model with same filename.
@@ -274,9 +308,14 @@ class Checkpoint(Serializable):
         - `score_name` can be used to define `score_function` automatically without providing `score_function`.
         - `save_handler` automatically saves to disk if path to directory is provided.
         - `save_on_rank` saves objects on this rank in a distributed configuration.
+
+    .. versionchanged:: 0.5.3
+
+        - Added ``SAVED_CHECKPOINT`` class attribute.
     """

+    SAVED_CHECKPOINT = CheckpointEvents.SAVED_CHECKPOINT
+
     Item = NamedTuple("Item", [("priority", Union[int, float]), ("filename", str)])
     _state_dict_all_req_keys = ("_saved",)

     def __init__(
@@ -284,7 +323,7 @@ class Checkpoint(Serializable):
         to_save: Mapping,
         save_handler: Union[str, Path, Callable, BaseSaveHandler],
         filename_prefix: str = "",
-        score_function: Optional[Callable] = None,
+        score_function: Optional[Callable[[Engine], Union[int, float]]] = None,
         score_name: Optional[str] = None,
         n_saved: Union[int, None] = 1,
         global_step_transform: Optional[Callable] = None,
@@ -400,6 +439,8 @@ class Checkpoint(Serializable):
         return new > self._saved[0].priority

     def __call__(self, engine: Engine) -> None:
+        if not engine.has_registered_events(CheckpointEvents.SAVED_CHECKPOINT):
+            engine.register_events(*CheckpointEvents)
         global_step = None
         if self.global_step_transform is not None:
             global_step = self.global_step_transform(engine, engine.last_event_name)
@@ -460,11 +501,11 @@ class Checkpoint(Serializable):
         if self.include_self:
             # Now that we've updated _saved, we can add our own state_dict.
             checkpoint["checkpointer"] = self.state_dict()
-
         try:
             self.save_handler(checkpoint, filename, metadata)
         except TypeError:
             self.save_handler(checkpoint, filename)
+        engine.fire_event(CheckpointEvents.SAVED_CHECKPOINT)

     def _setup_checkpoint(self) -> Dict[str, Any]:
         if self.to_save is not None:
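A complete, runnable variant of the docstring example (the temporary directory and dummy epochs are assumptions for the demo); registering `CheckpointEvents` up front lets a listener attach before the first save, and the guard in `__call__` keeps the later registration a no-op:

import tempfile

import torch

from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint
from ignite.handlers.checkpoint import CheckpointEvents

model = torch.nn.Linear(2, 2)
trainer = Engine(lambda engine, batch: None)

# A directory path as save_handler makes Checkpoint save to disk.
handler = Checkpoint({"model": model}, tempfile.mkdtemp(), n_saved=2)
trainer.add_event_handler(Events.EPOCH_COMPLETED, handler)
trainer.register_events(*CheckpointEvents)


@trainer.on(Checkpoint.SAVED_CHECKPOINT)
def on_saved(engine):
    print(f"checkpoint saved at epoch {engine.state.epoch}")


trainer.run([0, 1], max_epochs=2)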
ignite/handlers/clearml_logger.py
CHANGED
@@ -109,8 +109,17 @@ class ClearMLLogger(BaseLogger):
                 log_handler=WeightsScalarHandler(model)
             )

+    Note:
+        :class:`~ignite.handlers.clearml_logger.OutputHandler` can handle
+        metrics, state attributes and engine output values of the following format:
+        - scalar values (i.e. int, float)
+        - 0d and 1d pytorch tensors
+        - dicts and list/tuples of previous types
+
     """

+    _task: Any
+
     def __init__(self, **kwargs: Any):
         try:
             from clearml import Task
@@ -342,9 +351,10 @@ class OutputHandler(BaseOutputHandler):
         for key, value in metrics.items():
             if len(key) == 2:
                 logger.clearml_logger.report_scalar(title=key[0], series=key[1], iteration=global_step, value=value)
-            elif len(key) == 3:
+            elif len(key) >= 3:
+                series = "/".join(key[2:])
                 logger.clearml_logger.report_scalar(
-                    title=f"{key[0]}/{key[1]}", series=key[2], iteration=global_step, value=value
+                    title=f"{key[0]}/{key[1]}", series=series, iteration=global_step, value=value
                 )

@@ -815,6 +825,8 @@ class ClearMLSaver(DiskSaver):

     """

+    _task: Any
+
     def __init__(
         self,
         logger: Optional[ClearMLLogger] = None,
@@ -949,8 +961,8 @@ class ClearMLSaver(DiskSaver):
             metadata=metadata,
         )

-        pre_cb_id = WeightsFileHandler.add_pre_callback(cb_context.pre_callback)
-        post_cb_id = WeightsFileHandler.add_post_callback(cb_context.post_callback)
+        pre_cb_id = WeightsFileHandler.add_pre_callback(cb_context.pre_callback)  # type: ignore[arg-type]
+        post_cb_id = WeightsFileHandler.add_post_callback(cb_context.post_callback)  # type: ignore[arg-type]

         try:
             super(ClearMLSaver, self).__call__(checkpoint, filename, metadata)
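The effect of the `len(key) >= 3` branch, sketched on a flattened tuple key (the values are illustrative):

# Flattened keys may now be arbitrarily deep, e.g. from nested metric dicts.
key = ("validation", "loss", "c", "e", "0")
title = f"{key[0]}/{key[1]}"  # "validation/loss"
series = "/".join(key[2:])    # "c/e/0"
print(title, series)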
ignite/handlers/fbresearch_logger.py
CHANGED
@@ -7,7 +7,7 @@ import torch

 from ignite import utils
 from ignite.engine import Engine, Events
-from ignite.handlers import Timer
+from ignite.handlers.timing import Timer

 MB = 1024.0 * 1024.0

@@ -154,7 +154,7 @@ class FBResearchLogger:
         if torch.cuda.is_available():
             cuda_max_mem = f"GPU Max Mem: {torch.cuda.max_memory_allocated() / MB:.0f} MB"

-        current_iter = engine.state.iteration % engine.state.epoch_length
+        current_iter = ((engine.state.iteration - 1) % engine.state.epoch_length) + 1
         iter_avg_time = self.iter_timer.value()

         eta_seconds = iter_avg_time * (engine.state.epoch_length - current_iter)
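Why the new formula matters, as plain arithmetic: the old modulo wrapped to 0 on the last iteration of every epoch, which inflated the ETA by a full epoch's worth of iterations:

epoch_length = 100
for iteration in (1, 99, 100, 101, 200):
    old = iteration % epoch_length              # 100 -> 0, 200 -> 0
    new = ((iteration - 1) % epoch_length) + 1  # 100 -> 100, 200 -> 100
    print(f"iter={iteration:3d}  old={old:3d}  new={new:3d}")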
ignite/handlers/lr_finder.py
CHANGED
@@ -98,15 +98,18 @@ class FastaiLRFinder:
         self._best_loss = None
         self._diverge_flag = False

+        assert trainer.state.epoch_length is not None
+        assert trainer.state.max_epochs is not None
+
         # attach LRScheduler to trainer.
         if num_iter is None:
             num_iter = trainer.state.epoch_length * trainer.state.max_epochs
         else:
-            max_iter = trainer.state.epoch_length * trainer.state.max_epochs
+            max_iter = trainer.state.epoch_length * trainer.state.max_epochs
             if max_iter < num_iter:
                 max_iter = num_iter
                 trainer.state.max_iters = num_iter
-                trainer.state.max_epochs = ceil(num_iter / trainer.state.epoch_length)
+                trainer.state.max_epochs = ceil(num_iter / trainer.state.epoch_length)

         if not trainer.has_event_handler(self._reached_num_iterations):
             trainer.add_event_handler(Events.ITERATION_COMPLETED, self._reached_num_iterations, num_iter)
@@ -178,17 +181,14 @@ class FastaiLRFinder:
         loss = idist.all_reduce(loss)
         lr = self._lr_schedule.get_param()
         self._history["lr"].append(lr)
-        if trainer.state.iteration == 1:
+        if trainer.state.iteration != 1 and smooth_f > 0:
+            loss = smooth_f * loss + (1 - smooth_f) * self._history["loss"][-1]
+        if self._best_loss is None or loss < self._best_loss:
             self._best_loss = loss
-        else:
-            if smooth_f > 0:
-                loss = smooth_f * loss + (1 - smooth_f) * self._history["loss"][-1]
-            if loss < self._best_loss:
-                self._best_loss = loss
         self._history["loss"].append(loss)

         # Check if the loss has diverged; if it has, stop the trainer
-        if self._history["loss"][-1] > diverge_th * self._best_loss:
+        if self._history["loss"][-1] > diverge_th * self._best_loss:
             self._diverge_flag = True
             self.logger.info("Stopping early, the loss has diverged")
             trainer.terminate()
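The refactored smoothing and best-loss logic, restated as a standalone sketch (same formula; `best` may now be refreshed on the first iteration as well):

def record_loss(history, loss, iteration, smooth_f=0.05, best=None):
    # Blend the raw loss with the previously recorded one (exponential smoothing).
    if iteration != 1 and smooth_f > 0:
        loss = smooth_f * loss + (1 - smooth_f) * history[-1]
    if best is None or loss < best:
        best = loss
    history.append(loss)
    return loss, best


history, best = [], None
for i, raw in enumerate([1.0, 0.9, 0.95, 0.4], start=1):
    _, best = record_loss(history, raw, i, best=best)
print(history, best)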
ignite/handlers/mlflow_logger.py
CHANGED
@@ -84,6 +84,49 @@ class MLflowLogger(BaseLogger):
                 optimizer=optimizer,
                 param_name='lr'  # optional
             )
+
+    Note:
+        :class:`~ignite.handlers.mlflow_logger.OutputHandler` can handle
+        metrics, state attributes and engine output values of the following format:
+        - scalar values (i.e. int, float)
+        - 0d and 1d pytorch tensors
+        - dicts and list/tuples of previous types
+
+        .. code-block:: python
+
+            # !!! This is not a runnable code !!!
+            evaluator.state.metrics = {
+                "a": 0,
+                "dict_value": {
+                    "a": 111,
+                    "c": {"d": 23, "e": [123, 234]},
+                },
+                "list_value": [12, 13, {"aa": 33, "bb": 44}],
+                "tuple_value": (112, 113, {"aaa": 33, "bbb": 44}),
+            }
+
+            handler = OutputHandler(
+                tag="tag",
+                metric_names="all",
+            )
+
+            handler(evaluator, mlflow_logger, event_name=Events.EPOCH_COMPLETED)
+            # Behind the scenes, this calls `mlflow_logger.log_metrics` on
+            # {
+            #     "tag/a": 0,
+            #     "tag/dict_value/a": 111,
+            #     "tag/dict_value/c/d": 23,
+            #     "tag/dict_value/c/e/0": 123,
+            #     "tag/dict_value/c/e/1": 234,
+            #     "tag/list_value/0": 12,
+            #     "tag/list_value/1": 13,
+            #     "tag/list_value/2/aa": 33,
+            #     "tag/list_value/2/bb": 44,
+            #     "tag/tuple_value/0": 112,
+            #     "tag/tuple_value/1": 113,
+            #     "tag/tuple_value/2/aaa": 33,
+            #     "tag/tuple_value/2/bbb": 44,
+            # }
     """

     def __init__(self, tracking_uri: Optional[str] = None):