pytorch-ignite 0.6.0.dev20250927__py3-none-any.whl → 0.6.0.dev20260101__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
Potentially problematic release: this version of pytorch-ignite might be problematic.
- ignite/__init__.py +1 -1
- ignite/contrib/engines/common.py +1 -0
- ignite/contrib/handlers/base_logger.py +1 -1
- ignite/contrib/handlers/clearml_logger.py +1 -1
- ignite/contrib/handlers/lr_finder.py +1 -1
- ignite/contrib/handlers/mlflow_logger.py +1 -1
- ignite/contrib/handlers/neptune_logger.py +1 -1
- ignite/contrib/handlers/param_scheduler.py +1 -1
- ignite/contrib/handlers/polyaxon_logger.py +1 -1
- ignite/contrib/handlers/tensorboard_logger.py +1 -1
- ignite/contrib/handlers/time_profilers.py +1 -1
- ignite/contrib/handlers/tqdm_logger.py +1 -1
- ignite/contrib/handlers/visdom_logger.py +1 -1
- ignite/contrib/handlers/wandb_logger.py +1 -1
- ignite/contrib/metrics/average_precision.py +1 -1
- ignite/contrib/metrics/cohen_kappa.py +1 -1
- ignite/contrib/metrics/gpu_info.py +1 -1
- ignite/contrib/metrics/precision_recall_curve.py +1 -1
- ignite/contrib/metrics/regression/canberra_metric.py +2 -3
- ignite/contrib/metrics/regression/fractional_absolute_error.py +2 -3
- ignite/contrib/metrics/regression/fractional_bias.py +2 -3
- ignite/contrib/metrics/regression/geometric_mean_absolute_error.py +2 -3
- ignite/contrib/metrics/regression/geometric_mean_relative_absolute_error.py +2 -3
- ignite/contrib/metrics/regression/manhattan_distance.py +2 -3
- ignite/contrib/metrics/regression/maximum_absolute_error.py +2 -3
- ignite/contrib/metrics/regression/mean_absolute_relative_error.py +2 -3
- ignite/contrib/metrics/regression/mean_error.py +2 -3
- ignite/contrib/metrics/regression/mean_normalized_bias.py +2 -3
- ignite/contrib/metrics/regression/median_absolute_error.py +2 -3
- ignite/contrib/metrics/regression/median_absolute_percentage_error.py +2 -3
- ignite/contrib/metrics/regression/median_relative_absolute_error.py +2 -3
- ignite/contrib/metrics/regression/r2_score.py +2 -3
- ignite/contrib/metrics/regression/wave_hedges_distance.py +2 -3
- ignite/contrib/metrics/roc_auc.py +1 -1
- ignite/distributed/auto.py +1 -0
- ignite/distributed/comp_models/horovod.py +8 -1
- ignite/distributed/comp_models/native.py +2 -1
- ignite/distributed/comp_models/xla.py +2 -0
- ignite/distributed/launcher.py +4 -8
- ignite/engine/__init__.py +9 -9
- ignite/engine/deterministic.py +1 -1
- ignite/engine/engine.py +9 -11
- ignite/engine/events.py +2 -1
- ignite/handlers/__init__.py +2 -0
- ignite/handlers/checkpoint.py +2 -2
- ignite/handlers/clearml_logger.py +2 -2
- ignite/handlers/fbresearch_logger.py +2 -2
- ignite/handlers/lr_finder.py +10 -10
- ignite/handlers/neptune_logger.py +1 -0
- ignite/handlers/param_scheduler.py +7 -3
- ignite/handlers/state_param_scheduler.py +8 -2
- ignite/handlers/time_profilers.py +6 -3
- ignite/handlers/tqdm_logger.py +7 -2
- ignite/handlers/visdom_logger.py +2 -2
- ignite/handlers/wandb_logger.py +9 -8
- ignite/metrics/accuracy.py +2 -0
- ignite/metrics/metric.py +1 -0
- ignite/metrics/nlp/rouge.py +6 -3
- ignite/metrics/roc_auc.py +1 -0
- ignite/metrics/ssim.py +4 -0
- ignite/metrics/vision/object_detection_average_precision_recall.py +3 -0
- {pytorch_ignite-0.6.0.dev20250927.dist-info → pytorch_ignite-0.6.0.dev20260101.dist-info}/METADATA +2 -2
- {pytorch_ignite-0.6.0.dev20250927.dist-info → pytorch_ignite-0.6.0.dev20260101.dist-info}/RECORD +65 -65
- {pytorch_ignite-0.6.0.dev20250927.dist-info → pytorch_ignite-0.6.0.dev20260101.dist-info}/WHEEL +1 -1
- {pytorch_ignite-0.6.0.dev20250927.dist-info → pytorch_ignite-0.6.0.dev20260101.dist-info}/licenses/LICENSE +0 -0
ignite/engine/engine.py
CHANGED
@@ -148,7 +148,7 @@ class Engine(Serializable):
         self.should_interrupt = False
         self.state = State()
         self._state_dict_user_keys: List[str] = []
-        self._allowed_events: List[EventEnum] = []
+        self._allowed_events: List[Union[str, EventEnum]] = []

         self._dataloader_iter: Optional[Iterator[Any]] = None
         self._init_iter: Optional[int] = None
@@ -163,9 +163,7 @@ class Engine(Serializable):
         # generator provided by self._internal_run_as_gen
         self._internal_run_generator: Optional[Generator[Any, None, State]] = None

-    def register_events(
-        self, *event_names: Union[List[str], List[EventEnum]], event_to_attr: Optional[dict] = None
-    ) -> None:
+    def register_events(self, *event_names: Union[str, EventEnum], event_to_attr: Optional[dict] = None) -> None:
         """Add events that can be fired.

         Registering an event will let the user trigger these events at any point.
@@ -450,7 +448,7 @@ class Engine(Serializable):
                 first, others = ((resolved_engine,), args[1:])
             else:
                 # metrics do not provide engine when registered
-                first, others = (tuple(), args)
+                first, others = (tuple(), args)

             func(*first, *(event_args + others), **kwargs)

@@ -989,9 +987,9 @@ class Engine(Serializable):
     def _internal_run_as_gen(self) -> Generator[Any, None, State]:
         self.should_terminate = self.should_terminate_single_epoch = self.should_interrupt = False
         self._init_timers(self.state)
+        start_time = time.time()
         try:
             try:
-                start_time = time.time()
                 self._fire_event(Events.STARTED)
                 yield from self._maybe_terminate_or_interrupt()

@@ -1010,7 +1008,7 @@ class Engine(Serializable):
                 # time is available for handlers but must be updated after fire
                 self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken

-                if self.should_terminate_single_epoch != "skip_epoch_completed":
+                if self.should_terminate_single_epoch != "skip_epoch_completed":
                     handlers_start_time = time.time()
                     self._fire_event(Events.EPOCH_COMPLETED)
                     epoch_time_taken += time.time() - handlers_start_time
@@ -1043,7 +1041,7 @@ class Engine(Serializable):
             self.state.times[Events.COMPLETED.name] = time_taken

             # do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True`
-            if self.should_terminate != "skip_completed":
+            if self.should_terminate != "skip_completed":
                 handlers_start_time = time.time()
                 self._fire_event(Events.COMPLETED)
                 time_taken += time.time() - handlers_start_time
@@ -1189,9 +1187,9 @@ class Engine(Serializable):
         # internal_run without generator for BC
         self.should_terminate = self.should_terminate_single_epoch = self.should_interrupt = False
         self._init_timers(self.state)
+        start_time = time.time()
         try:
             try:
-                start_time = time.time()
                 self._fire_event(Events.STARTED)
                 self._maybe_terminate_legacy()

@@ -1210,7 +1208,7 @@ class Engine(Serializable):
                 # time is available for handlers but must be updated after fire
                 self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken

-                if self.should_terminate_single_epoch != "skip_epoch_completed":
+                if self.should_terminate_single_epoch != "skip_epoch_completed":
                     handlers_start_time = time.time()
                     self._fire_event(Events.EPOCH_COMPLETED)
                     epoch_time_taken += time.time() - handlers_start_time
@@ -1243,7 +1241,7 @@ class Engine(Serializable):
             self.state.times[Events.COMPLETED.name] = time_taken

             # do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True`
-            if self.should_terminate != "skip_completed":
+            if self.should_terminate != "skip_completed":
                 handlers_start_time = time.time()
                 self._fire_event(Events.COMPLETED)
                 time_taken += time.time() - handlers_start_time
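For context on the `register_events` change: each positional argument is now annotated as a single event name (a `str` or an `EventEnum` member) rather than a list of names. A minimal sketch of registering and firing a custom string event; the event name, step function and toy data are illustrative, not taken from the diff:

from ignite.engine import Engine, Events

def train_step(engine, batch):
    return batch

trainer = Engine(train_step)

# Each argument is one event name: a plain string or an EventEnum member.
trainer.register_events("custom_validation_event")

@trainer.on("custom_validation_event")
def run_validation(engine):
    print(f"custom event fired at iteration {engine.state.iteration}")

@trainer.on(Events.ITERATION_COMPLETED(every=2))
def trigger_custom_event(engine):
    engine.fire_event("custom_validation_event")

trainer.run([0, 1, 2, 3], max_epochs=1)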
ignite/engine/events.py
CHANGED
@@ -91,7 +91,7 @@ class CallableEventWithFilter:
             raise ValueError("Argument every should be integer and greater than zero")

         if once is not None:
-            c1 = isinstance(once,
+            c1 = isinstance(once, int) and once > 0
             c2 = isinstance(once, Sequence) and len(once) > 0 and all(isinstance(e, int) and e > 0 for e in once)
             if not (c1 or c2):
                 raise ValueError(
@@ -240,6 +240,7 @@ class EventEnum(CallableEventWithFilter, Enum):
     def __new__(cls, value: str) -> "EventEnum":
         obj = CallableEventWithFilter.__new__(cls)
         obj._value_ = value
+        # pyrefly: ignore [bad-return]
         return obj
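The `once` check touched above accepts either a single positive integer or a non-empty sequence of positive integers. A small sketch of both filter forms (handler names and data are illustrative):

from ignite.engine import Engine, Events

trainer = Engine(lambda engine, batch: batch)

# `once` as a single positive int: fire only at iteration 3.
@trainer.on(Events.ITERATION_COMPLETED(once=3))
def log_once(engine):
    print("iteration", engine.state.iteration)

# `once` as a sequence of positive ints: fire at iterations 2 and 5.
@trainer.on(Events.ITERATION_COMPLETED(once=[2, 5]))
def log_twice(engine):
    print("iteration", engine.state.iteration)

trainer.run(range(6), max_epochs=1)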
ignite/handlers/__init__.py
CHANGED
@@ -6,6 +6,7 @@ from ignite.handlers.checkpoint import Checkpoint, DiskSaver, ModelCheckpoint
 from ignite.handlers.clearml_logger import ClearMLLogger
 from ignite.handlers.early_stopping import EarlyStopping
 from ignite.handlers.ema_handler import EMAHandler
+from ignite.handlers.fbresearch_logger import FBResearchLogger
 from ignite.handlers.lr_finder import FastaiLRFinder
 from ignite.handlers.mlflow_logger import MLflowLogger
 from ignite.handlers.neptune_logger import NeptuneLogger
@@ -64,6 +65,7 @@ __all__ = [
     "CyclicalScheduler",
     "create_lr_scheduler_with_warmup",
     "FastaiLRFinder",
+    "FBResearchLogger",
     "EMAHandler",
     "BasicTimeProfiler",
     "HandlersTimeProfiler",
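With this re-export, `FBResearchLogger` can be imported directly from `ignite.handlers`. A hedged usage sketch: the model, optimizer, step function and the `attach` arguments (`name`, `every`) follow the handler's typical documented usage and are illustrative, not taken from this diff:

import logging

import torch
from ignite.engine import Engine
from ignite.handlers import FBResearchLogger  # newly re-exported here
# equivalent: from ignite.handlers.fbresearch_logger import FBResearchLogger

logging.basicConfig(level=logging.INFO)

model = torch.nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

def train_step(engine, batch):
    optimizer.zero_grad()
    loss = model(batch).sum()
    loss.backward()
    optimizer.step()
    return {"loss": loss.item()}

trainer = Engine(train_step)
fb_logger = FBResearchLogger(logger=logging.getLogger("training"), show_output=True)
fb_logger.attach(trainer, name="Train", every=5, optimizer=optimizer)

trainer.run([torch.rand(4, 2) for _ in range(20)], max_epochs=2)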
ignite/handlers/checkpoint.py
CHANGED
@@ -315,7 +315,7 @@ class Checkpoint(Serializable):
     """

     SAVED_CHECKPOINT = CheckpointEvents.SAVED_CHECKPOINT
-    Item = NamedTuple("Item", [("priority", int), ("filename", str)])
+    Item = NamedTuple("Item", [("priority", Union[int, float]), ("filename", str)])
     _state_dict_all_req_keys = ("_saved",)

     def __init__(
@@ -323,7 +323,7 @@ class Checkpoint(Serializable):
         to_save: Mapping,
         save_handler: Union[str, Path, Callable, BaseSaveHandler],
         filename_prefix: str = "",
-        score_function: Optional[Callable] = None,
+        score_function: Optional[Callable[[Engine], Union[int, float]]] = None,
         score_name: Optional[str] = None,
         n_saved: Union[int, None] = 1,
         global_step_transform: Optional[Callable] = None,
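The tightened `score_function` annotation spells out what the handler already expected: a callable that receives the engine and returns a number. A minimal sketch; the directory, metric name and toy model are illustrative assumptions:

import torch
from ignite.engine import Engine
from ignite.handlers import Checkpoint, DiskSaver

model = torch.nn.Linear(2, 1)

def neg_loss_score(engine: Engine) -> float:
    # Higher score is better, so negate the validation loss.
    return -engine.state.metrics["loss"]

handler = Checkpoint(
    {"model": model},
    DiskSaver("/tmp/checkpoints", create_dir=True, require_empty=False),
    score_function=neg_loss_score,
    score_name="neg_loss",
    n_saved=2,
)
# Typically attached to an evaluator, e.g.:
# evaluator.add_event_handler(Events.COMPLETED, handler)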
ignite/handlers/clearml_logger.py
CHANGED
@@ -862,7 +862,7 @@ class ClearMLSaver(DiskSaver):
         except ImportError:
             try:
                 # Backwards-compatibility for legacy Trains SDK
-                from trains import Task
+                from trains import Task
             except ImportError:
                 raise ModuleNotFoundError(
                     "This contrib module requires clearml to be installed. "
@@ -937,7 +937,7 @@ class ClearMLSaver(DiskSaver):
         except ImportError:
             try:
                 # Backwards-compatibility for legacy Trains SDK
-                from trains.binding.frameworks import WeightsFileHandler
+                from trains.binding.frameworks import WeightsFileHandler
             except ImportError:
                 raise ModuleNotFoundError(
                     "This contrib module requires clearml to be installed. "
ignite/handlers/fbresearch_logger.py
CHANGED
@@ -7,7 +7,7 @@ import torch

 from ignite import utils
 from ignite.engine import Engine, Events
-from ignite.handlers import Timer
+from ignite.handlers.timing import Timer

 MB = 1024.0 * 1024.0

@@ -154,7 +154,7 @@ class FBResearchLogger:
         if torch.cuda.is_available():
             cuda_max_mem = f"GPU Max Mem: {torch.cuda.max_memory_allocated() / MB:.0f} MB"

-        current_iter = engine.state.iteration %
+        current_iter = ((engine.state.iteration - 1) % engine.state.epoch_length) + 1
         iter_avg_time = self.iter_timer.value()

         eta_seconds = iter_avg_time * (engine.state.epoch_length - current_iter)
ignite/handlers/lr_finder.py
CHANGED
@@ -98,15 +98,18 @@ class FastaiLRFinder:
         self._best_loss = None
         self._diverge_flag = False

+        assert trainer.state.epoch_length is not None
+        assert trainer.state.max_epochs is not None
+
         # attach LRScheduler to trainer.
         if num_iter is None:
             num_iter = trainer.state.epoch_length * trainer.state.max_epochs
         else:
-            max_iter = trainer.state.epoch_length * trainer.state.max_epochs
+            max_iter = trainer.state.epoch_length * trainer.state.max_epochs
             if max_iter < num_iter:
                 max_iter = num_iter
                 trainer.state.max_iters = num_iter
-                trainer.state.max_epochs = ceil(num_iter / trainer.state.epoch_length)
+                trainer.state.max_epochs = ceil(num_iter / trainer.state.epoch_length)

         if not trainer.has_event_handler(self._reached_num_iterations):
             trainer.add_event_handler(Events.ITERATION_COMPLETED, self._reached_num_iterations, num_iter)
@@ -178,17 +181,14 @@ class FastaiLRFinder:
             loss = idist.all_reduce(loss)
         lr = self._lr_schedule.get_param()
         self._history["lr"].append(lr)
-        if trainer.state.iteration
-
-
-
-            loss = smooth_f * loss + (1 - smooth_f) * self._history["loss"][-1]
-            if loss < self._best_loss:
-                self._best_loss = loss
+        if trainer.state.iteration != 1 and smooth_f > 0:
+            loss = smooth_f * loss + (1 - smooth_f) * self._history["loss"][-1]
+        if self._best_loss is None or loss < self._best_loss:
+            self._best_loss = loss
         self._history["loss"].append(loss)

         # Check if the loss has diverged; if it has, stop the trainer
-        if self._history["loss"][-1] > diverge_th * self._best_loss:
+        if self._history["loss"][-1] > diverge_th * self._best_loss:
             self._diverge_flag = True
             self.logger.info("Stopping early, the loss has diverged")
             trainer.terminate()
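The added asserts only make explicit that the finder needs a trainer whose `epoch_length` and `max_epochs` are known when it is attached. Typical usage, as a hedged sketch with a toy model, optimizer and data:

import torch
from ignite.engine import Engine
from ignite.handlers import FastaiLRFinder

model = torch.nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
criterion = torch.nn.MSELoss()

def train_step(engine, batch):
    x, y = batch
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
    return loss.item()

trainer = Engine(train_step)
data = [(torch.rand(8, 2), torch.rand(8, 1)) for _ in range(10)]

lr_finder = FastaiLRFinder()
to_save = {"model": model, "optimizer": optimizer}
# attach() returns a context manager; inside it the trainer runs the LR range test.
with lr_finder.attach(trainer, to_save=to_save) as trainer_with_lr_finder:
    trainer_with_lr_finder.run(data)

print(lr_finder.lr_suggestion())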
ignite/handlers/neptune_logger.py
CHANGED
@@ -698,6 +698,7 @@ class NeptuneSaver(BaseSaveHandler):
         # hold onto the file stream for uploading.
         # NOTE: This won't load the whole file in memory and upload
         # the stream in smaller chunks.
+        # pyrefly: ignore [bad-argument-type]
         self._logger[filename].upload(File.from_stream(tmp.file))

     @idist.one_rank_only(with_barrier=True)
ignite/handlers/param_scheduler.py
CHANGED
@@ -1122,13 +1122,14 @@ def create_lr_scheduler_with_warmup(
             f"but given {type(lr_scheduler)}"
         )

-    if not isinstance(warmup_duration,
+    if not isinstance(warmup_duration, int):
         raise TypeError(f"Argument warmup_duration should be integer, but given {warmup_duration}")

    if not (warmup_duration > 1):
        raise ValueError(f"Argument warmup_duration should be at least 2 events, but given {warmup_duration}")

    warmup_schedulers: List[ParamScheduler] = []
+   milestones_values: List[Tuple[int, float]] = []

    for param_group_index, param_group in enumerate(lr_scheduler.optimizer.param_groups):
        if warmup_end_value is None:
@@ -1176,6 +1177,7 @@ def create_lr_scheduler_with_warmup(
         lr_scheduler,
     ]
     durations = [milestones_values[-1][0] + 1]
+    # pyrefly: ignore [bad-argument-type]
     combined_scheduler = ConcatScheduler(schedulers, durations=durations, save_history=save_history)

     if output_simulated_values is not None:
@@ -1185,6 +1187,7 @@ def create_lr_scheduler_with_warmup(
                 f"but given {type(output_simulated_values)}."
             )
         num_events = len(output_simulated_values)
+        # pyrefly: ignore [bad-argument-type]
         result = ConcatScheduler.simulate_values(num_events=num_events, schedulers=schedulers, durations=durations)
         for i in range(num_events):
             output_simulated_values[i] = result[i]
@@ -1650,6 +1653,7 @@ class ReduceLROnPlateauScheduler(ParamScheduler):
         self.trainer = trainer
         self.optimizer = optimizer

+        min_lr: Union[float, List[float]]
         if "min_lr" in scheduler_kwargs and param_group_index is not None:
             min_lr = scheduler_kwargs["min_lr"]
             if not isinstance(min_lr, float):
@@ -1670,11 +1674,11 @@ class ReduceLROnPlateauScheduler(ParamScheduler):
             _scheduler_kwargs["verbose"] = False

         self.scheduler = ReduceLROnPlateau(optimizer, **_scheduler_kwargs)
-        self.scheduler._reduce_lr = self._reduce_lr
+        self.scheduler._reduce_lr = self._reduce_lr

         self._state_attrs += ["metric_name", "scheduler"]

-    def __call__(self, engine: Engine, name: Optional[str] = None) -> None:
+    def __call__(self, engine: Engine, name: Optional[str] = None) -> None:
         if not hasattr(engine.state, "metrics") or self.metric_name not in engine.state.metrics:
             raise ValueError(
                 "Argument engine should have in its 'state', attribute 'metrics' "
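The stricter `warmup_duration` check matches the documented contract: an integer number of events, at least 2. A minimal sketch; the optimizer, wrapped scheduler and values are illustrative:

import torch
from ignite.handlers import create_lr_scheduler_with_warmup

model = torch.nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
torch_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)

scheduler = create_lr_scheduler_with_warmup(
    torch_scheduler,
    warmup_start_value=0.0,
    warmup_end_value=0.1,
    warmup_duration=5,  # must be an int >= 2
)
# Then attach it to a trainer, e.g.:
# trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)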
ignite/handlers/state_param_scheduler.py
CHANGED
@@ -1,7 +1,7 @@
 import numbers
 import warnings
 from bisect import bisect_right
-from typing import Any, List, Sequence, Tuple, Union
+from typing import Any, Callable, List, Sequence, Tuple, Union

 from ignite.engine import CallableEventWithFilter, Engine, Events, EventsList
 from ignite.handlers.param_scheduler import BaseParamScheduler
@@ -183,7 +183,13 @@ class LambdaStateScheduler(StateParamScheduler):

     """

-    def __init__(
+    def __init__(
+        self,
+        lambda_obj: Callable[[int], Union[List[float], float]],
+        param_name: str,
+        save_history: bool = False,
+        create_new: bool = False,
+    ):
         super(LambdaStateScheduler, self).__init__(param_name, save_history, create_new)

         if not callable(lambda_obj):
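The expanded `__init__` signature documents that `lambda_obj` is a callable mapping the event index to a parameter value (or list of values). A small sketch; the decay rule and parameter name are illustrative:

from ignite.engine import Engine, Events
from ignite.handlers.state_param_scheduler import LambdaStateScheduler

class ExpDecay:
    def __init__(self, initial_value: float, gamma: float):
        self.initial_value = initial_value
        self.gamma = gamma

    def __call__(self, event_index: int) -> float:
        # Called with the current event index; returns the new parameter value.
        return self.initial_value * self.gamma ** event_index

trainer = Engine(lambda engine, batch: batch)
scheduler = LambdaStateScheduler(ExpDecay(10.0, 0.99), param_name="custom_param", create_new=True)
scheduler.attach(trainer, Events.EPOCH_COMPLETED)

trainer.run(range(4), max_epochs=3)
print(trainer.state.custom_param)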
ignite/handlers/time_profilers.py
CHANGED
@@ -251,6 +251,7 @@ class BasicTimeProfiler:
         total_eh_time: Union[int, torch.Tensor] = sum(
             [(self.event_handlers_times[e]).sum() for e in Events if e not in self.events_to_ignore]
         )
+        # pyrefly: ignore [no-matching-overload]
         event_handlers_stats = dict(
             [
                 (str(e.name).replace(".", "_"), self._compute_basic_stats(self.event_handlers_times[e]))
@@ -334,6 +335,7 @@ class BasicTimeProfiler:

         results_df = pd.DataFrame(
             data=results_dump,
+            # pyrefly: ignore [bad-argument-type]
             columns=[
                 "epoch",
                 "iteration",
@@ -498,14 +500,14 @@ class HandlersTimeProfiler:

         self.dataflow_times: List[float] = []
         self.processing_times: List[float] = []
-        self.event_handlers_times: Dict[EventEnum, Dict[str, List[float]]] = {}
+        self.event_handlers_times: Dict[Union[str, EventEnum], Dict[str, List[float]]] = {}

     @staticmethod
     def _get_callable_name(handler: Callable) -> str:
         # get name of the callable handler
         return getattr(handler, "__qualname__", handler.__class__.__name__)

-    def _create_wrapped_handler(self, handler: Callable, event: EventEnum) -> Callable:
+    def _create_wrapped_handler(self, handler: Callable, event: Union[str, EventEnum]) -> Callable:
         @functools.wraps(handler)
         def _timeit_handler(*args: Any, **kwargs: Any) -> None:
             self._event_handlers_timer.reset()
@@ -530,7 +532,7 @@ class HandlersTimeProfiler:
         t = self._dataflow_timer.value()
         self.dataflow_times.append(t)

-    def _reset(self, event_handlers_names: Mapping[EventEnum, List[str]]) -> None:
+    def _reset(self, event_handlers_names: Mapping[Union[str, EventEnum], List[str]]) -> None:
         # reset the variables used for profiling
         self.dataflow_times = []
         self.processing_times = []
@@ -689,6 +691,7 @@ class HandlersTimeProfiler:

         results_dump = torch.stack(cols, dim=1).numpy()

+        # pyrefly: ignore [bad-argument-type]
         results_df = pd.DataFrame(data=results_dump, columns=headers)
         results_df.to_csv(output_path, index=False)
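The typing changes above only widen the event keys to `Union[str, EventEnum]`; the profiler's public usage is unchanged. A minimal sketch with a toy engine and data:

from ignite.engine import Engine
from ignite.handlers import HandlersTimeProfiler

trainer = Engine(lambda engine, batch: batch)

profiler = HandlersTimeProfiler()
profiler.attach(trainer)

trainer.run(range(8), max_epochs=2)

results = profiler.get_results()
profiler.print_results(results)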
ignite/handlers/tqdm_logger.py
CHANGED
@@ -223,8 +223,13 @@ class ProgressBar(BaseLogger):
         super(ProgressBar, self).attach(engine, log_handler, event_name)
         engine.add_event_handler(closing_event_name, self._close)

-    def attach_opt_params_handler(
-        self,
+    def attach_opt_params_handler(
+        self,
+        engine: Engine,
+        event_name: Union[str, Events],
+        *args: Any,
+        **kwargs: Any,
+        # pyrefly: ignore [bad-return]
     ) -> RemovableEventHandle:
         """Intentionally empty"""
         pass
ignite/handlers/visdom_logger.py
CHANGED
@@ -1,7 +1,7 @@
 """Visdom logger and its helper handlers."""

 import os
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union, TYPE_CHECKING

 import torch
 import torch.nn as nn
@@ -165,7 +165,7 @@ class VisdomLogger(BaseLogger):
                 "pip install git+https://github.com/fossasia/visdom.git"
             )

-        if num_workers > 0:
+        if num_workers > 0 or TYPE_CHECKING:
             # If visdom is installed, one of its dependencies `tornado`
             # requires also `futures` to be installed.
             # Let's check anyway if we can import it.
ignite/handlers/wandb_logger.py
CHANGED
@@ -1,6 +1,7 @@
 """WandB logger and its helper handlers."""

 from typing import Any, Callable, List, Optional, Union
+from warnings import warn

 from torch.optim import Optimizer

@@ -172,8 +173,7 @@ class OutputHandler(BaseOutputHandler):
             Default is None, global_step based on attached engine. If provided,
             uses function output as global_step. To setup global step from another engine, please use
             :meth:`~ignite.handlers.wandb_logger.global_step_from_engine`.
-        sync:
-            the default value of wandb.log.
+        sync: Deprecated, has no function. Argument is kept here for compatibility with existing code.

     Examples:
         .. code-block:: python
@@ -284,7 +284,8 @@ class OutputHandler(BaseOutputHandler):
         state_attributes: Optional[List[str]] = None,
     ):
         super().__init__(tag, metric_names, output_transform, global_step_transform, state_attributes)
-
+        if sync is not None:
+            warn("The sync argument for the WandBLoggers is no longer used, and may be removed in the future")

     def __call__(self, engine: Engine, logger: WandBLogger, event_name: Union[str, Events]) -> None:
         if not isinstance(logger, WandBLogger):
@@ -298,7 +299,7 @@ class OutputHandler(BaseOutputHandler):
             )

         metrics = self._setup_output_metrics_state_attrs(engine, log_text=True, key_tuple=False)
-        logger.log(metrics, step=global_step
+        logger.log(metrics, step=global_step)


 class OptimizerParamsHandler(BaseOptimizerParamsHandler):
@@ -309,8 +310,7 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler):
             as a sequence.
         param_name: parameter name
         tag: common title for all produced plots. For example, "generator"
-        sync:
-            the default value of wandb.log.
+        sync: Deprecated, has no function. Argument is kept here for compatibility with existing code.

     Examples:
         .. code-block:: python
@@ -346,7 +346,8 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler):
         self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None, sync: Optional[bool] = None
     ):
         super(OptimizerParamsHandler, self).__init__(optimizer, param_name, tag)
-
+        if sync is not None:
+            warn("The sync argument for the WandBLoggers is no longer used, and may be removed in the future")

     def __call__(self, engine: Engine, logger: WandBLogger, event_name: Union[str, Events]) -> None:
         if not isinstance(logger, WandBLogger):
@@ -358,4 +359,4 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler):
             f"{tag_prefix}{self.param_name}/group_{i}": float(param_group[self.param_name])
             for i, param_group in enumerate(self.optimizer.param_groups)
         }
-        logger.log(params, step=global_step
+        logger.log(params, step=global_step)
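With `sync` deprecated, handlers are constructed without it; passing it now only emits a deprecation warning. A hedged sketch of attaching an output handler: the project name and tag are illustrative, `trainer` is assumed to be an existing Engine, and `wandb` must be installed and configured:

from ignite.engine import Events
from ignite.handlers.wandb_logger import WandBLogger

wandb_logger = WandBLogger(project="my-project", name="run-1")

# No `sync` argument anymore; passing one is ignored apart from a warning.
wandb_logger.attach_output_handler(
    trainer,  # assumed to be an existing Engine
    event_name=Events.ITERATION_COMPLETED,
    tag="training",
    output_transform=lambda loss: {"loss": loss},
)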
ignite/metrics/accuracy.py
CHANGED
@@ -254,6 +254,8 @@ class Accuracy(_BaseClassification):
             y_pred = torch.transpose(y_pred, 1, last_dim - 1).reshape(-1, num_classes)
             y = torch.transpose(y, 1, last_dim - 1).reshape(-1, num_classes)
             correct = torch.all(y == y_pred.type_as(y), dim=-1)
+        else:
+            raise ValueError(f"Unexpected type: {self._type}")

         self._num_correct += torch.sum(correct).to(self._device)
         self._num_examples += correct.shape[0]
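The new `else` branch turns an unexpected internal type into an explicit error instead of failing later on an undefined `correct`. Standard usage is unchanged; a tiny sketch:

import torch
from ignite.metrics import Accuracy

acc = Accuracy()
y_pred = torch.tensor([[0.2, 0.8], [0.9, 0.1], [0.4, 0.6]])  # per-class scores
y = torch.tensor([1, 0, 0])
acc.update((y_pred, y))
print(acc.compute())  # 2 of 3 predictions correct -> 0.666...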
ignite/metrics/metric.py
CHANGED
@@ -369,6 +369,7 @@ class Metric(Serializable, metaclass=ABCMeta):
         if torch.device(device).type == "xla":
             raise ValueError("Cannot create metric on an XLA device. Use device='cpu' instead.")

+        # pyrefly: ignore [read-only]
         self._device = torch.device(device)
         self._skip_unrolling = skip_unrolling
ignite/metrics/nlp/rouge.py
CHANGED
@@ -1,6 +1,5 @@
 from abc import ABCMeta, abstractmethod
-from
-from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, List, Mapping, NamedTuple, Optional, Sequence, Tuple, Union

 import torch

@@ -13,11 +12,15 @@ from ignite.metrics.nlp.utils import lcs, ngrams
 __all__ = ["Rouge", "RougeN", "RougeL"]


-class Score(
+class Score(NamedTuple):
     r"""
     Computes precision and recall for given matches, candidate and reference lengths.
     """

+    match: int
+    candidate: int
+    reference: int
+
     def precision(self) -> float:
         """
         Calculates precision.
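`Score` becomes a typed `NamedTuple` with explicit `match`, `candidate` and `reference` fields; the public `Rouge` metric API is unchanged. A small sketch following the metric's documented update format (the sentences and the exact result keys are illustrative):

from ignite.metrics import Rouge

rouge = Rouge(variants=["L", 2], multiref="best")

candidate = "the cat is not there".split()
references = [
    "the cat is on the mat".split(),
    "there is a cat on the mat".split(),
]
rouge.update(([candidate], [references]))
print(rouge.compute())  # dict of Rouge-L / Rouge-2 precision, recall and F scores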
ignite/metrics/roc_auc.py
CHANGED
ignite/metrics/ssim.py
CHANGED
@@ -161,11 +161,15 @@ class SSIM(Metric):
             kernel_y = self._gaussian(kernel_size[1], sigma[1])
             if ndims == 3:
                 kernel_z = self._gaussian(kernel_size[2], sigma[2])
+            else:
+                kernel_z = None
         else:
             kernel_x = self._uniform(kernel_size[0])
             kernel_y = self._uniform(kernel_size[1])
             if ndims == 3:
                 kernel_z = self._uniform(kernel_size[2])
+            else:
+                kernel_z = None

         result = (
             torch.einsum("i,j->ij", kernel_x, kernel_y)
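The added `else` branches just give `kernel_z` a defined value in the 2D case; usage stays the same. A tiny sketch with random 4D inputs:

import torch
from ignite.metrics import SSIM

metric = SSIM(data_range=1.0)

y_pred = torch.rand(4, 3, 32, 32)  # (batch, channels, height, width), values in [0, 1]
y = y_pred * 0.9 + 0.05            # a slightly perturbed "ground truth"
metric.update((y_pred, y))
print(metric.compute())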
ignite/metrics/vision/object_detection_average_precision_recall.py
CHANGED
@@ -160,6 +160,9 @@ class ObjectDetectionAvgPrecisionRecall(Metric, _BaseAveragePrecision):
         elif self._area_range == "large":
             min_area = 9216
             max_area = 1e10
+        else:
+            min_area = 0
+            max_area = 1e10
         return torch.logical_and(areas >= min_area, areas <= max_area)

     def _check_matching_input(
{pytorch_ignite-0.6.0.dev20250927.dist-info → pytorch_ignite-0.6.0.dev20260101.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pytorch-ignite
-Version: 0.6.0.dev20250927
+Version: 0.6.0.dev20260101
 Summary: A lightweight library to help with training neural networks in PyTorch.
 Project-URL: Homepage, https://pytorch-ignite.ai
 Project-URL: Repository, https://github.com/pytorch/ignite
@@ -412,7 +412,7 @@ Few pointers to get you started:
 - [](https://colab.research.google.com/github/pytorch/ignite/blob/master/examples/notebooks/FastaiLRFinder_MNIST.ipynb) [Basic example of LR finder on
   MNIST](https://github.com/pytorch/ignite/blob/master/examples/notebooks/FastaiLRFinder_MNIST.ipynb)
 - [](https://colab.research.google.com/github/pytorch/ignite/blob/master/examples/notebooks/Cifar100_bench_amp.ipynb) [Benchmark mixed precision training on Cifar100:
-  torch.
+  torch.amp vs nvidia/apex](https://github.com/pytorch/ignite/blob/master/examples/notebooks/Cifar100_bench_amp.ipynb)
 - [](https://colab.research.google.com/github/pytorch/ignite/blob/master/examples/notebooks/MNIST_on_TPU.ipynb) [MNIST training on a single
   TPU](https://github.com/pytorch/ignite/blob/master/examples/notebooks/MNIST_on_TPU.ipynb)
 - [](https://colab.research.google.com/drive/1E9zJrptnLJ_PKhmaP5Vhb6DTVRvyrKHx) [CIFAR10 Training on multiple TPUs](https://github.com/pytorch/ignite/tree/master/examples/cifar10)