pytorch-ignite 0.6.0.dev20260118__py3-none-any.whl → 0.6.0.dev20260120__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ignite/__init__.py CHANGED
@@ -6,4 +6,4 @@ import ignite.handlers
6
6
  import ignite.metrics
7
7
  import ignite.utils
8
8
 
9
- __version__ = "0.6.0.dev20260118"
9
+ __version__ = "0.6.0.dev20260120"
@@ -1,6 +1,6 @@
1
1
  """TensorBoard logger and its helper handlers."""
2
2
 
3
- from typing import Any, Callable, List, Optional, Union
3
+ from typing import Any, Callable
4
4
 
5
5
  from torch.optim import Optimizer
6
6
 
@@ -321,16 +321,16 @@ class OutputHandler(BaseOutputHandler):
321
321
  def __init__(
322
322
  self,
323
323
  tag: str,
324
- metric_names: Optional[Union[List[str], str]] = None,
325
- output_transform: Optional[Callable] = None,
326
- global_step_transform: Optional[Callable[[Engine, Union[str, Events]], int]] = None,
327
- state_attributes: Optional[List[str]] = None,
324
+ metric_names: list[str] | str | None = None,
325
+ output_transform: Callable | None = None,
326
+ global_step_transform: Callable[[Engine, str | Events], int] | None = None,
327
+ state_attributes: list[str] | None = None,
328
328
  ):
329
329
  super(OutputHandler, self).__init__(
330
330
  tag, metric_names, output_transform, global_step_transform, state_attributes
331
331
  )
332
332
 
333
- def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: Union[str, Events]) -> None:
333
+ def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: str | Events) -> None:
334
334
  if not isinstance(logger, TensorboardLogger):
335
335
  raise RuntimeError("Handler 'OutputHandler' works only with TensorboardLogger")
336
336
 
@@ -377,10 +377,10 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler):
377
377
  )
378
378
  """
379
379
 
380
- def __init__(self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None):
380
+ def __init__(self, optimizer: Optimizer, param_name: str = "lr", tag: str | None = None):
381
381
  super(OptimizerParamsHandler, self).__init__(optimizer, param_name, tag)
382
382
 
383
- def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: Union[str, Events]) -> None:
383
+ def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: str | Events) -> None:
384
384
  if not isinstance(logger, TensorboardLogger):
385
385
  raise RuntimeError("Handler OptimizerParamsHandler works only with TensorboardLogger")
386
386
 
@@ -463,7 +463,7 @@ class WeightsScalarHandler(BaseWeightsScalarHandler):
463
463
  optional argument `whitelist` added.
464
464
  """
465
465
 
466
- def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: Union[str, Events]) -> None:
466
+ def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: str | Events) -> None:
467
467
  if not isinstance(logger, TensorboardLogger):
468
468
  raise RuntimeError("Handler 'WeightsScalarHandler' works only with TensorboardLogger")
469
469
 
@@ -542,7 +542,7 @@ class WeightsHistHandler(BaseWeightsHandler):
542
542
  optional argument `whitelist` added.
543
543
  """
544
544
 
545
- def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: Union[str, Events]) -> None:
545
+ def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: str | Events) -> None:
546
546
  if not isinstance(logger, TensorboardLogger):
547
547
  raise RuntimeError("Handler 'WeightsHistHandler' works only with TensorboardLogger")
548
548
 
@@ -624,7 +624,7 @@ class GradsScalarHandler(BaseWeightsScalarHandler):
624
624
  optional argument `whitelist` added.
625
625
  """
626
626
 
627
- def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: Union[str, Events]) -> None:
627
+ def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: str | Events) -> None:
628
628
  if not isinstance(logger, TensorboardLogger):
629
629
  raise RuntimeError("Handler 'GradsScalarHandler' works only with TensorboardLogger")
630
630
 
@@ -701,7 +701,7 @@ class GradsHistHandler(BaseWeightsHandler):
701
701
  optional argument `whitelist` added.
702
702
  """
703
703
 
704
- def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: Union[str, Events]) -> None:
704
+ def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: str | Events) -> None:
705
705
  if not isinstance(logger, TensorboardLogger):
706
706
  raise RuntimeError("Handler 'GradsHistHandler' works only with TensorboardLogger")
707
707
 
@@ -1,4 +1,4 @@
1
- from typing import Callable, Optional, Sequence, Tuple, Union
1
+ from typing import Callable, Sequence
2
2
 
3
3
  import torch
4
4
 
@@ -13,12 +13,12 @@ class _BaseClassification(Metric):
13
13
  self,
14
14
  output_transform: Callable = lambda x: x,
15
15
  is_multilabel: bool = False,
16
- device: Union[str, torch.device] = torch.device("cpu"),
16
+ device: str | torch.device = torch.device("cpu"),
17
17
  skip_unrolling: bool = False,
18
18
  ):
19
19
  self._is_multilabel = is_multilabel
20
- self._type: Optional[str] = None
21
- self._num_classes: Optional[int] = None
20
+ self._type: str | None = None
21
+ self._num_classes: int | None = None
22
22
  super(_BaseClassification, self).__init__(
23
23
  output_transform=output_transform, device=device, skip_unrolling=skip_unrolling
24
24
  )
@@ -38,7 +38,7 @@ class _BaseClassification(Metric):
38
38
  )
39
39
 
40
40
  y_shape = y.shape
41
- y_pred_shape: Tuple[int, ...] = y_pred.shape
41
+ y_pred_shape: tuple[int, ...] = y_pred.shape
42
42
 
43
43
  if y.ndimension() + 1 == y_pred.ndimension():
44
44
  y_pred_shape = (y_pred_shape[0],) + y_pred_shape[2:]
@@ -223,7 +223,7 @@ class Accuracy(_BaseClassification):
223
223
  self,
224
224
  output_transform: Callable = lambda x: x,
225
225
  is_multilabel: bool = False,
226
- device: Union[str, torch.device] = torch.device("cpu"),
226
+ device: str | torch.device = torch.device("cpu"),
227
227
  skip_unrolling: bool = False,
228
228
  ):
229
229
  super(Accuracy, self).__init__(
@@ -1,4 +1,4 @@
1
- from typing import Callable, Optional, Union
1
+ from typing import Callable, Literal
2
2
 
3
3
  import torch
4
4
 
@@ -56,9 +56,9 @@ class CohenKappa(EpochMetric):
56
56
  def __init__(
57
57
  self,
58
58
  output_transform: Callable = lambda x: x,
59
- weights: Optional[str] = None,
59
+ weights: Literal["linear", "quadratic"] | None = None,
60
60
  check_compute_fn: bool = False,
61
- device: Union[str, torch.device] = torch.device("cpu"),
61
+ device: str | torch.device = torch.device("cpu"),
62
62
  skip_unrolling: bool = False,
63
63
  ):
64
64
  try:
@@ -68,8 +68,8 @@ class CohenKappa(EpochMetric):
68
68
  if weights not in (None, "linear", "quadratic"):
69
69
  raise ValueError("Kappa Weighting type must be None or linear or quadratic.")
70
70
 
71
- # initalize weights
72
- self.weights = weights
71
+ # initialize weights
72
+ self.weights: Literal["linear", "quadratic"] | None = weights
73
73
 
74
74
  super(CohenKappa, self).__init__(
75
75
  self._cohen_kappa_score,
@@ -1,4 +1,4 @@
1
- from typing import Any, Callable, Dict, Sequence, Tuple
1
+ from typing import Any, Callable, Sequence
2
2
 
3
3
  import torch
4
4
 
@@ -44,10 +44,10 @@ class MetricGroup(Metric):
44
44
  ``skip_unrolling`` argument is added.
45
45
  """
46
46
 
47
- _state_dict_all_req_keys: Tuple[str, ...] = ("metrics",)
47
+ _state_dict_all_req_keys: tuple[str, ...] = ("metrics",)
48
48
 
49
49
  def __init__(
50
- self, metrics: Dict[str, Metric], output_transform: Callable = lambda x: x, skip_unrolling: bool = False
50
+ self, metrics: dict[str, Metric], output_transform: Callable = lambda x: x, skip_unrolling: bool = False
51
51
  ):
52
52
  self.metrics = metrics
53
53
  super(MetricGroup, self).__init__(output_transform=output_transform, skip_unrolling=skip_unrolling)
@@ -60,5 +60,5 @@ class MetricGroup(Metric):
60
60
  for m in self.metrics.values():
61
61
  m.update(m._output_transform(output))
62
62
 
63
- def compute(self) -> Dict[str, Any]:
63
+ def compute(self) -> dict[str, Any]:
64
64
  return {k: m.compute() for k, m in self.metrics.items()}
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pytorch-ignite
3
- Version: 0.6.0.dev20260118
3
+ Version: 0.6.0.dev20260120
4
4
  Summary: A lightweight library to help with training neural networks in PyTorch.
5
5
  Project-URL: Homepage, https://pytorch-ignite.ai
6
6
  Project-URL: Repository, https://github.com/pytorch/ignite
@@ -1,4 +1,4 @@
1
- ignite/__init__.py,sha256=xQh94zBR0e8L-_Niygli9Iy2moUBK8XfRHQ_J9miR1Q,194
1
+ ignite/__init__.py,sha256=KCPXiPhh4fXqJPVd1LqDaY1tRXNOBFLJoL2qhYTHDOs,194
2
2
  ignite/_utils.py,sha256=XDPpUDJ8ykLXWMV2AYTqGSj8XCfApsyzsQ3Vij_OB4M,182
3
3
  ignite/exceptions.py,sha256=5ZWCVLPC9rgoW8t84D-VeEleqz5O7XpAGPpCdU8rKd0,150
4
4
  ignite/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -73,7 +73,7 @@ ignite/handlers/param_scheduler.py,sha256=Tn4o27YBrp5JsuadHobIrsHfmvB_cR1IrV_oV1
73
73
  ignite/handlers/polyaxon_logger.py,sha256=7nyOhu4rgg4Sj6uTNUobPG9IJqYxjb_fopKiF4fk9vc,12346
74
74
  ignite/handlers/state_param_scheduler.py,sha256=B89YKZyj9DXLXQyr3amDNMslUOWNHZDis2DXIwW0q10,20841
75
75
  ignite/handlers/stores.py,sha256=8XM_Qqsitfu0WtOOE-K2FMtv51vD90r3GgQlCzRABYc,2616
76
- ignite/handlers/tensorboard_logger.py,sha256=e6gS9b7dkGjm-iR4iTQWPZcYKrDUml7cmrriNXoMBoU,27959
76
+ ignite/handlers/tensorboard_logger.py,sha256=_kESDcCqFqhdA-IIgwVnMZ3JxS1O7grROG5Bc6wRvoc,27873
77
77
  ignite/handlers/terminate_on_nan.py,sha256=RFSKd3Oqn9Me2xLCos4lSE-hnY7fYWWjE9blioeMlIs,2103
78
78
  ignite/handlers/time_limit.py,sha256=heTuS-ReBbOUCm1NcNJGhzxI080Hanc4hOLB2Y4GeZk,1567
79
79
  ignite/handlers/time_profilers.py,sha256=8iCcBYPxv0vKFSO_ujFV0ST54a9PD9ezFLvYTIu9lFI,30482
@@ -84,10 +84,10 @@ ignite/handlers/visdom_logger.py,sha256=OlnqVDXKYTea6VG7rwcQdzrfYmai0SSxNLh2Kqsj
84
84
  ignite/handlers/wandb_logger.py,sha256=9HbwRMHzWckrZ-m0rkMF5Ug6r9C9J4sdq73yqaAHWYE,14829
85
85
  ignite/metrics/__init__.py,sha256=m-8F8J17r-aEwsO6Ww-8AqDRN59WFfYBwCDKwqGDSmI,3627
86
86
  ignite/metrics/accumulation.py,sha256=xWdsm9u6JfsfODX_GUKzQc_omrdFDJ4yELBR-xXgc4s,12448
87
- ignite/metrics/accuracy.py,sha256=W8mO4W11VzryMXKy8G7W_g4A9PH9RYpejW_tQ-T_Txw,10245
87
+ ignite/metrics/accuracy.py,sha256=hVDvMG2kc-EN9H54tdMYWy6D3Zj74DzpHdG3Eq5B6Nc,10203
88
88
  ignite/metrics/average_precision.py,sha256=laDD8BnAC5OuAJrCRtwCZ7EjjoQKRb7D3o-86IRsdN4,3681
89
89
  ignite/metrics/classification_report.py,sha256=zjGlaMnRz2__op6hrZq74OusO0W_5B1AIe8KzYGFilM,5988
90
- ignite/metrics/cohen_kappa.py,sha256=Qwcd4P2kN12CVCFC-kVdzn_2XV7kGzP6LlWkK209JJ8,3815
90
+ ignite/metrics/cohen_kappa.py,sha256=D3vOkIK86qoGvg9nis5PlbRJ8uYZmNdpLJaBRVi0Moc,3865
91
91
  ignite/metrics/confusion_matrix.py,sha256=dZDuK3vxrrbiQh6VfyV5aWFpuTJWsfnZ30Mxt6u6eOA,18215
92
92
  ignite/metrics/cosine_similarity.py,sha256=9f5dM0QaiXBznidlUuKb8Q3E45W_Z_hGiy9WP5Fpcqw,4416
93
93
  ignite/metrics/entropy.py,sha256=gJZkR5Sl1ZdIzJ9pFkydf1186bZU8OnkOLvOtKz6Wrs,4511
@@ -105,7 +105,7 @@ ignite/metrics/mean_average_precision.py,sha256=cXP9pYidQnAazGXBrhC80WoI4eK4lb3a
105
105
  ignite/metrics/mean_pairwise_distance.py,sha256=Ys6Rns6s-USS_tyP6Pa3bWZSI7f_hP5-lZM64UGJGjo,4104
106
106
  ignite/metrics/mean_squared_error.py,sha256=UnLLb7XKwvHhOxQWTVhYDCluKETVazG0yDSdX4s9pQY,3666
107
107
  ignite/metrics/metric.py,sha256=T3IiFIGTv_UOScd8ei4H9SraHfTJ09OM8I6hRfzr_sA,35141
108
- ignite/metrics/metric_group.py,sha256=UE7WrMbpKlO9_DPqxQdlmFAWveWoT1knKwRlHDl9YIU,2544
108
+ ignite/metrics/metric_group.py,sha256=yiS7MXQB1wROUXnK57omu6oGMkQ0WudYDx-Ee-togmo,2531
109
109
  ignite/metrics/metrics_lambda.py,sha256=NwKZ1J-KzFFbSw7YUaNJozdfKZLVqrkjQvFKT6ixnkg,7309
110
110
  ignite/metrics/multilabel_confusion_matrix.py,sha256=1pjLNPGTDJWAkN_BHdBPekcish6Ra0uRUeEbdj3Dm6Y,7377
111
111
  ignite/metrics/mutual_information.py,sha256=lu1ucVfkx01tGQfELyXzS9woCPOMVImFHfrbIXCvPe8,4692
@@ -153,7 +153,7 @@ ignite/metrics/regression/spearman_correlation.py,sha256=IzmN4WIe7C4cTUU3BOkBmaw
153
153
  ignite/metrics/regression/wave_hedges_distance.py,sha256=Ji_NRUgnZ3lJgi5fyNFLRjbHO648z4dBmqVDQU9ImKA,2792
154
154
  ignite/metrics/vision/__init__.py,sha256=lPBAEq1idc6Q17poFm1SjttE27irHF1-uNeiwrxnLrU,159
155
155
  ignite/metrics/vision/object_detection_average_precision_recall.py,sha256=4wwiNVd658ynIpIbQlffTA-ehvyJ2EzmJ5pBSBuA8XQ,25091
156
- pytorch_ignite-0.6.0.dev20260118.dist-info/METADATA,sha256=EhmYgR6EJCry4dtHlWqk0E-DvLRY3sphi_JZ_eJfhnE,27979
157
- pytorch_ignite-0.6.0.dev20260118.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
158
- pytorch_ignite-0.6.0.dev20260118.dist-info/licenses/LICENSE,sha256=SwJvaRmy1ql-k9_nL4WnER4_ODTMF9fWoP9HXkoicgw,1527
159
- pytorch_ignite-0.6.0.dev20260118.dist-info/RECORD,,
156
+ pytorch_ignite-0.6.0.dev20260120.dist-info/METADATA,sha256=s3fEGt1qbwRV18P-TNhk7LzNiE9cQPacXI8Mey3Cztk,27979
157
+ pytorch_ignite-0.6.0.dev20260120.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
158
+ pytorch_ignite-0.6.0.dev20260120.dist-info/licenses/LICENSE,sha256=SwJvaRmy1ql-k9_nL4WnER4_ODTMF9fWoP9HXkoicgw,1527
159
+ pytorch_ignite-0.6.0.dev20260120.dist-info/RECORD,,