pytorch-ignite 0.6.0.dev20251216__py3-none-any.whl → 0.6.0.dev20251218__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ignite/__init__.py CHANGED
@@ -6,4 +6,4 @@ import ignite.handlers
 import ignite.metrics
 import ignite.utils
 
-__version__ = "0.6.0.dev20251216"
+__version__ = "0.6.0.dev20251218"
ignite/contrib/engines/tbptt.py CHANGED
@@ -117,6 +117,5 @@ def create_supervised_tbptt_trainer(
         return sum(loss_list) / len(loss_list)
 
     engine = Engine(_update)
-    # pyrefly: ignore [bad-argument-type]
     engine.register_events(*Tbptt_Events)
     return engine
ignite/distributed/comp_models/horovod.py CHANGED
@@ -1,6 +1,6 @@
 import os
 import warnings
-from typing import Any, Callable, cast, List, Mapping, Optional, Tuple
+from typing import Any, Callable, cast, List, Mapping, Optional, Tuple, TYPE_CHECKING
 
 import torch
 
@@ -20,6 +20,11 @@ try:
 except ImportError:
     has_hvd_support = False
 
+if TYPE_CHECKING:
+    # Tell the type checker that hvd imports are always defined.
+    import horovod.torch as hvd
+    from horovod import run as hvd_mp_spawn
+
 
 if has_hvd_support:
     HOROVOD = "horovod"
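
The hunk above swaps per-use `# pyrefly: ignore [unbound-name]` suppressions for a `TYPE_CHECKING` block: the type checker always sees `hvd` as bound, while the runtime import remains optional. A minimal sketch of the same pattern, using a hypothetical optional dependency `somelib`:

```python
from typing import TYPE_CHECKING

try:
    import somelib  # hypothetical optional dependency

    has_somelib = True
except ImportError:
    has_somelib = False

if TYPE_CHECKING:
    # Never executed at runtime (TYPE_CHECKING is False), but the type
    # checker analyzes it, so `somelib` is always considered bound below.
    import somelib

if has_somelib:

    def scaled(x: "somelib.Array") -> "somelib.Array":  # hypothetical type
        return somelib.mul(x, 2)  # hypothetical API
```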
@@ -171,11 +176,8 @@ if has_hvd_support:
            return group

        _reduce_op_map = {
-            # pyrefly: ignore [unbound-name]
            "SUM": hvd.mpi_ops.Sum,
-            # pyrefly: ignore [unbound-name]
            "AVERAGE": hvd.mpi_ops.Average,
-            # pyrefly: ignore [unbound-name]
            "ADASUM": hvd.mpi_ops.Adasum,
        }

ignite/distributed/launcher.py CHANGED
@@ -322,19 +322,15 @@ class Parallel:
         idist.initialize(self.backend, init_method=self.init_method)
 
         # The logger can be setup from now since idist.initialize() has been called (if needed)
-        self._logger = setup_logger(__name__ + "." + self.__class__.__name__)  # type: ignore[assignment]
+        self._logger = setup_logger(__name__ + "." + self.__class__.__name__)
 
         if self.backend is not None:
             if self._spawn_params is None:
-                self._logger.info(  # type: ignore[attr-defined]
-                    f"Initialized processing group with backend: '{self.backend}'"
-                )
+                self._logger.info(f"Initialized processing group with backend: '{self.backend}'")
             else:
-                self._logger.info(  # type: ignore[attr-defined]
-                    f"Initialized distributed launcher with backend: '{self.backend}'"
-                )
+                self._logger.info(f"Initialized distributed launcher with backend: '{self.backend}'")
                 msg = "\n\t".join([f"{k}: {v}" for k, v in self._spawn_params.items() if v is not None])
-                self._logger.info(f"- Parameters to spawn processes: \n\t{msg}")  # type: ignore[attr-defined]
+                self._logger.info(f"- Parameters to spawn processes: \n\t{msg}")
 
         return self
 
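For context, `Parallel` is the public `idist` launcher whose `__enter__` logging the hunk above simplifies. A usage sketch (the backend and process count are illustrative):

```python
import ignite.distributed as idist


def training(local_rank: int) -> None:
    # Runs in each spawned process; idist exposes the process topology.
    print(f"rank={idist.get_rank()} world_size={idist.get_world_size()}")


if __name__ == "__main__":
    # Entering the context triggers the "Initialized ..." log messages above.
    with idist.Parallel(backend="gloo", nproc_per_node=2) as parallel:
        parallel.run(training)
```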
ignite/engine/engine.py CHANGED
@@ -148,12 +148,11 @@ class Engine(Serializable):
         self.should_interrupt = False
         self.state = State()
         self._state_dict_user_keys: List[str] = []
-        self._allowed_events: List[EventEnum] = []
+        self._allowed_events: List[Union[str, EventEnum]] = []
 
         self._dataloader_iter: Optional[Iterator[Any]] = None
         self._init_iter: Optional[int] = None
 
-        # pyrefly: ignore [bad-argument-type]
         self.register_events(*Events)
 
         if self._process_function is None:
@@ -164,9 +163,7 @@ class Engine(Serializable):
         # generator provided by self._internal_run_as_gen
         self._internal_run_generator: Optional[Generator[Any, None, State]] = None
 
-    def register_events(
-        self, *event_names: Union[List[str], List[EventEnum]], event_to_attr: Optional[dict] = None
-    ) -> None:
+    def register_events(self, *event_names: Union[str, EventEnum], event_to_attr: Optional[dict] = None) -> None:
         """Add events that can be fired.
 
         Registering an event will let the user trigger these events at any point.
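
The new `register_events` signature takes event names variadically as plain strings or `EventEnum` members instead of lists, matching how call sites like `self.register_events(*Events)` unpack them; this is what lets the `# pyrefly: ignore [bad-argument-type]` suppressions at those call sites be dropped. A short sketch against the new signature (the event names are illustrative):

```python
from ignite.engine import Engine, EventEnum


class BackpropEvents(EventEnum):
    BACKWARD_STARTED = "backward_started"
    BACKWARD_COMPLETED = "backward_completed"


engine = Engine(lambda engine, batch: batch)

# Unpacking the enum passes individual members, matching
# register_events(*event_names: Union[str, EventEnum], ...).
engine.register_events(*BackpropEvents)
engine.register_events("my_custom_event")  # plain strings also accepted
```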
@@ -451,7 +448,7 @@ class Engine(Serializable):
             first, others = ((resolved_engine,), args[1:])
         else:
             # metrics do not provide engine when registered
-            first, others = (tuple(), args)  # type: ignore[assignment]
+            first, others = (tuple(), args)
 
         func(*first, *(event_args + others), **kwargs)
 
@@ -990,9 +987,9 @@ class Engine(Serializable):
     def _internal_run_as_gen(self) -> Generator[Any, None, State]:
         self.should_terminate = self.should_terminate_single_epoch = self.should_interrupt = False
         self._init_timers(self.state)
+        start_time = time.time()
         try:
             try:
-                start_time = time.time()
                 self._fire_event(Events.STARTED)
                 yield from self._maybe_terminate_or_interrupt()
 
@@ -1011,7 +1008,7 @@ class Engine(Serializable):
                 # time is available for handlers but must be updated after fire
                 self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken
 
-                if self.should_terminate_single_epoch != "skip_epoch_completed":  # type: ignore[comparison-overlap]
+                if self.should_terminate_single_epoch != "skip_epoch_completed":
                     handlers_start_time = time.time()
                     self._fire_event(Events.EPOCH_COMPLETED)
                     epoch_time_taken += time.time() - handlers_start_time
@@ -1039,13 +1036,12 @@ class Engine(Serializable):
                "https://github.com/pytorch/ignite/issues/new/choose"
            )
 
-            # pyrefly: ignore [unbound-name]
            time_taken = time.time() - start_time
            # time is available for handlers but must be updated after fire
            self.state.times[Events.COMPLETED.name] = time_taken
 
            # do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True`
-            if self.should_terminate != "skip_completed":  # type: ignore[comparison-overlap]
+            if self.should_terminate != "skip_completed":
                handlers_start_time = time.time()
                self._fire_event(Events.COMPLETED)
                time_taken += time.time() - handlers_start_time
@@ -1191,9 +1187,9 @@ class Engine(Serializable):
         # internal_run without generator for BC
         self.should_terminate = self.should_terminate_single_epoch = self.should_interrupt = False
         self._init_timers(self.state)
+        start_time = time.time()
         try:
             try:
-                start_time = time.time()
                 self._fire_event(Events.STARTED)
                 self._maybe_terminate_legacy()
 
@@ -1212,7 +1208,7 @@ class Engine(Serializable):
                 # time is available for handlers but must be updated after fire
                 self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken
 
-                if self.should_terminate_single_epoch != "skip_epoch_completed":  # type: ignore[comparison-overlap]
+                if self.should_terminate_single_epoch != "skip_epoch_completed":
                     handlers_start_time = time.time()
                     self._fire_event(Events.EPOCH_COMPLETED)
                     epoch_time_taken += time.time() - handlers_start_time
@@ -1240,13 +1236,12 @@ class Engine(Serializable):
                "https://github.com/pytorch/ignite/issues/new/choose"
            )
 
-            # pyrefly: ignore [unbound-name]
            time_taken = time.time() - start_time
            # time is available for handlers but must be updated after fire
            self.state.times[Events.COMPLETED.name] = time_taken
 
            # do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True`
-            if self.should_terminate != "skip_completed":  # type: ignore[comparison-overlap]
+            if self.should_terminate != "skip_completed":
                handlers_start_time = time.time()
                self._fire_event(Events.COMPLETED)
                time_taken += time.time() - handlers_start_time
ignite/handlers/checkpoint.py CHANGED
@@ -315,7 +315,7 @@ class Checkpoint(Serializable):
     """
 
     SAVED_CHECKPOINT = CheckpointEvents.SAVED_CHECKPOINT
-    Item = NamedTuple("Item", [("priority", int), ("filename", str)])
+    Item = NamedTuple("Item", [("priority", Union[int, float]), ("filename", str)])
     _state_dict_all_req_keys = ("_saved",)
 
     def __init__(
@@ -323,7 +323,7 @@ class Checkpoint(Serializable):
         to_save: Mapping,
         save_handler: Union[str, Path, Callable, BaseSaveHandler],
         filename_prefix: str = "",
-        score_function: Optional[Callable] = None,
+        score_function: Optional[Callable[[Engine], Union[int, float]]] = None,
         score_name: Optional[str] = None,
         n_saved: Union[int, None] = 1,
         global_step_transform: Optional[Callable] = None,
@@ -440,7 +440,6 @@ class Checkpoint(Serializable):
 
     def __call__(self, engine: Engine) -> None:
         if not engine.has_registered_events(CheckpointEvents.SAVED_CHECKPOINT):
-            # pyrefly: ignore [bad-argument-type]
             engine.register_events(*CheckpointEvents)
         global_step = None
         if self.global_step_transform is not None:
@@ -455,7 +454,6 @@ class Checkpoint(Serializable):
            global_step = engine.state.get_event_attrib_value(Events.ITERATION_COMPLETED)
            priority = global_step
 
-        # pyrefly: ignore [bad-argument-type]
        if self._check_lt_n_saved() or self._compare_fn(priority):
            priority_str = f"{priority}" if isinstance(priority, numbers.Integral) else f"{priority:.4f}"
 
@@ -497,7 +495,6 @@ class Checkpoint(Serializable):
            if isinstance(self.save_handler, BaseSaveHandler):
                self.save_handler.remove(item.filename)
 
-            # pyrefly: ignore [bad-argument-type]
            self._saved.append(Checkpoint.Item(priority, filename))
            self._saved.sort(key=lambda it: it[0])
 
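`Item.priority` and `score_function` are widened to `Union[int, float]` because score-based checkpointing passes float metric values as priorities, while step-based checkpointing passes integer steps. A typical float-score setup (a sketch; the model, metric name, and save path are placeholders):

```python
import torch.nn as nn

from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, DiskSaver

model = nn.Linear(2, 2)
evaluator = Engine(lambda engine, batch: batch)


def score_function(engine: Engine) -> float:
    # A float priority, matching the new Callable[[Engine], Union[int, float]].
    return engine.state.metrics.get("accuracy", 0.0)


checkpoint = Checkpoint(
    to_save={"model": model},
    save_handler=DiskSaver("/tmp/checkpoints", require_empty=False),
    score_function=score_function,
    score_name="accuracy",
    n_saved=2,
)
evaluator.add_event_handler(Events.COMPLETED, checkpoint)
```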
ignite/handlers/clearml_logger.py CHANGED
@@ -862,7 +862,7 @@ class ClearMLSaver(DiskSaver):
         except ImportError:
             try:
                 # Backwards-compatibility for legacy Trains SDK
-                from trains import Task  # type: ignore[no-redef]
+                from trains import Task
             except ImportError:
                 raise ModuleNotFoundError(
                     "This contrib module requires clearml to be installed. "
@@ -937,7 +937,7 @@ class ClearMLSaver(DiskSaver):
         except ImportError:
             try:
                 # Backwards-compatibility for legacy Trains SDK
-                from trains.binding.frameworks import WeightsFileHandler  # type: ignore[no-redef]
+                from trains.binding.frameworks import WeightsFileHandler
             except ImportError:
                 raise ModuleNotFoundError(
                     "This contrib module requires clearml to be installed. "
ignite/handlers/lr_finder.py CHANGED
@@ -98,16 +98,18 @@ class FastaiLRFinder:
         self._best_loss = None
         self._diverge_flag = False
 
+        assert trainer.state.epoch_length is not None
+        assert trainer.state.max_epochs is not None
+
         # attach LRScheduler to trainer.
         if num_iter is None:
-            # pyrefly: ignore [unsupported-operation]
             num_iter = trainer.state.epoch_length * trainer.state.max_epochs
         else:
-            max_iter = trainer.state.epoch_length * trainer.state.max_epochs  # type: ignore[operator]
+            max_iter = trainer.state.epoch_length * trainer.state.max_epochs
             if max_iter < num_iter:
                 max_iter = num_iter
                 trainer.state.max_iters = num_iter
-                trainer.state.max_epochs = ceil(num_iter / trainer.state.epoch_length)  # type: ignore[operator]
+                trainer.state.max_epochs = ceil(num_iter / trainer.state.epoch_length)
 
         if not trainer.has_event_handler(self._reached_num_iterations):
             trainer.add_event_handler(Events.ITERATION_COMPLETED, self._reached_num_iterations, num_iter)
@@ -179,18 +181,14 @@ class FastaiLRFinder:
         loss = idist.all_reduce(loss)
         lr = self._lr_schedule.get_param()
         self._history["lr"].append(lr)
-        if trainer.state.iteration == 1:
-            self._best_loss = loss  # type: ignore[assignment]
-        else:
-            if smooth_f > 0:
-                loss = smooth_f * loss + (1 - smooth_f) * self._history["loss"][-1]
-            # pyrefly: ignore [unsupported-operation]
-            if loss < self._best_loss:
-                self._best_loss = loss
+        if trainer.state.iteration != 1 and smooth_f > 0:
+            loss = smooth_f * loss + (1 - smooth_f) * self._history["loss"][-1]
+        if self._best_loss is None or loss < self._best_loss:
+            self._best_loss = loss
         self._history["loss"].append(loss)
 
         # Check if the loss has diverged; if it has, stop the trainer
-        if self._history["loss"][-1] > diverge_th * self._best_loss:  # type: ignore[operator]
+        if self._history["loss"][-1] > diverge_th * self._best_loss:
             self._diverge_flag = True
             self.logger.info("Stopping early, the loss has diverged")
             trainer.terminate()
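
The rewritten branch tracks the best loss lazily (`if self._best_loss is None or loss < self._best_loss`) instead of special-casing iteration 1, which removes the `Optional` comparisons the suppressed errors pointed at without changing the smoothing behavior. Usage of the finder is unchanged; a sketch with placeholder model and data:

```python
import torch
import torch.nn as nn

from ignite.engine import create_supervised_trainer
from ignite.handlers import FastaiLRFinder

model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
trainer = create_supervised_trainer(model, optimizer, nn.MSELoss())

dataloader = [(torch.rand(4, 10), torch.rand(4, 1)) for _ in range(8)]
lr_finder = FastaiLRFinder()
to_save = {"model": model, "optimizer": optimizer}

# attach() restores model/optimizer state after the range test finishes.
with lr_finder.attach(trainer, to_save=to_save) as trainer_with_lr_finder:
    trainer_with_lr_finder.run(dataloader)

print(lr_finder.lr_suggestion())
```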
ignite/handlers/param_scheduler.py CHANGED
@@ -1122,13 +1122,14 @@ def create_lr_scheduler_with_warmup(
             f"but given {type(lr_scheduler)}"
         )
 
-    if not isinstance(warmup_duration, numbers.Integral):
+    if not isinstance(warmup_duration, int):
         raise TypeError(f"Argument warmup_duration should be integer, but given {warmup_duration}")
 
     if not (warmup_duration > 1):
         raise ValueError(f"Argument warmup_duration should be at least 2 events, but given {warmup_duration}")
 
     warmup_schedulers: List[ParamScheduler] = []
+    milestones_values: List[Tuple[int, float]] = []
 
     for param_group_index, param_group in enumerate(lr_scheduler.optimizer.param_groups):
         if warmup_end_value is None:
@@ -1154,7 +1155,6 @@ def create_lr_scheduler_with_warmup(
             init_lr = lr_scheduler.get_param()
             if init_lr == param_group_warmup_end_value:
                 if warmup_duration > 2:
-                    # pyrefly: ignore [unsupported-operation]
                     d = (param_group_warmup_end_value - warmup_start_value) / (warmup_duration - 1)
                     milestones_values[-1] = (warmup_duration - 2, param_group_warmup_end_value - d)
             else:
@@ -1164,7 +1164,6 @@ def create_lr_scheduler_with_warmup(
             PiecewiseLinear(
                 lr_scheduler.optimizer,
                 param_name="lr",
-                # pyrefly: ignore [bad-argument-type]
                 milestones_values=milestones_values,
                 param_group_index=param_group_index,
                 save_history=save_history,
@@ -1177,7 +1176,6 @@ def create_lr_scheduler_with_warmup(
         warmup_scheduler,
         lr_scheduler,
     ]
-    # pyrefly: ignore [unbound-name, unsupported-operation]
     durations = [milestones_values[-1][0] + 1]
     # pyrefly: ignore [bad-argument-type]
     combined_scheduler = ConcatScheduler(schedulers, durations=durations, save_history=save_history)
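
Initializing `milestones_values` before the param-group loop means the later `durations = [milestones_values[-1][0] + 1]` is provably bound, so the `unbound-name` suppression goes away. Usage is unchanged; a sketch with a placeholder optimizer and wrapped scheduler:

```python
import torch
import torch.nn as nn

from ignite.handlers import create_lr_scheduler_with_warmup

model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
torch_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)

# Linear warmup from 0.0 to 0.1 over 10 events, then the wrapped scheduler.
scheduler = create_lr_scheduler_with_warmup(
    torch_scheduler,
    warmup_start_value=0.0,
    warmup_end_value=0.1,
    warmup_duration=10,
)
# Then: trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
```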
@@ -1655,13 +1653,13 @@ class ReduceLROnPlateauScheduler(ParamScheduler):
         self.trainer = trainer
         self.optimizer = optimizer
 
+        min_lr: Union[float, List[float]]
         if "min_lr" in scheduler_kwargs and param_group_index is not None:
             min_lr = scheduler_kwargs["min_lr"]
             if not isinstance(min_lr, float):
                 raise TypeError(f"When param_group_index is given, min_lr should be a float, but given {type(min_lr)}")
             _min_lr = min_lr
             min_lr = [0] * len(optimizer.param_groups)
-            # pyrefly: ignore [unsupported-operation]
             min_lr[param_group_index] = _min_lr
         else:
             min_lr = 0
@@ -1676,11 +1674,11 @@ class ReduceLROnPlateauScheduler(ParamScheduler):
         _scheduler_kwargs["verbose"] = False
 
         self.scheduler = ReduceLROnPlateau(optimizer, **_scheduler_kwargs)
-        self.scheduler._reduce_lr = self._reduce_lr  # type: ignore[method-assign]
+        self.scheduler._reduce_lr = self._reduce_lr
 
         self._state_attrs += ["metric_name", "scheduler"]
 
-    def __call__(self, engine: Engine, name: Optional[str] = None) -> None:  # type: ignore[override]
+    def __call__(self, engine: Engine, name: Optional[str] = None) -> None:
         if not hasattr(engine.state, "metrics") or self.metric_name not in engine.state.metrics:
             raise ValueError(
                 "Argument engine should have in its 'state', attribute 'metrics' "
ignite/handlers/state_param_scheduler.py CHANGED
@@ -1,7 +1,7 @@
 import numbers
 import warnings
 from bisect import bisect_right
-from typing import Any, List, Sequence, Tuple, Union
+from typing import Any, Callable, List, Sequence, Tuple, Union
 
 from ignite.engine import CallableEventWithFilter, Engine, Events, EventsList
 from ignite.handlers.param_scheduler import BaseParamScheduler
@@ -183,7 +183,13 @@ class LambdaStateScheduler(StateParamScheduler):
 
     """
 
-    def __init__(self, lambda_obj: Any, param_name: str, save_history: bool = False, create_new: bool = False):
+    def __init__(
+        self,
+        lambda_obj: Callable[[int], Union[List[float], float]],
+        param_name: str,
+        save_history: bool = False,
+        create_new: bool = False,
+    ):
         super(LambdaStateScheduler, self).__init__(param_name, save_history, create_new)
 
         if not callable(lambda_obj):
@@ -193,7 +199,6 @@ class LambdaStateScheduler(StateParamScheduler):
         self._state_attrs += ["lambda_obj"]
 
     def get_param(self) -> Union[List[float], float]:
-        # pyrefly: ignore [bad-return]
         return self.lambda_obj(self.event_index)
 
 
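`lambda_obj` is now annotated as `Callable[[int], Union[List[float], float]]`: it receives the current event index and returns the parameter value(s), which is exactly how `get_param` invokes it. A small sketch under that contract:

```python
from ignite.engine import Engine, Events
from ignite.handlers.state_param_scheduler import LambdaStateScheduler


def exponential_decay(event_index: int) -> float:
    # Matches the new Callable[[int], Union[List[float], float]] annotation.
    return 0.1 * (0.9 ** event_index)


engine = Engine(lambda engine, batch: None)
scheduler = LambdaStateScheduler(exponential_decay, param_name="custom_param", create_new=True)
scheduler.attach(engine, Events.EPOCH_COMPLETED)
# After running, engine.state.custom_param holds the scheduled value.
```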
ignite/handlers/time_profilers.py CHANGED
@@ -500,14 +500,14 @@ class HandlersTimeProfiler:
 
         self.dataflow_times: List[float] = []
         self.processing_times: List[float] = []
-        self.event_handlers_times: Dict[EventEnum, Dict[str, List[float]]] = {}
+        self.event_handlers_times: Dict[Union[str, EventEnum], Dict[str, List[float]]] = {}
 
     @staticmethod
     def _get_callable_name(handler: Callable) -> str:
         # get name of the callable handler
         return getattr(handler, "__qualname__", handler.__class__.__name__)
 
-    def _create_wrapped_handler(self, handler: Callable, event: EventEnum) -> Callable:
+    def _create_wrapped_handler(self, handler: Callable, event: Union[str, EventEnum]) -> Callable:
         @functools.wraps(handler)
         def _timeit_handler(*args: Any, **kwargs: Any) -> None:
             self._event_handlers_timer.reset()
@@ -532,7 +532,7 @@ class HandlersTimeProfiler:
         t = self._dataflow_timer.value()
         self.dataflow_times.append(t)
 
-    def _reset(self, event_handlers_names: Mapping[EventEnum, List[str]]) -> None:
+    def _reset(self, event_handlers_names: Mapping[Union[str, EventEnum], List[str]]) -> None:
         # reset the variables used for profiling
         self.dataflow_times = []
         self.processing_times = []
ignite/handlers/tqdm_logger.py CHANGED
@@ -223,7 +223,7 @@ class ProgressBar(BaseLogger):
         super(ProgressBar, self).attach(engine, log_handler, event_name)
         engine.add_event_handler(closing_event_name, self._close)
 
-    def attach_opt_params_handler(  # type: ignore[empty-body]
+    def attach_opt_params_handler(
         self,
         engine: Engine,
         event_name: Union[str, Events],
ignite/handlers/visdom_logger.py CHANGED
@@ -1,7 +1,7 @@
 """Visdom logger and its helper handlers."""
 
 import os
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union, TYPE_CHECKING
 
 import torch
 import torch.nn as nn
@@ -165,7 +165,7 @@ class VisdomLogger(BaseLogger):
                 "pip install git+https://github.com/fossasia/visdom.git"
             )
 
-        if num_workers > 0:
+        if num_workers > 0 or TYPE_CHECKING:
             # If visdom is installed, one of its dependencies `tornado`
             # requires also `futures` to be installed.
             # Let's check anyway if we can import it.
@@ -199,7 +199,6 @@ class VisdomLogger(BaseLogger):
 
         self.executor: Union[_DummyExecutor, "ThreadPoolExecutor"] = _DummyExecutor()
         if num_workers > 0:
-            # pyrefly: ignore [unbound-name]
             self.executor = ThreadPoolExecutor(max_workers=num_workers)
 
     def _save(self) -> None:
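
`num_workers > 0 or TYPE_CHECKING` makes the checker always analyze the guarded `ThreadPoolExecutor` import (so its later use is never flagged unbound) while leaving runtime behavior untouched, since `TYPE_CHECKING` is `False` at runtime. The trick in isolation (a sketch):

```python
from typing import TYPE_CHECKING

num_workers = 0

if num_workers > 0 or TYPE_CHECKING:
    # Always analyzed by the type checker; skipped at runtime
    # when num_workers == 0 because TYPE_CHECKING is False.
    from concurrent.futures import ThreadPoolExecutor

if num_workers > 0:
    executor = ThreadPoolExecutor(max_workers=num_workers)
```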
ignite/metrics/accuracy.py CHANGED
@@ -254,10 +254,10 @@ class Accuracy(_BaseClassification):
            y_pred = torch.transpose(y_pred, 1, last_dim - 1).reshape(-1, num_classes)
            y = torch.transpose(y, 1, last_dim - 1).reshape(-1, num_classes)
            correct = torch.all(y == y_pred.type_as(y), dim=-1)
+        else:
+            raise ValueError(f"Unexpected type: {self._type}")
 
-        # pyrefly: ignore [unbound-name]
        self._num_correct += torch.sum(correct).to(self._device)
-        # pyrefly: ignore [unbound-name]
        self._num_examples += correct.shape[0]
 
    @sync_all_reduce("_num_examples", "_num_correct")
ignite/metrics/nlp/rouge.py CHANGED
@@ -1,6 +1,5 @@
 from abc import ABCMeta, abstractmethod
-from collections import namedtuple
-from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, List, Mapping, NamedTuple, Optional, Sequence, Tuple, Union
 
 import torch
 
@@ -13,24 +12,25 @@ from ignite.metrics.nlp.utils import lcs, ngrams
 __all__ = ["Rouge", "RougeN", "RougeL"]
 
 
-# pyrefly: ignore [invalid-inheritance]
-class Score(namedtuple("Score", ["match", "candidate", "reference"])):
+class Score(NamedTuple):
     r"""
     Computes precision and recall for given matches, candidate and reference lengths.
     """
 
+    match: int
+    candidate: int
+    reference: int
+
     def precision(self) -> float:
         """
         Calculates precision.
         """
-        # pyrefly: ignore [missing-attribute]
         return self.match / self.candidate if self.candidate > 0 else 0
 
     def recall(self) -> float:
         """
         Calculates recall.
         """
-        # pyrefly: ignore [missing-attribute]
        return self.match / self.reference if self.reference > 0 else 0
 
 
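Moving `Score` from a `namedtuple(...)` base class to `typing.NamedTuple` gives the fields declared types, so the `missing-attribute` suppressions in `precision`/`recall` become unnecessary. The pattern in isolation (a sketch mirroring `Score`):

```python
from typing import NamedTuple


class Fraction(NamedTuple):
    # Typed fields replace namedtuple("Fraction", ["num", "den"]).
    num: int
    den: int

    def value(self) -> float:
        # Fields are real typed attributes; no suppressions needed.
        return self.num / self.den if self.den > 0 else 0.0


print(Fraction(3, 4).value())  # 0.75
```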
pytorch_ignite-0.6.0.dev20251216.dist-info/METADATA → pytorch_ignite-0.6.0.dev20251218.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pytorch-ignite
-Version: 0.6.0.dev20251216
+Version: 0.6.0.dev20251218
 Summary: A lightweight library to help with training neural networks in PyTorch.
 Project-URL: Homepage, https://pytorch-ignite.ai
 Project-URL: Repository, https://github.com/pytorch/ignite
pytorch_ignite-0.6.0.dev20251216.dist-info/RECORD → pytorch_ignite-0.6.0.dev20251218.dist-info/RECORD RENAMED
@@ -1,4 +1,4 @@
-ignite/__init__.py,sha256=U1-KvPu1oRNr4WXKelJ_2A28wYaRpTv5KEfH825zMx8,194
+ignite/__init__.py,sha256=PgiRswfoDhr3-z_4hEfXVGwKuQXbsFvn-IY1_1BNIBU,194
 ignite/_utils.py,sha256=XDPpUDJ8ykLXWMV2AYTqGSj8XCfApsyzsQ3Vij_OB4M,182
 ignite/exceptions.py,sha256=5ZWCVLPC9rgoW8t84D-VeEleqz5O7XpAGPpCdU8rKd0,150
 ignite/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -8,7 +8,7 @@ ignite/base/mixins.py,sha256=Ip1SHCQCsvNUnLJKJwX9L-hqpfcZAlTad87-PaVgCBI,991
 ignite/contrib/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 ignite/contrib/engines/__init__.py,sha256=BxmXYIYEtEB1niMWITL8pgyKufCIpXR61rSzPQOhA0g,87
 ignite/contrib/engines/common.py,sha256=8WyVV6pqVHKnBfcdZoBPbOUXqqwSOTUI2OKUyMqvOks,28483
-ignite/contrib/engines/tbptt.py,sha256=pFVwnk8mN0mFFLwDfVK0ee8DrQnHi_Zpr20mOFHGArs,4549
+ignite/contrib/engines/tbptt.py,sha256=FSmF5SnoZn7mWNZWRZ-ohWUCfucET78GQu3lvVRNItk,4507
 ignite/contrib/handlers/__init__.py,sha256=rZszZnCbzncE2jqsvx9KP1iS3WZ0I-CnrV3Jh3Xl8_o,1073
 ignite/contrib/handlers/base_logger.py,sha256=gHVTkVvYMRUXI793rNq8564mMyJaL_HCuoCu8xiKxFY,1158
 ignite/contrib/handlers/clearml_logger.py,sha256=4CRD38jrif-8MeKYiEu5RbF5B-PhEkPnUGfielvt5s8,1385
@@ -47,44 +47,44 @@ ignite/contrib/metrics/regression/r2_score.py,sha256=1Mwo3Ft2PkYL8xq-CcbKqidJP5j
 ignite/contrib/metrics/regression/wave_hedges_distance.py,sha256=1uSqAUZX5aBzw0UJNla6bRYhHM3uPdVPuEzNJa4dixk,847
 ignite/distributed/__init__.py,sha256=qC28ok9XHWJawZfQR2MqWf6ctggS4rUY9PiTJjOCNvI,181
 ignite/distributed/auto.py,sha256=9nk9ArklntyzTaHx-odUTtKtX7bch-qQf1HQE7Y6YQE,15443
-ignite/distributed/launcher.py,sha256=hjdL8pnWNrpMQjw_GrY9CGWyUqvb6g42nfEsT_5cxdo,13492
+ignite/distributed/launcher.py,sha256=lEzoLqfVQDDXoPJ0ELUNs7090o1I6cDBFKuq3lTLPs4,13298
 ignite/distributed/utils.py,sha256=D97JwWgL9RKP8rTfDRf1zMmfRUeJizr7XfLZ8LAScOI,24999
 ignite/distributed/comp_models/__init__.py,sha256=S2WHl463U7BvpcUe9-JaGtuCi3G1cMHFW5QFBQ6fv20,1357
 ignite/distributed/comp_models/base.py,sha256=pTIylP1h2g6NWopBEponfXC6UefqS1l2lEdzTUTNXFc,14185
-ignite/distributed/comp_models/horovod.py,sha256=0YatfhWVGRsNCDNcOZDtQi4F_KIWJOv6uPh4lP8VmYg,9297
+ignite/distributed/comp_models/horovod.py,sha256=pGrcHQcwjuuMWJufBR4RyT5YR6RHT8wtk4-Bz_ir3_w,9353
 ignite/distributed/comp_models/native.py,sha256=k2ADEkHNTRDyWfBE1JP7AvTQTjjPtW8a2pyNLkeV6AQ,28139
 ignite/distributed/comp_models/xla.py,sha256=XhKFeo7kNu4mTe9yyzLoEzxS8cDbTFJKAYY9m_dDHIk,6367
 ignite/engine/__init__.py,sha256=MRFj6yywKhVkov4ccPkrw4dX1O8PfqceiJkngrcFb7A,36094
 ignite/engine/deterministic.py,sha256=uXn5VfxN_AgcEzZwBk_zdPWlSdKH2tl8Md1lcx1mvJ4,11643
-ignite/engine/engine.py,sha256=H_dbEsuydRHcD7uSpfLX7Qn71WjEAusnbsU1gTyFzxA,61051
+ignite/engine/engine.py,sha256=R0cDvh_MxFWOucmVuxrjiH3_xcybNDo9c4BkHUk2CEI,60713
 ignite/engine/events.py,sha256=FrcvnvjNZEzzohMQU6ZxL8ezrUQshUuM917Rsyxf8v0,21833
 ignite/engine/utils.py,sha256=QG5mkdg4OipspqgpNQcJuoHTYdr2Sx5LS16kfjOHDdI,1073
 ignite/handlers/__init__.py,sha256=Qq85YTtHPcii6UAfMOoCPg9RwigH96iqxOJKIlRfDqw,2728
 ignite/handlers/base_logger.py,sha256=wPiGn9iCh5ung1GaRUf_qAlqe63h1NpUUQ0XK709p2k,13011
-ignite/handlers/checkpoint.py,sha256=1d59fTyO4gh_iJWTDCTRQS9fEdoWMs82McD7tWIqhYE,46412
-ignite/handlers/clearml_logger.py,sha256=12a9eue6hnFh5CrdSFz_EpGF0-XKRMlBXpR2NWWw8DY,37949
+ignite/handlers/checkpoint.py,sha256=u6cFUDxAoSSBKCBprmDud2LEZGDEYHvyCoLUmtG3Xd4,46309
+ignite/handlers/clearml_logger.py,sha256=0-57RYznIz-EgTsKtkKFPdGGFQXJIhq146H_qiE8hVc,37897
 ignite/handlers/early_stopping.py,sha256=UA6TiKho5CbD085R-16H8w3r0BYPQcWQjhEXg8aITSw,4139
 ignite/handlers/ema_handler.py,sha256=SmUyyWIFPZW3yMvjD_sSk5m_LfnMFl9R-uQdbXNFfY0,11854
 ignite/handlers/fbresearch_logger.py,sha256=MfQeiBIXBYLEwZoDIld2oCceMeTAsz8rc5cd7fLtpJs,11133
-ignite/handlers/lr_finder.py,sha256=zHOb-gEW_e0obZnw5olnasLVLxInUjTnftjS6vvoifg,22253
+ignite/handlers/lr_finder.py,sha256=EMcQR3NDPOuh2s85a5Zu5Bqt0I4pg1cACJpjSa5cO4A,22100
 ignite/handlers/mlflow_logger.py,sha256=M5Mggrnr2wMsms8wbEaHqNtTk5L1zNs1MlPWD0ZCpDQ,13894
 ignite/handlers/neptune_logger.py,sha256=Rv-O_i0zGZC2Ozzeetxv7rtD7iP3IeWEcbY-U28Mkzg,27348
-ignite/handlers/param_scheduler.py,sha256=L4e9Nx9QbIbf4HHMCvrG12Nb_p3Uv7bzbZWd5McsRYU,68579
+ignite/handlers/param_scheduler.py,sha256=Tn4o27YBrp5JsuadHobIrsHfmvB_cR1IrV_oV1Eo7us,68373
 ignite/handlers/polyaxon_logger.py,sha256=5b7Zxhksne8Ufg_SBTG-rlf_9CPSjkBQOJR4-ynoZnQ,12354
-ignite/handlers/state_param_scheduler.py,sha256=Jk4tAFQhmP-C3jo1L_paiInn74Flc2Vaxqi-zQwCfqc,20784
+ignite/handlers/state_param_scheduler.py,sha256=B89YKZyj9DXLXQyr3amDNMslUOWNHZDis2DXIwW0q10,20841
 ignite/handlers/stores.py,sha256=8XM_Qqsitfu0WtOOE-K2FMtv51vD90r3GgQlCzRABYc,2616
 ignite/handlers/tensorboard_logger.py,sha256=q3YxXkbIFayBggI_kcHyl-upttVVjjnqFOLgyjj2cRo,27967
 ignite/handlers/terminate_on_nan.py,sha256=RFSKd3Oqn9Me2xLCos4lSE-hnY7fYWWjE9blioeMlIs,2103
 ignite/handlers/time_limit.py,sha256=heTuS-ReBbOUCm1NcNJGhzxI080Hanc4hOLB2Y4GeZk,1567
-ignite/handlers/time_profilers.py,sha256=0Jd_dDBcD5i280xsN4KoaBdmXbS04S2nbFNghSNLmBc,30446
+ignite/handlers/time_profilers.py,sha256=8iCcBYPxv0vKFSO_ujFV0ST54a9PD9ezFLvYTIu9lFI,30482
 ignite/handlers/timing.py,sha256=nHeBHvPwYdPRMAx-jk_8MjZit4a7rmsmIWkUrajAG-s,4705
-ignite/handlers/tqdm_logger.py,sha256=GvrvSLz2WDQHosPUvAZe4GHffMSO8wnfLURbjnmbcOg,13121
+ignite/handlers/tqdm_logger.py,sha256=3kxH39vM0LCDVwIZl9HQRaWM2Pr6bYC_l9oydFJmdM4,13093
 ignite/handlers/utils.py,sha256=X4LRqo1kqGsbmX0pEuZKYR6K4C8sZudAqxCLriiXtCg,872
-ignite/handlers/visdom_logger.py,sha256=HhKurlolglUaqX_rzvK3iG2ofwU0-XGr6rSSzQCRSkk,21875
+ignite/handlers/visdom_logger.py,sha256=RY5ss3NAPad7d3xFFnqczCtuO6RgmWq9ROz-sFf6imI,21862
 ignite/handlers/wandb_logger.py,sha256=vGok3gADQmTNkc6KkfFBreYoHAO8EneuU65xjBpT5-Q,14837
 ignite/metrics/__init__.py,sha256=m-8F8J17r-aEwsO6Ww-8AqDRN59WFfYBwCDKwqGDSmI,3627
 ignite/metrics/accumulation.py,sha256=xWdsm9u6JfsfODX_GUKzQc_omrdFDJ4yELBR-xXgc4s,12448
-ignite/metrics/accuracy.py,sha256=iddxBPlOOGv1Vda2hIcMTi9jZwuPibVQqvYptRPMaa8,10250
+ignite/metrics/accuracy.py,sha256=W8mO4W11VzryMXKy8G7W_g4A9PH9RYpejW_tQ-T_Txw,10245
 ignite/metrics/average_precision.py,sha256=AL4fvWCUL6zMNq_u2vQRnAdmdByB8S8x8jSE-MoFVjY,3694
 ignite/metrics/classification_report.py,sha256=zjGlaMnRz2__op6hrZq74OusO0W_5B1AIe8KzYGFilM,5988
 ignite/metrics/cohen_kappa.py,sha256=Qwcd4P2kN12CVCFC-kVdzn_2XV7kGzP6LlWkK209JJ8,3815
@@ -129,7 +129,7 @@ ignite/metrics/gan/inception_score.py,sha256=78_qrECWb_KsbLbo1lvDnvFJ9FsWPsbUi1a
 ignite/metrics/gan/utils.py,sha256=3nihbBrcM9MRcu6r0p3x5SgZQ5V4aag20ZppM7j_HiM,3993
 ignite/metrics/nlp/__init__.py,sha256=TiDKRhw7lhZeoL2Cn4s306cKIuBbXl2fizN1ZepMhwI,168
 ignite/metrics/nlp/bleu.py,sha256=NyQZ3CQB1xUnH_KWer5QtxkM_S_aiO3ok86UMxHaQ_w,11539
-ignite/metrics/nlp/rouge.py,sha256=dyQEMTLbq5DLei2P2SoOZpMykPITDr1vq3BtdBIEysk,15460
+ignite/metrics/nlp/rouge.py,sha256=siAxJzGE3KjH23u-F3DCUPke--ls-1XMygncGhTYJp4,15313
 ignite/metrics/nlp/utils.py,sha256=CA0MRMk9l97QockFYYhU6k0-hLhP3GwW36ONZ7TRqmc,2341
 ignite/metrics/regression/__init__.py,sha256=I594yB38ypWi9IDi9rrdshdXeBnSRcST09tnLRjN0yk,1472
 ignite/metrics/regression/_base.py,sha256=5V6GkkaBYRuW9J3yDXucyTZp1XJ2uIG7F4w2XcBsd3w,2365
@@ -153,7 +153,7 @@ ignite/metrics/regression/spearman_correlation.py,sha256=IzmN4WIe7C4cTUU3BOkBmaw
 ignite/metrics/regression/wave_hedges_distance.py,sha256=Ji_NRUgnZ3lJgi5fyNFLRjbHO648z4dBmqVDQU9ImKA,2792
 ignite/metrics/vision/__init__.py,sha256=lPBAEq1idc6Q17poFm1SjttE27irHF1-uNeiwrxnLrU,159
 ignite/metrics/vision/object_detection_average_precision_recall.py,sha256=4wwiNVd658ynIpIbQlffTA-ehvyJ2EzmJ5pBSBuA8XQ,25091
-pytorch_ignite-0.6.0.dev20251216.dist-info/METADATA,sha256=_hQO-qN-QvAtkMGRts0VsNPPKsfTpyJhcFuvVlF6-Yw,27979
-pytorch_ignite-0.6.0.dev20251216.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-pytorch_ignite-0.6.0.dev20251216.dist-info/licenses/LICENSE,sha256=SwJvaRmy1ql-k9_nL4WnER4_ODTMF9fWoP9HXkoicgw,1527
-pytorch_ignite-0.6.0.dev20251216.dist-info/RECORD,,
+pytorch_ignite-0.6.0.dev20251218.dist-info/METADATA,sha256=rX74hh67QAuRKc4xTn19Gqc62MTsbQ702xRQxfWH6Nk,27979
+pytorch_ignite-0.6.0.dev20251218.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+pytorch_ignite-0.6.0.dev20251218.dist-info/licenses/LICENSE,sha256=SwJvaRmy1ql-k9_nL4WnER4_ODTMF9fWoP9HXkoicgw,1527
+pytorch_ignite-0.6.0.dev20251218.dist-info/RECORD,,