pytorch-ignite 0.6.0.dev20260115__py3-none-any.whl → 0.6.0.dev20260117__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ignite/__init__.py +1 -1
- ignite/contrib/engines/common.py +44 -44
- ignite/metrics/average_precision.py +2 -2
- ignite/metrics/cosine_similarity.py +2 -2
- ignite/metrics/frequency.py +2 -2
- ignite/metrics/mean_absolute_error.py +2 -2
- ignite/metrics/mean_squared_error.py +2 -2
- ignite/metrics/psnr.py +3 -3
- ignite/metrics/root_mean_squared_error.py +1 -2
- ignite/utils.py +29 -27
- {pytorch_ignite-0.6.0.dev20260115.dist-info → pytorch_ignite-0.6.0.dev20260117.dist-info}/METADATA +1 -1
- {pytorch_ignite-0.6.0.dev20260115.dist-info → pytorch_ignite-0.6.0.dev20260117.dist-info}/RECORD +14 -14
- {pytorch_ignite-0.6.0.dev20260115.dist-info → pytorch_ignite-0.6.0.dev20260117.dist-info}/WHEEL +0 -0
- {pytorch_ignite-0.6.0.dev20260115.dist-info → pytorch_ignite-0.6.0.dev20260117.dist-info}/licenses/LICENSE +0 -0
ignite/__init__.py
CHANGED
ignite/contrib/engines/common.py
CHANGED
@@ -1,7 +1,7 @@
 import numbers
 import warnings
 from functools import partial
-from typing import Any, Callable, cast, …
+from typing import Any, Callable, cast, Iterable, Mapping, Sequence
 
 import torch
 import torch.nn as nn
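The hunk above shows the pattern repeated throughout this diff: `typing.Optional`/`typing.Union` spellings replaced by PEP 604 pipe unions and built-in generics. A minimal sketch of the before/after, assuming the truncated removed lines used the `Optional`/`Union` spellings (the class names are illustrative, not ignite's):

from typing import Optional, Union


class SchedulerA: ...
class SchedulerB: ...


# Assumed old spelling (the removed lines in this diff are truncated):
def setup_old(
    output_path: Optional[str] = None,
    lr_scheduler: Optional[Union[SchedulerA, SchedulerB]] = None,
) -> None: ...


# New spelling, as on the added lines. These annotations are evaluated at
# definition time, so bare `X | Y | None` needs Python 3.10+ unless a module
# defers evaluation with `from __future__ import annotations`, as utils.py
# does later in this diff.
def setup_new(
    output_path: str | None = None,
    lr_scheduler: SchedulerA | SchedulerB | None = None,
) -> None: ...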
@@ -41,19 +41,19 @@ from ignite.utils import deprecated
 
 def setup_common_training_handlers(
     trainer: Engine,
-    train_sampler: …
-    to_save: …
+    train_sampler: DistributedSampler | None = None,
+    to_save: Mapping | None = None,
     save_every_iters: int = 1000,
-    output_path: …
-    lr_scheduler: …
+    output_path: str | None = None,
+    lr_scheduler: ParamScheduler | PyTorchLRScheduler | None = None,
     with_gpu_stats: bool = False,
-    output_names: …
+    output_names: Iterable[str] | None = None,
     with_pbars: bool = True,
     with_pbar_on_iters: bool = True,
     log_every_iters: int = 100,
     stop_on_nan: bool = True,
     clear_cuda_cache: bool = True,
-    save_handler: …
+    save_handler: Callable | BaseSaveHandler | None = None,
     **kwargs: Any,
 ) -> None:
     """Helper method to setup trainer with common handlers (it also supports distributed configuration):
@@ -145,18 +145,18 @@ setup_common_distrib_training_handlers = setup_common_training_handlers
 
 def _setup_common_training_handlers(
     trainer: Engine,
-    to_save: …
+    to_save: Mapping | None = None,
     save_every_iters: int = 1000,
-    output_path: …
-    lr_scheduler: …
+    output_path: str | None = None,
+    lr_scheduler: ParamScheduler | PyTorchLRScheduler | None = None,
     with_gpu_stats: bool = False,
-    output_names: …
+    output_names: Iterable[str] | None = None,
     with_pbars: bool = True,
     with_pbar_on_iters: bool = True,
     log_every_iters: int = 100,
     stop_on_nan: bool = True,
     clear_cuda_cache: bool = True,
-    save_handler: …
+    save_handler: Callable | BaseSaveHandler | None = None,
     **kwargs: Any,
 ) -> None:
     if output_path is not None and save_handler is not None:
@@ -185,7 +185,7 @@ def _setup_common_training_handlers(
         save_handler = DiskSaver(dirname=output_path, require_empty=False)
 
     checkpoint_handler = Checkpoint(
-        to_save, cast(…
+        to_save, cast(Callable | BaseSaveHandler, save_handler), filename_prefix="training", **kwargs
     )
     trainer.add_event_handler(Events.ITERATION_COMPLETED(every=save_every_iters), checkpoint_handler)
 
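One runtime subtlety in the hunk above: unlike an annotation, the first argument of `typing.cast` is an ordinary expression evaluated when the line executes, so `Callable | BaseSaveHandler` builds a real union object and needs Python 3.10+ regardless of any `from __future__ import annotations`. A self-contained sketch (the stand-in class is illustrative, not ignite's `BaseSaveHandler`):

from typing import Callable, cast


class SaveHandlerStandIn:  # illustrative stand-in for ignite's BaseSaveHandler
    pass


def as_handler(obj: object):
    # cast() is a no-op at runtime, but its first argument is still
    # evaluated: `Callable | SaveHandlerStandIn` constructs a union here,
    # which is what ties this spelling to Python 3.10+ (PEP 604).
    return cast(Callable | SaveHandlerStandIn, obj)


handler = as_handler(lambda engine, to_save: None)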
@@ -227,19 +227,19 @@ def _setup_common_training_handlers(
 
 def _setup_common_distrib_training_handlers(
     trainer: Engine,
-    train_sampler: …
-    to_save: …
+    train_sampler: DistributedSampler | None = None,
+    to_save: Mapping | None = None,
     save_every_iters: int = 1000,
-    output_path: …
-    lr_scheduler: …
+    output_path: str | None = None,
+    lr_scheduler: ParamScheduler | PyTorchLRScheduler | None = None,
     with_gpu_stats: bool = False,
-    output_names: …
+    output_names: Iterable[str] | None = None,
     with_pbars: bool = True,
     with_pbar_on_iters: bool = True,
     log_every_iters: int = 100,
     stop_on_nan: bool = True,
     clear_cuda_cache: bool = True,
-    save_handler: …
+    save_handler: Callable | BaseSaveHandler | None = None,
     **kwargs: Any,
 ) -> None:
     _setup_common_training_handlers(
@@ -286,8 +286,8 @@ def setup_any_logging(
     logger: BaseLogger,
     logger_module: Any,
     trainer: Engine,
-    optimizers: …
-    evaluators: …
+    optimizers: Optimizer | dict[str, Optimizer] | dict[None, Optimizer] | None,
+    evaluators: Engine | dict[str, Engine] | None,
     log_every_iters: int,
 ) -> None:
     pass
@@ -296,8 +296,8 @@ def setup_any_logging(
 def _setup_logging(
     logger: BaseLogger,
     trainer: Engine,
-    optimizers: …
-    evaluators: …
+    optimizers: Optimizer | dict[str, Optimizer] | dict[None, Optimizer] | None,
+    evaluators: Engine | dict[str, Engine] | None,
     log_every_iters: int,
 ) -> None:
     if optimizers is not None:
@@ -341,8 +341,8 @@ def _setup_logging(
 def setup_tb_logging(
     output_path: str,
     trainer: Engine,
-    optimizers: …
-    evaluators: …
+    optimizers: Optimizer | dict[str, Optimizer] | None = None,
+    evaluators: Engine | dict[str, Engine] | None = None,
     log_every_iters: int = 100,
     **kwargs: Any,
 ) -> TensorboardLogger:
@@ -373,8 +373,8 @@ def setup_tb_logging(
 
 def setup_visdom_logging(
     trainer: Engine,
-    optimizers: …
-    evaluators: …
+    optimizers: Optimizer | dict[str, Optimizer] | None = None,
+    evaluators: Engine | dict[str, Engine] | None = None,
     log_every_iters: int = 100,
     **kwargs: Any,
 ) -> VisdomLogger:
@@ -404,8 +404,8 @@ def setup_visdom_logging(
 
 def setup_mlflow_logging(
     trainer: Engine,
-    optimizers: …
-    evaluators: …
+    optimizers: Optimizer | dict[str, Optimizer] | None = None,
+    evaluators: Engine | dict[str, Engine] | None = None,
     log_every_iters: int = 100,
     **kwargs: Any,
 ) -> MLflowLogger:
@@ -435,8 +435,8 @@ def setup_mlflow_logging(
 
 def setup_neptune_logging(
     trainer: Engine,
-    optimizers: …
-    evaluators: …
+    optimizers: Optimizer | dict[str, Optimizer] | None = None,
+    evaluators: Engine | dict[str, Engine] | None = None,
     log_every_iters: int = 100,
     **kwargs: Any,
 ) -> NeptuneLogger:
@@ -466,8 +466,8 @@ def setup_neptune_logging(
 
 def setup_wandb_logging(
     trainer: Engine,
-    optimizers: …
-    evaluators: …
+    optimizers: Optimizer | dict[str, Optimizer] | None = None,
+    evaluators: Engine | dict[str, Engine] | None = None,
     log_every_iters: int = 100,
     **kwargs: Any,
 ) -> WandBLogger:
@@ -497,8 +497,8 @@ def setup_wandb_logging(
 
 def setup_plx_logging(
     trainer: Engine,
-    optimizers: …
-    evaluators: …
+    optimizers: Optimizer | dict[str, Optimizer] | None = None,
+    evaluators: Engine | dict[str, Engine] | None = None,
     log_every_iters: int = 100,
     **kwargs: Any,
 ) -> PolyaxonLogger:
@@ -528,8 +528,8 @@ def setup_plx_logging(
 
 def setup_clearml_logging(
     trainer: Engine,
-    optimizers: …
-    evaluators: …
+    optimizers: Optimizer | dict[str, Optimizer] | None = None,
+    evaluators: Engine | dict[str, Engine] | None = None,
     log_every_iters: int = 100,
     **kwargs: Any,
 ) -> ClearMLLogger:
@@ -559,8 +559,8 @@ def setup_clearml_logging(
 
 def setup_trains_logging(
     trainer: Engine,
-    optimizers: …
-    evaluators: …
+    optimizers: Optimizer | dict[str, Optimizer] | None = None,
+    evaluators: Engine | dict[str, Engine] | None = None,
     log_every_iters: int = 100,
     **kwargs: Any,
 ) -> ClearMLLogger:
@@ -573,12 +573,12 @@ get_default_score_fn = Checkpoint.get_default_score_fn
 
 
 def gen_save_best_models_by_val_score(
-    save_handler: …
+    save_handler: Callable | BaseSaveHandler,
     evaluator: Engine,
-    models: …
+    models: torch.nn.Module | dict[str, torch.nn.Module],
     metric_name: str,
     n_saved: int = 3,
-    trainer: …
+    trainer: Engine | None = None,
     tag: str = "val",
     score_sign: float = 1.0,
     **kwargs: Any,
@@ -615,7 +615,7 @@ def gen_save_best_models_by_val_score(
         global_step_transform = global_step_from_engine(trainer)
 
     if isinstance(models, nn.Module):
-        to_save: …
+        to_save: dict[str, nn.Module] = {"model": models}
     else:
         to_save = models
 
@@ -640,7 +640,7 @@ def save_best_model_by_val_score(
     model: torch.nn.Module,
     metric_name: str,
     n_saved: int = 3,
-    trainer: …
+    trainer: Engine | None = None,
     tag: str = "val",
     score_sign: float = 1.0,
     **kwargs: Any,
ignite/metrics/average_precision.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import Callable …
+from typing import Callable
 
 import torch
 
@@ -71,7 +71,7 @@ class AveragePrecision(EpochMetric):
         self,
         output_transform: Callable = lambda x: x,
         check_compute_fn: bool = False,
-        device: …
+        device: str | torch.device = torch.device("cpu"),
         skip_unrolling: bool = False,
     ):
         try:
ignite/metrics/cosine_similarity.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import Callable, Sequence …
+from typing import Callable, Sequence
 
 import torch
 
@@ -77,7 +77,7 @@ class CosineSimilarity(Metric):
         self,
         eps: float = 1e-8,
         output_transform: Callable = lambda x: x,
-        device: …
+        device: str | torch.device = torch.device("cpu"),
         skip_unrolling: bool = False,
     ):
         super().__init__(output_transform, device, skip_unrolling=skip_unrolling)
ignite/metrics/frequency.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import Callable …
+from typing import Callable
 
 import torch
 
@@ -56,7 +56,7 @@ class Frequency(Metric):
     def __init__(
         self,
         output_transform: Callable = lambda x: x,
-        device: …
+        device: str | torch.device = torch.device("cpu"),
         skip_unrolling: bool = False,
     ) -> None:
         super(Frequency, self).__init__(output_transform=output_transform, device=device, skip_unrolling=skip_unrolling)
ignite/metrics/mean_absolute_error.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import Sequence …
+from typing import Sequence
 
 import torch
 
@@ -80,7 +80,7 @@ class MeanAbsoluteError(Metric):
         self._num_examples += y.shape[0]
 
     @sync_all_reduce("_sum_of_absolute_errors", "_num_examples")
-    def compute(self) -> …
+    def compute(self) -> float | torch.Tensor:
         if self._num_examples == 0:
             raise NotComputableError("MeanAbsoluteError must have at least one example before it can be computed.")
         return self._sum_of_absolute_errors.item() / self._num_examples
ignite/metrics/mean_squared_error.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import Sequence …
+from typing import Sequence
 
 import torch
 
@@ -80,7 +80,7 @@ class MeanSquaredError(Metric):
         self._num_examples += y.shape[0]
 
     @sync_all_reduce("_sum_of_squared_errors", "_num_examples")
-    def compute(self) -> …
+    def compute(self) -> float | torch.Tensor:
         if self._num_examples == 0:
             raise NotComputableError("MeanSquaredError must have at least one example before it can be computed.")
         return self._sum_of_squared_errors.item() / self._num_examples
ignite/metrics/psnr.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import Callable, Sequence …
+from typing import Callable, Sequence
 
 import torch
 
@@ -91,9 +91,9 @@ class PSNR(Metric):
 
     def __init__(
         self,
-        data_range: …
+        data_range: int | float,
         output_transform: Callable = lambda x: x,
-        device: …
+        device: str | torch.device = torch.device("cpu"),
         skip_unrolling: bool = False,
     ):
         super().__init__(output_transform=output_transform, device=device, skip_unrolling=skip_unrolling)
ignite/metrics/root_mean_squared_error.py
CHANGED
@@ -1,5 +1,4 @@
 import math
-from typing import Union
 
 import torch
 
@@ -65,6 +64,6 @@ class RootMeanSquaredError(MeanSquaredError):
     ``skip_unrolling`` argument is added.
     """
 
-    def compute(self) -> …
+    def compute(self) -> torch.Tensor | float:
         mse = super(RootMeanSquaredError, self).compute()
         return math.sqrt(mse)
ignite/utils.py
CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import collections.abc as collections
 import functools
 import hashlib
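utils.py is the only module in this diff that adds `from __future__ import annotations` (PEP 563): with it, annotations are stored as strings rather than evaluated at definition time, which is what lets the hunks below use `list[str]`, `tuple[Path, str]`, and `X | None` without depending on interpreters where those expressions evaluate. A minimal illustration:

from __future__ import annotations  # PEP 563: defer annotation evaluation


def first_or_none(items: list[str]) -> str | None:
    # With the future import these annotations are never evaluated at
    # definition time; they are kept verbatim as strings.
    return items[0] if items else None


assert first_or_none.__annotations__["return"] == "str | None"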
@@ -7,7 +9,7 @@ import random
 import shutil
 import warnings
 from pathlib import Path
-from typing import Any, Callable, cast, …
+from typing import Any, Callable, cast, TextIO, TypeVar
 
 import torch
 
@@ -24,10 +26,10 @@ __all__ = [
 
 
 def convert_tensor(
-    x: …
-    device: …
+    x: torch.Tensor | collections.Sequence | collections.Mapping | str | bytes,
+    device: str | torch.device | None = None,
     non_blocking: bool = False,
-) -> …
+) -> torch.Tensor | collections.Sequence | collections.Mapping | str | bytes:
     """Move tensors to relevant device.
 
     Args:
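For context, `convert_tensor` recursively walks sequences and mappings and moves the tensors it finds onto the target device, which is why the annotations above enumerate the tensor, sequence, mapping, and `str`/`bytes` pass-through cases. A hedged usage sketch based on that signature:

import torch

from ignite.utils import convert_tensor

batch = {"x": torch.rand(4, 3), "y": torch.tensor([0, 1, 1, 0])}

# Moves every tensor in the mapping to the chosen device; non_blocking=True
# only has an effect for pinned-memory host-to-device copies.
device = "cuda" if torch.cuda.is_available() else "cpu"
moved = convert_tensor(batch, device=device, non_blocking=True)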
@@ -44,8 +46,8 @@ def convert_tensor(
 
 
 def apply_to_tensor(
-    x: …
-) -> …
+    x: torch.Tensor | collections.Sequence | collections.Mapping | str | bytes, func: Callable
+) -> torch.Tensor | collections.Sequence | collections.Mapping | str | bytes:
     """Apply a function on a tensor or mapping, or sequence of tensors.
 
     Args:
@@ -56,10 +58,10 @@ def apply_to_tensor(
 
 
 def apply_to_type(
-    x: …
-    input_type: …
+    x: Any | collections.Sequence | collections.Mapping | str | bytes,
+    input_type: type | tuple[type[Any], Any],
     func: Callable,
-) -> …
+) -> Any | collections.Sequence | collections.Mapping | str | bytes:
     """Apply a function on an object of `input_type` or mapping, or sequence of objects of `input_type`.
 
     Args:
@@ -81,8 +83,8 @@ def apply_to_type(
 
 
 def _tree_map(
-    func: Callable, x: …
-) -> …
+    func: Callable, x: Any | collections.Sequence | collections.Mapping, key: int | str | None = None
+) -> Any | collections.Sequence | collections.Mapping:
     if isinstance(x, collections.Mapping):
         return cast(Callable, type(x))({k: _tree_map(func, sample, key=k) for k, sample in x.items()})
     if isinstance(x, tuple) and hasattr(x, "_fields"):  # namedtuple
@@ -92,7 +94,7 @@ def _tree_map(
     return func(x, key=key)
 
 
-def _to_str_list(data: Any) -> …
+def _to_str_list(data: Any) -> list[str]:
     """
     Recursively flattens and formats complex data structures, including keys for
     dictionaries, into a list of human-readable strings.
@@ -127,9 +129,9 @@ def _to_str_list(data: Any) -> List[str]:
         A list of formatted strings, each representing a part of the input data
         structure.
     """
-    formatted_items: …
+    formatted_items: list[str] = []
 
-    def format_item(item: Any, prefix: str = "") -> …
+    def format_item(item: Any, prefix: str = "") -> str | None:
         if isinstance(item, numbers.Number):
             return f"{prefix}{item:.4f}"
         elif torch.is_tensor(item):
@@ -169,9 +171,9 @@ def _to_str_list(data: Any) -> List[str]:
 
 
 class _CollectionItem:
-    types_as_collection_item: …
+    types_as_collection_item: tuple = (int, float, torch.Tensor)
 
-    def __init__(self, collection: …
+    def __init__(self, collection: dict | list, key: int | str) -> None:
         if not isinstance(collection, (dict, list)):
             raise TypeError(
                 f"Input type is expected to be a mapping or list, but got {type(collection)} " f"for input key '{key}'."
@@ -189,7 +191,7 @@ class _CollectionItem:
         return self.collection[self.key]  # type: ignore[index]
 
     @staticmethod
-    def wrap(object: …
+    def wrap(object: dict | list, key: int | str, value: Any) -> Any | "_CollectionItem":
         return (
             _CollectionItem(object, key)
             if value is None or isinstance(value, _CollectionItem.types_as_collection_item)
@@ -199,8 +201,8 @@ class _CollectionItem:
 
 def _tree_apply2(
     func: Callable,
-    x: …
-    y: …
+    x: Any | list | dict,
+    y: Any | collections.Sequence | collections.Mapping,
 ) -> None:
     if isinstance(x, dict) and isinstance(y, collections.Mapping):
         for k, v in x.items():
@@ -234,14 +236,14 @@ def to_onehot(indices: torch.Tensor, num_classes: int) -> torch.Tensor:
 
 
 def setup_logger(
-    name: …
+    name: str | None = "ignite",
     level: int = logging.INFO,
-    stream: …
+    stream: TextIO | None = None,
     format: str = "%(asctime)s %(name)s %(levelname)s: %(message)s",
-    filepath: …
-    distributed_rank: …
+    filepath: str | None = None,
+    distributed_rank: int | None = None,
     reset: bool = False,
-    encoding: …
+    encoding: str | None = "utf-8",
 ) -> logging.Logger:
     """Setups logger: name, level, format etc.
 
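A short usage sketch for the signature above; every argument shown is a default visible in this hunk except `filepath`, which is a hypothetical log-file destination:

import logging

from ignite.utils import setup_logger

# `name`, `stream`, `filepath`, `distributed_rank`, and `encoding` are now
# plain `X | None` unions. "train.log" is a hypothetical path.
logger = setup_logger(name="ignite", level=logging.INFO, filepath="train.log")
logger.info("trainer configured")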
@@ -393,7 +395,7 @@ def manual_seed(seed: int) -> None:
 
 
 def deprecated(
-    deprecated_in: str, removed_in: str = "", reasons: …
+    deprecated_in: str, removed_in: str = "", reasons: tuple[str, ...] = (), raise_exception: bool = False
 ) -> Callable:
     F = TypeVar("F", bound=Callable[..., Any])
 
@@ -406,7 +408,7 @@ def deprecated(
         )
 
         @functools.wraps(func)
-        def wrapper(*args: Any, **kwargs: …
+        def wrapper(*args: Any, **kwargs: dict[str, Any]) -> Callable:
            if raise_exception:
                raise DeprecationWarning(deprecation_warning)
            warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
@@ -422,7 +424,7 @@ def deprecated(
     return decorator
 
 
-def hash_checkpoint(checkpoint_path: …
+def hash_checkpoint(checkpoint_path: str | Path, output_dir: str | Path) -> tuple[Path, str]:
     """
     Hash the checkpoint file in the format of ``<filename>-<hash>.<ext>``
     to be used with ``check_hash`` of :func:`torch.hub.load_state_dict_from_url`.
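Per the docstring, `hash_checkpoint` writes a `<filename>-<hash>.<ext>` version of a checkpoint into `output_dir` for use with `check_hash=True` in `torch.hub.load_state_dict_from_url`; the new `tuple[Path, str]` annotation covers the hashed path and the digest it returns. A hedged sketch with hypothetical paths:

import torch

from ignite.utils import hash_checkpoint

# "checkpoint.pt" and "./output" are hypothetical paths; the call returns
# the new <filename>-<hash>.<ext> path plus the checkpoint's hash string.
torch.save({"weight": torch.zeros(1)}, "checkpoint.pt")
hashed_path, digest = hash_checkpoint("checkpoint.pt", "./output")
print(hashed_path, digest)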
{pytorch_ignite-0.6.0.dev20260115.dist-info → pytorch_ignite-0.6.0.dev20260117.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pytorch-ignite
-Version: 0.6.0.dev20260115
+Version: 0.6.0.dev20260117
 Summary: A lightweight library to help with training neural networks in PyTorch.
 Project-URL: Homepage, https://pytorch-ignite.ai
 Project-URL: Repository, https://github.com/pytorch/ignite
{pytorch_ignite-0.6.0.dev20260115.dist-info → pytorch_ignite-0.6.0.dev20260117.dist-info}/RECORD
RENAMED
@@ -1,13 +1,13 @@
-ignite/__init__.py,sha256=…
+ignite/__init__.py,sha256=knPS0MilS-UCCCvT5DMsY2Rj4cUbtZLNsMP5W9Iko-4,194
 ignite/_utils.py,sha256=XDPpUDJ8ykLXWMV2AYTqGSj8XCfApsyzsQ3Vij_OB4M,182
 ignite/exceptions.py,sha256=5ZWCVLPC9rgoW8t84D-VeEleqz5O7XpAGPpCdU8rKd0,150
 ignite/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-ignite/utils.py,sha256=…
+ignite/utils.py,sha256=D_aro3zrFukYCH7v2yoNEgjD_6b-MlW_8ceWBTs3pCM,16981
 ignite/base/__init__.py,sha256=y2g9egjuVCYRtaj-4ge081y-8cjIXsw_ZgZ6BRguHi0,44
 ignite/base/mixins.py,sha256=Ip1SHCQCsvNUnLJKJwX9L-hqpfcZAlTad87-PaVgCBI,991
 ignite/contrib/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 ignite/contrib/engines/__init__.py,sha256=BxmXYIYEtEB1niMWITL8pgyKufCIpXR61rSzPQOhA0g,87
-ignite/contrib/engines/common.py,sha256=…
+ignite/contrib/engines/common.py,sha256=2g3xnIfThvzAIVS8zApsOfb2bedNydzmsoZrgpWn8ao,28171
 ignite/contrib/engines/tbptt.py,sha256=FSmF5SnoZn7mWNZWRZ-ohWUCfucET78GQu3lvVRNItk,4507
 ignite/contrib/handlers/__init__.py,sha256=rZszZnCbzncE2jqsvx9KP1iS3WZ0I-CnrV3Jh3Xl8_o,1073
 ignite/contrib/handlers/base_logger.py,sha256=gHVTkVvYMRUXI793rNq8564mMyJaL_HCuoCu8xiKxFY,1158
@@ -85,25 +85,25 @@ ignite/handlers/wandb_logger.py,sha256=9HbwRMHzWckrZ-m0rkMF5Ug6r9C9J4sdq73yqaAHW
 ignite/metrics/__init__.py,sha256=m-8F8J17r-aEwsO6Ww-8AqDRN59WFfYBwCDKwqGDSmI,3627
 ignite/metrics/accumulation.py,sha256=xWdsm9u6JfsfODX_GUKzQc_omrdFDJ4yELBR-xXgc4s,12448
 ignite/metrics/accuracy.py,sha256=W8mO4W11VzryMXKy8G7W_g4A9PH9RYpejW_tQ-T_Txw,10245
-ignite/metrics/average_precision.py,sha256=…
+ignite/metrics/average_precision.py,sha256=laDD8BnAC5OuAJrCRtwCZ7EjjoQKRb7D3o-86IRsdN4,3681
 ignite/metrics/classification_report.py,sha256=zjGlaMnRz2__op6hrZq74OusO0W_5B1AIe8KzYGFilM,5988
 ignite/metrics/cohen_kappa.py,sha256=Qwcd4P2kN12CVCFC-kVdzn_2XV7kGzP6LlWkK209JJ8,3815
 ignite/metrics/confusion_matrix.py,sha256=dZDuK3vxrrbiQh6VfyV5aWFpuTJWsfnZ30Mxt6u6eOA,18215
-ignite/metrics/cosine_similarity.py,sha256=…
+ignite/metrics/cosine_similarity.py,sha256=9f5dM0QaiXBznidlUuKb8Q3E45W_Z_hGiy9WP5Fpcqw,4416
 ignite/metrics/entropy.py,sha256=gJZkR5Sl1ZdIzJ9pFkydf1186bZU8OnkOLvOtKz6Wrs,4511
 ignite/metrics/epoch_metric.py,sha256=H4PVsDtcqk53l47Ehc3kliKT4QtyZUf600ut-8rRP8M,7050
 ignite/metrics/fbeta.py,sha256=2oDsRM7XXJ8LPVrn7iwLdRy75RLJELijmshtMQO3mJM,6870
-ignite/metrics/frequency.py,sha256=…
+ignite/metrics/frequency.py,sha256=eyvfaTIPaxU_s7vU0DCrAUveb5AKKgEuxdg5VsCaNkw,4024
 ignite/metrics/gpu_info.py,sha256=kcDIifr9js_P-32LddizEggvvL6eqFLYCHYeFDR4GL0,4301
 ignite/metrics/hsic.py,sha256=am-gor2mXY3H3u2vVNQGPJtkx_5W5JNZeukl2uYqajE,7099
 ignite/metrics/js_divergence.py,sha256=HAgj12JwL9bT33cCSAX7g4EKSfqFNNehkgwZfJuncfw,4828
 ignite/metrics/kl_divergence.py,sha256=FdC5BT-nd8nmYqT95Xozw-hW0hZC6dtTklkpJdwWJ6o,5152
 ignite/metrics/loss.py,sha256=mB-zYptymtcyIys0OlbVgUOAqL2WHT2dCPMFda-Klpo,4818
 ignite/metrics/maximum_mean_discrepancy.py,sha256=AcrlYW6seQn3ZQKcnPIrLzYK2Ho0riGjuRsJmTNtCms,6444
-ignite/metrics/mean_absolute_error.py,sha256=…
+ignite/metrics/mean_absolute_error.py,sha256=lWZWZU7B4My_eEMuICCW3yHllZigNnYuKLRwqQrQYO0,3682
 ignite/metrics/mean_average_precision.py,sha256=cXP9pYidQnAazGXBrhC80WoI4eK4lb3avNO5d70TLd4,19136
 ignite/metrics/mean_pairwise_distance.py,sha256=Ys6Rns6s-USS_tyP6Pa3bWZSI7f_hP5-lZM64UGJGjo,4104
-ignite/metrics/mean_squared_error.py,sha256=…
+ignite/metrics/mean_squared_error.py,sha256=UnLLb7XKwvHhOxQWTVhYDCluKETVazG0yDSdX4s9pQY,3666
 ignite/metrics/metric.py,sha256=T3IiFIGTv_UOScd8ei4H9SraHfTJ09OM8I6hRfzr_sA,35141
 ignite/metrics/metric_group.py,sha256=UE7WrMbpKlO9_DPqxQdlmFAWveWoT1knKwRlHDl9YIU,2544
 ignite/metrics/metrics_lambda.py,sha256=NwKZ1J-KzFFbSw7YUaNJozdfKZLVqrkjQvFKT6ixnkg,7309
@@ -111,10 +111,10 @@ ignite/metrics/multilabel_confusion_matrix.py,sha256=1pjLNPGTDJWAkN_BHdBPekcish6
 ignite/metrics/mutual_information.py,sha256=lu1ucVfkx01tGQfELyXzS9woCPOMVImFHfrbIXCvPe8,4692
 ignite/metrics/precision.py,sha256=xe8_e13cPMaC1Mfw-RTlmkag6pdcHCIbi70ASI1IahY,18622
 ignite/metrics/precision_recall_curve.py,sha256=rcmG2W7dDuA_8fyekHNk4ronecewolMprW4rxUB8xsc,6228
-ignite/metrics/psnr.py,sha256=…
+ignite/metrics/psnr.py,sha256=a5ZYYVqIdDu9nTQw7WtSYgBuXKqa2xRyQ8WjQ1QeBu8,5539
 ignite/metrics/recall.py,sha256=MaywS5E8ioaHZvTPGhQaYPQV-xDmptYuv8kDRe_-BEY,9867
 ignite/metrics/roc_auc.py,sha256=U97y_JApK2vU1OmZKUJqolHQOZ1qemCSHdxcsLOO2Jg,9246
-ignite/metrics/root_mean_squared_error.py,sha256=…
+ignite/metrics/root_mean_squared_error.py,sha256=NQcqnjMn8WUc0QAeGuZiqANTEuk7P1JlIU4eB4hrKUs,2872
 ignite/metrics/running_average.py,sha256=vcC_LtsrJxEMea05TmBFzFqCK6nZd8hHavsfIlf2C6c,11333
 ignite/metrics/ssim.py,sha256=yU877i4wXcHA7vr5qAU9p0LmehEJdKQTFzd2L4Lwm3Q,11866
 ignite/metrics/top_k_categorical_accuracy.py,sha256=pqsArVTSxnwt49S3lZFVqOkCXbzx-WPxfQnhtQ390RM,4706
@@ -153,7 +153,7 @@ ignite/metrics/regression/spearman_correlation.py,sha256=IzmN4WIe7C4cTUU3BOkBmaw
 ignite/metrics/regression/wave_hedges_distance.py,sha256=Ji_NRUgnZ3lJgi5fyNFLRjbHO648z4dBmqVDQU9ImKA,2792
 ignite/metrics/vision/__init__.py,sha256=lPBAEq1idc6Q17poFm1SjttE27irHF1-uNeiwrxnLrU,159
 ignite/metrics/vision/object_detection_average_precision_recall.py,sha256=4wwiNVd658ynIpIbQlffTA-ehvyJ2EzmJ5pBSBuA8XQ,25091
-pytorch_ignite-0.6.0.dev20260115.dist-info/METADATA,sha256=…
-pytorch_ignite-0.6.0.dev20260115.dist-info/WHEEL,sha256=…
-pytorch_ignite-0.6.0.dev20260115.dist-info/licenses/LICENSE,sha256=…
-pytorch_ignite-0.6.0.dev20260115.dist-info/RECORD,,
+pytorch_ignite-0.6.0.dev20260117.dist-info/METADATA,sha256=A96hpkJgzy8Bpcmzxqjpk9KKAj3QNWFJw3VsZ4R5E4I,27979
+pytorch_ignite-0.6.0.dev20260117.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+pytorch_ignite-0.6.0.dev20260117.dist-info/licenses/LICENSE,sha256=SwJvaRmy1ql-k9_nL4WnER4_ODTMF9fWoP9HXkoicgw,1527
+pytorch_ignite-0.6.0.dev20260117.dist-info/RECORD,,
{pytorch_ignite-0.6.0.dev20260115.dist-info → pytorch_ignite-0.6.0.dev20260117.dist-info}/WHEEL
RENAMED
File without changes
{pytorch_ignite-0.6.0.dev20260115.dist-info → pytorch_ignite-0.6.0.dev20260117.dist-info}/licenses/LICENSE
RENAMED
File without changes