pytorch-ignite 0.6.0.dev20250310__py3-none-any.whl → 0.6.0.dev20260101__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- ignite/__init__.py +1 -1
- ignite/contrib/engines/common.py +1 -0
- ignite/contrib/handlers/base_logger.py +1 -1
- ignite/contrib/handlers/clearml_logger.py +1 -1
- ignite/contrib/handlers/lr_finder.py +1 -1
- ignite/contrib/handlers/mlflow_logger.py +1 -1
- ignite/contrib/handlers/neptune_logger.py +1 -1
- ignite/contrib/handlers/param_scheduler.py +1 -1
- ignite/contrib/handlers/polyaxon_logger.py +1 -1
- ignite/contrib/handlers/tensorboard_logger.py +1 -1
- ignite/contrib/handlers/time_profilers.py +1 -1
- ignite/contrib/handlers/tqdm_logger.py +1 -1
- ignite/contrib/handlers/visdom_logger.py +1 -1
- ignite/contrib/handlers/wandb_logger.py +1 -1
- ignite/contrib/metrics/average_precision.py +1 -1
- ignite/contrib/metrics/cohen_kappa.py +1 -1
- ignite/contrib/metrics/gpu_info.py +1 -1
- ignite/contrib/metrics/precision_recall_curve.py +1 -1
- ignite/contrib/metrics/regression/canberra_metric.py +2 -3
- ignite/contrib/metrics/regression/fractional_absolute_error.py +2 -3
- ignite/contrib/metrics/regression/fractional_bias.py +2 -3
- ignite/contrib/metrics/regression/geometric_mean_absolute_error.py +2 -3
- ignite/contrib/metrics/regression/geometric_mean_relative_absolute_error.py +2 -3
- ignite/contrib/metrics/regression/manhattan_distance.py +2 -3
- ignite/contrib/metrics/regression/maximum_absolute_error.py +2 -3
- ignite/contrib/metrics/regression/mean_absolute_relative_error.py +2 -3
- ignite/contrib/metrics/regression/mean_error.py +2 -3
- ignite/contrib/metrics/regression/mean_normalized_bias.py +2 -3
- ignite/contrib/metrics/regression/median_absolute_error.py +2 -3
- ignite/contrib/metrics/regression/median_absolute_percentage_error.py +2 -3
- ignite/contrib/metrics/regression/median_relative_absolute_error.py +2 -3
- ignite/contrib/metrics/regression/r2_score.py +2 -3
- ignite/contrib/metrics/regression/wave_hedges_distance.py +2 -3
- ignite/contrib/metrics/roc_auc.py +1 -1
- ignite/distributed/auto.py +1 -0
- ignite/distributed/comp_models/base.py +7 -0
- ignite/distributed/comp_models/horovod.py +35 -5
- ignite/distributed/comp_models/native.py +8 -4
- ignite/distributed/comp_models/xla.py +5 -0
- ignite/distributed/launcher.py +4 -8
- ignite/distributed/utils.py +12 -4
- ignite/engine/__init__.py +9 -9
- ignite/engine/deterministic.py +1 -1
- ignite/engine/engine.py +38 -14
- ignite/engine/events.py +2 -1
- ignite/handlers/__init__.py +2 -0
- ignite/handlers/base_logger.py +47 -12
- ignite/handlers/checkpoint.py +46 -5
- ignite/handlers/clearml_logger.py +16 -4
- ignite/handlers/fbresearch_logger.py +2 -2
- ignite/handlers/lr_finder.py +9 -9
- ignite/handlers/mlflow_logger.py +43 -0
- ignite/handlers/neptune_logger.py +8 -0
- ignite/handlers/param_scheduler.py +7 -3
- ignite/handlers/polyaxon_logger.py +7 -0
- ignite/handlers/state_param_scheduler.py +8 -2
- ignite/handlers/tensorboard_logger.py +43 -0
- ignite/handlers/time_profilers.py +6 -3
- ignite/handlers/tqdm_logger.py +9 -5
- ignite/handlers/visdom_logger.py +10 -3
- ignite/handlers/wandb_logger.py +16 -9
- ignite/metrics/accuracy.py +2 -0
- ignite/metrics/clustering/calinski_harabasz_score.py +1 -1
- ignite/metrics/clustering/silhouette_score.py +1 -1
- ignite/metrics/fbeta.py +17 -8
- ignite/metrics/gan/fid.py +3 -3
- ignite/metrics/js_divergence.py +1 -1
- ignite/metrics/maximum_mean_discrepancy.py +1 -1
- ignite/metrics/metric.py +3 -0
- ignite/metrics/nlp/bleu.py +8 -6
- ignite/metrics/nlp/rouge.py +9 -6
- ignite/metrics/nlp/utils.py +1 -1
- ignite/metrics/precision_recall_curve.py +5 -5
- ignite/metrics/regression/_base.py +4 -0
- ignite/metrics/regression/fractional_bias.py +1 -1
- ignite/metrics/roc_auc.py +4 -3
- ignite/metrics/ssim.py +63 -21
- ignite/metrics/vision/object_detection_average_precision_recall.py +3 -0
- {pytorch_ignite-0.6.0.dev20250310.dist-info → pytorch_ignite-0.6.0.dev20260101.dist-info}/METADATA +11 -17
- {pytorch_ignite-0.6.0.dev20250310.dist-info → pytorch_ignite-0.6.0.dev20260101.dist-info}/RECORD +82 -83
- {pytorch_ignite-0.6.0.dev20250310.dist-info → pytorch_ignite-0.6.0.dev20260101.dist-info}/WHEEL +1 -2
- pytorch_ignite-0.6.0.dev20250310.dist-info/top_level.txt +0 -1
- {pytorch_ignite-0.6.0.dev20250310.dist-info → pytorch_ignite-0.6.0.dev20260101.dist-info/licenses}/LICENSE +0 -0
ignite/handlers/neptune_logger.py CHANGED

@@ -153,6 +153,13 @@ class NeptuneLogger(BaseLogger):
      output_transform=lambda loss: {"loss": loss},
  )

+ Note:
+     :class:`~ignite.handlers.neptune_logger.OutputHandler` can handle
+     metrics, state attributes and engine output values of the following format:
+     - scalar values (i.e. int, float)
+     - 0d and 1d pytorch tensors
+     - dicts and list/tuples of previous types
+
  """

  def __getattr__(self, attr: Any) -> Any:

@@ -691,6 +698,7 @@ class NeptuneSaver(BaseSaveHandler):
  # hold onto the file stream for uploading.
  # NOTE: This won't load the whole file in memory and upload
  # the stream in smaller chunks.
+ # pyrefly: ignore [bad-argument-type]
  self._logger[filename].upload(File.from_stream(tmp.file))

  @idist.one_rank_only(with_barrier=True)
ignite/handlers/param_scheduler.py CHANGED

@@ -1122,13 +1122,14 @@ def create_lr_scheduler_with_warmup(
      f"but given {type(lr_scheduler)}"
  )

- if not isinstance(warmup_duration,
+ if not isinstance(warmup_duration, int):
      raise TypeError(f"Argument warmup_duration should be integer, but given {warmup_duration}")

  if not (warmup_duration > 1):
      raise ValueError(f"Argument warmup_duration should be at least 2 events, but given {warmup_duration}")

  warmup_schedulers: List[ParamScheduler] = []
+ milestones_values: List[Tuple[int, float]] = []

  for param_group_index, param_group in enumerate(lr_scheduler.optimizer.param_groups):
      if warmup_end_value is None:

@@ -1176,6 +1177,7 @@ def create_lr_scheduler_with_warmup(
      lr_scheduler,
  ]
  durations = [milestones_values[-1][0] + 1]
+ # pyrefly: ignore [bad-argument-type]
  combined_scheduler = ConcatScheduler(schedulers, durations=durations, save_history=save_history)

  if output_simulated_values is not None:

@@ -1185,6 +1187,7 @@ def create_lr_scheduler_with_warmup(
      f"but given {type(output_simulated_values)}."
  )
  num_events = len(output_simulated_values)
+ # pyrefly: ignore [bad-argument-type]
  result = ConcatScheduler.simulate_values(num_events=num_events, schedulers=schedulers, durations=durations)
  for i in range(num_events):
      output_simulated_values[i] = result[i]

@@ -1650,6 +1653,7 @@ class ReduceLROnPlateauScheduler(ParamScheduler):
  self.trainer = trainer
  self.optimizer = optimizer

+ min_lr: Union[float, List[float]]
  if "min_lr" in scheduler_kwargs and param_group_index is not None:
      min_lr = scheduler_kwargs["min_lr"]
      if not isinstance(min_lr, float):

@@ -1670,11 +1674,11 @@ class ReduceLROnPlateauScheduler(ParamScheduler):
      _scheduler_kwargs["verbose"] = False

  self.scheduler = ReduceLROnPlateau(optimizer, **_scheduler_kwargs)
- self.scheduler._reduce_lr = self._reduce_lr
+ self.scheduler._reduce_lr = self._reduce_lr

  self._state_attrs += ["metric_name", "scheduler"]

- def __call__(self, engine: Engine, name: Optional[str] = None) -> None:
+ def __call__(self, engine: Engine, name: Optional[str] = None) -> None:
  if not hasattr(engine.state, "metrics") or self.metric_name not in engine.state.metrics:
      raise ValueError(
          "Argument engine should have in its 'state', attribute 'metrics' "
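The warmup_duration check now requires a plain int (of at least 2 events). As a reminder of how this helper is wired up in practice, here is a minimal sketch following the documented usage pattern; the toy model, optimizer and dummy trainer are illustrative only.

import torch
from ignite.engine import Engine, Events
from ignite.handlers import create_lr_scheduler_with_warmup

model = torch.nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
torch_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

# warmup_duration must now be a plain int (at least 2 events)
scheduler = create_lr_scheduler_with_warmup(
    torch_lr_scheduler,
    warmup_start_value=0.0,
    warmup_end_value=0.1,
    warmup_duration=3,
)

trainer = Engine(lambda engine, batch: None)  # dummy training step
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

@trainer.on(Events.ITERATION_COMPLETED)
def log_lr(engine):
    print(engine.state.iteration, optimizer.param_groups[0]["lr"])

trainer.run(data=list(range(5)), max_epochs=2)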
ignite/handlers/polyaxon_logger.py CHANGED

@@ -92,6 +92,13 @@ class PolyaxonLogger(BaseLogger):
  )
  # to manually end a run
  plx_logger.close()
+
+ Note:
+     :class:`~ignite.handlers.polyaxon_logger.OutputHandler` can handle
+     metrics, state attributes and engine output values of the following format:
+     - scalar values (i.e. int, float)
+     - 0d and 1d pytorch tensors
+     - dicts and list/tuples of previous types
  """

  def __init__(self, *args: Any, **kwargs: Any):
ignite/handlers/state_param_scheduler.py CHANGED

@@ -1,7 +1,7 @@
  import numbers
  import warnings
  from bisect import bisect_right
- from typing import Any, List, Sequence, Tuple, Union
+ from typing import Any, Callable, List, Sequence, Tuple, Union

  from ignite.engine import CallableEventWithFilter, Engine, Events, EventsList
  from ignite.handlers.param_scheduler import BaseParamScheduler

@@ -183,7 +183,13 @@ class LambdaStateScheduler(StateParamScheduler):

  """

- def __init__(
+ def __init__(
+     self,
+     lambda_obj: Callable[[int], Union[List[float], float]],
+     param_name: str,
+     save_history: bool = False,
+     create_new: bool = False,
+ ):
      super(LambdaStateScheduler, self).__init__(param_name, save_history, create_new)

      if not callable(lambda_obj):
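The rewritten __init__ makes explicit that lambda_obj maps an event index to a float (or list of floats). A minimal sketch of how such a scheduler is typically attached, assuming the StateParamScheduler.attach(engine, event) API; the parameter name and decay rule below are made up for illustration.

from ignite.engine import Engine, Events
from ignite.handlers.state_param_scheduler import LambdaStateScheduler

# lambda_obj: Callable[[int], Union[List[float], float]]
scheduler = LambdaStateScheduler(
    lambda_obj=lambda event_index: 0.9 ** event_index,  # illustrative decay rule
    param_name="custom_scale",
    create_new=True,  # create engine.state.custom_scale if it does not exist yet
)

trainer = Engine(lambda engine, batch: None)  # dummy training step
scheduler.attach(trainer, Events.EPOCH_COMPLETED)
trainer.run(data=[0, 1, 2], max_epochs=3)
print(trainer.state.custom_scale)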
ignite/handlers/tensorboard_logger.py CHANGED

@@ -145,6 +145,49 @@ class TensorboardLogger(BaseLogger):
      output_transform=lambda loss: {"loss": loss}
  )

+ Note:
+     :class:`~ignite.handlers.tensorboard_logger.OutputHandler` can handle
+     metrics, state attributes and engine output values of the following format:
+     - scalar values (i.e. int, float)
+     - 0d and 1d pytorch tensors
+     - dicts and list/tuples of previous types
+
+     .. code-block:: python
+
+         # !!! This is not a runnable code !!!
+         evalutator.state.metrics = {
+             "a": 0,
+             "dict_value": {
+                 "a": 111,
+                 "c": {"d": 23, "e": [123, 234]},
+             },
+             "list_value": [12, 13, {"aa": 33, "bb": 44}],
+             "tuple_value": (112, 113, {"aaa": 33, "bbb": 44}),
+         }
+
+         handler = OutputHandler(
+             tag="tag",
+             metric_names="all",
+         )
+
+         handler(evaluator, tb_logger, event_name=Events.EPOCH_COMPLETED)
+         # Behind it would call `tb_logger.writer.add_scalar` on
+         # {
+         #     "tag/a": 0,
+         #     "tag/dict_value/a": 111,
+         #     "tag/dict_value/c/d": 23,
+         #     "tag/dict_value/c/e/0": 123,
+         #     "tag/dict_value/c/e/1": 234,
+         #     "tag/list_value/0": 12,
+         #     "tag/list_value/1": 13,
+         #     "tag/list_value/2/aa": 33,
+         #     "tag/list_value/2/bb": 44,
+         #     "tag/tuple_value/0": 112,
+         #     "tag/tuple_value/1": 113,
+         #     "tag/tuple_value/2/aaa": 33,
+         #     "tag/tuple_value/2/bbb": 44,
+         # }
+
  """

  def __init__(self, *args: Any, **kwargs: Any):
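The key layout described in the added docstring can be reproduced with a small standalone helper that flattens nested containers of scalars and 0d/1d tensors into slash-separated keys. This is an illustrative sketch, not ignite's internal implementation; the flatten_metrics name is hypothetical.

from numbers import Number

import torch


def flatten_metrics(value, prefix="tag"):
    """Recursively turn nested dicts/lists/tuples of scalars and 0d/1d tensors
    into a flat {slash/separated/key: float} mapping."""
    out = {}
    if isinstance(value, dict):
        for key, sub_value in value.items():
            out.update(flatten_metrics(sub_value, f"{prefix}/{key}"))
    elif isinstance(value, (list, tuple)):
        for index, sub_value in enumerate(value):
            out.update(flatten_metrics(sub_value, f"{prefix}/{index}"))
    elif isinstance(value, torch.Tensor) and value.ndimension() == 1:
        for index, sub_value in enumerate(value):
            out[f"{prefix}/{index}"] = float(sub_value)
    elif isinstance(value, torch.Tensor) and value.ndimension() == 0:
        out[prefix] = float(value)
    elif isinstance(value, Number):
        out[prefix] = float(value)
    return out


metrics = {"a": 0, "dict_value": {"c": {"d": 23, "e": [123, 234]}}, "vec": torch.tensor([1.0, 2.0])}
print(flatten_metrics(metrics))
# {'tag/a': 0.0, 'tag/dict_value/c/d': 23.0, 'tag/dict_value/c/e/0': 123.0, ..., 'tag/vec/0': 1.0, 'tag/vec/1': 2.0}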
ignite/handlers/time_profilers.py CHANGED

@@ -251,6 +251,7 @@ class BasicTimeProfiler:
  total_eh_time: Union[int, torch.Tensor] = sum(
      [(self.event_handlers_times[e]).sum() for e in Events if e not in self.events_to_ignore]
  )
+ # pyrefly: ignore [no-matching-overload]
  event_handlers_stats = dict(
      [
          (str(e.name).replace(".", "_"), self._compute_basic_stats(self.event_handlers_times[e]))

@@ -334,6 +335,7 @@ class BasicTimeProfiler:

  results_df = pd.DataFrame(
      data=results_dump,
+     # pyrefly: ignore [bad-argument-type]
      columns=[
          "epoch",
          "iteration",

@@ -498,14 +500,14 @@ class HandlersTimeProfiler:

  self.dataflow_times: List[float] = []
  self.processing_times: List[float] = []
- self.event_handlers_times: Dict[EventEnum, Dict[str, List[float]]] = {}
+ self.event_handlers_times: Dict[Union[str, EventEnum], Dict[str, List[float]]] = {}

  @staticmethod
  def _get_callable_name(handler: Callable) -> str:
      # get name of the callable handler
      return getattr(handler, "__qualname__", handler.__class__.__name__)

- def _create_wrapped_handler(self, handler: Callable, event: EventEnum) -> Callable:
+ def _create_wrapped_handler(self, handler: Callable, event: Union[str, EventEnum]) -> Callable:
      @functools.wraps(handler)
      def _timeit_handler(*args: Any, **kwargs: Any) -> None:
          self._event_handlers_timer.reset()

@@ -530,7 +532,7 @@ class HandlersTimeProfiler:
  t = self._dataflow_timer.value()
  self.dataflow_times.append(t)

- def _reset(self, event_handlers_names: Mapping[EventEnum, List[str]]) -> None:
+ def _reset(self, event_handlers_names: Mapping[Union[str, EventEnum], List[str]]) -> None:
      # reset the variables used for profiling
      self.dataflow_times = []
      self.processing_times = []

@@ -689,6 +691,7 @@ class HandlersTimeProfiler:

  results_dump = torch.stack(cols, dim=1).numpy()

+ # pyrefly: ignore [bad-argument-type]
  results_df = pd.DataFrame(data=results_dump, columns=headers)
  results_df.to_csv(output_path, index=False)
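The widened Union[str, EventEnum] annotations allow handlers registered on custom string events to be profiled alongside built-in ones. Typical profiler usage, sketched from the documented pattern with a dummy trainer and handler:

import time

from ignite.engine import Engine, Events
from ignite.handlers import HandlersTimeProfiler

trainer = Engine(lambda engine, batch: time.sleep(0.01))  # toy training step


@trainer.on(Events.EPOCH_COMPLETED)
def checkpointing(engine):
    time.sleep(0.02)  # stand-in for real work


profiler = HandlersTimeProfiler()
profiler.attach(trainer)  # attach after handlers are registered, before run
trainer.run(data=list(range(10)), max_epochs=2)

results = profiler.get_results()
profiler.print_results(results)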
ignite/handlers/tqdm_logger.py CHANGED

@@ -200,7 +200,7 @@ class ProgressBar(BaseLogger):
  Accepted output value types are numbers, 0d and 1d torch tensors and strings.

  """
- desc = self.tqdm_kwargs.get("desc",
+ desc = self.tqdm_kwargs.get("desc", "")

  if event_name not in engine._allowed_events:
      raise ValueError(f"Logging event {event_name.name} is not in allowed events for this engine")

@@ -223,8 +223,13 @@ class ProgressBar(BaseLogger):
  super(ProgressBar, self).attach(engine, log_handler, event_name)
  engine.add_event_handler(closing_event_name, self._close)

- def attach_opt_params_handler(
-     self,
+ def attach_opt_params_handler(
+     self,
+     engine: Engine,
+     event_name: Union[str, Events],
+     *args: Any,
+     **kwargs: Any,
+     # pyrefly: ignore [bad-return]
  ) -> RemovableEventHandle:
      """Intentionally empty"""
      pass

@@ -298,8 +303,7 @@ class _OutputHandler(BaseOutputHandler):
  rendered_metrics = self._setup_output_metrics_state_attrs(engine, log_text=True)
  metrics = OrderedDict()
  for key, value in rendered_metrics.items():
-     key = "_".join(key[1:])  # tqdm has tag as description
-
+     key = "_".join(key[1:])  # skip tag as tqdm has tag as description
      metrics[key] = value

  if metrics:
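With the fixed default, a ProgressBar constructed without a desc keyword falls back to an empty description. A minimal sketch of the usual wiring, with a dummy engine and output for illustration:

from ignite.engine import Engine
from ignite.handlers import ProgressBar

trainer = Engine(lambda engine, batch: batch)  # toy step that just echoes the batch

# `desc` is forwarded to tqdm; when omitted it now defaults to ""
pbar = ProgressBar(persist=True, desc="Training")
pbar.attach(trainer, output_transform=lambda out: {"batch": out})

trainer.run(data=list(range(100)), max_epochs=2)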
ignite/handlers/visdom_logger.py CHANGED

@@ -1,7 +1,7 @@
  """Visdom logger and its helper handlers."""

  import os
- from typing import Any, Callable,
+ from typing import Any, Callable, Dict, List, Optional, Union, TYPE_CHECKING

  import torch
  import torch.nn as nn

@@ -137,6 +137,13 @@ class VisdomLogger(BaseLogger):
      output_transform=lambda loss: {"loss": loss}
  )

+ Note:
+     :class:`~ignite.handlers.visdom_logger.OutputHandler` can handle
+     metrics, state attributes and engine output values of the following format:
+     - scalar values (i.e. int, float)
+     - 0d and 1d pytorch tensors
+     - dicts and list/tuples of previous types
+
  .. versionchanged:: 0.4.7
      accepts an optional list of `state_attributes`
  """

@@ -158,7 +165,7 @@ class VisdomLogger(BaseLogger):
      "pip install git+https://github.com/fossasia/visdom.git"
  )

- if num_workers > 0:
+ if num_workers > 0 or TYPE_CHECKING:
      # If visdom is installed, one of its dependencies `tornado`
      # requires also `futures` to be installed.
      # Let's check anyway if we can import it.

@@ -172,7 +179,7 @@ class VisdomLogger(BaseLogger):
  )

  if server is None:
-     server =
+     server = os.environ.get("VISDOM_SERVER_URL", "localhost")

  if port is None:
      port = int(os.environ.get("VISDOM_PORT", 8097))
ignite/handlers/wandb_logger.py CHANGED

@@ -1,6 +1,7 @@
  """WandB logger and its helper handlers."""

  from typing import Any, Callable, List, Optional, Union
+ from warnings import warn

  from torch.optim import Optimizer

@@ -26,7 +27,7 @@ class WandBLogger(BaseLogger):
  Args:
      args: Positional arguments accepted by `wandb.init`.
      kwargs: Keyword arguments accepted by `wandb.init`.
-         Please see `wandb.init <https://docs.wandb.ai/ref/python/init
+         Please see `wandb.init <https://docs.wandb.ai/ref/python/sdk/functions/init/>`_ for documentation of possible parameters.

  Examples:
      .. code-block:: python

@@ -120,6 +121,12 @@ class WandBLogger(BaseLogger):
  )
  evaluator.add_event_handler(Events.COMPLETED, model_checkpoint, {'model': model})

+ Note:
+     :class:`~ignite.handlers.wandb_logger.OutputHandler` can handle
+     metrics, state attributes and engine output values of the following format:
+     - scalar values (i.e. int, float)
+     - 0d and 1d pytorch tensors
+     - dicts and list/tuples of previous types

  """

@@ -166,8 +173,7 @@ class OutputHandler(BaseOutputHandler):
  Default is None, global_step based on attached engine. If provided,
  uses function output as global_step. To setup global step from another engine, please use
  :meth:`~ignite.handlers.wandb_logger.global_step_from_engine`.
- sync:
-     the default value of wandb.log.
+ sync: Deprecated, has no function. Argument is kept here for compatibility with existing code.

  Examples:
      .. code-block:: python

@@ -278,7 +284,8 @@ class OutputHandler(BaseOutputHandler):
  state_attributes: Optional[List[str]] = None,
  ):
  super().__init__(tag, metric_names, output_transform, global_step_transform, state_attributes)
-
+ if sync is not None:
+     warn("The sync argument for the WandBLoggers is no longer used, and may be removed in the future")

  def __call__(self, engine: Engine, logger: WandBLogger, event_name: Union[str, Events]) -> None:
  if not isinstance(logger, WandBLogger):

@@ -292,7 +299,7 @@ class OutputHandler(BaseOutputHandler):
  )

  metrics = self._setup_output_metrics_state_attrs(engine, log_text=True, key_tuple=False)
- logger.log(metrics, step=global_step
+ logger.log(metrics, step=global_step)


  class OptimizerParamsHandler(BaseOptimizerParamsHandler):

@@ -303,8 +310,7 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler):
  as a sequence.
  param_name: parameter name
  tag: common title for all produced plots. For example, "generator"
- sync:
-     the default value of wandb.log.
+ sync: Deprecated, has no function. Argument is kept here for compatibility with existing code.

  Examples:
      .. code-block:: python

@@ -340,7 +346,8 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler):
  self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None, sync: Optional[bool] = None
  ):
  super(OptimizerParamsHandler, self).__init__(optimizer, param_name, tag)
-
+ if sync is not None:
+     warn("The sync argument for the WandBLoggers is no longer used, and may be removed in the future")

  def __call__(self, engine: Engine, logger: WandBLogger, event_name: Union[str, Events]) -> None:
  if not isinstance(logger, WandBLogger):

@@ -352,4 +359,4 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler):
  f"{tag_prefix}{self.param_name}/group_{i}": float(param_group[self.param_name])
  for i, param_group in enumerate(self.optimizer.param_groups)
  }
- logger.log(params, step=global_step
+ logger.log(params, step=global_step)
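The sync argument is now a documented no-op that only emits a warning. A quick way to observe the new behaviour, assuming OutputHandler can be constructed without an active wandb run (wandb itself is only imported when WandBLogger is instantiated):

import warnings

from ignite.handlers.wandb_logger import OutputHandler

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Passing any non-None value for `sync` triggers the deprecation warning
    OutputHandler(tag="training", metric_names="all", sync=True)

print([str(w.message) for w in caught])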
ignite/metrics/accuracy.py CHANGED

@@ -254,6 +254,8 @@ class Accuracy(_BaseClassification):
  y_pred = torch.transpose(y_pred, 1, last_dim - 1).reshape(-1, num_classes)
  y = torch.transpose(y, 1, last_dim - 1).reshape(-1, num_classes)
  correct = torch.all(y == y_pred.type_as(y), dim=-1)
+ else:
+     raise ValueError(f"Unexpected type: {self._type}")

  self._num_correct += torch.sum(correct).to(self._device)
  self._num_examples += correct.shape[0]
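The new else branch is a defensive guard against an unexpected internal _type; regular usage still flows through the existing branches, for example the multilabel path touched above. A minimal sketch with random data:

import torch
from ignite.metrics import Accuracy

accuracy = Accuracy(is_multilabel=True)

y_pred = torch.randint(0, 2, size=(8, 4))  # 8 samples, 4 binary labels each
y_true = torch.randint(0, 2, size=(8, 4))

accuracy.update((y_pred, y_true))
print(accuracy.compute())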
ignite/metrics/fbeta.py CHANGED

@@ -1,4 +1,4 @@
- from typing import Callable, Optional, Union
+ from typing import Callable, cast, Optional, Union

  import torch

@@ -15,7 +15,7 @@ def Fbeta(
  precision: Optional[Precision] = None,
  recall: Optional[Recall] = None,
  output_transform: Optional[Callable] = None,
- device: Union[str, torch.device] =
+ device: Optional[Union[str, torch.device]] = None,
  ) -> MetricsLambda:
  r"""Calculates F-beta score.

@@ -143,17 +143,26 @@ def Fbeta(
  if not (beta > 0):
      raise ValueError(f"Beta should be a positive integer, but given {beta}")

- if precision is not None
-
+ if precision is not None:
+     if output_transform is not None:
+         raise ValueError("If precision argument is provided, output_transform should be None")
+     if device is not None:
+         raise ValueError("If precision argument is provided, device should be None")

- if recall is not None
-
+ if recall is not None:
+     if output_transform is not None:
+         raise ValueError("If recall argument is provided, output_transform should be None")
+     if device is not None:
+         raise ValueError("If recall argument is provided, device should be None")
+
+ if precision is None and recall is None and device is None:
+     device = torch.device("cpu")

  if precision is None:
      precision = Precision(
          output_transform=(lambda x: x) if output_transform is None else output_transform,
          average=False,
-         device=device,
+         device=cast(Union[str, torch.device], recall._device if recall else device),
      )
  elif precision._average:
      raise ValueError("Input precision metric should have average=False")

@@ -162,7 +171,7 @@ def Fbeta(
      recall = Recall(
          output_transform=(lambda x: x) if output_transform is None else output_transform,
          average=False,
-         device=device,
+         device=cast(Union[str, torch.device], precision._device if precision else device),
      )
  elif recall._average:
      raise ValueError("Input recall metric should have average=False")
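Under the new checks, device (and output_transform) may only be given when Fbeta builds its own Precision and Recall; combining them with a ready-made metric now raises. A short sketch of both paths using the public Fbeta and Precision API:

import torch
from ignite.metrics import Fbeta, Precision

# Fbeta constructs Precision and Recall itself: device may be passed explicitly
f2 = Fbeta(beta=2.0, device="cpu")

# Reusing an existing Precision: its device is picked up, so `device` must stay None
precision = Precision(average=False, device="cpu")
f2_shared = Fbeta(beta=2.0, precision=precision)

try:
    Fbeta(beta=2.0, precision=precision, device="cpu")
except ValueError as e:
    print(e)  # "If precision argument is provided, device should be None"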
ignite/metrics/gan/fid.py CHANGED

@@ -31,13 +31,13 @@ def fid_score(
  except ImportError:
      raise ModuleNotFoundError("fid_score requires scipy to be installed.")

- mu1, mu2 = mu1.cpu(), mu2.cpu()
- sigma1, sigma2 = sigma1.cpu(), sigma2.cpu()
+ mu1, mu2 = mu1.detach().cpu(), mu2.detach().cpu()
+ sigma1, sigma2 = sigma1.detach().cpu(), sigma2.detach().cpu()

  diff = mu1 - mu2

  # Product might be almost singular
- covmean, _ = scipy.linalg.sqrtm(sigma1.mm(sigma2), disp=False)
+ covmean, _ = scipy.linalg.sqrtm(sigma1.mm(sigma2).numpy(), disp=False)
  # Numerical error might give slight imaginary component
  if np.iscomplexobj(covmean):
      if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
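The change detaches the statistics and hands scipy a plain NumPy array, since tensors that still carry autograd history cannot be converted implicitly. A standalone sketch of the same pattern with toy covariance matrices (not ignite code):

import numpy as np
import scipy.linalg
import torch

# Toy covariance-like matrices; sigma1 carries autograd history on purpose
a = torch.rand(4, 4, requires_grad=True)
sigma1 = a @ a.T
b = torch.rand(4, 4)
sigma2 = b @ b.T

# detach() drops the autograd graph, .cpu() ensures a host tensor, .numpy() gives scipy a plain array
product = sigma1.detach().cpu().mm(sigma2.cpu()).numpy()
covmean, _ = scipy.linalg.sqrtm(product, disp=False)

if np.iscomplexobj(covmean):
    covmean = covmean.real  # numerical noise can introduce a tiny imaginary part

print(float(np.trace(covmean)))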
ignite/metrics/js_divergence.py CHANGED

ignite/metrics/metric.py CHANGED

@@ -361,12 +361,15 @@ class Metric(Serializable, metaclass=ABCMeta):
  device: Union[str, torch.device] = torch.device("cpu"),
  skip_unrolling: bool = False,
  ):
+ if not callable(output_transform):
+     raise TypeError(f"Argument output_transform should be callable, got {type(output_transform)}")
  self._output_transform = output_transform

  # Some metrics have a large performance regression when run on XLA devices, so for now, we disallow it.
  if torch.device(device).type == "xla":
      raise ValueError("Cannot create metric on an XLA device. Use device='cpu' instead.")

+ # pyrefly: ignore [read-only]
  self._device = torch.device(device)
  self._skip_unrolling = skip_unrolling
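Because the check lives in Metric.__init__, every metric now fails fast when output_transform is not callable, for example (using Accuracy as an arbitrary concrete metric):

from ignite.metrics import Accuracy

try:
    Accuracy(output_transform="y_pred")  # a string instead of a callable
except TypeError as e:
    print(e)  # Argument output_transform should be callable, got <class 'str'>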
ignite/metrics/nlp/bleu.py CHANGED

@@ -2,6 +2,7 @@ import math
  from typing import Any, Callable, Sequence, Tuple, Union

  import torch
+ from torch import Tensor

  from ignite.exceptions import NotComputableError
  from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce

@@ -71,11 +72,11 @@ class Bleu(Metric):

  More details can be found in `Papineni et al. 2002`__.

- __ https://
+ __ https://aclanthology.org/P02-1040/

  In addition, a review of smoothing techniques can be found in `Chen et al. 2014`__

- __ https://aclanthology.org/W14-3346
+ __ https://aclanthology.org/W14-3346/

  - ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
  - `y_pred` (list(list(str))) - a list of hypotheses sentences.

@@ -236,12 +237,12 @@ class Bleu(Metric):
  @reinit__is_reduced
  def reset(self) -> None:
      if self.average == "macro":
-         self._sum_of_bleu = torch.tensor(0.0, dtype=
+         self._sum_of_bleu = torch.tensor(0.0, dtype=self._double_dtype, device=self._device)
          self._num_sentences = 0

      if self.average == "micro":
-         self.p_numerators = torch.zeros(self.ngrams_order + 1)
-         self.p_denominators = torch.zeros(self.ngrams_order + 1)
+         self.p_numerators = torch.zeros(self.ngrams_order + 1, dtype=self._double_dtype)
+         self.p_denominators = torch.zeros(self.ngrams_order + 1, dtype=self._double_dtype)
          self.hyp_length_sum = 0
          self.ref_length_sum = 0

@@ -278,8 +279,9 @@ class Bleu(Metric):
  )
  return bleu_score

- def compute(self) -> None:
+ def compute(self) -> Union[None, Tensor, float]:
      if self.average == "macro":
          return self._compute_macro()
      elif self.average == "micro":
          return self._compute_micro()
+     return None
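The new return annotation reflects that compute() yields a tensor under macro averaging and a float-like value under micro averaging, with None only as a fallback. The documented usage is unchanged; a minimal sketch:

from ignite.metrics import Bleu

candidate = "the cat sat on the mat".split()
references = [
    "the cat is on the mat".split(),
    "there is a cat on the mat".split(),
]

bleu_macro = Bleu(ngram=4, smooth="smooth1", average="macro")
bleu_macro.update(([candidate], [references]))
print(type(bleu_macro.compute()))  # a 0d torch tensor

bleu_micro = Bleu(ngram=4, smooth="smooth1", average="micro")
bleu_micro.update(([candidate], [references]))
print(type(bleu_micro.compute()))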
ignite/metrics/nlp/rouge.py CHANGED

@@ -1,6 +1,5 @@
  from abc import ABCMeta, abstractmethod
- from
- from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Union
+ from typing import Any, Callable, List, Mapping, NamedTuple, Optional, Sequence, Tuple, Union

  import torch

@@ -13,11 +12,15 @@ from ignite.metrics.nlp.utils import lcs, ngrams
  __all__ = ["Rouge", "RougeN", "RougeL"]


- class Score(
+ class Score(NamedTuple):
      r"""
      Computes precision and recall for given matches, candidate and reference lengths.
      """

+     match: int
+     candidate: int
+     reference: int
+
      def precision(self) -> float:
          """
          Calculates precision.

@@ -191,7 +194,7 @@ class RougeN(_BaseRouge):

  More details can be found in `Lin 2004`__.

- __ https://
+ __ https://aclanthology.org/W04-1013

  - ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
  - `y_pred` (list(list(str))) must be a sequence of tokens.

@@ -265,7 +268,7 @@ class RougeL(_BaseRouge):

  More details can be found in `Lin 2004`__.

- __ https://
+ __ https://aclanthology.org/W04-1013

  - ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
  - `y_pred` (list(list(str))) must be a sequence of tokens.

@@ -331,7 +334,7 @@ class Rouge(Metric):

  More details can be found in `Lin 2004`__.

- __ https://
+ __ https://aclanthology.org/W04-1013

  - ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
  - `y_pred` (list(list(str))) must be a sequence of tokens.
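Turning Score into a typed NamedTuple is an internal change; the public Rouge interface stays the same, and the documented usage pattern still applies. A minimal sketch with toy token lists:

from ignite.metrics import Rouge

candidate = "the cat is not there".split()
references = [
    "the cat is on the mat".split(),
    "there is a cat on the mat".split(),
]

rouge = Rouge(variants=["L", 2], multiref="best")
rouge.update(([candidate], [references]))
print(rouge.compute())  # e.g. {'Rouge-L-P': ..., 'Rouge-L-R': ..., 'Rouge-L-F': ..., 'Rouge-2-P': ..., ...}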
ignite/metrics/nlp/utils.py CHANGED

@@ -63,7 +63,7 @@ def modified_precision(references: Sequence[Sequence[Any]], candidate: Any, n: i

  More details can be found in `Papineni et al. 2002`__.

- __ https://
+ __ https://aclanthology.org/P02-1040

  Args:
      references: list of references R
ignite/metrics/precision_recall_curve.py CHANGED

@@ -97,7 +97,7 @@ class PrecisionRecallCurve(EpochMetric):
  if len(self._predictions) < 1 or len(self._targets) < 1:
      raise NotComputableError("PrecisionRecallCurve must have at least one example before it can be computed.")

- if self._result is None:
+ if self._result is None:
      _prediction_tensor = torch.cat(self._predictions, dim=0)
      _target_tensor = torch.cat(self._targets, dim=0)

@@ -110,11 +110,11 @@ class PrecisionRecallCurve(EpochMetric):
  if idist.get_rank() == 0:
      # Run compute_fn on zero rank only
      precision, recall, thresholds = cast(Tuple, self.compute_fn(_prediction_tensor, _target_tensor))
-     precision = torch.tensor(precision, device=_prediction_tensor.device)
-     recall = torch.tensor(recall, device=_prediction_tensor.device)
+     precision = torch.tensor(precision, device=_prediction_tensor.device, dtype=self._double_dtype)
+     recall = torch.tensor(recall, device=_prediction_tensor.device, dtype=self._double_dtype)
      # thresholds can have negative strides, not compatible with torch tensors
      # https://discuss.pytorch.org/t/negative-strides-in-tensor-error/134287/2
-     thresholds = torch.tensor(thresholds.copy(), device=_prediction_tensor.device)
+     thresholds = torch.tensor(thresholds.copy(), device=_prediction_tensor.device, dtype=self._double_dtype)
  else:
      precision, recall, thresholds = None, None, None

@@ -126,4 +126,4 @@ class PrecisionRecallCurve(EpochMetric):

  self._result = (precision, recall, thresholds)  # type: ignore[assignment]

- return cast(Tuple[torch.Tensor, torch.Tensor, torch.Tensor], self._result)
+ return cast(Tuple[torch.Tensor, torch.Tensor, torch.Tensor], self._result)
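With the added dtype argument, the returned curves follow the metric's double dtype (float64 on a typical CPU setup). A minimal sketch, assuming scikit-learn is installed for the underlying compute_fn and that PrecisionRecallCurve is importable from ignite.metrics:

import torch
from ignite.metrics import PrecisionRecallCurve

metric = PrecisionRecallCurve()
y_pred = torch.tensor([0.1, 0.4, 0.35, 0.8])
y_true = torch.tensor([0, 0, 1, 1])

metric.update((y_pred, y_true))
precision, recall, thresholds = metric.compute()
print(precision.dtype, recall.dtype, thresholds.dtype)  # expected: torch.float64 on CPU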