congrads 0.1.0-py3-none-any.whl → 0.3.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- congrads/__init__.py +10 -20
- congrads/callbacks/base.py +357 -0
- congrads/callbacks/registry.py +106 -0
- congrads/checkpoints.py +178 -0
- congrads/constraints/base.py +242 -0
- congrads/constraints/registry.py +1255 -0
- congrads/core/batch_runner.py +200 -0
- congrads/core/congradscore.py +271 -0
- congrads/core/constraint_engine.py +209 -0
- congrads/core/epoch_runner.py +119 -0
- congrads/datasets/registry.py +799 -0
- congrads/descriptor.py +147 -43
- congrads/metrics.py +116 -41
- congrads/networks/registry.py +68 -0
- congrads/py.typed +0 -0
- congrads/transformations/base.py +37 -0
- congrads/transformations/registry.py +86 -0
- congrads/utils/preprocessors.py +439 -0
- congrads/utils/utility.py +506 -0
- congrads/utils/validation.py +182 -0
- congrads-0.3.0.dist-info/METADATA +234 -0
- congrads-0.3.0.dist-info/RECORD +23 -0
- congrads-0.3.0.dist-info/WHEEL +4 -0
- congrads/constraints.py +0 -507
- congrads/core.py +0 -211
- congrads/datasets.py +0 -742
- congrads/learners.py +0 -233
- congrads/networks.py +0 -91
- congrads-0.1.0.dist-info/LICENSE +0 -34
- congrads-0.1.0.dist-info/METADATA +0 -196
- congrads-0.1.0.dist-info/RECORD +0 -13
- congrads-0.1.0.dist-info/WHEEL +0 -5
- congrads-0.1.0.dist-info/top_level.txt +0 -1
congrads/__init__.py
CHANGED
@@ -1,21 +1,11 @@
-#
+try:  # noqa: D104
+    from importlib.metadata import version as get_version  # Python 3.8+
+except ImportError:
+    from pkg_resources import (
+        get_distribution as get_version,
+    )  # Fallback for older versions
 
-
-
-
-
-from . import descriptor
-from . import learners
-from . import metrics
-from . import networks
-
-# Define __all__ to specify that the submodules are accessible, but not classes directly.
-__all__ = [
-    "core",
-    "constraints",
-    "datasets",
-    "descriptor",
-    "learners",
-    "metrics",
-    "networks"
-]
+try:
+    version = get_version("congrads")  # Replace with your package name
+except Exception:
+    version = "0.0.0"  # Fallback if the package isn't installed
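With the 0.3.0 `__init__.py`, the package root only resolves the installed version; submodules are no longer re-exported via `__all__` and must be imported explicitly (e.g. `congrads.callbacks.base`). A minimal sketch of what the new module exposes, assuming the wheel is installed; the printed value falls back to "0.0.0" when the package metadata cannot be found:

```python
# Minimal sketch: reading the version attribute set in congrads/__init__.py (0.3.0).
import congrads

print(congrads.version)  # e.g. "0.3.0", or "0.0.0" if metadata lookup fails
```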
congrads/callbacks/base.py
ADDED
@@ -0,0 +1,357 @@
"""Callback and Operation Framework for Modular Training Pipelines.

This module provides a structured system for defining and executing
callbacks and operations at different stages of a training lifecycle.
It is designed to support:

- Stateless, reusable operations that produce outputs merged into
  the event-local data.
- Callbacks that group operations and/or custom logic for specific
  stages of training, epochs, batches, and steps.
- A central CallbackManager to orchestrate multiple callbacks,
  maintain shared context, and execute stage-specific pipelines
  in deterministic order.

Stages supported:
- on_train_start
- on_train_end
- on_epoch_start
- on_epoch_end
- on_batch_start
- on_batch_end
- on_test_start
- on_test_end
- on_train_batch_start
- on_train_batch_end
- on_valid_batch_start
- on_valid_batch_end
- on_test_batch_start
- on_test_batch_end
- after_train_forward
- after_valid_forward
- after_test_forward

Usage:
    1. Define Operations by subclassing `Operation` and implementing
       the `compute` method.
    2. Create a Callback subclass or instance and register Operations
       to stages via `add(stage, operation)`.
    3. Register callbacks with `CallbackManager`.
    4. Invoke `CallbackManager.run(stage, data)` at appropriate points
       in the training loop, passing in event-local data.
"""

from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Any, Literal, Self

Stage = Literal[
    "on_train_start",
    "on_train_end",
    "on_epoch_start",
    "on_epoch_end",
    "on_test_start",
    "on_test_end",
    "on_batch_start",
    "on_batch_end",
    "on_train_batch_start",
    "on_train_batch_end",
    "on_valid_batch_start",
    "on_valid_batch_end",
    "on_test_batch_start",
    "on_test_batch_end",
    "after_train_forward",
    "after_valid_forward",
    "after_test_forward",
]

STAGES: tuple[Stage, ...] = (
    "on_train_start",
    "on_train_end",
    "on_epoch_start",
    "on_epoch_end",
    "on_test_start",
    "on_test_end",
    "on_batch_start",
    "on_batch_end",
    "on_train_batch_start",
    "on_train_batch_end",
    "on_valid_batch_start",
    "on_valid_batch_end",
    "on_test_batch_start",
    "on_test_batch_end",
    "after_train_forward",
    "after_valid_forward",
    "after_test_forward",
)


class Operation(ABC):
    """Abstract base class representing a stateless unit of work executed inside a callback stage.

    Subclasses should implement the `compute` method which returns
    a dictionary of outputs to merge into the running event data.
    """

    def __repr__(self) -> str:
        """Return a concise string representation of the operation."""
        return f"<{self.__class__.__name__}>"

    def __call__(self, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
        """Execute the operation with the given event-local data and shared context.

        Args:
            data (dict[str, Any]): Event-local dictionary containing data for this stage.
            ctx (dict[str, Any]): Shared context dictionary accessible by all operations and callbacks.

        Returns:
            dict[str, Any]: Outputs produced by the operation to merge into the running data.
                Returns an empty dict if `compute` returns None.
        """
        out = self.compute(data, ctx)
        if out is None:
            return {}
        if not isinstance(out, dict):
            raise TypeError(f"{self.__class__.__name__}.compute must return dict or None")
        return out

    @abstractmethod
    def compute(self, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any] | None:
        """Perform the operation's computation.

        Args:
            data (dict[str, Any]): Event-local dictionary containing the current data.
            ctx (dict[str, Any]): Shared context dictionary.

        Returns:
            dict[str, Any] or None: Outputs to merge into the running data.
                Returning None is equivalent to {}.
        """
        raise NotImplementedError


class Callback(ABC):  # noqa: B024
    """Abstract base class representing a callback that can have multiple operations registered to different stages of the training lifecycle.

    Each stage method executes all operations registered for that stage
    in insertion order. Operations can modify the event-local data dictionary.
    """

    def __init__(self):
        """Initialize the callback with empty operation lists for all stages."""
        self._ops_by_stage: dict[Stage, list[Operation]] = {s: [] for s in STAGES}

    def __repr__(self) -> str:
        """Return a concise string showing number of operations per stage."""
        ops_summary = {stage: len(ops) for stage, ops in self._ops_by_stage.items() if ops}
        return f"<{self.__class__.__name__} ops={ops_summary}>"

    def add(self, stage: Stage, op: Operation) -> Self:
        """Register an operation to execute at the given stage.

        Args:
            stage (Stage): Lifecycle stage at which to run the operation.
            op (Operation): Operation instance to add.

        Returns:
            Self: Returns self for method chaining.
        """
        self._ops_by_stage[stage].append(op)
        return self

    def _run_ops(self, stage: Stage, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
        """Execute all operations registered for a specific stage.

        Args:
            stage (Stage): Lifecycle stage to execute.
            data (dict[str, Any]): Event-local data to pass to operations.
            ctx (dict[str, Any]): Shared context across callbacks and operations.

        Returns:
            dict[str, Any]: Merged data dictionary after executing all operations.

        Notes:
            - Operations are executed in insertion order.
            - If an operation overwrites existing keys, a warning is issued.
        """
        out = dict(data)

        for operation in self._ops_by_stage[stage]:
            try:
                produced = operation(out, ctx) or {}
            except Exception as e:
                raise RuntimeError(f"Error in operation {operation} at stage {stage}") from e

            collisions = set(produced.keys()) & set(out.keys())
            if collisions:
                import warnings

                warnings.warn(
                    f"Operation {operation} at stage '{stage}' is overwriting keys: {collisions}",
                    stacklevel=2,
                )

            out.update(produced)

        return out

    # --- training ---
    def on_train_start(self, data: dict[str, Any], ctx: dict[str, Any]):
        """Execute operations registered for the 'on_train_start' stage."""
        self._run_ops("on_train_start", data, ctx)

    def on_train_end(self, data: dict[str, Any], ctx: dict[str, Any]):
        """Execute operations registered for the 'on_train_end' stage."""
        self._run_ops("on_train_end", data, ctx)

    # --- epoch ---
    def on_epoch_start(self, data: dict[str, Any], ctx: dict[str, Any]):
        """Execute operations registered for the 'on_epoch_start' stage."""
        self._run_ops("on_epoch_start", data, ctx)

    def on_epoch_end(self, data: dict[str, Any], ctx: dict[str, Any]):
        """Execute operations registered for the 'on_epoch_end' stage."""
        self._run_ops("on_epoch_end", data, ctx)

    # --- test ---
    def on_test_start(self, data: dict[str, Any], ctx: dict[str, Any]):
        """Execute operations registered for the 'on_test_start' stage."""
        self._run_ops("on_test_start", data, ctx)

    def on_test_end(self, data: dict[str, Any], ctx: dict[str, Any]):
        """Execute operations registered for the 'on_test_end' stage."""
        self._run_ops("on_test_end", data, ctx)

    # --- batch ---
    def on_batch_start(self, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
        """Execute operations registered for the 'on_batch_start' stage."""
        return self._run_ops("on_batch_start", data, ctx)

    def on_batch_end(self, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
        """Execute operations registered for the 'on_batch_end' stage."""
        return self._run_ops("on_batch_end", data, ctx)

    def on_train_batch_start(self, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
        """Execute operations registered for the 'on_train_batch_start' stage."""
        return self._run_ops("on_train_batch_start", data, ctx)

    def on_train_batch_end(self, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
        """Execute operations registered for the 'on_train_batch_end' stage."""
        return self._run_ops("on_train_batch_end", data, ctx)

    def on_valid_batch_start(self, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
        """Execute operations registered for the 'on_valid_batch_start' stage."""
        return self._run_ops("on_valid_batch_start", data, ctx)

    def on_valid_batch_end(self, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
        """Execute operations registered for the 'on_valid_batch_end' stage."""
        return self._run_ops("on_valid_batch_end", data, ctx)

    def on_test_batch_start(self, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
        """Execute operations registered for the 'on_test_batch_start' stage."""
        return self._run_ops("on_test_batch_start", data, ctx)

    def on_test_batch_end(self, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
        """Execute operations registered for the 'on_test_batch_end' stage."""
        return self._run_ops("on_test_batch_end", data, ctx)

    def after_train_forward(self, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
        """Execute operations registered for the 'after_train_forward' stage."""
        return self._run_ops("after_train_forward", data, ctx)

    def after_valid_forward(self, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
        """Execute operations registered for the 'after_valid_forward' stage."""
        return self._run_ops("after_valid_forward", data, ctx)

    def after_test_forward(self, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
        """Execute operations registered for the 'after_test_forward' stage."""
        return self._run_ops("after_test_forward", data, ctx)


class CallbackManager:
    """Orchestrates multiple callbacks and executes them at specific lifecycle stages.

    - Callbacks are executed in registration order.
    - Event-local data flows through all callbacks.
    - Shared context is available for cross-callback communication.
    """

    def __init__(self, callbacks: Iterable[Callback] | None = None):
        """Initialize a CallbackManager instance.

        Args:
            callbacks (Iterable[Callback] | None): Optional initial callbacks to register.
                If None, starts with an empty callback list.

        Attributes:
            _callbacks (list[Callback]): Internal list of registered callbacks.
            ctx (dict[str, Any]): Shared context dictionary accessible to all callbacks
                and operations for cross-event communication.
        """
        self._callbacks: list[Callback] = list(callbacks) if callbacks else []
        self.ctx: dict[str, Any] = {}

    def __repr__(self) -> str:
        """Return a concise representation showing registered callbacks and ctx keys."""
        names = [cb.__class__.__name__ for cb in self._callbacks]
        return f"<CallbackManager callbacks={names} ctx_keys={list(self.ctx.keys())}>"

    def add(self, callback: Callback) -> Self:
        """Register a single callback.

        Args:
            callback (Callback): Callback instance to add.

        Returns:
            Self: Returns self for fluent chaining.
        """
        self._callbacks.append(callback)
        return self

    def extend(self, callbacks: Iterable[Callback]) -> None:
        """Register multiple callbacks at once.

        Args:
            callbacks (Iterable[Callback]): Iterable of callbacks to add.
        """
        self._callbacks.extend(callbacks)

    def run(self, stage: Stage, data: dict[str, Any]) -> dict[str, Any]:
        """Execute all registered callbacks for a specific stage.

        Args:
            stage (Stage): Lifecycle stage to run (e.g., "on_batch_start").
            data (dict[str, Any]): Event-local data dictionary to pass through callbacks.

        Returns:
            dict[str, Any]: The final merged data dictionary after executing all callbacks.

        Raises:
            ValueError: If a callback does not implement the requested stage.
            RuntimeError: If any callback raises an exception during execution.
        """
        for cb in self._callbacks:
            if not hasattr(cb, stage):
                raise ValueError(
                    f"Callback {cb.__class__.__name__} has no handler for stage {stage}"
                )
            handler = getattr(cb, stage)

            try:
                new_data = handler(data, self.ctx)
                if new_data is not None:
                    data = new_data

            except Exception as e:
                raise RuntimeError(f"Error in callback {cb.__class__.__name__}.{stage}") from e

        return data

    @property
    def callbacks(self) -> tuple[Callback, ...]:
        """Return a read-only tuple of registered callbacks.

        Returns:
            tuple[Callback, ...]: Registered callbacks.
        """
        return tuple(self._callbacks)
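The module docstring above describes the intended workflow: subclass `Operation`, register instances on a `Callback` via `add(stage, operation)`, and drive everything through `CallbackManager.run`. A minimal sketch of that flow using only the API defined in this file; the `BatchCounter` operation and the `"batch_index"`/`"batches_seen"` keys are illustrative, not part of the package:

```python
from typing import Any

from congrads.callbacks.base import Callback, CallbackManager, Operation


class BatchCounter(Operation):
    """Illustrative operation: counts processed batches via the shared context."""

    def compute(self, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
        ctx["batches_seen"] = ctx.get("batches_seen", 0) + 1
        # Returned keys are merged into the event-local data by Operation.__call__.
        return {"batch_index": ctx["batches_seen"]}


# A plain Callback instance groups operations per stage; add() returns self.
counter_cb = Callback().add("on_train_batch_end", BatchCounter())
manager = CallbackManager([counter_cb])

# Inside a hypothetical training loop:
for step in range(3):
    data = manager.run("on_train_batch_end", {"loss": 0.1 * step})
    print(data["batch_index"], manager.ctx["batches_seen"])
```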
congrads/callbacks/registry.py
ADDED
@@ -0,0 +1,106 @@
"""Holds all callback implementations for use in the training workflow.

This module acts as a central registry for defining and storing different
callback classes, such as logging, checkpointing, or custom behaviors
triggered during training, validation, or testing. It is intended to
collect all callback implementations in one place for easy reference
and import, and can be extended as new callbacks are added.
"""

from torch.utils.tensorboard import SummaryWriter

from ..callbacks.base import Callback
from ..metrics import MetricManager
from ..utils.utility import CSVLogger


class LoggerCallback(Callback):
    """Callback to log metrics to TensorBoard and CSV after each epoch or test.

    This callback queries a MetricManager for aggregated metrics, writes them
    to TensorBoard using SummaryWriter, and logs them to a CSV file via CSVLogger.
    It also flushes loggers and resets metrics after logging.

    Methods implemented:
    - on_epoch_end: Logs metrics at the end of a training epoch.
    - on_test_end: Logs metrics at the end of testing.
    """

    def __init__(
        self,
        metric_manager: MetricManager,
        tensorboard_logger: SummaryWriter,
        csv_logger: CSVLogger,
    ):
        """Initialize the LoggerCallback.

        Args:
            metric_manager: Instance of MetricManager used to collect metrics.
            tensorboard_logger: TensorBoard SummaryWriter instance for logging scalars.
            csv_logger: CSVLogger instance for logging metrics to CSV files.
        """
        super().__init__()
        self.metric_manager = metric_manager
        self.tensorboard_logger = tensorboard_logger
        self.csv_logger = csv_logger

    def on_epoch_end(self, data: dict[str, any], ctx: dict[str, any]):
        """Log training metrics at the end of an epoch.

        Aggregates metrics from the MetricManager under the 'during_training' category,
        writes them to TensorBoard and CSV, flushes the loggers, and resets the metrics
        for the next epoch.

        Args:
            data: Dictionary containing batch/epoch context (must include 'epoch').
            ctx: Additional context dictionary (unused in this implementation).

        Returns:
            data: The same input dictionary, unmodified.
        """
        epoch = data["epoch"]

        # Log training metrics
        metrics = self.metric_manager.aggregate("during_training")
        for name, value in metrics.items():
            self.tensorboard_logger.add_scalar(name, value.item(), epoch)
            self.csv_logger.add_value(name, value.item(), epoch)

        # Flush/save
        self.tensorboard_logger.flush()
        self.csv_logger.save()

        # Reset metric manager for training
        self.metric_manager.reset("during_training")

        return data

    def on_test_end(self, data: dict[str, any], ctx: dict[str, any]):
        """Log test metrics at the end of testing.

        Aggregates metrics from the MetricManager under the 'after_training' category,
        writes them to TensorBoard and CSV, flushes the loggers, and resets the metrics.

        Args:
            data: Dictionary containing test context (must include 'epoch').
            ctx: Additional context dictionary (unused in this implementation).

        Returns:
            data: The same input dictionary, unmodified.
        """
        epoch = data["epoch"]

        # Log test metrics
        metrics = self.metric_manager.aggregate("after_training")
        for name, value in metrics.items():
            self.tensorboard_logger.add_scalar(name, value.item(), epoch)
            self.csv_logger.add_value(name, value.item(), epoch)

        # Flush/save
        self.tensorboard_logger.flush()
        self.csv_logger.save()

        # Reset metric manager for test
        self.metric_manager.reset("after_training")

        return data
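A sketch of how `LoggerCallback` might be wired into a `CallbackManager`. Only the `LoggerCallback` and `Callback`/`CallbackManager` APIs come from this diff; the `MetricManager()` and `CSVLogger(...)` constructor calls are assumptions, since their signatures live in `congrads/metrics.py` and `congrads/utils/utility.py`, which are not shown in this hunk:

```python
from torch.utils.tensorboard import SummaryWriter

from congrads.callbacks.base import CallbackManager
from congrads.callbacks.registry import LoggerCallback
from congrads.metrics import MetricManager
from congrads.utils.utility import CSVLogger

# Assumed constructors: MetricManager() and CSVLogger("runs/metrics.csv") are
# illustrative; check congrads/metrics.py and congrads/utils/utility.py for
# the real signatures.
metric_manager = MetricManager()
tensorboard_logger = SummaryWriter(log_dir="runs/exp1")
csv_logger = CSVLogger("runs/metrics.csv")

logger_cb = LoggerCallback(metric_manager, tensorboard_logger, csv_logger)
manager = CallbackManager([logger_cb])

# At the end of each epoch the training loop would call:
#   manager.run("on_epoch_end", {"epoch": epoch})
# which aggregates the 'during_training' metrics, writes them to TensorBoard
# and CSV, flushes both loggers, and resets the metric group.
```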
congrads/checkpoints.py
ADDED
@@ -0,0 +1,178 @@
"""Module for managing PyTorch model checkpoints.

Provides the `CheckpointManager` class to save and load model and optimizer
states during training, track the best metric values, and optionally report
checkpoint events.
"""

import os
from collections.abc import Callable
from pathlib import Path

from torch import Tensor, load, save
from torch.nn import Module
from torch.optim import Optimizer

from .metrics import MetricManager
from .utils.validation import validate_callable, validate_type


class CheckpointManager:
    """Manage saving and loading checkpoints for PyTorch models and optimizers.

    Handles checkpointing based on a criteria function, restores metric
    states, and optionally reports when a checkpoint is saved.
    """

    def __init__(
        self,
        criteria_function: Callable[[dict[str, Tensor], dict[str, Tensor]], bool],
        network: Module,
        optimizer: Optimizer,
        metric_manager: MetricManager,
        save_dir: str = "checkpoints",
        create_dir: bool = False,
        report_save: bool = False,
    ):
        """Initialize the CheckpointManager.

        Args:
            criteria_function (Callable[[dict[str, Tensor], dict[str, Tensor]], bool]):
                Function that determines if the current checkpoint should be
                saved based on the current and best metric values.
            network (torch.nn.Module): The model to save/load.
            optimizer (torch.optim.Optimizer): The optimizer to save/load.
            metric_manager (MetricManager): Manages metric states for checkpointing.
            save_dir (str, optional): Directory to save checkpoints. Defaults to 'checkpoints'.
            create_dir (bool, optional): Whether to create `save_dir` if it does not exist.
                Defaults to False.
            report_save (bool, optional): Whether to report when a checkpoint is saved.
                Defaults to False.

        Raises:
            TypeError: If any provided attribute has an incompatible type.
            FileNotFoundError: If `save_dir` does not exist and `create_dir` is False.
        """
        # Type checking
        validate_callable("criteria_function", criteria_function)
        validate_type("network", network, Module)
        validate_type("optimizer", optimizer, Optimizer)
        validate_type("metric_manager", metric_manager, MetricManager)
        validate_type("create_dir", create_dir, bool)
        validate_type("report_save", report_save, bool)

        # Create path or raise error if create_dir is not found
        if not os.path.exists(save_dir):
            if not create_dir:
                raise FileNotFoundError(
                    f"Save directory '{save_dir}' configured in checkpoint manager is not found."
                )
            Path(save_dir).mkdir(parents=True, exist_ok=True)

        # Initialize objects variables
        self.criteria_function = criteria_function
        self.network = network
        self.optimizer = optimizer
        self.metric_manager = metric_manager
        self.save_dir = save_dir
        self.report_save = report_save

        self.best_metric_values: dict[str, Tensor] = {}

    def evaluate_criteria(self, epoch: int, metric_group: str = "during_training"):
        """Evaluate the criteria function to determine if a better model is found.

        Aggregates the current metric values during training and applies the
        criteria function. If the criteria function indicates improvement, the
        best metric values are updated, a checkpoint is saved, and a message is
        optionally printed.

        Args:
            epoch (int): The current epoch number.
            metric_group (str, optional): The metric group to evaluate. Defaults to 'during_training'.
        """
        current_metric_values = self.metric_manager.aggregate(metric_group)
        if self.criteria_function is not None and self.criteria_function(
            current_metric_values, self.best_metric_values
        ):
            # Print message if a new checkpoint is saved
            if self.report_save:
                print(f"New checkpoint saved at epoch {epoch}.")

            # Update current best metric values
            for metric_name, metric_value in current_metric_values.items():
                self.best_metric_values[metric_name] = metric_value

            # Save the current state
            self.save(epoch)

    def resume(self, filename: str = "checkpoint.pth", ignore_missing: bool = False) -> int:
        """Resumes training from a saved checkpoint file.

        Args:
            filename (str): The name of the checkpoint file to load.
                Defaults to "checkpoint.pth".
            ignore_missing (bool): If True, does not raise an error if the
                checkpoint file is missing and continues without loading,
                starting from epoch 0. Defaults to False.

        Returns:
            int: The epoch number from the loaded checkpoint, or 0 if
                ignore_missing is True and no checkpoint was found.

        Raises:
            TypeError: If a provided attribute has an incompatible type.
            FileNotFoundError: If the specified checkpoint file does not exist.
        """
        # Type checking
        validate_type("filename", filename, str)
        validate_type("ignore_missing", ignore_missing, bool)

        # Return starting epoch, either from checkpoint file or default
        filepath = os.path.join(self.save_dir, filename)
        if os.path.exists(filepath):
            checkpoint = self.load(filename)
            return checkpoint["epoch"]
        elif ignore_missing:
            return 0
        else:
            raise FileNotFoundError(f"A checkpoint was not found at {filepath} to resume training.")

    def save(self, epoch: int, filename: str = "checkpoint.pth"):
        """Save a checkpoint.

        Args:
            epoch (int): Current epoch number.
            filename (str): Name of the checkpoint file. Defaults to
                'checkpoint.pth'.
        """
        state = {
            "epoch": epoch,
            "network_state": self.network.state_dict(),
            "optimizer_state": self.optimizer.state_dict(),
            "best_metrics": self.best_metric_values,
        }
        filepath = os.path.join(self.save_dir, filename)
        save(state, filepath)

    def load(self, filename: str):
        """Load a checkpoint and restore the training state.

        Loads the checkpoint from the specified file and restores the network
        weights, optimizer state, and best metric values.

        Args:
            filename (str): Name of the checkpoint file.

        Returns:
            dict: A dictionary containing the loaded checkpoint information,
                including epoch, loss, and other relevant training state.
        """
        filepath = os.path.join(self.save_dir, filename)

        checkpoint = load(filepath, weights_only=True)
        self.network.load_state_dict(checkpoint["network_state"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state"])
        self.best_metric_values = checkpoint["best_metrics"]

        return checkpoint
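A sketch of a typical `CheckpointManager` setup built from the constructor and methods above. The `lower_valid_loss` criteria function, the `"loss/valid"` metric key, the toy network/optimizer, and the `MetricManager()` constructor call are illustrative assumptions, not part of the package:

```python
import torch
from torch import Tensor

from congrads.checkpoints import CheckpointManager
from congrads.metrics import MetricManager


def lower_valid_loss(current: dict[str, Tensor], best: dict[str, Tensor]) -> bool:
    # Illustrative criteria: save whenever the (assumed) "loss/valid" metric improves.
    if "loss/valid" not in best:
        return True
    return bool(current["loss/valid"] < best["loss/valid"])


network = torch.nn.Linear(4, 1)                      # toy model for illustration
optimizer = torch.optim.Adam(network.parameters())
metric_manager = MetricManager()                     # assumed constructor

checkpoints = CheckpointManager(
    criteria_function=lower_valid_loss,
    network=network,
    optimizer=optimizer,
    metric_manager=metric_manager,
    save_dir="checkpoints",
    create_dir=True,   # create the directory instead of raising FileNotFoundError
    report_save=True,  # print a message whenever a new checkpoint is written
)

start_epoch = checkpoints.resume(ignore_missing=True)  # 0 if no checkpoint.pth exists
for epoch in range(start_epoch, 10):
    # ... train and update metric_manager here ...
    checkpoints.evaluate_criteria(epoch)  # saves checkpoint.pth when the criteria pass
```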