pytorch-ignite 0.6.0.dev20260118__py3-none-any.whl → 0.6.0.dev20260120__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ignite/__init__.py +1 -1
- ignite/handlers/tensorboard_logger.py +12 -12
- ignite/metrics/accuracy.py +6 -6
- ignite/metrics/cohen_kappa.py +5 -5
- ignite/metrics/metric_group.py +4 -4
- {pytorch_ignite-0.6.0.dev20260118.dist-info → pytorch_ignite-0.6.0.dev20260120.dist-info}/METADATA +1 -1
- {pytorch_ignite-0.6.0.dev20260118.dist-info → pytorch_ignite-0.6.0.dev20260120.dist-info}/RECORD +9 -9
- {pytorch_ignite-0.6.0.dev20260118.dist-info → pytorch_ignite-0.6.0.dev20260120.dist-info}/WHEEL +0 -0
- {pytorch_ignite-0.6.0.dev20260118.dist-info → pytorch_ignite-0.6.0.dev20260120.dist-info}/licenses/LICENSE +0 -0
ignite/__init__.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""TensorBoard logger and its helper handlers."""
|
|
2
2
|
|
|
3
|
-
from typing import Any, Callable
|
|
3
|
+
from typing import Any, Callable
|
|
4
4
|
|
|
5
5
|
from torch.optim import Optimizer
|
|
6
6
|
|
|
@@ -321,16 +321,16 @@ class OutputHandler(BaseOutputHandler):
|
|
|
321
321
|
def __init__(
|
|
322
322
|
self,
|
|
323
323
|
tag: str,
|
|
324
|
-
metric_names:
|
|
325
|
-
output_transform:
|
|
326
|
-
global_step_transform:
|
|
327
|
-
state_attributes:
|
|
324
|
+
metric_names: list[str] | str | None = None,
|
|
325
|
+
output_transform: Callable | None = None,
|
|
326
|
+
global_step_transform: Callable[[Engine, str | Events], int] | None = None,
|
|
327
|
+
state_attributes: list[str] | None = None,
|
|
328
328
|
):
|
|
329
329
|
super(OutputHandler, self).__init__(
|
|
330
330
|
tag, metric_names, output_transform, global_step_transform, state_attributes
|
|
331
331
|
)
|
|
332
332
|
|
|
333
|
-
def __call__(self, engine: Engine, logger: TensorboardLogger, event_name:
|
|
333
|
+
def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: str | Events) -> None:
|
|
334
334
|
if not isinstance(logger, TensorboardLogger):
|
|
335
335
|
raise RuntimeError("Handler 'OutputHandler' works only with TensorboardLogger")
|
|
336
336
|
|
|
@@ -377,10 +377,10 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler):
|
|
|
377
377
|
)
|
|
378
378
|
"""
|
|
379
379
|
|
|
380
|
-
def __init__(self, optimizer: Optimizer, param_name: str = "lr", tag:
|
|
380
|
+
def __init__(self, optimizer: Optimizer, param_name: str = "lr", tag: str | None = None):
|
|
381
381
|
super(OptimizerParamsHandler, self).__init__(optimizer, param_name, tag)
|
|
382
382
|
|
|
383
|
-
def __call__(self, engine: Engine, logger: TensorboardLogger, event_name:
|
|
383
|
+
def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: str | Events) -> None:
|
|
384
384
|
if not isinstance(logger, TensorboardLogger):
|
|
385
385
|
raise RuntimeError("Handler OptimizerParamsHandler works only with TensorboardLogger")
|
|
386
386
|
|
|
@@ -463,7 +463,7 @@ class WeightsScalarHandler(BaseWeightsScalarHandler):
|
|
|
463
463
|
optional argument `whitelist` added.
|
|
464
464
|
"""
|
|
465
465
|
|
|
466
|
-
def __call__(self, engine: Engine, logger: TensorboardLogger, event_name:
|
|
466
|
+
def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: str | Events) -> None:
|
|
467
467
|
if not isinstance(logger, TensorboardLogger):
|
|
468
468
|
raise RuntimeError("Handler 'WeightsScalarHandler' works only with TensorboardLogger")
|
|
469
469
|
|
|
@@ -542,7 +542,7 @@ class WeightsHistHandler(BaseWeightsHandler):
|
|
|
542
542
|
optional argument `whitelist` added.
|
|
543
543
|
"""
|
|
544
544
|
|
|
545
|
-
def __call__(self, engine: Engine, logger: TensorboardLogger, event_name:
|
|
545
|
+
def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: str | Events) -> None:
|
|
546
546
|
if not isinstance(logger, TensorboardLogger):
|
|
547
547
|
raise RuntimeError("Handler 'WeightsHistHandler' works only with TensorboardLogger")
|
|
548
548
|
|
|
@@ -624,7 +624,7 @@ class GradsScalarHandler(BaseWeightsScalarHandler):
|
|
|
624
624
|
optional argument `whitelist` added.
|
|
625
625
|
"""
|
|
626
626
|
|
|
627
|
-
def __call__(self, engine: Engine, logger: TensorboardLogger, event_name:
|
|
627
|
+
def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: str | Events) -> None:
|
|
628
628
|
if not isinstance(logger, TensorboardLogger):
|
|
629
629
|
raise RuntimeError("Handler 'GradsScalarHandler' works only with TensorboardLogger")
|
|
630
630
|
|
|
@@ -701,7 +701,7 @@ class GradsHistHandler(BaseWeightsHandler):
|
|
|
701
701
|
optional argument `whitelist` added.
|
|
702
702
|
"""
|
|
703
703
|
|
|
704
|
-
def __call__(self, engine: Engine, logger: TensorboardLogger, event_name:
|
|
704
|
+
def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: str | Events) -> None:
|
|
705
705
|
if not isinstance(logger, TensorboardLogger):
|
|
706
706
|
raise RuntimeError("Handler 'GradsHistHandler' works only with TensorboardLogger")
|
|
707
707
|
|
ignite/metrics/accuracy.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Callable,
|
|
1
|
+
from typing import Callable, Sequence
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
4
|
|
|
@@ -13,12 +13,12 @@ class _BaseClassification(Metric):
|
|
|
13
13
|
self,
|
|
14
14
|
output_transform: Callable = lambda x: x,
|
|
15
15
|
is_multilabel: bool = False,
|
|
16
|
-
device:
|
|
16
|
+
device: str | torch.device = torch.device("cpu"),
|
|
17
17
|
skip_unrolling: bool = False,
|
|
18
18
|
):
|
|
19
19
|
self._is_multilabel = is_multilabel
|
|
20
|
-
self._type:
|
|
21
|
-
self._num_classes:
|
|
20
|
+
self._type: str | None = None
|
|
21
|
+
self._num_classes: int | None = None
|
|
22
22
|
super(_BaseClassification, self).__init__(
|
|
23
23
|
output_transform=output_transform, device=device, skip_unrolling=skip_unrolling
|
|
24
24
|
)
|
|
@@ -38,7 +38,7 @@ class _BaseClassification(Metric):
|
|
|
38
38
|
)
|
|
39
39
|
|
|
40
40
|
y_shape = y.shape
|
|
41
|
-
y_pred_shape:
|
|
41
|
+
y_pred_shape: tuple[int, ...] = y_pred.shape
|
|
42
42
|
|
|
43
43
|
if y.ndimension() + 1 == y_pred.ndimension():
|
|
44
44
|
y_pred_shape = (y_pred_shape[0],) + y_pred_shape[2:]
|
|
@@ -223,7 +223,7 @@ class Accuracy(_BaseClassification):
|
|
|
223
223
|
self,
|
|
224
224
|
output_transform: Callable = lambda x: x,
|
|
225
225
|
is_multilabel: bool = False,
|
|
226
|
-
device:
|
|
226
|
+
device: str | torch.device = torch.device("cpu"),
|
|
227
227
|
skip_unrolling: bool = False,
|
|
228
228
|
):
|
|
229
229
|
super(Accuracy, self).__init__(
|
ignite/metrics/cohen_kappa.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Callable,
|
|
1
|
+
from typing import Callable, Literal
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
4
|
|
|
@@ -56,9 +56,9 @@ class CohenKappa(EpochMetric):
|
|
|
56
56
|
def __init__(
|
|
57
57
|
self,
|
|
58
58
|
output_transform: Callable = lambda x: x,
|
|
59
|
-
weights:
|
|
59
|
+
weights: Literal["linear", "quadratic"] | None = None,
|
|
60
60
|
check_compute_fn: bool = False,
|
|
61
|
-
device:
|
|
61
|
+
device: str | torch.device = torch.device("cpu"),
|
|
62
62
|
skip_unrolling: bool = False,
|
|
63
63
|
):
|
|
64
64
|
try:
|
|
@@ -68,8 +68,8 @@ class CohenKappa(EpochMetric):
|
|
|
68
68
|
if weights not in (None, "linear", "quadratic"):
|
|
69
69
|
raise ValueError("Kappa Weighting type must be None or linear or quadratic.")
|
|
70
70
|
|
|
71
|
-
#
|
|
72
|
-
self.weights = weights
|
|
71
|
+
# initialize weights
|
|
72
|
+
self.weights: Literal["linear", "quadratic"] | None = weights
|
|
73
73
|
|
|
74
74
|
super(CohenKappa, self).__init__(
|
|
75
75
|
self._cohen_kappa_score,
|
ignite/metrics/metric_group.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Any, Callable,
|
|
1
|
+
from typing import Any, Callable, Sequence
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
4
|
|
|
@@ -44,10 +44,10 @@ class MetricGroup(Metric):
|
|
|
44
44
|
``skip_unrolling`` argument is added.
|
|
45
45
|
"""
|
|
46
46
|
|
|
47
|
-
_state_dict_all_req_keys:
|
|
47
|
+
_state_dict_all_req_keys: tuple[str, ...] = ("metrics",)
|
|
48
48
|
|
|
49
49
|
def __init__(
|
|
50
|
-
self, metrics:
|
|
50
|
+
self, metrics: dict[str, Metric], output_transform: Callable = lambda x: x, skip_unrolling: bool = False
|
|
51
51
|
):
|
|
52
52
|
self.metrics = metrics
|
|
53
53
|
super(MetricGroup, self).__init__(output_transform=output_transform, skip_unrolling=skip_unrolling)
|
|
@@ -60,5 +60,5 @@ class MetricGroup(Metric):
|
|
|
60
60
|
for m in self.metrics.values():
|
|
61
61
|
m.update(m._output_transform(output))
|
|
62
62
|
|
|
63
|
-
def compute(self) ->
|
|
63
|
+
def compute(self) -> dict[str, Any]:
|
|
64
64
|
return {k: m.compute() for k, m in self.metrics.items()}
|
{pytorch_ignite-0.6.0.dev20260118.dist-info → pytorch_ignite-0.6.0.dev20260120.dist-info}/METADATA
RENAMED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: pytorch-ignite
|
|
3
|
-
Version: 0.6.0.
|
|
3
|
+
Version: 0.6.0.dev20260120
|
|
4
4
|
Summary: A lightweight library to help with training neural networks in PyTorch.
|
|
5
5
|
Project-URL: Homepage, https://pytorch-ignite.ai
|
|
6
6
|
Project-URL: Repository, https://github.com/pytorch/ignite
|
{pytorch_ignite-0.6.0.dev20260118.dist-info → pytorch_ignite-0.6.0.dev20260120.dist-info}/RECORD
RENAMED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
ignite/__init__.py,sha256=
|
|
1
|
+
ignite/__init__.py,sha256=KCPXiPhh4fXqJPVd1LqDaY1tRXNOBFLJoL2qhYTHDOs,194
|
|
2
2
|
ignite/_utils.py,sha256=XDPpUDJ8ykLXWMV2AYTqGSj8XCfApsyzsQ3Vij_OB4M,182
|
|
3
3
|
ignite/exceptions.py,sha256=5ZWCVLPC9rgoW8t84D-VeEleqz5O7XpAGPpCdU8rKd0,150
|
|
4
4
|
ignite/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -73,7 +73,7 @@ ignite/handlers/param_scheduler.py,sha256=Tn4o27YBrp5JsuadHobIrsHfmvB_cR1IrV_oV1
|
|
|
73
73
|
ignite/handlers/polyaxon_logger.py,sha256=7nyOhu4rgg4Sj6uTNUobPG9IJqYxjb_fopKiF4fk9vc,12346
|
|
74
74
|
ignite/handlers/state_param_scheduler.py,sha256=B89YKZyj9DXLXQyr3amDNMslUOWNHZDis2DXIwW0q10,20841
|
|
75
75
|
ignite/handlers/stores.py,sha256=8XM_Qqsitfu0WtOOE-K2FMtv51vD90r3GgQlCzRABYc,2616
|
|
76
|
-
ignite/handlers/tensorboard_logger.py,sha256=
|
|
76
|
+
ignite/handlers/tensorboard_logger.py,sha256=_kESDcCqFqhdA-IIgwVnMZ3JxS1O7grROG5Bc6wRvoc,27873
|
|
77
77
|
ignite/handlers/terminate_on_nan.py,sha256=RFSKd3Oqn9Me2xLCos4lSE-hnY7fYWWjE9blioeMlIs,2103
|
|
78
78
|
ignite/handlers/time_limit.py,sha256=heTuS-ReBbOUCm1NcNJGhzxI080Hanc4hOLB2Y4GeZk,1567
|
|
79
79
|
ignite/handlers/time_profilers.py,sha256=8iCcBYPxv0vKFSO_ujFV0ST54a9PD9ezFLvYTIu9lFI,30482
|
|
@@ -84,10 +84,10 @@ ignite/handlers/visdom_logger.py,sha256=OlnqVDXKYTea6VG7rwcQdzrfYmai0SSxNLh2Kqsj
|
|
|
84
84
|
ignite/handlers/wandb_logger.py,sha256=9HbwRMHzWckrZ-m0rkMF5Ug6r9C9J4sdq73yqaAHWYE,14829
|
|
85
85
|
ignite/metrics/__init__.py,sha256=m-8F8J17r-aEwsO6Ww-8AqDRN59WFfYBwCDKwqGDSmI,3627
|
|
86
86
|
ignite/metrics/accumulation.py,sha256=xWdsm9u6JfsfODX_GUKzQc_omrdFDJ4yELBR-xXgc4s,12448
|
|
87
|
-
ignite/metrics/accuracy.py,sha256=
|
|
87
|
+
ignite/metrics/accuracy.py,sha256=hVDvMG2kc-EN9H54tdMYWy6D3Zj74DzpHdG3Eq5B6Nc,10203
|
|
88
88
|
ignite/metrics/average_precision.py,sha256=laDD8BnAC5OuAJrCRtwCZ7EjjoQKRb7D3o-86IRsdN4,3681
|
|
89
89
|
ignite/metrics/classification_report.py,sha256=zjGlaMnRz2__op6hrZq74OusO0W_5B1AIe8KzYGFilM,5988
|
|
90
|
-
ignite/metrics/cohen_kappa.py,sha256=
|
|
90
|
+
ignite/metrics/cohen_kappa.py,sha256=D3vOkIK86qoGvg9nis5PlbRJ8uYZmNdpLJaBRVi0Moc,3865
|
|
91
91
|
ignite/metrics/confusion_matrix.py,sha256=dZDuK3vxrrbiQh6VfyV5aWFpuTJWsfnZ30Mxt6u6eOA,18215
|
|
92
92
|
ignite/metrics/cosine_similarity.py,sha256=9f5dM0QaiXBznidlUuKb8Q3E45W_Z_hGiy9WP5Fpcqw,4416
|
|
93
93
|
ignite/metrics/entropy.py,sha256=gJZkR5Sl1ZdIzJ9pFkydf1186bZU8OnkOLvOtKz6Wrs,4511
|
|
@@ -105,7 +105,7 @@ ignite/metrics/mean_average_precision.py,sha256=cXP9pYidQnAazGXBrhC80WoI4eK4lb3a
|
|
|
105
105
|
ignite/metrics/mean_pairwise_distance.py,sha256=Ys6Rns6s-USS_tyP6Pa3bWZSI7f_hP5-lZM64UGJGjo,4104
|
|
106
106
|
ignite/metrics/mean_squared_error.py,sha256=UnLLb7XKwvHhOxQWTVhYDCluKETVazG0yDSdX4s9pQY,3666
|
|
107
107
|
ignite/metrics/metric.py,sha256=T3IiFIGTv_UOScd8ei4H9SraHfTJ09OM8I6hRfzr_sA,35141
|
|
108
|
-
ignite/metrics/metric_group.py,sha256=
|
|
108
|
+
ignite/metrics/metric_group.py,sha256=yiS7MXQB1wROUXnK57omu6oGMkQ0WudYDx-Ee-togmo,2531
|
|
109
109
|
ignite/metrics/metrics_lambda.py,sha256=NwKZ1J-KzFFbSw7YUaNJozdfKZLVqrkjQvFKT6ixnkg,7309
|
|
110
110
|
ignite/metrics/multilabel_confusion_matrix.py,sha256=1pjLNPGTDJWAkN_BHdBPekcish6Ra0uRUeEbdj3Dm6Y,7377
|
|
111
111
|
ignite/metrics/mutual_information.py,sha256=lu1ucVfkx01tGQfELyXzS9woCPOMVImFHfrbIXCvPe8,4692
|
|
@@ -153,7 +153,7 @@ ignite/metrics/regression/spearman_correlation.py,sha256=IzmN4WIe7C4cTUU3BOkBmaw
|
|
|
153
153
|
ignite/metrics/regression/wave_hedges_distance.py,sha256=Ji_NRUgnZ3lJgi5fyNFLRjbHO648z4dBmqVDQU9ImKA,2792
|
|
154
154
|
ignite/metrics/vision/__init__.py,sha256=lPBAEq1idc6Q17poFm1SjttE27irHF1-uNeiwrxnLrU,159
|
|
155
155
|
ignite/metrics/vision/object_detection_average_precision_recall.py,sha256=4wwiNVd658ynIpIbQlffTA-ehvyJ2EzmJ5pBSBuA8XQ,25091
|
|
156
|
-
pytorch_ignite-0.6.0.
|
|
157
|
-
pytorch_ignite-0.6.0.
|
|
158
|
-
pytorch_ignite-0.6.0.
|
|
159
|
-
pytorch_ignite-0.6.0.
|
|
156
|
+
pytorch_ignite-0.6.0.dev20260120.dist-info/METADATA,sha256=s3fEGt1qbwRV18P-TNhk7LzNiE9cQPacXI8Mey3Cztk,27979
|
|
157
|
+
pytorch_ignite-0.6.0.dev20260120.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
158
|
+
pytorch_ignite-0.6.0.dev20260120.dist-info/licenses/LICENSE,sha256=SwJvaRmy1ql-k9_nL4WnER4_ODTMF9fWoP9HXkoicgw,1527
|
|
159
|
+
pytorch_ignite-0.6.0.dev20260120.dist-info/RECORD,,
|
{pytorch_ignite-0.6.0.dev20260118.dist-info → pytorch_ignite-0.6.0.dev20260120.dist-info}/WHEEL
RENAMED
|
File without changes
|
|
File without changes
|