pytorch-ignite 0.6.0.dev20250927__py3-none-any.whl → 0.6.0.dev20260101__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pytorch-ignite might be problematic; see the registry's advisory page for this release for more details.

Files changed (65)
  1. ignite/__init__.py +1 -1
  2. ignite/contrib/engines/common.py +1 -0
  3. ignite/contrib/handlers/base_logger.py +1 -1
  4. ignite/contrib/handlers/clearml_logger.py +1 -1
  5. ignite/contrib/handlers/lr_finder.py +1 -1
  6. ignite/contrib/handlers/mlflow_logger.py +1 -1
  7. ignite/contrib/handlers/neptune_logger.py +1 -1
  8. ignite/contrib/handlers/param_scheduler.py +1 -1
  9. ignite/contrib/handlers/polyaxon_logger.py +1 -1
  10. ignite/contrib/handlers/tensorboard_logger.py +1 -1
  11. ignite/contrib/handlers/time_profilers.py +1 -1
  12. ignite/contrib/handlers/tqdm_logger.py +1 -1
  13. ignite/contrib/handlers/visdom_logger.py +1 -1
  14. ignite/contrib/handlers/wandb_logger.py +1 -1
  15. ignite/contrib/metrics/average_precision.py +1 -1
  16. ignite/contrib/metrics/cohen_kappa.py +1 -1
  17. ignite/contrib/metrics/gpu_info.py +1 -1
  18. ignite/contrib/metrics/precision_recall_curve.py +1 -1
  19. ignite/contrib/metrics/regression/canberra_metric.py +2 -3
  20. ignite/contrib/metrics/regression/fractional_absolute_error.py +2 -3
  21. ignite/contrib/metrics/regression/fractional_bias.py +2 -3
  22. ignite/contrib/metrics/regression/geometric_mean_absolute_error.py +2 -3
  23. ignite/contrib/metrics/regression/geometric_mean_relative_absolute_error.py +2 -3
  24. ignite/contrib/metrics/regression/manhattan_distance.py +2 -3
  25. ignite/contrib/metrics/regression/maximum_absolute_error.py +2 -3
  26. ignite/contrib/metrics/regression/mean_absolute_relative_error.py +2 -3
  27. ignite/contrib/metrics/regression/mean_error.py +2 -3
  28. ignite/contrib/metrics/regression/mean_normalized_bias.py +2 -3
  29. ignite/contrib/metrics/regression/median_absolute_error.py +2 -3
  30. ignite/contrib/metrics/regression/median_absolute_percentage_error.py +2 -3
  31. ignite/contrib/metrics/regression/median_relative_absolute_error.py +2 -3
  32. ignite/contrib/metrics/regression/r2_score.py +2 -3
  33. ignite/contrib/metrics/regression/wave_hedges_distance.py +2 -3
  34. ignite/contrib/metrics/roc_auc.py +1 -1
  35. ignite/distributed/auto.py +1 -0
  36. ignite/distributed/comp_models/horovod.py +8 -1
  37. ignite/distributed/comp_models/native.py +2 -1
  38. ignite/distributed/comp_models/xla.py +2 -0
  39. ignite/distributed/launcher.py +4 -8
  40. ignite/engine/__init__.py +9 -9
  41. ignite/engine/deterministic.py +1 -1
  42. ignite/engine/engine.py +9 -11
  43. ignite/engine/events.py +2 -1
  44. ignite/handlers/__init__.py +2 -0
  45. ignite/handlers/checkpoint.py +2 -2
  46. ignite/handlers/clearml_logger.py +2 -2
  47. ignite/handlers/fbresearch_logger.py +2 -2
  48. ignite/handlers/lr_finder.py +10 -10
  49. ignite/handlers/neptune_logger.py +1 -0
  50. ignite/handlers/param_scheduler.py +7 -3
  51. ignite/handlers/state_param_scheduler.py +8 -2
  52. ignite/handlers/time_profilers.py +6 -3
  53. ignite/handlers/tqdm_logger.py +7 -2
  54. ignite/handlers/visdom_logger.py +2 -2
  55. ignite/handlers/wandb_logger.py +9 -8
  56. ignite/metrics/accuracy.py +2 -0
  57. ignite/metrics/metric.py +1 -0
  58. ignite/metrics/nlp/rouge.py +6 -3
  59. ignite/metrics/roc_auc.py +1 -0
  60. ignite/metrics/ssim.py +4 -0
  61. ignite/metrics/vision/object_detection_average_precision_recall.py +3 -0
  62. {pytorch_ignite-0.6.0.dev20250927.dist-info → pytorch_ignite-0.6.0.dev20260101.dist-info}/METADATA +2 -2
  63. {pytorch_ignite-0.6.0.dev20250927.dist-info → pytorch_ignite-0.6.0.dev20260101.dist-info}/RECORD +65 -65
  64. {pytorch_ignite-0.6.0.dev20250927.dist-info → pytorch_ignite-0.6.0.dev20260101.dist-info}/WHEEL +1 -1
  65. {pytorch_ignite-0.6.0.dev20250927.dist-info → pytorch_ignite-0.6.0.dev20260101.dist-info}/licenses/LICENSE +0 -0
ignite/engine/engine.py CHANGED
@@ -148,7 +148,7 @@ class Engine(Serializable):
148
148
  self.should_interrupt = False
149
149
  self.state = State()
150
150
  self._state_dict_user_keys: List[str] = []
151
- self._allowed_events: List[EventEnum] = []
151
+ self._allowed_events: List[Union[str, EventEnum]] = []
152
152
 
153
153
  self._dataloader_iter: Optional[Iterator[Any]] = None
154
154
  self._init_iter: Optional[int] = None
@@ -163,9 +163,7 @@ class Engine(Serializable):
163
163
  # generator provided by self._internal_run_as_gen
164
164
  self._internal_run_generator: Optional[Generator[Any, None, State]] = None
165
165
 
166
- def register_events(
167
- self, *event_names: Union[List[str], List[EventEnum]], event_to_attr: Optional[dict] = None
168
- ) -> None:
166
+ def register_events(self, *event_names: Union[str, EventEnum], event_to_attr: Optional[dict] = None) -> None:
169
167
  """Add events that can be fired.
170
168
 
171
169
  Registering an event will let the user trigger these events at any point.
@@ -450,7 +448,7 @@ class Engine(Serializable):
450
448
  first, others = ((resolved_engine,), args[1:])
451
449
  else:
452
450
  # metrics do not provide engine when registered
453
- first, others = (tuple(), args) # type: ignore[assignment]
451
+ first, others = (tuple(), args)
454
452
 
455
453
  func(*first, *(event_args + others), **kwargs)
456
454
 
@@ -989,9 +987,9 @@ class Engine(Serializable):
989
987
  def _internal_run_as_gen(self) -> Generator[Any, None, State]:
990
988
  self.should_terminate = self.should_terminate_single_epoch = self.should_interrupt = False
991
989
  self._init_timers(self.state)
990
+ start_time = time.time()
992
991
  try:
993
992
  try:
994
- start_time = time.time()
995
993
  self._fire_event(Events.STARTED)
996
994
  yield from self._maybe_terminate_or_interrupt()
997
995
 
@@ -1010,7 +1008,7 @@ class Engine(Serializable):
1010
1008
  # time is available for handlers but must be updated after fire
1011
1009
  self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken
1012
1010
 
1013
- if self.should_terminate_single_epoch != "skip_epoch_completed": # type: ignore[comparison-overlap]
1011
+ if self.should_terminate_single_epoch != "skip_epoch_completed":
1014
1012
  handlers_start_time = time.time()
1015
1013
  self._fire_event(Events.EPOCH_COMPLETED)
1016
1014
  epoch_time_taken += time.time() - handlers_start_time
@@ -1043,7 +1041,7 @@ class Engine(Serializable):
1043
1041
  self.state.times[Events.COMPLETED.name] = time_taken
1044
1042
 
1045
1043
  # do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True`
1046
- if self.should_terminate != "skip_completed": # type: ignore[comparison-overlap]
1044
+ if self.should_terminate != "skip_completed":
1047
1045
  handlers_start_time = time.time()
1048
1046
  self._fire_event(Events.COMPLETED)
1049
1047
  time_taken += time.time() - handlers_start_time
@@ -1189,9 +1187,9 @@ class Engine(Serializable):
1189
1187
  # internal_run without generator for BC
1190
1188
  self.should_terminate = self.should_terminate_single_epoch = self.should_interrupt = False
1191
1189
  self._init_timers(self.state)
1190
+ start_time = time.time()
1192
1191
  try:
1193
1192
  try:
1194
- start_time = time.time()
1195
1193
  self._fire_event(Events.STARTED)
1196
1194
  self._maybe_terminate_legacy()
1197
1195
 
@@ -1210,7 +1208,7 @@ class Engine(Serializable):
1210
1208
  # time is available for handlers but must be updated after fire
1211
1209
  self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken
1212
1210
 
1213
- if self.should_terminate_single_epoch != "skip_epoch_completed": # type: ignore[comparison-overlap]
1211
+ if self.should_terminate_single_epoch != "skip_epoch_completed":
1214
1212
  handlers_start_time = time.time()
1215
1213
  self._fire_event(Events.EPOCH_COMPLETED)
1216
1214
  epoch_time_taken += time.time() - handlers_start_time
@@ -1243,7 +1241,7 @@ class Engine(Serializable):
1243
1241
  self.state.times[Events.COMPLETED.name] = time_taken
1244
1242
 
1245
1243
  # do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True`
1246
- if self.should_terminate != "skip_completed": # type: ignore[comparison-overlap]
1244
+ if self.should_terminate != "skip_completed":
1247
1245
  handlers_start_time = time.time()
1248
1246
  self._fire_event(Events.COMPLETED)
1249
1247
  time_taken += time.time() - handlers_start_time
ignite/engine/events.py CHANGED
@@ -91,7 +91,7 @@ class CallableEventWithFilter:
91
91
  raise ValueError("Argument every should be integer and greater than zero")
92
92
 
93
93
  if once is not None:
94
- c1 = isinstance(once, numbers.Integral) and once > 0
94
+ c1 = isinstance(once, int) and once > 0
95
95
  c2 = isinstance(once, Sequence) and len(once) > 0 and all(isinstance(e, int) and e > 0 for e in once)
96
96
  if not (c1 or c2):
97
97
  raise ValueError(
@@ -240,6 +240,7 @@ class EventEnum(CallableEventWithFilter, Enum):
240
240
  def __new__(cls, value: str) -> "EventEnum":
241
241
  obj = CallableEventWithFilter.__new__(cls)
242
242
  obj._value_ = value
243
+ # pyrefly: ignore [bad-return]
243
244
  return obj
244
245
 
245
246
 
@@ -6,6 +6,7 @@ from ignite.handlers.checkpoint import Checkpoint, DiskSaver, ModelCheckpoint
6
6
  from ignite.handlers.clearml_logger import ClearMLLogger
7
7
  from ignite.handlers.early_stopping import EarlyStopping
8
8
  from ignite.handlers.ema_handler import EMAHandler
9
+ from ignite.handlers.fbresearch_logger import FBResearchLogger
9
10
  from ignite.handlers.lr_finder import FastaiLRFinder
10
11
  from ignite.handlers.mlflow_logger import MLflowLogger
11
12
  from ignite.handlers.neptune_logger import NeptuneLogger
@@ -64,6 +65,7 @@ __all__ = [
64
65
  "CyclicalScheduler",
65
66
  "create_lr_scheduler_with_warmup",
66
67
  "FastaiLRFinder",
68
+ "FBResearchLogger",
67
69
  "EMAHandler",
68
70
  "BasicTimeProfiler",
69
71
  "HandlersTimeProfiler",
@@ -315,7 +315,7 @@ class Checkpoint(Serializable):
315
315
  """
316
316
 
317
317
  SAVED_CHECKPOINT = CheckpointEvents.SAVED_CHECKPOINT
318
- Item = NamedTuple("Item", [("priority", int), ("filename", str)])
318
+ Item = NamedTuple("Item", [("priority", Union[int, float]), ("filename", str)])
319
319
  _state_dict_all_req_keys = ("_saved",)
320
320
 
321
321
  def __init__(
@@ -323,7 +323,7 @@ class Checkpoint(Serializable):
323
323
  to_save: Mapping,
324
324
  save_handler: Union[str, Path, Callable, BaseSaveHandler],
325
325
  filename_prefix: str = "",
326
- score_function: Optional[Callable] = None,
326
+ score_function: Optional[Callable[[Engine], Union[int, float]]] = None,
327
327
  score_name: Optional[str] = None,
328
328
  n_saved: Union[int, None] = 1,
329
329
  global_step_transform: Optional[Callable] = None,
@@ -862,7 +862,7 @@ class ClearMLSaver(DiskSaver):
862
862
  except ImportError:
863
863
  try:
864
864
  # Backwards-compatibility for legacy Trains SDK
865
- from trains import Task # type: ignore[no-redef]
865
+ from trains import Task
866
866
  except ImportError:
867
867
  raise ModuleNotFoundError(
868
868
  "This contrib module requires clearml to be installed. "
@@ -937,7 +937,7 @@ class ClearMLSaver(DiskSaver):
937
937
  except ImportError:
938
938
  try:
939
939
  # Backwards-compatibility for legacy Trains SDK
940
- from trains.binding.frameworks import WeightsFileHandler # type: ignore[no-redef]
940
+ from trains.binding.frameworks import WeightsFileHandler
941
941
  except ImportError:
942
942
  raise ModuleNotFoundError(
943
943
  "This contrib module requires clearml to be installed. "
@@ -7,7 +7,7 @@ import torch
7
7
 
8
8
  from ignite import utils
9
9
  from ignite.engine import Engine, Events
10
- from ignite.handlers import Timer
10
+ from ignite.handlers.timing import Timer
11
11
 
12
12
  MB = 1024.0 * 1024.0
13
13
 
@@ -154,7 +154,7 @@ class FBResearchLogger:
154
154
  if torch.cuda.is_available():
155
155
  cuda_max_mem = f"GPU Max Mem: {torch.cuda.max_memory_allocated() / MB:.0f} MB"
156
156
 
157
- current_iter = engine.state.iteration % (engine.state.epoch_length + 1)
157
+ current_iter = ((engine.state.iteration - 1) % engine.state.epoch_length) + 1
158
158
  iter_avg_time = self.iter_timer.value()
159
159
 
160
160
  eta_seconds = iter_avg_time * (engine.state.epoch_length - current_iter)
@@ -98,15 +98,18 @@ class FastaiLRFinder:
98
98
  self._best_loss = None
99
99
  self._diverge_flag = False
100
100
 
101
+ assert trainer.state.epoch_length is not None
102
+ assert trainer.state.max_epochs is not None
103
+
101
104
  # attach LRScheduler to trainer.
102
105
  if num_iter is None:
103
106
  num_iter = trainer.state.epoch_length * trainer.state.max_epochs
104
107
  else:
105
- max_iter = trainer.state.epoch_length * trainer.state.max_epochs # type: ignore[operator]
108
+ max_iter = trainer.state.epoch_length * trainer.state.max_epochs
106
109
  if max_iter < num_iter:
107
110
  max_iter = num_iter
108
111
  trainer.state.max_iters = num_iter
109
- trainer.state.max_epochs = ceil(num_iter / trainer.state.epoch_length) # type: ignore[operator]
112
+ trainer.state.max_epochs = ceil(num_iter / trainer.state.epoch_length)
110
113
 
111
114
  if not trainer.has_event_handler(self._reached_num_iterations):
112
115
  trainer.add_event_handler(Events.ITERATION_COMPLETED, self._reached_num_iterations, num_iter)
@@ -178,17 +181,14 @@ class FastaiLRFinder:
178
181
  loss = idist.all_reduce(loss)
179
182
  lr = self._lr_schedule.get_param()
180
183
  self._history["lr"].append(lr)
181
- if trainer.state.iteration == 1:
182
- self._best_loss = loss # type: ignore[assignment]
183
- else:
184
- if smooth_f > 0:
185
- loss = smooth_f * loss + (1 - smooth_f) * self._history["loss"][-1]
186
- if loss < self._best_loss:
187
- self._best_loss = loss
184
+ if trainer.state.iteration != 1 and smooth_f > 0:
185
+ loss = smooth_f * loss + (1 - smooth_f) * self._history["loss"][-1]
186
+ if self._best_loss is None or loss < self._best_loss:
187
+ self._best_loss = loss
188
188
  self._history["loss"].append(loss)
189
189
 
190
190
  # Check if the loss has diverged; if it has, stop the trainer
191
- if self._history["loss"][-1] > diverge_th * self._best_loss: # type: ignore[operator]
191
+ if self._history["loss"][-1] > diverge_th * self._best_loss:
192
192
  self._diverge_flag = True
193
193
  self.logger.info("Stopping early, the loss has diverged")
194
194
  trainer.terminate()
@@ -698,6 +698,7 @@ class NeptuneSaver(BaseSaveHandler):
698
698
  # hold onto the file stream for uploading.
699
699
  # NOTE: This won't load the whole file in memory and upload
700
700
  # the stream in smaller chunks.
701
+ # pyrefly: ignore [bad-argument-type]
701
702
  self._logger[filename].upload(File.from_stream(tmp.file))
702
703
 
703
704
  @idist.one_rank_only(with_barrier=True)
@@ -1122,13 +1122,14 @@ def create_lr_scheduler_with_warmup(
1122
1122
  f"but given {type(lr_scheduler)}"
1123
1123
  )
1124
1124
 
1125
- if not isinstance(warmup_duration, numbers.Integral):
1125
+ if not isinstance(warmup_duration, int):
1126
1126
  raise TypeError(f"Argument warmup_duration should be integer, but given {warmup_duration}")
1127
1127
 
1128
1128
  if not (warmup_duration > 1):
1129
1129
  raise ValueError(f"Argument warmup_duration should be at least 2 events, but given {warmup_duration}")
1130
1130
 
1131
1131
  warmup_schedulers: List[ParamScheduler] = []
1132
+ milestones_values: List[Tuple[int, float]] = []
1132
1133
 
1133
1134
  for param_group_index, param_group in enumerate(lr_scheduler.optimizer.param_groups):
1134
1135
  if warmup_end_value is None:
@@ -1176,6 +1177,7 @@ def create_lr_scheduler_with_warmup(
1176
1177
  lr_scheduler,
1177
1178
  ]
1178
1179
  durations = [milestones_values[-1][0] + 1]
1180
+ # pyrefly: ignore [bad-argument-type]
1179
1181
  combined_scheduler = ConcatScheduler(schedulers, durations=durations, save_history=save_history)
1180
1182
 
1181
1183
  if output_simulated_values is not None:
@@ -1185,6 +1187,7 @@ def create_lr_scheduler_with_warmup(
1185
1187
  f"but given {type(output_simulated_values)}."
1186
1188
  )
1187
1189
  num_events = len(output_simulated_values)
1190
+ # pyrefly: ignore [bad-argument-type]
1188
1191
  result = ConcatScheduler.simulate_values(num_events=num_events, schedulers=schedulers, durations=durations)
1189
1192
  for i in range(num_events):
1190
1193
  output_simulated_values[i] = result[i]
@@ -1650,6 +1653,7 @@ class ReduceLROnPlateauScheduler(ParamScheduler):
1650
1653
  self.trainer = trainer
1651
1654
  self.optimizer = optimizer
1652
1655
 
1656
+ min_lr: Union[float, List[float]]
1653
1657
  if "min_lr" in scheduler_kwargs and param_group_index is not None:
1654
1658
  min_lr = scheduler_kwargs["min_lr"]
1655
1659
  if not isinstance(min_lr, float):
@@ -1670,11 +1674,11 @@ class ReduceLROnPlateauScheduler(ParamScheduler):
1670
1674
  _scheduler_kwargs["verbose"] = False
1671
1675
 
1672
1676
  self.scheduler = ReduceLROnPlateau(optimizer, **_scheduler_kwargs)
1673
- self.scheduler._reduce_lr = self._reduce_lr # type: ignore[method-assign]
1677
+ self.scheduler._reduce_lr = self._reduce_lr
1674
1678
 
1675
1679
  self._state_attrs += ["metric_name", "scheduler"]
1676
1680
 
1677
- def __call__(self, engine: Engine, name: Optional[str] = None) -> None: # type: ignore[override]
1681
+ def __call__(self, engine: Engine, name: Optional[str] = None) -> None:
1678
1682
  if not hasattr(engine.state, "metrics") or self.metric_name not in engine.state.metrics:
1679
1683
  raise ValueError(
1680
1684
  "Argument engine should have in its 'state', attribute 'metrics' "
@@ -1,7 +1,7 @@
1
1
  import numbers
2
2
  import warnings
3
3
  from bisect import bisect_right
4
- from typing import Any, List, Sequence, Tuple, Union
4
+ from typing import Any, Callable, List, Sequence, Tuple, Union
5
5
 
6
6
  from ignite.engine import CallableEventWithFilter, Engine, Events, EventsList
7
7
  from ignite.handlers.param_scheduler import BaseParamScheduler
@@ -183,7 +183,13 @@ class LambdaStateScheduler(StateParamScheduler):
183
183
 
184
184
  """
185
185
 
186
- def __init__(self, lambda_obj: Any, param_name: str, save_history: bool = False, create_new: bool = False):
186
+ def __init__(
187
+ self,
188
+ lambda_obj: Callable[[int], Union[List[float], float]],
189
+ param_name: str,
190
+ save_history: bool = False,
191
+ create_new: bool = False,
192
+ ):
187
193
  super(LambdaStateScheduler, self).__init__(param_name, save_history, create_new)
188
194
 
189
195
  if not callable(lambda_obj):
@@ -251,6 +251,7 @@ class BasicTimeProfiler:
251
251
  total_eh_time: Union[int, torch.Tensor] = sum(
252
252
  [(self.event_handlers_times[e]).sum() for e in Events if e not in self.events_to_ignore]
253
253
  )
254
+ # pyrefly: ignore [no-matching-overload]
254
255
  event_handlers_stats = dict(
255
256
  [
256
257
  (str(e.name).replace(".", "_"), self._compute_basic_stats(self.event_handlers_times[e]))
@@ -334,6 +335,7 @@ class BasicTimeProfiler:
334
335
 
335
336
  results_df = pd.DataFrame(
336
337
  data=results_dump,
338
+ # pyrefly: ignore [bad-argument-type]
337
339
  columns=[
338
340
  "epoch",
339
341
  "iteration",
@@ -498,14 +500,14 @@ class HandlersTimeProfiler:
498
500
 
499
501
  self.dataflow_times: List[float] = []
500
502
  self.processing_times: List[float] = []
501
- self.event_handlers_times: Dict[EventEnum, Dict[str, List[float]]] = {}
503
+ self.event_handlers_times: Dict[Union[str, EventEnum], Dict[str, List[float]]] = {}
502
504
 
503
505
  @staticmethod
504
506
  def _get_callable_name(handler: Callable) -> str:
505
507
  # get name of the callable handler
506
508
  return getattr(handler, "__qualname__", handler.__class__.__name__)
507
509
 
508
- def _create_wrapped_handler(self, handler: Callable, event: EventEnum) -> Callable:
510
+ def _create_wrapped_handler(self, handler: Callable, event: Union[str, EventEnum]) -> Callable:
509
511
  @functools.wraps(handler)
510
512
  def _timeit_handler(*args: Any, **kwargs: Any) -> None:
511
513
  self._event_handlers_timer.reset()
@@ -530,7 +532,7 @@ class HandlersTimeProfiler:
530
532
  t = self._dataflow_timer.value()
531
533
  self.dataflow_times.append(t)
532
534
 
533
- def _reset(self, event_handlers_names: Mapping[EventEnum, List[str]]) -> None:
535
+ def _reset(self, event_handlers_names: Mapping[Union[str, EventEnum], List[str]]) -> None:
534
536
  # reset the variables used for profiling
535
537
  self.dataflow_times = []
536
538
  self.processing_times = []
@@ -689,6 +691,7 @@ class HandlersTimeProfiler:
689
691
 
690
692
  results_dump = torch.stack(cols, dim=1).numpy()
691
693
 
694
+ # pyrefly: ignore [bad-argument-type]
692
695
  results_df = pd.DataFrame(data=results_dump, columns=headers)
693
696
  results_df.to_csv(output_path, index=False)
694
697
 
@@ -223,8 +223,13 @@ class ProgressBar(BaseLogger):
223
223
  super(ProgressBar, self).attach(engine, log_handler, event_name)
224
224
  engine.add_event_handler(closing_event_name, self._close)
225
225
 
226
- def attach_opt_params_handler( # type: ignore[empty-body]
227
- self, engine: Engine, event_name: Union[str, Events], *args: Any, **kwargs: Any
226
+ def attach_opt_params_handler(
227
+ self,
228
+ engine: Engine,
229
+ event_name: Union[str, Events],
230
+ *args: Any,
231
+ **kwargs: Any,
232
+ # pyrefly: ignore [bad-return]
228
233
  ) -> RemovableEventHandle:
229
234
  """Intentionally empty"""
230
235
  pass
@@ -1,7 +1,7 @@
1
1
  """Visdom logger and its helper handlers."""
2
2
 
3
3
  import os
4
- from typing import Any, Callable, Dict, List, Optional, Union
4
+ from typing import Any, Callable, Dict, List, Optional, Union, TYPE_CHECKING
5
5
 
6
6
  import torch
7
7
  import torch.nn as nn
@@ -165,7 +165,7 @@ class VisdomLogger(BaseLogger):
165
165
  "pip install git+https://github.com/fossasia/visdom.git"
166
166
  )
167
167
 
168
- if num_workers > 0:
168
+ if num_workers > 0 or TYPE_CHECKING:
169
169
  # If visdom is installed, one of its dependencies `tornado`
170
170
  # requires also `futures` to be installed.
171
171
  # Let's check anyway if we can import it.
@@ -1,6 +1,7 @@
1
1
  """WandB logger and its helper handlers."""
2
2
 
3
3
  from typing import Any, Callable, List, Optional, Union
4
+ from warnings import warn
4
5
 
5
6
  from torch.optim import Optimizer
6
7
 
@@ -172,8 +173,7 @@ class OutputHandler(BaseOutputHandler):
172
173
  Default is None, global_step based on attached engine. If provided,
173
174
  uses function output as global_step. To setup global step from another engine, please use
174
175
  :meth:`~ignite.handlers.wandb_logger.global_step_from_engine`.
175
- sync: If set to False, process calls to log in a seperate thread. Default (None) uses whatever
176
- the default value of wandb.log.
176
+ sync: Deprecated, has no function. Argument is kept here for compatibility with existing code.
177
177
 
178
178
  Examples:
179
179
  .. code-block:: python
@@ -284,7 +284,8 @@ class OutputHandler(BaseOutputHandler):
284
284
  state_attributes: Optional[List[str]] = None,
285
285
  ):
286
286
  super().__init__(tag, metric_names, output_transform, global_step_transform, state_attributes)
287
- self.sync = sync
287
+ if sync is not None:
288
+ warn("The sync argument for the WandBLoggers is no longer used, and may be removed in the future")
288
289
 
289
290
  def __call__(self, engine: Engine, logger: WandBLogger, event_name: Union[str, Events]) -> None:
290
291
  if not isinstance(logger, WandBLogger):
@@ -298,7 +299,7 @@ class OutputHandler(BaseOutputHandler):
298
299
  )
299
300
 
300
301
  metrics = self._setup_output_metrics_state_attrs(engine, log_text=True, key_tuple=False)
301
- logger.log(metrics, step=global_step, sync=self.sync)
302
+ logger.log(metrics, step=global_step)
302
303
 
303
304
 
304
305
  class OptimizerParamsHandler(BaseOptimizerParamsHandler):
@@ -309,8 +310,7 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler):
309
310
  as a sequence.
310
311
  param_name: parameter name
311
312
  tag: common title for all produced plots. For example, "generator"
312
- sync: If set to False, process calls to log in a seperate thread. Default (None) uses whatever
313
- the default value of wandb.log.
313
+ sync: Deprecated, has no function. Argument is kept here for compatibility with existing code.
314
314
 
315
315
  Examples:
316
316
  .. code-block:: python
@@ -346,7 +346,8 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler):
346
346
  self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None, sync: Optional[bool] = None
347
347
  ):
348
348
  super(OptimizerParamsHandler, self).__init__(optimizer, param_name, tag)
349
- self.sync = sync
349
+ if sync is not None:
350
+ warn("The sync argument for the WandBLoggers is no longer used, and may be removed in the future")
350
351
 
351
352
  def __call__(self, engine: Engine, logger: WandBLogger, event_name: Union[str, Events]) -> None:
352
353
  if not isinstance(logger, WandBLogger):
@@ -358,4 +359,4 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler):
358
359
  f"{tag_prefix}{self.param_name}/group_{i}": float(param_group[self.param_name])
359
360
  for i, param_group in enumerate(self.optimizer.param_groups)
360
361
  }
361
- logger.log(params, step=global_step, sync=self.sync)
362
+ logger.log(params, step=global_step)
@@ -254,6 +254,8 @@ class Accuracy(_BaseClassification):
254
254
  y_pred = torch.transpose(y_pred, 1, last_dim - 1).reshape(-1, num_classes)
255
255
  y = torch.transpose(y, 1, last_dim - 1).reshape(-1, num_classes)
256
256
  correct = torch.all(y == y_pred.type_as(y), dim=-1)
257
+ else:
258
+ raise ValueError(f"Unexpected type: {self._type}")
257
259
 
258
260
  self._num_correct += torch.sum(correct).to(self._device)
259
261
  self._num_examples += correct.shape[0]
ignite/metrics/metric.py CHANGED
@@ -369,6 +369,7 @@ class Metric(Serializable, metaclass=ABCMeta):
369
369
  if torch.device(device).type == "xla":
370
370
  raise ValueError("Cannot create metric on an XLA device. Use device='cpu' instead.")
371
371
 
372
+ # pyrefly: ignore [read-only]
372
373
  self._device = torch.device(device)
373
374
  self._skip_unrolling = skip_unrolling
374
375
 
@@ -1,6 +1,5 @@
1
1
  from abc import ABCMeta, abstractmethod
2
- from collections import namedtuple
3
- from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Union
2
+ from typing import Any, Callable, List, Mapping, NamedTuple, Optional, Sequence, Tuple, Union
4
3
 
5
4
  import torch
6
5
 
@@ -13,11 +12,15 @@ from ignite.metrics.nlp.utils import lcs, ngrams
13
12
  __all__ = ["Rouge", "RougeN", "RougeL"]
14
13
 
15
14
 
16
- class Score(namedtuple("Score", ["match", "candidate", "reference"])):
15
+ class Score(NamedTuple):
17
16
  r"""
18
17
  Computes precision and recall for given matches, candidate and reference lengths.
19
18
  """
20
19
 
20
+ match: int
21
+ candidate: int
22
+ reference: int
23
+
21
24
  def precision(self) -> float:
22
25
  """
23
26
  Calculates precision.
ignite/metrics/roc_auc.py CHANGED
@@ -210,4 +210,5 @@ class RocCurve(EpochMetric):
210
210
  tpr = idist.broadcast(tpr, src=0, safe_mode=True)
211
211
  thresholds = idist.broadcast(thresholds, src=0, safe_mode=True)
212
212
 
213
+ # pyrefly: ignore [bad-return]
213
214
  return fpr, tpr, thresholds
ignite/metrics/ssim.py CHANGED
@@ -161,11 +161,15 @@ class SSIM(Metric):
161
161
  kernel_y = self._gaussian(kernel_size[1], sigma[1])
162
162
  if ndims == 3:
163
163
  kernel_z = self._gaussian(kernel_size[2], sigma[2])
164
+ else:
165
+ kernel_z = None
164
166
  else:
165
167
  kernel_x = self._uniform(kernel_size[0])
166
168
  kernel_y = self._uniform(kernel_size[1])
167
169
  if ndims == 3:
168
170
  kernel_z = self._uniform(kernel_size[2])
171
+ else:
172
+ kernel_z = None
169
173
 
170
174
  result = (
171
175
  torch.einsum("i,j->ij", kernel_x, kernel_y)
@@ -160,6 +160,9 @@ class ObjectDetectionAvgPrecisionRecall(Metric, _BaseAveragePrecision):
160
160
  elif self._area_range == "large":
161
161
  min_area = 9216
162
162
  max_area = 1e10
163
+ else:
164
+ min_area = 0
165
+ max_area = 1e10
163
166
  return torch.logical_and(areas >= min_area, areas <= max_area)
164
167
 
165
168
  def _check_matching_input(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pytorch-ignite
3
- Version: 0.6.0.dev20250927
3
+ Version: 0.6.0.dev20260101
4
4
  Summary: A lightweight library to help with training neural networks in PyTorch.
5
5
  Project-URL: Homepage, https://pytorch-ignite.ai
6
6
  Project-URL: Repository, https://github.com/pytorch/ignite
@@ -412,7 +412,7 @@ Few pointers to get you started:
412
412
  - [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pytorch/ignite/blob/master/examples/notebooks/FastaiLRFinder_MNIST.ipynb) [Basic example of LR finder on
413
413
  MNIST](https://github.com/pytorch/ignite/blob/master/examples/notebooks/FastaiLRFinder_MNIST.ipynb)
414
414
  - [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pytorch/ignite/blob/master/examples/notebooks/Cifar100_bench_amp.ipynb) [Benchmark mixed precision training on Cifar100:
415
- torch.cuda.amp vs nvidia/apex](https://github.com/pytorch/ignite/blob/master/examples/notebooks/Cifar100_bench_amp.ipynb)
415
+ torch.amp vs nvidia/apex](https://github.com/pytorch/ignite/blob/master/examples/notebooks/Cifar100_bench_amp.ipynb)
416
416
  - [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pytorch/ignite/blob/master/examples/notebooks/MNIST_on_TPU.ipynb) [MNIST training on a single
417
417
  TPU](https://github.com/pytorch/ignite/blob/master/examples/notebooks/MNIST_on_TPU.ipynb)
418
418
  - [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1E9zJrptnLJ_PKhmaP5Vhb6DTVRvyrKHx) [CIFAR10 Training on multiple TPUs](https://github.com/pytorch/ignite/tree/master/examples/cifar10)