pytorch-ignite 0.6.0.dev20250324-py3-none-any.whl → 0.6.0.dev20251103-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Note: this version of pytorch-ignite has been flagged as a potentially problematic release in the registry.
- ignite/__init__.py +1 -1
- ignite/distributed/comp_models/native.py +1 -1
- ignite/engine/__init__.py +9 -9
- ignite/engine/engine.py +30 -4
- ignite/handlers/__init__.py +2 -0
- ignite/handlers/base_logger.py +47 -12
- ignite/handlers/checkpoint.py +44 -3
- ignite/handlers/clearml_logger.py +18 -6
- ignite/handlers/fbresearch_logger.py +2 -2
- ignite/handlers/lr_finder.py +1 -1
- ignite/handlers/mlflow_logger.py +43 -0
- ignite/handlers/neptune_logger.py +7 -0
- ignite/handlers/polyaxon_logger.py +7 -0
- ignite/handlers/tensorboard_logger.py +43 -0
- ignite/handlers/tqdm_logger.py +2 -3
- ignite/handlers/visdom_logger.py +9 -2
- ignite/handlers/wandb_logger.py +7 -1
- ignite/metrics/clustering/calinski_harabasz_score.py +1 -1
- ignite/metrics/clustering/silhouette_score.py +1 -1
- ignite/metrics/fbeta.py +17 -8
- ignite/metrics/gan/fid.py +3 -3
- ignite/metrics/js_divergence.py +1 -1
- ignite/metrics/maximum_mean_discrepancy.py +1 -1
- ignite/metrics/metric.py +2 -0
- ignite/metrics/nlp/bleu.py +8 -6
- ignite/metrics/nlp/rouge.py +3 -3
- ignite/metrics/nlp/utils.py +1 -1
- ignite/metrics/precision_recall_curve.py +5 -5
- ignite/metrics/regression/_base.py +4 -0
- ignite/metrics/regression/fractional_bias.py +1 -1
- ignite/metrics/roc_auc.py +3 -3
- ignite/metrics/ssim.py +58 -20
- {pytorch_ignite-0.6.0.dev20250324.dist-info → pytorch_ignite-0.6.0.dev20251103.dist-info}/METADATA +11 -17
- {pytorch_ignite-0.6.0.dev20250324.dist-info → pytorch_ignite-0.6.0.dev20251103.dist-info}/RECORD +36 -37
- {pytorch_ignite-0.6.0.dev20250324.dist-info → pytorch_ignite-0.6.0.dev20251103.dist-info}/WHEEL +1 -2
- pytorch_ignite-0.6.0.dev20250324.dist-info/top_level.txt +0 -1
- {pytorch_ignite-0.6.0.dev20250324.dist-info → pytorch_ignite-0.6.0.dev20251103.dist-info/licenses}/LICENSE +0 -0
ignite/distributed/comp_models/native.py
CHANGED

@@ -178,7 +178,7 @@ if has_native_dist_support:
            c: Counter = Counter(hostnames)
            sizes = torch.tensor([0] + list(c.values()))
            cumsum_sizes = torch.cumsum(sizes, dim=0)
-           node_rank = (rank // cumsum_sizes[1:]).clamp(0, 1).sum().item()
+           node_rank = cast(int, (rank // cumsum_sizes[1:]).clamp(0, 1).sum().item())
            local_rank = rank - cumsum_sizes[node_rank].item()
            return int(local_rank), node_rank

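The change above only wraps the existing expression in `cast` for typing; the node/local rank arithmetic itself is unchanged. A standalone sketch of what it computes, with made-up hostnames and rank values:

```python
import torch
from collections import Counter

# Hypothetical 2-node, 4-process job: one hostname entry per process.
hostnames = ["node0", "node0", "node1", "node1"]
rank = 2  # global rank of the current process

c = Counter(hostnames)
sizes = torch.tensor([0] + list(c.values()))   # tensor([0, 2, 2])
cumsum_sizes = torch.cumsum(sizes, dim=0)      # tensor([0, 2, 4])

# Count how many node boundaries the rank has passed -> node index,
# then the offset within that node -> local rank.
node_rank = int((rank // cumsum_sizes[1:]).clamp(0, 1).sum().item())
local_rank = int(rank - cumsum_sizes[node_rank].item())
print(node_rank, local_rank)  # 1 0
```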
ignite/engine/__init__.py
CHANGED

@@ -133,11 +133,11 @@ def supervised_training_step_amp(
    prepare_batch: Callable = _prepare_batch,
    model_transform: Callable[[Any], Any] = lambda output: output,
    output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
-   scaler: Optional["torch.cuda.amp.GradScaler"] = None,
+   scaler: Optional["torch.amp.GradScaler"] = None,
    gradient_accumulation_steps: int = 1,
    model_fn: Callable[[torch.nn.Module, Any], Any] = lambda model, x: model(x),
 ) -> Callable:
-   """Factory function for supervised training using ``torch.cuda.amp``.
+   """Factory function for supervised training using ``torch.amp``.

    Args:
        model: the model to train.
@@ -170,7 +170,7 @@ def supervised_training_step_amp(
        model = ...
        optimizer = ...
        loss_fn = ...
-       scaler = torch.cuda.amp.GradScaler(2**10)
+       scaler = torch.amp.GradScaler('cuda', 2**10)

        update_fn = supervised_training_step_amp(model, optimizer, loss_fn, 'cuda', scaler=scaler)
        trainer = Engine(update_fn)
@@ -393,8 +393,8 @@ def supervised_training_step_tpu(


 def _check_arg(
-   on_tpu: bool, on_mps: bool, amp_mode: Optional[str], scaler: Optional[Union[bool, "torch.cuda.amp.GradScaler"]]
-) -> Tuple[Optional[str], Optional["torch.cuda.amp.GradScaler"]]:
+   on_tpu: bool, on_mps: bool, amp_mode: Optional[str], scaler: Optional[Union[bool, "torch.amp.GradScaler"]]
+) -> Tuple[Optional[str], Optional["torch.amp.GradScaler"]]:
    """Checking tpu, mps, amp and GradScaler instance combinations."""
    if on_mps and amp_mode:
        raise ValueError("amp_mode cannot be used with mps device. Consider using amp_mode=None or device='cuda'.")
@@ -410,9 +410,9 @@ def _check_arg(
        raise ValueError(f"scaler argument is {scaler}, but amp_mode is {amp_mode}. Consider using amp_mode='amp'.")
    elif amp_mode == "amp" and isinstance(scaler, bool):
        try:
-           from torch.cuda.amp import GradScaler
+           from torch.amp import GradScaler
        except ImportError:
-           raise ImportError("Please install torch>=1.6.0 to use scaler argument.")
+           raise ImportError("Please install torch>=2.3.1 to use scaler argument.")
        scaler = GradScaler(enabled=True)

    if on_tpu:
@@ -434,7 +434,7 @@ def create_supervised_trainer(
    output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
    deterministic: bool = False,
    amp_mode: Optional[str] = None,
-   scaler: Union[bool, "torch.cuda.amp.GradScaler"] = False,
+   scaler: Union[bool, "torch.amp.GradScaler"] = False,
    gradient_accumulation_steps: int = 1,
    model_fn: Callable[[torch.nn.Module, Any], Any] = lambda model, x: model(x),
 ) -> Engine:
@@ -459,7 +459,7 @@ def create_supervised_trainer(
        :class:`~ignite.engine.deterministic.DeterministicEngine`, otherwise :class:`~ignite.engine.engine.Engine`
        (default: False).
    amp_mode: can be ``amp`` or ``apex``, model and optimizer will be casted to float16 using
-       `torch.cuda.amp <https://pytorch.org/docs/stable/amp.html>`_ for ``amp`` and
+       `torch.amp <https://pytorch.org/docs/stable/amp.html>`_ for ``amp`` and
        using `apex <https://nvidia.github.io/apex>`_ for ``apex``. (default: None)
    scaler: GradScaler instance for gradient scaling if `torch>=1.6.0`
        and ``amp_mode`` is ``amp``. If ``amp_mode`` is ``apex``, this argument will be ignored.
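These hunks move the AMP integration from ``torch.cuda.amp`` to the device-agnostic ``torch.amp`` namespace. A minimal usage sketch, assuming a CUDA device and torch>=2.3.1; the model, optimizer and loss are placeholders:

```python
import torch
from ignite.engine import create_supervised_trainer

model = torch.nn.Linear(10, 2).to("cuda")
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()

# scaler=True asks ignite to build a torch.amp.GradScaler internally;
# an explicit torch.amp.GradScaler("cuda") instance also works.
trainer = create_supervised_trainer(
    model, optimizer, loss_fn, device="cuda", amp_mode="amp", scaler=True
)
```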
ignite/engine/engine.py
CHANGED

@@ -249,6 +249,17 @@ class Engine(Serializable):
        # we need to update state attributes associated with new custom events
        self.state._update_attrs()

+   def has_registered_events(self, event: Any) -> bool:
+       """Check whether engine has a registered event.
+
+       Args:
+           event: Event to check for registration.
+
+       Returns:
+           bool: True if the event is registered, False otherwise.
+       """
+       return event in self._allowed_events
+
    def _handler_wrapper(self, handler: Callable, event_name: Any, event_filter: Callable) -> Callable:
        # signature of the following wrapper will be inspected during registering to check if engine is necessary
        # we have to build a wrapper with relevant signature : solution is functools.wraps
@@ -328,7 +339,7 @@ class Engine(Serializable):

        try:
            _check_signature(handler, "handler", self, *(event_args + args), **kwargs)
-           self._event_handlers[event_name].append((handler, (self,) + args, kwargs))
+           self._event_handlers[event_name].append((handler, (weakref.ref(self),) + args, kwargs))
        except ValueError:
            _check_signature(handler, "handler", *(event_args + args), **kwargs)
            self._event_handlers[event_name].append((handler, args, kwargs))
@@ -432,7 +443,15 @@ class Engine(Serializable):
        self.last_event_name = event_name
        for func, args, kwargs in self._event_handlers[event_name]:
            kwargs.update(event_kwargs)
-           first, others = ((args[0],), args[1:]) if (args and args[0] == self) else ((), args)
+           if args and isinstance(args[0], weakref.ref):
+               resolved_engine = args[0]()
+               if resolved_engine is None:
+                   raise RuntimeError("Engine reference not resolved. Cannot execute event handler.")
+               first, others = ((resolved_engine,), args[1:])
+           else:
+               # metrics do not provide engine when registered
+               first, others = (tuple(), args)  # type: ignore[assignment]
+
            func(*first, *(event_args + others), **kwargs)

    def fire_event(self, event_name: Any) -> None:
@@ -1069,7 +1088,7 @@ class Engine(Serializable):
            )

        while True:
-           self.state.batch = self.state.output = None
+           self.state.batch = None

            try:
                # Avoid Events.GET_BATCH_STARTED triggered twice when data iter is restarted
@@ -1081,6 +1100,9 @@ class Engine(Serializable):
                yield from self._maybe_terminate_or_interrupt()

                self.state.batch = next(self._dataloader_iter)
+               # We on purpose reset state.output here as for iterable dataloaders
+               # we accidentally can remove it when one epoch is completed.
+               self.state.output = None

                # We should not trigger GET_BATCH_STARTED, GET_BATCH_COMPLETED, DATALOADER_STOP_ITERATION events
                # if no data was provided to engine.run(data=None, ...)
@@ -1254,7 +1276,7 @@ class Engine(Serializable):
            )

        while True:
-           self.state.batch = self.state.output = None
+           self.state.batch = None
            try:
                # Avoid Events.GET_BATCH_STARTED triggered twice when data iter is restarted
                if self.last_event_name != Events.DATALOADER_STOP_ITERATION:
@@ -1265,6 +1287,10 @@ class Engine(Serializable):
                self._maybe_terminate_legacy()

                self.state.batch = next(self._dataloader_iter)
+               # We on purpose reset state.output here as for iterable dataloaders
+               # we accidentally can remove it when one epoch is completed.
+               self.state.output = None
+
                # We should not trigger GET_BATCH_STARTED, GET_BATCH_COMPLETED, DATALOADER_STOP_ITERATION events
                # if no data was provided to engine.run(data=None, ...)
                if self.state.dataloader is not None:
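The new ``Engine.has_registered_events`` is a small query helper over ``register_events`` (the ``Checkpoint`` changes below use it to register their events lazily). A short sketch with a made-up custom event:

```python
from ignite.engine import Engine, EventEnum

class MyEvents(EventEnum):
    # hypothetical custom event, for illustration only
    DATA_PREPARED = "data_prepared"

engine = Engine(lambda e, batch: batch)

# Guard registration so it is safe to call from code that may run many times.
if not engine.has_registered_events(MyEvents.DATA_PREPARED):
    engine.register_events(*MyEvents)

@engine.on(MyEvents.DATA_PREPARED)
def on_data_prepared(e):
    print("fired at iteration", e.state.iteration)
```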
ignite/handlers/__init__.py
CHANGED

@@ -6,6 +6,7 @@ from ignite.handlers.checkpoint import Checkpoint, DiskSaver, ModelCheckpoint
 from ignite.handlers.clearml_logger import ClearMLLogger
 from ignite.handlers.early_stopping import EarlyStopping
 from ignite.handlers.ema_handler import EMAHandler
+from ignite.handlers.fbresearch_logger import FBResearchLogger
 from ignite.handlers.lr_finder import FastaiLRFinder
 from ignite.handlers.mlflow_logger import MLflowLogger
 from ignite.handlers.neptune_logger import NeptuneLogger
@@ -64,6 +65,7 @@ __all__ = [
    "CyclicalScheduler",
    "create_lr_scheduler_with_warmup",
    "FastaiLRFinder",
+   "FBResearchLogger",
    "EMAHandler",
    "BasicTimeProfiler",
    "HandlersTimeProfiler",
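With this re-export, the logger is importable from the package root as well as from its own module:

```python
# Only the first import was available before this release; both now work.
from ignite.handlers.fbresearch_logger import FBResearchLogger
from ignite.handlers import FBResearchLogger
```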
ignite/handlers/base_logger.py
CHANGED

@@ -1,5 +1,6 @@
 """Base logger and its helper handlers."""

+import collections.abc as collections
 import numbers
 import warnings
 from abc import ABCMeta, abstractmethod
@@ -145,30 +146,64 @@ class BaseOutputHandler(BaseHandler):

        metrics_state_attrs_dict: Dict[Any, Union[str, float, numbers.Number]] = OrderedDict()

-       def key_tuple_tf(tag: str, name: str, *args: str) -> Tuple[str, ...]:
-           return (tag, name) + args
+       def key_tuple_fn(parent_key: Union[str, Tuple[str, ...]], *args: str) -> Tuple[str, ...]:
+           if parent_key is None or isinstance(parent_key, str):
+               return (parent_key,) + args
+           return parent_key + args

-       def key_str_tf(tag: str, name: str, *args: str) -> str:
-           return "/".join((tag, name) + args)
+       def key_str_fn(parent_key: str, *args: str) -> str:
+           args_str = "/".join(args)
+           return f"{parent_key}/{args_str}"

-       key_tf = key_tuple_tf if key_tuple else key_str_tf
+       key_fn = key_tuple_fn if key_tuple else key_str_fn

-       for name, value in metrics_state_attrs.items():
+       def handle_value_fn(
+           value: Union[str, int, float, numbers.Number, torch.Tensor]
+       ) -> Union[None, str, float, numbers.Number]:
            if isinstance(value, numbers.Number):
-               metrics_state_attrs_dict[key_tf(self.tag, name)] = value
+               return value
            elif isinstance(value, torch.Tensor) and value.ndimension() == 0:
-               metrics_state_attrs_dict[key_tf(self.tag, name)] = value.item()
-           elif isinstance(value, torch.Tensor) and value.ndimension() == 1:
-               for i, v in enumerate(value):
-                   metrics_state_attrs_dict[key_tf(self.tag, name, str(i))] = v.item()
+               return value.item()
            else:
                if isinstance(value, str) and log_text:
-                   metrics_state_attrs_dict[key_tf(self.tag, name)] = value
+                   return value
                else:
                    warnings.warn(f"Logger output_handler can not log metrics value type {type(value)}")
+                   return None
+
+       metrics_state_attrs_dict = _flatten_dict(metrics_state_attrs, key_fn, handle_value_fn, parent_key=self.tag)
        return metrics_state_attrs_dict


+def _flatten_dict(
+   in_dict: collections.Mapping,
+   key_fn: Callable,
+   value_fn: Callable,
+   parent_key: Optional[Union[str, Tuple[str, ...]]] = None,
+) -> Dict:
+   items = {}
+   for key, value in in_dict.items():
+       new_key = key_fn(parent_key, key)
+       if isinstance(value, collections.Mapping):
+           items.update(_flatten_dict(value, key_fn, value_fn, new_key))
+       elif any(
+           [
+               isinstance(value, tuple) and hasattr(value, "_fields"),  # namedtuple
+               not isinstance(value, str) and isinstance(value, collections.Sequence),
+           ]
+       ):
+           for i, item in enumerate(value):
+               items.update(_flatten_dict({str(i): item}, key_fn, value_fn, new_key))
+       elif isinstance(value, torch.Tensor) and value.ndimension() == 1:
+           for i, item in enumerate(value):
+               items.update(_flatten_dict({str(i): item.item()}, key_fn, value_fn, new_key))
+       else:
+           new_value = value_fn(value)
+           if new_value is not None:
+               items[new_key] = new_value
+   return items
+
+
 class BaseWeightsScalarHandler(BaseWeightsHandler):
    """
    Helper handler to log model's weights or gradients as scalars.
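A rough standalone sketch of the flattening behaviour that ``_flatten_dict`` introduces (a plain-Python stand-in, not the private ignite helper; values are invented and the string key scheme mirrors ``key_str_fn`` above):

```python
import torch

metrics = {
    "accuracy": 0.91,
    "per_class": torch.tensor([0.5, 0.25]),        # 1d tensor -> indexed keys
    "loss": {"total": 1.25, "parts": [1.0, 0.25]},  # nested dict and list
}

def flatten(d, parent="tag"):
    out = {}
    for k, v in d.items():
        key = f"{parent}/{k}"
        if isinstance(v, dict):
            out.update(flatten(v, key))
        elif isinstance(v, (list, tuple)) or (isinstance(v, torch.Tensor) and v.ndimension() == 1):
            for i, item in enumerate(v):
                out.update(flatten({str(i): float(item)}, key))
        else:
            out[key] = float(v)
    return out

print(flatten(metrics))
# {'tag/accuracy': 0.91, 'tag/per_class/0': 0.5, 'tag/per_class/1': 0.25,
#  'tag/loss/total': 1.25, 'tag/loss/parts/0': 1.0, 'tag/loss/parts/1': 0.25}
```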
ignite/handlers/checkpoint.py
CHANGED

@@ -21,10 +21,21 @@ else:

 import ignite.distributed as idist
 from ignite.base import Serializable
-from ignite.engine import Engine, Events
+from ignite.engine import Engine, Events, EventEnum
 from ignite.utils import _tree_apply2, _tree_map

-__all__ = ["Checkpoint", "DiskSaver", "ModelCheckpoint", "BaseSaveHandler"]
+__all__ = ["Checkpoint", "DiskSaver", "ModelCheckpoint", "BaseSaveHandler", "CheckpointEvents"]
+
+
+class CheckpointEvents(EventEnum):
+   """Events fired by :class:`~ignite.handlers.checkpoint.Checkpoint`
+
+   - SAVED_CHECKPOINT : triggered when checkpoint handler has saved objects
+
+   .. versionadded:: 0.5.3
+   """
+
+   SAVED_CHECKPOINT = "saved_checkpoint"


 class BaseSaveHandler(metaclass=ABCMeta):
@@ -264,6 +275,29 @@ class Checkpoint(Serializable):
                to_save, save_handler=DiskSaver('/tmp/models', create_dir=True, **kwargs), n_saved=2
            )

+       Respond to checkpoint events:
+
+       .. code-block:: python
+
+           from ignite.handlers import Checkpoint
+           from ignite.engine import Engine, Events
+
+           checkpoint_handler = Checkpoint(
+               {'model': model, 'optimizer': optimizer},
+               save_dir,
+               n_saved=2
+           )
+
+           @trainer.on(Checkpoint.SAVED_CHECKPOINT)
+           def on_checkpoint_saved(engine):
+               print(f"Checkpoint saved at epoch {engine.state.epoch}")
+
+           trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler)
+
+   Attributes:
+       SAVED_CHECKPOINT: Alias of ``SAVED_CHECKPOINT`` from
+           :class:`~ignite.handlers.checkpoint.CheckpointEvents`.
+
    .. versionchanged:: 0.4.3

        - Checkpoint can save model with same filename.
@@ -274,8 +308,13 @@ class Checkpoint(Serializable):
        - `score_name` can be used to define `score_function` automatically without providing `score_function`.
        - `save_handler` automatically saves to disk if path to directory is provided.
        - `save_on_rank` saves objects on this rank in a distributed configuration.
+
+   .. versionchanged:: 0.5.3
+
+       - Added ``SAVED_CHECKPOINT`` class attribute.
    """

+   SAVED_CHECKPOINT = CheckpointEvents.SAVED_CHECKPOINT
    Item = NamedTuple("Item", [("priority", int), ("filename", str)])
    _state_dict_all_req_keys = ("_saved",)

@@ -400,6 +439,8 @@ class Checkpoint(Serializable):
        return new > self._saved[0].priority

    def __call__(self, engine: Engine) -> None:
+       if not engine.has_registered_events(CheckpointEvents.SAVED_CHECKPOINT):
+           engine.register_events(*CheckpointEvents)
        global_step = None
        if self.global_step_transform is not None:
            global_step = self.global_step_transform(engine, engine.last_event_name)
@@ -460,11 +501,11 @@ class Checkpoint(Serializable):
            if self.include_self:
                # Now that we've updated _saved, we can add our own state_dict.
                checkpoint["checkpointer"] = self.state_dict()
-
            try:
                self.save_handler(checkpoint, filename, metadata)
            except TypeError:
                self.save_handler(checkpoint, filename)
+           engine.fire_event(CheckpointEvents.SAVED_CHECKPOINT)

    def _setup_checkpoint(self) -> Dict[str, Any]:
        if self.to_save is not None:
ignite/handlers/clearml_logger.py
CHANGED

@@ -109,8 +109,17 @@ class ClearMLLogger(BaseLogger):
                log_handler=WeightsScalarHandler(model)
            )

+   Note:
+       :class:`~ignite.handlers.clearml_logger.OutputHandler` can handle
+       metrics, state attributes and engine output values of the following format:
+       - scalar values (i.e. int, float)
+       - 0d and 1d pytorch tensors
+       - dicts and list/tuples of previous types
+
    """

+   _task: Any
+
    def __init__(self, **kwargs: Any):
        try:
            from clearml import Task
@@ -342,9 +351,10 @@ class OutputHandler(BaseOutputHandler):
        for key, value in metrics.items():
            if len(key) == 2:
                logger.clearml_logger.report_scalar(title=key[0], series=key[1], iteration=global_step, value=value)
-           elif len(key) == 3:
+           elif len(key) >= 3:
+               series = "/".join(key[2:])
                logger.clearml_logger.report_scalar(
-                   title=f"{key[0]}/{key[1]}", series=key[2], iteration=global_step, value=value
+                   title=f"{key[0]}/{key[1]}", series=series, iteration=global_step, value=value
                )


@@ -815,6 +825,8 @@ class ClearMLSaver(DiskSaver):

    """

+   _task: Any
+
    def __init__(
        self,
        logger: Optional[ClearMLLogger] = None,
@@ -850,7 +862,7 @@ class ClearMLSaver(DiskSaver):
        except ImportError:
            try:
                # Backwards-compatibility for legacy Trains SDK
-               from trains import Task
+               from trains import Task  # type: ignore[no-redef]
            except ImportError:
                raise ModuleNotFoundError(
                    "This contrib module requires clearml to be installed. "
@@ -925,7 +937,7 @@ class ClearMLSaver(DiskSaver):
        except ImportError:
            try:
                # Backwards-compatibility for legacy Trains SDK
-               from trains.binding.frameworks import WeightsFileHandler
+               from trains.binding.frameworks import WeightsFileHandler  # type: ignore[no-redef]
            except ImportError:
                raise ModuleNotFoundError(
                    "This contrib module requires clearml to be installed. "
@@ -949,8 +961,8 @@ class ClearMLSaver(DiskSaver):
            metadata=metadata,
        )

-       pre_cb_id = WeightsFileHandler.add_pre_callback(cb_context.pre_callback)
-       post_cb_id = WeightsFileHandler.add_post_callback(cb_context.post_callback)
+       pre_cb_id = WeightsFileHandler.add_pre_callback(cb_context.pre_callback)  # type: ignore[arg-type]
+       post_cb_id = WeightsFileHandler.add_post_callback(cb_context.post_callback)  # type: ignore[arg-type]

        try:
            super(ClearMLSaver, self).__call__(checkpoint, filename, metadata)
ignite/handlers/fbresearch_logger.py
CHANGED

@@ -7,7 +7,7 @@ import torch

 from ignite import utils
 from ignite.engine import Engine, Events
-from ignite.handlers import Timer
+from ignite.handlers.timing import Timer

 MB = 1024.0 * 1024.0

@@ -154,7 +154,7 @@ class FBResearchLogger:
        if torch.cuda.is_available():
            cuda_max_mem = f"GPU Max Mem: {torch.cuda.max_memory_allocated() / MB:.0f} MB"

-       current_iter = engine.state.iteration % engine.state.epoch_length
+       current_iter = ((engine.state.iteration - 1) % engine.state.epoch_length) + 1
        iter_avg_time = self.iter_timer.value()

        eta_seconds = iter_avg_time * (engine.state.epoch_length - current_iter)
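The iteration fix above keeps `current_iter` in the range 1..epoch_length; the plain modulo collapsed to 0 on the last iteration of every epoch, which also skewed the ETA term ``epoch_length - current_iter``. A quick check:

```python
epoch_length = 100

for iteration in (1, 50, 100, 101, 200):
    old = iteration % epoch_length                 # 1, 50, 0, 1, 0
    new = ((iteration - 1) % epoch_length) + 1     # 1, 50, 100, 1, 100
    print(iteration, old, new)
```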
ignite/handlers/lr_finder.py
CHANGED

@@ -179,7 +179,7 @@ class FastaiLRFinder:
        lr = self._lr_schedule.get_param()
        self._history["lr"].append(lr)
        if trainer.state.iteration == 1:
-           self._best_loss = loss
+           self._best_loss = loss  # type: ignore[assignment]
        else:
            if smooth_f > 0:
                loss = smooth_f * loss + (1 - smooth_f) * self._history["loss"][-1]
ignite/handlers/mlflow_logger.py
CHANGED

@@ -84,6 +84,49 @@ class MLflowLogger(BaseLogger):
                optimizer=optimizer,
                param_name='lr'  # optional
            )
+
+   Note:
+       :class:`~ignite.handlers.mlflow_logger.OutputHandler` can handle
+       metrics, state attributes and engine output values of the following format:
+       - scalar values (i.e. int, float)
+       - 0d and 1d pytorch tensors
+       - dicts and list/tuples of previous types
+
+       .. code-block:: python
+
+           # !!! This is not a runnable code !!!
+           evalutator.state.metrics = {
+               "a": 0,
+               "dict_value": {
+                   "a": 111,
+                   "c": {"d": 23, "e": [123, 234]},
+               },
+               "list_value": [12, 13, {"aa": 33, "bb": 44}],
+               "tuple_value": (112, 113, {"aaa": 33, "bbb": 44}),
+           }
+
+           handler = OutputHandler(
+               tag="tag",
+               metric_names="all",
+           )
+
+           handler(evaluator, mlflow_logger, event_name=Events.EPOCH_COMPLETED)
+           # Behind it would call `mlflow_logger.log_metrics` on
+           # {
+           #     "tag/a": 0,
+           #     "tag/dict_value/a": 111,
+           #     "tag/dict_value/c/d": 23,
+           #     "tag/dict_value/c/e/0": 123,
+           #     "tag/dict_value/c/e/1": 234,
+           #     "tag/list_value/0": 12,
+           #     "tag/list_value/1": 13,
+           #     "tag/list_value/2/aa": 33,
+           #     "tag/list_value/2/bb": 44,
+           #     "tag/tuple_value/0": 112,
+           #     "tag/tuple_value/1": 113,
+           #     "tag/tuple_value/2/aaa": 33,
+           #     "tag/tuple_value/2/bbb": 44,
+           # }
    """

    def __init__(self, tracking_uri: Optional[str] = None):
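The docstring example above calls the handler directly; in practice it is usually attached through the logger. A hedged sketch of that pattern (an existing ``evaluator`` engine and an MLflow tracking setup are assumed):

```python
from ignite.engine import Events
from ignite.handlers.mlflow_logger import MLflowLogger

mlflow_logger = MLflowLogger()
# Logs all evaluator metrics at the end of each of its epochs, with nested
# dict/list/tensor values flattened into "validation/..." keys as shown above.
mlflow_logger.attach_output_handler(
    evaluator,  # assumed to be an ignite Engine created elsewhere
    event_name=Events.EPOCH_COMPLETED,
    tag="validation",
    metric_names="all",
)
```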
ignite/handlers/neptune_logger.py
CHANGED

@@ -153,6 +153,13 @@ class NeptuneLogger(BaseLogger):
            output_transform=lambda loss: {"loss": loss},
        )

+   Note:
+       :class:`~ignite.handlers.neptune_logger.OutputHandler` can handle
+       metrics, state attributes and engine output values of the following format:
+       - scalar values (i.e. int, float)
+       - 0d and 1d pytorch tensors
+       - dicts and list/tuples of previous types
+
    """

    def __getattr__(self, attr: Any) -> Any:
ignite/handlers/polyaxon_logger.py
CHANGED

@@ -92,6 +92,13 @@ class PolyaxonLogger(BaseLogger):
        )
        # to manually end a run
        plx_logger.close()
+
+   Note:
+       :class:`~ignite.handlers.polyaxon_logger.OutputHandler` can handle
+       metrics, state attributes and engine output values of the following format:
+       - scalar values (i.e. int, float)
+       - 0d and 1d pytorch tensors
+       - dicts and list/tuples of previous types
    """

    def __init__(self, *args: Any, **kwargs: Any):
ignite/handlers/tensorboard_logger.py
CHANGED

@@ -145,6 +145,49 @@ class TensorboardLogger(BaseLogger):
            output_transform=lambda loss: {"loss": loss}
        )

+   Note:
+       :class:`~ignite.handlers.tensorboard_logger.OutputHandler` can handle
+       metrics, state attributes and engine output values of the following format:
+       - scalar values (i.e. int, float)
+       - 0d and 1d pytorch tensors
+       - dicts and list/tuples of previous types
+
+       .. code-block:: python
+
+           # !!! This is not a runnable code !!!
+           evalutator.state.metrics = {
+               "a": 0,
+               "dict_value": {
+                   "a": 111,
+                   "c": {"d": 23, "e": [123, 234]},
+               },
+               "list_value": [12, 13, {"aa": 33, "bb": 44}],
+               "tuple_value": (112, 113, {"aaa": 33, "bbb": 44}),
+           }
+
+           handler = OutputHandler(
+               tag="tag",
+               metric_names="all",
+           )
+
+           handler(evaluator, tb_logger, event_name=Events.EPOCH_COMPLETED)
+           # Behind it would call `tb_logger.writer.add_scalar` on
+           # {
+           #     "tag/a": 0,
+           #     "tag/dict_value/a": 111,
+           #     "tag/dict_value/c/d": 23,
+           #     "tag/dict_value/c/e/0": 123,
+           #     "tag/dict_value/c/e/1": 234,
+           #     "tag/list_value/0": 12,
+           #     "tag/list_value/1": 13,
+           #     "tag/list_value/2/aa": 33,
+           #     "tag/list_value/2/bb": 44,
+           #     "tag/tuple_value/0": 112,
+           #     "tag/tuple_value/1": 113,
+           #     "tag/tuple_value/2/aaa": 33,
+           #     "tag/tuple_value/2/bbb": 44,
+           # }
+
    """

    def __init__(self, *args: Any, **kwargs: Any):
ignite/handlers/tqdm_logger.py
CHANGED

@@ -200,7 +200,7 @@ class ProgressBar(BaseLogger):
            Accepted output value types are numbers, 0d and 1d torch tensors and strings.

        """
-       desc = self.tqdm_kwargs.get("desc",
+       desc = self.tqdm_kwargs.get("desc", "")

        if event_name not in engine._allowed_events:
            raise ValueError(f"Logging event {event_name.name} is not in allowed events for this engine")
@@ -298,8 +298,7 @@ class _OutputHandler(BaseOutputHandler):
        rendered_metrics = self._setup_output_metrics_state_attrs(engine, log_text=True)
        metrics = OrderedDict()
        for key, value in rendered_metrics.items():
-           key = "_".join(key[1:])  # tqdm has tag as description
-
+           key = "_".join(key[1:])  # skip tag as tqdm has tag as description
            metrics[key] = value

        if metrics:
ignite/handlers/visdom_logger.py
CHANGED

@@ -1,7 +1,7 @@
 """Visdom logger and its helper handlers."""

 import os
-from typing import Any, Callable,
+from typing import Any, Callable, Dict, List, Optional, Union

 import torch
 import torch.nn as nn
@@ -137,6 +137,13 @@ class VisdomLogger(BaseLogger):
            output_transform=lambda loss: {"loss": loss}
        )

+   Note:
+       :class:`~ignite.handlers.visdom_logger.OutputHandler` can handle
+       metrics, state attributes and engine output values of the following format:
+       - scalar values (i.e. int, float)
+       - 0d and 1d pytorch tensors
+       - dicts and list/tuples of previous types
+
    .. versionchanged:: 0.4.7
        accepts an optional list of `state_attributes`
    """
@@ -172,7 +179,7 @@ class VisdomLogger(BaseLogger):
        )

        if server is None:
-           server =
+           server = os.environ.get("VISDOM_SERVER_URL", "localhost")

        if port is None:
            port = int(os.environ.get("VISDOM_PORT", 8097))
ignite/handlers/wandb_logger.py
CHANGED

@@ -26,7 +26,7 @@ class WandBLogger(BaseLogger):
    Args:
        args: Positional arguments accepted by `wandb.init`.
        kwargs: Keyword arguments accepted by `wandb.init`.
-           Please see `wandb.init <https://docs.wandb.ai/ref/python/init>`_ for documentation of possible parameters.
+           Please see `wandb.init <https://docs.wandb.ai/ref/python/sdk/functions/init/>`_ for documentation of possible parameters.

    Examples:
        .. code-block:: python
@@ -120,6 +120,12 @@ class WandBLogger(BaseLogger):
            )
            evaluator.add_event_handler(Events.COMPLETED, model_checkpoint, {'model': model})

+   Note:
+       :class:`~ignite.handlers.wandb_logger.OutputHandler` can handle
+       metrics, state attributes and engine output values of the following format:
+       - scalar values (i.e. int, float)
+       - 0d and 1d pytorch tensors
+       - dicts and list/tuples of previous types

    """