pytorch-ignite 0.6.0.dev20251216__py3-none-any.whl → 0.6.0.dev20251217__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ignite/__init__.py +1 -1
- ignite/contrib/engines/tbptt.py +0 -1
- ignite/distributed/comp_models/horovod.py +6 -4
- ignite/distributed/launcher.py +4 -8
- ignite/engine/engine.py +9 -14
- ignite/handlers/checkpoint.py +2 -5
- ignite/handlers/clearml_logger.py +2 -2
- ignite/handlers/lr_finder.py +10 -12
- ignite/handlers/param_scheduler.py +5 -7
- ignite/handlers/state_param_scheduler.py +8 -3
- ignite/handlers/time_profilers.py +3 -3
- ignite/handlers/tqdm_logger.py +1 -1
- ignite/handlers/visdom_logger.py +2 -3
- ignite/metrics/accuracy.py +2 -2
- ignite/metrics/nlp/rouge.py +6 -6
- {pytorch_ignite-0.6.0.dev20251216.dist-info → pytorch_ignite-0.6.0.dev20251217.dist-info}/METADATA +1 -1
- {pytorch_ignite-0.6.0.dev20251216.dist-info → pytorch_ignite-0.6.0.dev20251217.dist-info}/RECORD +19 -19
- {pytorch_ignite-0.6.0.dev20251216.dist-info → pytorch_ignite-0.6.0.dev20251217.dist-info}/WHEEL +0 -0
- {pytorch_ignite-0.6.0.dev20251216.dist-info → pytorch_ignite-0.6.0.dev20251217.dist-info}/licenses/LICENSE +0 -0
ignite/__init__.py
CHANGED
ignite/contrib/engines/tbptt.py
CHANGED
|
ignite/distributed/comp_models/horovod.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import warnings
|
|
3
|
-
from typing import Any, Callable, cast, List, Mapping, Optional, Tuple
|
|
3
|
+
from typing import Any, Callable, cast, List, Mapping, Optional, Tuple, TYPE_CHECKING
|
|
4
4
|
|
|
5
5
|
import torch
|
|
6
6
|
|
|
@@ -20,6 +20,11 @@ try:
|
|
|
20
20
|
except ImportError:
|
|
21
21
|
has_hvd_support = False
|
|
22
22
|
|
|
23
|
+
if TYPE_CHECKING:
|
|
24
|
+
# Tell the type checker that hvd imports are always defined.
|
|
25
|
+
import horovod.torch as hvd
|
|
26
|
+
from horovod import run as hvd_mp_spawn
|
|
27
|
+
|
|
23
28
|
|
|
24
29
|
if has_hvd_support:
|
|
25
30
|
HOROVOD = "horovod"
|
|
@@ -171,11 +176,8 @@ if has_hvd_support:
|
|
|
171
176
|
return group
|
|
172
177
|
|
|
173
178
|
_reduce_op_map = {
|
|
174
|
-
# pyrefly: ignore [unbound-name]
|
|
175
179
|
"SUM": hvd.mpi_ops.Sum,
|
|
176
|
-
# pyrefly: ignore [unbound-name]
|
|
177
180
|
"AVERAGE": hvd.mpi_ops.Average,
|
|
178
|
-
# pyrefly: ignore [unbound-name]
|
|
179
181
|
"ADASUM": hvd.mpi_ops.Adasum,
|
|
180
182
|
}
|
|
181
183
|
|
ignite/distributed/launcher.py
CHANGED
|
@@ -322,19 +322,15 @@ class Parallel:
|
|
|
322
322
|
idist.initialize(self.backend, init_method=self.init_method)
|
|
323
323
|
|
|
324
324
|
# The logger can be setup from now since idist.initialize() has been called (if needed)
|
|
325
|
-
self._logger = setup_logger(__name__ + "." + self.__class__.__name__)
|
|
325
|
+
self._logger = setup_logger(__name__ + "." + self.__class__.__name__)
|
|
326
326
|
|
|
327
327
|
if self.backend is not None:
|
|
328
328
|
if self._spawn_params is None:
|
|
329
|
-
self._logger.info(
|
|
330
|
-
f"Initialized processing group with backend: '{self.backend}'"
|
|
331
|
-
)
|
|
329
|
+
self._logger.info(f"Initialized processing group with backend: '{self.backend}'")
|
|
332
330
|
else:
|
|
333
|
-
self._logger.info(
|
|
334
|
-
f"Initialized distributed launcher with backend: '{self.backend}'"
|
|
335
|
-
)
|
|
331
|
+
self._logger.info(f"Initialized distributed launcher with backend: '{self.backend}'")
|
|
336
332
|
msg = "\n\t".join([f"{k}: {v}" for k, v in self._spawn_params.items() if v is not None])
|
|
337
|
-
self._logger.info(f"- Parameters to spawn processes: \n\t{msg}")
|
|
333
|
+
self._logger.info(f"- Parameters to spawn processes: \n\t{msg}")
|
|
338
334
|
|
|
339
335
|
return self
|
|
340
336
|
|
ignite/engine/engine.py
CHANGED
|
@@ -148,12 +148,11 @@ class Engine(Serializable):
|
|
|
148
148
|
self.should_interrupt = False
|
|
149
149
|
self.state = State()
|
|
150
150
|
self._state_dict_user_keys: List[str] = []
|
|
151
|
-
self._allowed_events: List[EventEnum] = []
|
|
151
|
+
self._allowed_events: List[Union[str, EventEnum]] = []
|
|
152
152
|
|
|
153
153
|
self._dataloader_iter: Optional[Iterator[Any]] = None
|
|
154
154
|
self._init_iter: Optional[int] = None
|
|
155
155
|
|
|
156
|
-
# pyrefly: ignore [bad-argument-type]
|
|
157
156
|
self.register_events(*Events)
|
|
158
157
|
|
|
159
158
|
if self._process_function is None:
|
|
@@ -164,9 +163,7 @@ class Engine(Serializable):
|
|
|
164
163
|
# generator provided by self._internal_run_as_gen
|
|
165
164
|
self._internal_run_generator: Optional[Generator[Any, None, State]] = None
|
|
166
165
|
|
|
167
|
-
def register_events(
|
|
168
|
-
self, *event_names: Union[List[str], List[EventEnum]], event_to_attr: Optional[dict] = None
|
|
169
|
-
) -> None:
|
|
166
|
+
def register_events(self, *event_names: Union[str, EventEnum], event_to_attr: Optional[dict] = None) -> None:
|
|
170
167
|
"""Add events that can be fired.
|
|
171
168
|
|
|
172
169
|
Registering an event will let the user trigger these events at any point.
|
|
@@ -451,7 +448,7 @@ class Engine(Serializable):
|
|
|
451
448
|
first, others = ((resolved_engine,), args[1:])
|
|
452
449
|
else:
|
|
453
450
|
# metrics do not provide engine when registered
|
|
454
|
-
first, others = (tuple(), args)
|
|
451
|
+
first, others = (tuple(), args)
|
|
455
452
|
|
|
456
453
|
func(*first, *(event_args + others), **kwargs)
|
|
457
454
|
|
|
@@ -990,9 +987,9 @@ class Engine(Serializable):
|
|
|
990
987
|
def _internal_run_as_gen(self) -> Generator[Any, None, State]:
|
|
991
988
|
self.should_terminate = self.should_terminate_single_epoch = self.should_interrupt = False
|
|
992
989
|
self._init_timers(self.state)
|
|
990
|
+
start_time = time.time()
|
|
993
991
|
try:
|
|
994
992
|
try:
|
|
995
|
-
start_time = time.time()
|
|
996
993
|
self._fire_event(Events.STARTED)
|
|
997
994
|
yield from self._maybe_terminate_or_interrupt()
|
|
998
995
|
|
|
@@ -1011,7 +1008,7 @@ class Engine(Serializable):
|
|
|
1011
1008
|
# time is available for handlers but must be updated after fire
|
|
1012
1009
|
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken
|
|
1013
1010
|
|
|
1014
|
-
if self.should_terminate_single_epoch != "skip_epoch_completed":
|
|
1011
|
+
if self.should_terminate_single_epoch != "skip_epoch_completed":
|
|
1015
1012
|
handlers_start_time = time.time()
|
|
1016
1013
|
self._fire_event(Events.EPOCH_COMPLETED)
|
|
1017
1014
|
epoch_time_taken += time.time() - handlers_start_time
|
|
@@ -1039,13 +1036,12 @@ class Engine(Serializable):
|
|
|
1039
1036
|
"https://github.com/pytorch/ignite/issues/new/choose"
|
|
1040
1037
|
)
|
|
1041
1038
|
|
|
1042
|
-
# pyrefly: ignore [unbound-name]
|
|
1043
1039
|
time_taken = time.time() - start_time
|
|
1044
1040
|
# time is available for handlers but must be updated after fire
|
|
1045
1041
|
self.state.times[Events.COMPLETED.name] = time_taken
|
|
1046
1042
|
|
|
1047
1043
|
# do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True`
|
|
1048
|
-
if self.should_terminate != "skip_completed":
|
|
1044
|
+
if self.should_terminate != "skip_completed":
|
|
1049
1045
|
handlers_start_time = time.time()
|
|
1050
1046
|
self._fire_event(Events.COMPLETED)
|
|
1051
1047
|
time_taken += time.time() - handlers_start_time
|
|
@@ -1191,9 +1187,9 @@ class Engine(Serializable):
|
|
|
1191
1187
|
# internal_run without generator for BC
|
|
1192
1188
|
self.should_terminate = self.should_terminate_single_epoch = self.should_interrupt = False
|
|
1193
1189
|
self._init_timers(self.state)
|
|
1190
|
+
start_time = time.time()
|
|
1194
1191
|
try:
|
|
1195
1192
|
try:
|
|
1196
|
-
start_time = time.time()
|
|
1197
1193
|
self._fire_event(Events.STARTED)
|
|
1198
1194
|
self._maybe_terminate_legacy()
|
|
1199
1195
|
|
|
@@ -1212,7 +1208,7 @@ class Engine(Serializable):
|
|
|
1212
1208
|
# time is available for handlers but must be updated after fire
|
|
1213
1209
|
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken
|
|
1214
1210
|
|
|
1215
|
-
if self.should_terminate_single_epoch != "skip_epoch_completed":
|
|
1211
|
+
if self.should_terminate_single_epoch != "skip_epoch_completed":
|
|
1216
1212
|
handlers_start_time = time.time()
|
|
1217
1213
|
self._fire_event(Events.EPOCH_COMPLETED)
|
|
1218
1214
|
epoch_time_taken += time.time() - handlers_start_time
|
|
@@ -1240,13 +1236,12 @@ class Engine(Serializable):
|
|
|
1240
1236
|
"https://github.com/pytorch/ignite/issues/new/choose"
|
|
1241
1237
|
)
|
|
1242
1238
|
|
|
1243
|
-
# pyrefly: ignore [unbound-name]
|
|
1244
1239
|
time_taken = time.time() - start_time
|
|
1245
1240
|
# time is available for handlers but must be updated after fire
|
|
1246
1241
|
self.state.times[Events.COMPLETED.name] = time_taken
|
|
1247
1242
|
|
|
1248
1243
|
# do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True`
|
|
1249
|
-
if self.should_terminate != "skip_completed":
|
|
1244
|
+
if self.should_terminate != "skip_completed":
|
|
1250
1245
|
handlers_start_time = time.time()
|
|
1251
1246
|
self._fire_event(Events.COMPLETED)
|
|
1252
1247
|
time_taken += time.time() - handlers_start_time
|
ignite/handlers/checkpoint.py
CHANGED
|
@@ -315,7 +315,7 @@ class Checkpoint(Serializable):
|
|
|
315
315
|
"""
|
|
316
316
|
|
|
317
317
|
SAVED_CHECKPOINT = CheckpointEvents.SAVED_CHECKPOINT
|
|
318
|
-
Item = NamedTuple("Item", [("priority", int), ("filename", str)])
|
|
318
|
+
Item = NamedTuple("Item", [("priority", Union[int, float]), ("filename", str)])
|
|
319
319
|
_state_dict_all_req_keys = ("_saved",)
|
|
320
320
|
|
|
321
321
|
def __init__(
|
|
@@ -323,7 +323,7 @@ class Checkpoint(Serializable):
|
|
|
323
323
|
to_save: Mapping,
|
|
324
324
|
save_handler: Union[str, Path, Callable, BaseSaveHandler],
|
|
325
325
|
filename_prefix: str = "",
|
|
326
|
-
score_function: Optional[Callable] = None,
|
|
326
|
+
score_function: Optional[Callable[[Engine], Union[int, float]]] = None,
|
|
327
327
|
score_name: Optional[str] = None,
|
|
328
328
|
n_saved: Union[int, None] = 1,
|
|
329
329
|
global_step_transform: Optional[Callable] = None,
|
|
@@ -440,7 +440,6 @@ class Checkpoint(Serializable):
|
|
|
440
440
|
|
|
441
441
|
def __call__(self, engine: Engine) -> None:
|
|
442
442
|
if not engine.has_registered_events(CheckpointEvents.SAVED_CHECKPOINT):
|
|
443
|
-
# pyrefly: ignore [bad-argument-type]
|
|
444
443
|
engine.register_events(*CheckpointEvents)
|
|
445
444
|
global_step = None
|
|
446
445
|
if self.global_step_transform is not None:
|
|
@@ -455,7 +454,6 @@ class Checkpoint(Serializable):
|
|
|
455
454
|
global_step = engine.state.get_event_attrib_value(Events.ITERATION_COMPLETED)
|
|
456
455
|
priority = global_step
|
|
457
456
|
|
|
458
|
-
# pyrefly: ignore [bad-argument-type]
|
|
459
457
|
if self._check_lt_n_saved() or self._compare_fn(priority):
|
|
460
458
|
priority_str = f"{priority}" if isinstance(priority, numbers.Integral) else f"{priority:.4f}"
|
|
461
459
|
|
|
@@ -497,7 +495,6 @@ class Checkpoint(Serializable):
|
|
|
497
495
|
if isinstance(self.save_handler, BaseSaveHandler):
|
|
498
496
|
self.save_handler.remove(item.filename)
|
|
499
497
|
|
|
500
|
-
# pyrefly: ignore [bad-argument-type]
|
|
501
498
|
self._saved.append(Checkpoint.Item(priority, filename))
|
|
502
499
|
self._saved.sort(key=lambda it: it[0])
|
|
503
500
|
|
|
ignite/handlers/clearml_logger.py
CHANGED
|
@@ -862,7 +862,7 @@ class ClearMLSaver(DiskSaver):
|
|
|
862
862
|
except ImportError:
|
|
863
863
|
try:
|
|
864
864
|
# Backwards-compatibility for legacy Trains SDK
|
|
865
|
-
from trains import Task
|
|
865
|
+
from trains import Task
|
|
866
866
|
except ImportError:
|
|
867
867
|
raise ModuleNotFoundError(
|
|
868
868
|
"This contrib module requires clearml to be installed. "
|
|
@@ -937,7 +937,7 @@ class ClearMLSaver(DiskSaver):
|
|
|
937
937
|
except ImportError:
|
|
938
938
|
try:
|
|
939
939
|
# Backwards-compatibility for legacy Trains SDK
|
|
940
|
-
from trains.binding.frameworks import WeightsFileHandler
|
|
940
|
+
from trains.binding.frameworks import WeightsFileHandler
|
|
941
941
|
except ImportError:
|
|
942
942
|
raise ModuleNotFoundError(
|
|
943
943
|
"This contrib module requires clearml to be installed. "
|
ignite/handlers/lr_finder.py
CHANGED
|
@@ -98,16 +98,18 @@ class FastaiLRFinder:
|
|
|
98
98
|
self._best_loss = None
|
|
99
99
|
self._diverge_flag = False
|
|
100
100
|
|
|
101
|
+
assert trainer.state.epoch_length is not None
|
|
102
|
+
assert trainer.state.max_epochs is not None
|
|
103
|
+
|
|
101
104
|
# attach LRScheduler to trainer.
|
|
102
105
|
if num_iter is None:
|
|
103
|
-
# pyrefly: ignore [unsupported-operation]
|
|
104
106
|
num_iter = trainer.state.epoch_length * trainer.state.max_epochs
|
|
105
107
|
else:
|
|
106
|
-
max_iter = trainer.state.epoch_length * trainer.state.max_epochs
|
|
108
|
+
max_iter = trainer.state.epoch_length * trainer.state.max_epochs
|
|
107
109
|
if max_iter < num_iter:
|
|
108
110
|
max_iter = num_iter
|
|
109
111
|
trainer.state.max_iters = num_iter
|
|
110
|
-
trainer.state.max_epochs = ceil(num_iter / trainer.state.epoch_length)
|
|
112
|
+
trainer.state.max_epochs = ceil(num_iter / trainer.state.epoch_length)
|
|
111
113
|
|
|
112
114
|
if not trainer.has_event_handler(self._reached_num_iterations):
|
|
113
115
|
trainer.add_event_handler(Events.ITERATION_COMPLETED, self._reached_num_iterations, num_iter)
|
|
@@ -179,18 +181,14 @@ class FastaiLRFinder:
|
|
|
179
181
|
loss = idist.all_reduce(loss)
|
|
180
182
|
lr = self._lr_schedule.get_param()
|
|
181
183
|
self._history["lr"].append(lr)
|
|
182
|
-
if trainer.state.iteration == 1:
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
loss = smooth_f * loss + (1 - smooth_f) * self._history["loss"][-1]
|
|
187
|
-
# pyrefly: ignore [unsupported-operation]
|
|
188
|
-
if loss < self._best_loss:
|
|
189
|
-
self._best_loss = loss
|
|
184
|
+
if trainer.state.iteration != 1 and smooth_f > 0:
|
|
185
|
+
loss = smooth_f * loss + (1 - smooth_f) * self._history["loss"][-1]
|
|
186
|
+
if self._best_loss is None or loss < self._best_loss:
|
|
187
|
+
self._best_loss = loss
|
|
190
188
|
self._history["loss"].append(loss)
|
|
191
189
|
|
|
192
190
|
# Check if the loss has diverged; if it has, stop the trainer
|
|
193
|
-
if self._history["loss"][-1] > diverge_th * self._best_loss:
|
|
191
|
+
if self._history["loss"][-1] > diverge_th * self._best_loss:
|
|
194
192
|
self._diverge_flag = True
|
|
195
193
|
self.logger.info("Stopping early, the loss has diverged")
|
|
196
194
|
trainer.terminate()
|
|
ignite/handlers/param_scheduler.py
CHANGED
|
@@ -1122,13 +1122,14 @@ def create_lr_scheduler_with_warmup(
|
|
|
1122
1122
|
f"but given {type(lr_scheduler)}"
|
|
1123
1123
|
)
|
|
1124
1124
|
|
|
1125
|
-
if not isinstance(warmup_duration, numbers.Integral):
|
|
1125
|
+
if not isinstance(warmup_duration, int):
|
|
1126
1126
|
raise TypeError(f"Argument warmup_duration should be integer, but given {warmup_duration}")
|
|
1127
1127
|
|
|
1128
1128
|
if not (warmup_duration > 1):
|
|
1129
1129
|
raise ValueError(f"Argument warmup_duration should be at least 2 events, but given {warmup_duration}")
|
|
1130
1130
|
|
|
1131
1131
|
warmup_schedulers: List[ParamScheduler] = []
|
|
1132
|
+
milestones_values: List[Tuple[int, float]] = []
|
|
1132
1133
|
|
|
1133
1134
|
for param_group_index, param_group in enumerate(lr_scheduler.optimizer.param_groups):
|
|
1134
1135
|
if warmup_end_value is None:
|
|
@@ -1154,7 +1155,6 @@ def create_lr_scheduler_with_warmup(
|
|
|
1154
1155
|
init_lr = lr_scheduler.get_param()
|
|
1155
1156
|
if init_lr == param_group_warmup_end_value:
|
|
1156
1157
|
if warmup_duration > 2:
|
|
1157
|
-
# pyrefly: ignore [unsupported-operation]
|
|
1158
1158
|
d = (param_group_warmup_end_value - warmup_start_value) / (warmup_duration - 1)
|
|
1159
1159
|
milestones_values[-1] = (warmup_duration - 2, param_group_warmup_end_value - d)
|
|
1160
1160
|
else:
|
|
@@ -1164,7 +1164,6 @@ def create_lr_scheduler_with_warmup(
|
|
|
1164
1164
|
PiecewiseLinear(
|
|
1165
1165
|
lr_scheduler.optimizer,
|
|
1166
1166
|
param_name="lr",
|
|
1167
|
-
# pyrefly: ignore [bad-argument-type]
|
|
1168
1167
|
milestones_values=milestones_values,
|
|
1169
1168
|
param_group_index=param_group_index,
|
|
1170
1169
|
save_history=save_history,
|
|
@@ -1177,7 +1176,6 @@ def create_lr_scheduler_with_warmup(
|
|
|
1177
1176
|
warmup_scheduler,
|
|
1178
1177
|
lr_scheduler,
|
|
1179
1178
|
]
|
|
1180
|
-
# pyrefly: ignore [unbound-name, unsupported-operation]
|
|
1181
1179
|
durations = [milestones_values[-1][0] + 1]
|
|
1182
1180
|
# pyrefly: ignore [bad-argument-type]
|
|
1183
1181
|
combined_scheduler = ConcatScheduler(schedulers, durations=durations, save_history=save_history)
|
|
@@ -1655,13 +1653,13 @@ class ReduceLROnPlateauScheduler(ParamScheduler):
|
|
|
1655
1653
|
self.trainer = trainer
|
|
1656
1654
|
self.optimizer = optimizer
|
|
1657
1655
|
|
|
1656
|
+
min_lr: Union[float, List[float]]
|
|
1658
1657
|
if "min_lr" in scheduler_kwargs and param_group_index is not None:
|
|
1659
1658
|
min_lr = scheduler_kwargs["min_lr"]
|
|
1660
1659
|
if not isinstance(min_lr, float):
|
|
1661
1660
|
raise TypeError(f"When param_group_index is given, min_lr should be a float, but given {type(min_lr)}")
|
|
1662
1661
|
_min_lr = min_lr
|
|
1663
1662
|
min_lr = [0] * len(optimizer.param_groups)
|
|
1664
|
-
# pyrefly: ignore [unsupported-operation]
|
|
1665
1663
|
min_lr[param_group_index] = _min_lr
|
|
1666
1664
|
else:
|
|
1667
1665
|
min_lr = 0
|
|
@@ -1676,11 +1674,11 @@ class ReduceLROnPlateauScheduler(ParamScheduler):
|
|
|
1676
1674
|
_scheduler_kwargs["verbose"] = False
|
|
1677
1675
|
|
|
1678
1676
|
self.scheduler = ReduceLROnPlateau(optimizer, **_scheduler_kwargs)
|
|
1679
|
-
self.scheduler._reduce_lr = self._reduce_lr
|
|
1677
|
+
self.scheduler._reduce_lr = self._reduce_lr
|
|
1680
1678
|
|
|
1681
1679
|
self._state_attrs += ["metric_name", "scheduler"]
|
|
1682
1680
|
|
|
1683
|
-
def __call__(self, engine: Engine, name: Optional[str] = None) -> None:
|
|
1681
|
+
def __call__(self, engine: Engine, name: Optional[str] = None) -> None:
|
|
1684
1682
|
if not hasattr(engine.state, "metrics") or self.metric_name not in engine.state.metrics:
|
|
1685
1683
|
raise ValueError(
|
|
1686
1684
|
"Argument engine should have in its 'state', attribute 'metrics' "
|
|
ignite/handlers/state_param_scheduler.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import numbers
|
|
2
2
|
import warnings
|
|
3
3
|
from bisect import bisect_right
|
|
4
|
-
from typing import Any, List, Sequence, Tuple, Union
|
|
4
|
+
from typing import Any, Callable, List, Sequence, Tuple, Union
|
|
5
5
|
|
|
6
6
|
from ignite.engine import CallableEventWithFilter, Engine, Events, EventsList
|
|
7
7
|
from ignite.handlers.param_scheduler import BaseParamScheduler
|
|
@@ -183,7 +183,13 @@ class LambdaStateScheduler(StateParamScheduler):
|
|
|
183
183
|
|
|
184
184
|
"""
|
|
185
185
|
|
|
186
|
-
def __init__(
|
|
186
|
+
def __init__(
|
|
187
|
+
self,
|
|
188
|
+
lambda_obj: Callable[[int], Union[List[float], float]],
|
|
189
|
+
param_name: str,
|
|
190
|
+
save_history: bool = False,
|
|
191
|
+
create_new: bool = False,
|
|
192
|
+
):
|
|
187
193
|
super(LambdaStateScheduler, self).__init__(param_name, save_history, create_new)
|
|
188
194
|
|
|
189
195
|
if not callable(lambda_obj):
|
|
@@ -193,7 +199,6 @@ class LambdaStateScheduler(StateParamScheduler):
|
|
|
193
199
|
self._state_attrs += ["lambda_obj"]
|
|
194
200
|
|
|
195
201
|
def get_param(self) -> Union[List[float], float]:
|
|
196
|
-
# pyrefly: ignore [bad-return]
|
|
197
202
|
return self.lambda_obj(self.event_index)
|
|
198
203
|
|
|
199
204
|
|
|
ignite/handlers/time_profilers.py
CHANGED
|
@@ -500,14 +500,14 @@ class HandlersTimeProfiler:
|
|
|
500
500
|
|
|
501
501
|
self.dataflow_times: List[float] = []
|
|
502
502
|
self.processing_times: List[float] = []
|
|
503
|
-
self.event_handlers_times: Dict[EventEnum, Dict[str, List[float]]] = {}
|
|
503
|
+
self.event_handlers_times: Dict[Union[str, EventEnum], Dict[str, List[float]]] = {}
|
|
504
504
|
|
|
505
505
|
@staticmethod
|
|
506
506
|
def _get_callable_name(handler: Callable) -> str:
|
|
507
507
|
# get name of the callable handler
|
|
508
508
|
return getattr(handler, "__qualname__", handler.__class__.__name__)
|
|
509
509
|
|
|
510
|
-
def _create_wrapped_handler(self, handler: Callable, event: EventEnum) -> Callable:
|
|
510
|
+
def _create_wrapped_handler(self, handler: Callable, event: Union[str, EventEnum]) -> Callable:
|
|
511
511
|
@functools.wraps(handler)
|
|
512
512
|
def _timeit_handler(*args: Any, **kwargs: Any) -> None:
|
|
513
513
|
self._event_handlers_timer.reset()
|
|
@@ -532,7 +532,7 @@ class HandlersTimeProfiler:
|
|
|
532
532
|
t = self._dataflow_timer.value()
|
|
533
533
|
self.dataflow_times.append(t)
|
|
534
534
|
|
|
535
|
-
def _reset(self, event_handlers_names: Mapping[EventEnum, List[str]]) -> None:
|
|
535
|
+
def _reset(self, event_handlers_names: Mapping[Union[str, EventEnum], List[str]]) -> None:
|
|
536
536
|
# reset the variables used for profiling
|
|
537
537
|
self.dataflow_times = []
|
|
538
538
|
self.processing_times = []
|
ignite/handlers/tqdm_logger.py
CHANGED
|
@@ -223,7 +223,7 @@ class ProgressBar(BaseLogger):
|
|
|
223
223
|
super(ProgressBar, self).attach(engine, log_handler, event_name)
|
|
224
224
|
engine.add_event_handler(closing_event_name, self._close)
|
|
225
225
|
|
|
226
|
-
def attach_opt_params_handler(
|
|
226
|
+
def attach_opt_params_handler(
|
|
227
227
|
self,
|
|
228
228
|
engine: Engine,
|
|
229
229
|
event_name: Union[str, Events],
|
ignite/handlers/visdom_logger.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
"""Visdom logger and its helper handlers."""
|
|
2
2
|
|
|
3
3
|
import os
|
|
4
|
-
from typing import Any, Callable, Dict, List, Optional, Union
|
|
4
|
+
from typing import Any, Callable, Dict, List, Optional, Union, TYPE_CHECKING
|
|
5
5
|
|
|
6
6
|
import torch
|
|
7
7
|
import torch.nn as nn
|
|
@@ -165,7 +165,7 @@ class VisdomLogger(BaseLogger):
|
|
|
165
165
|
"pip install git+https://github.com/fossasia/visdom.git"
|
|
166
166
|
)
|
|
167
167
|
|
|
168
|
-
if num_workers > 0:
|
|
168
|
+
if num_workers > 0 or TYPE_CHECKING:
|
|
169
169
|
# If visdom is installed, one of its dependencies `tornado`
|
|
170
170
|
# requires also `futures` to be installed.
|
|
171
171
|
# Let's check anyway if we can import it.
|
|
@@ -199,7 +199,6 @@ class VisdomLogger(BaseLogger):
|
|
|
199
199
|
|
|
200
200
|
self.executor: Union[_DummyExecutor, "ThreadPoolExecutor"] = _DummyExecutor()
|
|
201
201
|
if num_workers > 0:
|
|
202
|
-
# pyrefly: ignore [unbound-name]
|
|
203
202
|
self.executor = ThreadPoolExecutor(max_workers=num_workers)
|
|
204
203
|
|
|
205
204
|
def _save(self) -> None:
|
ignite/metrics/accuracy.py
CHANGED
|
@@ -254,10 +254,10 @@ class Accuracy(_BaseClassification):
|
|
|
254
254
|
y_pred = torch.transpose(y_pred, 1, last_dim - 1).reshape(-1, num_classes)
|
|
255
255
|
y = torch.transpose(y, 1, last_dim - 1).reshape(-1, num_classes)
|
|
256
256
|
correct = torch.all(y == y_pred.type_as(y), dim=-1)
|
|
257
|
+
else:
|
|
258
|
+
raise ValueError(f"Unexpected type: {self._type}")
|
|
257
259
|
|
|
258
|
-
# pyrefly: ignore [unbound-name]
|
|
259
260
|
self._num_correct += torch.sum(correct).to(self._device)
|
|
260
|
-
# pyrefly: ignore [unbound-name]
|
|
261
261
|
self._num_examples += correct.shape[0]
|
|
262
262
|
|
|
263
263
|
@sync_all_reduce("_num_examples", "_num_correct")
|
ignite/metrics/nlp/rouge.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
from abc import ABCMeta, abstractmethod
|
|
2
|
-
from collections import namedtuple
|
|
3
|
-
from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Union
|
|
2
|
+
from typing import Any, Callable, List, Mapping, NamedTuple, Optional, Sequence, Tuple, Union
|
|
4
3
|
|
|
5
4
|
import torch
|
|
6
5
|
|
|
@@ -13,24 +12,25 @@ from ignite.metrics.nlp.utils import lcs, ngrams
|
|
|
13
12
|
__all__ = ["Rouge", "RougeN", "RougeL"]
|
|
14
13
|
|
|
15
14
|
|
|
16
|
-
|
|
17
|
-
class Score(namedtuple("Score", ["match", "candidate", "reference"])):
|
|
15
|
+
class Score(NamedTuple):
|
|
18
16
|
r"""
|
|
19
17
|
Computes precision and recall for given matches, candidate and reference lengths.
|
|
20
18
|
"""
|
|
21
19
|
|
|
20
|
+
match: int
|
|
21
|
+
candidate: int
|
|
22
|
+
reference: int
|
|
23
|
+
|
|
22
24
|
def precision(self) -> float:
|
|
23
25
|
"""
|
|
24
26
|
Calculates precision.
|
|
25
27
|
"""
|
|
26
|
-
# pyrefly: ignore [missing-attribute]
|
|
27
28
|
return self.match / self.candidate if self.candidate > 0 else 0
|
|
28
29
|
|
|
29
30
|
def recall(self) -> float:
|
|
30
31
|
"""
|
|
31
32
|
Calculates recall.
|
|
32
33
|
"""
|
|
33
|
-
# pyrefly: ignore [missing-attribute]
|
|
34
34
|
return self.match / self.reference if self.reference > 0 else 0
|
|
35
35
|
|
|
36
36
|
|
{pytorch_ignite-0.6.0.dev20251216.dist-info → pytorch_ignite-0.6.0.dev20251217.dist-info}/METADATA
RENAMED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: pytorch-ignite
|
|
3
|
-
Version: 0.6.0.dev20251216
|
|
3
|
+
Version: 0.6.0.dev20251217
|
|
4
4
|
Summary: A lightweight library to help with training neural networks in PyTorch.
|
|
5
5
|
Project-URL: Homepage, https://pytorch-ignite.ai
|
|
6
6
|
Project-URL: Repository, https://github.com/pytorch/ignite
|
{pytorch_ignite-0.6.0.dev20251216.dist-info → pytorch_ignite-0.6.0.dev20251217.dist-info}/RECORD
RENAMED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
ignite/__init__.py,sha256=
|
|
1
|
+
ignite/__init__.py,sha256=Qj__TPyyfR88JFHqrH51XDNSIbnV-iaV-PJQPhNVCcY,194
|
|
2
2
|
ignite/_utils.py,sha256=XDPpUDJ8ykLXWMV2AYTqGSj8XCfApsyzsQ3Vij_OB4M,182
|
|
3
3
|
ignite/exceptions.py,sha256=5ZWCVLPC9rgoW8t84D-VeEleqz5O7XpAGPpCdU8rKd0,150
|
|
4
4
|
ignite/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -8,7 +8,7 @@ ignite/base/mixins.py,sha256=Ip1SHCQCsvNUnLJKJwX9L-hqpfcZAlTad87-PaVgCBI,991
|
|
|
8
8
|
ignite/contrib/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
9
9
|
ignite/contrib/engines/__init__.py,sha256=BxmXYIYEtEB1niMWITL8pgyKufCIpXR61rSzPQOhA0g,87
|
|
10
10
|
ignite/contrib/engines/common.py,sha256=8WyVV6pqVHKnBfcdZoBPbOUXqqwSOTUI2OKUyMqvOks,28483
|
|
11
|
-
ignite/contrib/engines/tbptt.py,sha256=
|
|
11
|
+
ignite/contrib/engines/tbptt.py,sha256=FSmF5SnoZn7mWNZWRZ-ohWUCfucET78GQu3lvVRNItk,4507
|
|
12
12
|
ignite/contrib/handlers/__init__.py,sha256=rZszZnCbzncE2jqsvx9KP1iS3WZ0I-CnrV3Jh3Xl8_o,1073
|
|
13
13
|
ignite/contrib/handlers/base_logger.py,sha256=gHVTkVvYMRUXI793rNq8564mMyJaL_HCuoCu8xiKxFY,1158
|
|
14
14
|
ignite/contrib/handlers/clearml_logger.py,sha256=4CRD38jrif-8MeKYiEu5RbF5B-PhEkPnUGfielvt5s8,1385
|
|
@@ -47,44 +47,44 @@ ignite/contrib/metrics/regression/r2_score.py,sha256=1Mwo3Ft2PkYL8xq-CcbKqidJP5j
|
|
|
47
47
|
ignite/contrib/metrics/regression/wave_hedges_distance.py,sha256=1uSqAUZX5aBzw0UJNla6bRYhHM3uPdVPuEzNJa4dixk,847
|
|
48
48
|
ignite/distributed/__init__.py,sha256=qC28ok9XHWJawZfQR2MqWf6ctggS4rUY9PiTJjOCNvI,181
|
|
49
49
|
ignite/distributed/auto.py,sha256=9nk9ArklntyzTaHx-odUTtKtX7bch-qQf1HQE7Y6YQE,15443
|
|
50
|
-
ignite/distributed/launcher.py,sha256=
|
|
50
|
+
ignite/distributed/launcher.py,sha256=lEzoLqfVQDDXoPJ0ELUNs7090o1I6cDBFKuq3lTLPs4,13298
|
|
51
51
|
ignite/distributed/utils.py,sha256=D97JwWgL9RKP8rTfDRf1zMmfRUeJizr7XfLZ8LAScOI,24999
|
|
52
52
|
ignite/distributed/comp_models/__init__.py,sha256=S2WHl463U7BvpcUe9-JaGtuCi3G1cMHFW5QFBQ6fv20,1357
|
|
53
53
|
ignite/distributed/comp_models/base.py,sha256=pTIylP1h2g6NWopBEponfXC6UefqS1l2lEdzTUTNXFc,14185
|
|
54
|
-
ignite/distributed/comp_models/horovod.py,sha256=
|
|
54
|
+
ignite/distributed/comp_models/horovod.py,sha256=pGrcHQcwjuuMWJufBR4RyT5YR6RHT8wtk4-Bz_ir3_w,9353
|
|
55
55
|
ignite/distributed/comp_models/native.py,sha256=k2ADEkHNTRDyWfBE1JP7AvTQTjjPtW8a2pyNLkeV6AQ,28139
|
|
56
56
|
ignite/distributed/comp_models/xla.py,sha256=XhKFeo7kNu4mTe9yyzLoEzxS8cDbTFJKAYY9m_dDHIk,6367
|
|
57
57
|
ignite/engine/__init__.py,sha256=MRFj6yywKhVkov4ccPkrw4dX1O8PfqceiJkngrcFb7A,36094
|
|
58
58
|
ignite/engine/deterministic.py,sha256=uXn5VfxN_AgcEzZwBk_zdPWlSdKH2tl8Md1lcx1mvJ4,11643
|
|
59
|
-
ignite/engine/engine.py,sha256=
|
|
59
|
+
ignite/engine/engine.py,sha256=R0cDvh_MxFWOucmVuxrjiH3_xcybNDo9c4BkHUk2CEI,60713
|
|
60
60
|
ignite/engine/events.py,sha256=FrcvnvjNZEzzohMQU6ZxL8ezrUQshUuM917Rsyxf8v0,21833
|
|
61
61
|
ignite/engine/utils.py,sha256=QG5mkdg4OipspqgpNQcJuoHTYdr2Sx5LS16kfjOHDdI,1073
|
|
62
62
|
ignite/handlers/__init__.py,sha256=Qq85YTtHPcii6UAfMOoCPg9RwigH96iqxOJKIlRfDqw,2728
|
|
63
63
|
ignite/handlers/base_logger.py,sha256=wPiGn9iCh5ung1GaRUf_qAlqe63h1NpUUQ0XK709p2k,13011
|
|
64
|
-
ignite/handlers/checkpoint.py,sha256=
|
|
65
|
-
ignite/handlers/clearml_logger.py,sha256=
|
|
64
|
+
ignite/handlers/checkpoint.py,sha256=u6cFUDxAoSSBKCBprmDud2LEZGDEYHvyCoLUmtG3Xd4,46309
|
|
65
|
+
ignite/handlers/clearml_logger.py,sha256=0-57RYznIz-EgTsKtkKFPdGGFQXJIhq146H_qiE8hVc,37897
|
|
66
66
|
ignite/handlers/early_stopping.py,sha256=UA6TiKho5CbD085R-16H8w3r0BYPQcWQjhEXg8aITSw,4139
|
|
67
67
|
ignite/handlers/ema_handler.py,sha256=SmUyyWIFPZW3yMvjD_sSk5m_LfnMFl9R-uQdbXNFfY0,11854
|
|
68
68
|
ignite/handlers/fbresearch_logger.py,sha256=MfQeiBIXBYLEwZoDIld2oCceMeTAsz8rc5cd7fLtpJs,11133
|
|
69
|
-
ignite/handlers/lr_finder.py,sha256=
|
|
69
|
+
ignite/handlers/lr_finder.py,sha256=EMcQR3NDPOuh2s85a5Zu5Bqt0I4pg1cACJpjSa5cO4A,22100
|
|
70
70
|
ignite/handlers/mlflow_logger.py,sha256=M5Mggrnr2wMsms8wbEaHqNtTk5L1zNs1MlPWD0ZCpDQ,13894
|
|
71
71
|
ignite/handlers/neptune_logger.py,sha256=Rv-O_i0zGZC2Ozzeetxv7rtD7iP3IeWEcbY-U28Mkzg,27348
|
|
72
|
-
ignite/handlers/param_scheduler.py,sha256=
|
|
72
|
+
ignite/handlers/param_scheduler.py,sha256=Tn4o27YBrp5JsuadHobIrsHfmvB_cR1IrV_oV1Eo7us,68373
|
|
73
73
|
ignite/handlers/polyaxon_logger.py,sha256=5b7Zxhksne8Ufg_SBTG-rlf_9CPSjkBQOJR4-ynoZnQ,12354
|
|
74
|
-
ignite/handlers/state_param_scheduler.py,sha256=
|
|
74
|
+
ignite/handlers/state_param_scheduler.py,sha256=B89YKZyj9DXLXQyr3amDNMslUOWNHZDis2DXIwW0q10,20841
|
|
75
75
|
ignite/handlers/stores.py,sha256=8XM_Qqsitfu0WtOOE-K2FMtv51vD90r3GgQlCzRABYc,2616
|
|
76
76
|
ignite/handlers/tensorboard_logger.py,sha256=q3YxXkbIFayBggI_kcHyl-upttVVjjnqFOLgyjj2cRo,27967
|
|
77
77
|
ignite/handlers/terminate_on_nan.py,sha256=RFSKd3Oqn9Me2xLCos4lSE-hnY7fYWWjE9blioeMlIs,2103
|
|
78
78
|
ignite/handlers/time_limit.py,sha256=heTuS-ReBbOUCm1NcNJGhzxI080Hanc4hOLB2Y4GeZk,1567
|
|
79
|
-
ignite/handlers/time_profilers.py,sha256=
|
|
79
|
+
ignite/handlers/time_profilers.py,sha256=8iCcBYPxv0vKFSO_ujFV0ST54a9PD9ezFLvYTIu9lFI,30482
|
|
80
80
|
ignite/handlers/timing.py,sha256=nHeBHvPwYdPRMAx-jk_8MjZit4a7rmsmIWkUrajAG-s,4705
|
|
81
|
-
ignite/handlers/tqdm_logger.py,sha256=
|
|
81
|
+
ignite/handlers/tqdm_logger.py,sha256=3kxH39vM0LCDVwIZl9HQRaWM2Pr6bYC_l9oydFJmdM4,13093
|
|
82
82
|
ignite/handlers/utils.py,sha256=X4LRqo1kqGsbmX0pEuZKYR6K4C8sZudAqxCLriiXtCg,872
|
|
83
|
-
ignite/handlers/visdom_logger.py,sha256=
|
|
83
|
+
ignite/handlers/visdom_logger.py,sha256=RY5ss3NAPad7d3xFFnqczCtuO6RgmWq9ROz-sFf6imI,21862
|
|
84
84
|
ignite/handlers/wandb_logger.py,sha256=vGok3gADQmTNkc6KkfFBreYoHAO8EneuU65xjBpT5-Q,14837
|
|
85
85
|
ignite/metrics/__init__.py,sha256=m-8F8J17r-aEwsO6Ww-8AqDRN59WFfYBwCDKwqGDSmI,3627
|
|
86
86
|
ignite/metrics/accumulation.py,sha256=xWdsm9u6JfsfODX_GUKzQc_omrdFDJ4yELBR-xXgc4s,12448
|
|
87
|
-
ignite/metrics/accuracy.py,sha256=
|
|
87
|
+
ignite/metrics/accuracy.py,sha256=W8mO4W11VzryMXKy8G7W_g4A9PH9RYpejW_tQ-T_Txw,10245
|
|
88
88
|
ignite/metrics/average_precision.py,sha256=AL4fvWCUL6zMNq_u2vQRnAdmdByB8S8x8jSE-MoFVjY,3694
|
|
89
89
|
ignite/metrics/classification_report.py,sha256=zjGlaMnRz2__op6hrZq74OusO0W_5B1AIe8KzYGFilM,5988
|
|
90
90
|
ignite/metrics/cohen_kappa.py,sha256=Qwcd4P2kN12CVCFC-kVdzn_2XV7kGzP6LlWkK209JJ8,3815
|
|
@@ -129,7 +129,7 @@ ignite/metrics/gan/inception_score.py,sha256=78_qrECWb_KsbLbo1lvDnvFJ9FsWPsbUi1a
|
|
|
129
129
|
ignite/metrics/gan/utils.py,sha256=3nihbBrcM9MRcu6r0p3x5SgZQ5V4aag20ZppM7j_HiM,3993
|
|
130
130
|
ignite/metrics/nlp/__init__.py,sha256=TiDKRhw7lhZeoL2Cn4s306cKIuBbXl2fizN1ZepMhwI,168
|
|
131
131
|
ignite/metrics/nlp/bleu.py,sha256=NyQZ3CQB1xUnH_KWer5QtxkM_S_aiO3ok86UMxHaQ_w,11539
|
|
132
|
-
ignite/metrics/nlp/rouge.py,sha256=
|
|
132
|
+
ignite/metrics/nlp/rouge.py,sha256=siAxJzGE3KjH23u-F3DCUPke--ls-1XMygncGhTYJp4,15313
|
|
133
133
|
ignite/metrics/nlp/utils.py,sha256=CA0MRMk9l97QockFYYhU6k0-hLhP3GwW36ONZ7TRqmc,2341
|
|
134
134
|
ignite/metrics/regression/__init__.py,sha256=I594yB38ypWi9IDi9rrdshdXeBnSRcST09tnLRjN0yk,1472
|
|
135
135
|
ignite/metrics/regression/_base.py,sha256=5V6GkkaBYRuW9J3yDXucyTZp1XJ2uIG7F4w2XcBsd3w,2365
|
|
@@ -153,7 +153,7 @@ ignite/metrics/regression/spearman_correlation.py,sha256=IzmN4WIe7C4cTUU3BOkBmaw
|
|
|
153
153
|
ignite/metrics/regression/wave_hedges_distance.py,sha256=Ji_NRUgnZ3lJgi5fyNFLRjbHO648z4dBmqVDQU9ImKA,2792
|
|
154
154
|
ignite/metrics/vision/__init__.py,sha256=lPBAEq1idc6Q17poFm1SjttE27irHF1-uNeiwrxnLrU,159
|
|
155
155
|
ignite/metrics/vision/object_detection_average_precision_recall.py,sha256=4wwiNVd658ynIpIbQlffTA-ehvyJ2EzmJ5pBSBuA8XQ,25091
|
|
156
|
-
pytorch_ignite-0.6.0.
|
|
157
|
-
pytorch_ignite-0.6.0.
|
|
158
|
-
pytorch_ignite-0.6.0.
|
|
159
|
-
pytorch_ignite-0.6.0.
|
|
156
|
+
pytorch_ignite-0.6.0.dev20251217.dist-info/METADATA,sha256=cDj0GfWJzI7Hi_DBrSiZkH91_lQBQgF1M-7kzmBBHyc,27979
|
|
157
|
+
pytorch_ignite-0.6.0.dev20251217.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
158
|
+
pytorch_ignite-0.6.0.dev20251217.dist-info/licenses/LICENSE,sha256=SwJvaRmy1ql-k9_nL4WnER4_ODTMF9fWoP9HXkoicgw,1527
|
|
159
|
+
pytorch_ignite-0.6.0.dev20251217.dist-info/RECORD,,
|
{pytorch_ignite-0.6.0.dev20251216.dist-info → pytorch_ignite-0.6.0.dev20251217.dist-info}/WHEEL
RENAMED
|
File without changes
|
|
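{pytorch_ignite-0.6.0.dev20251216.dist-info → pytorch_ignite-0.6.0.dev20251217.dist-info}/licenses/LICENSE
RENAMED
|
File without changes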
|