congrads-0.2.0-py3-none-any.whl → congrads-0.3.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
congrads/__init__.py CHANGED
@@ -1,22 +1,11 @@
- # __init__.py
- version = "0.2.0"
+ try:  # noqa: D104
+     from importlib.metadata import version as get_version  # Python 3.8+
+ except ImportError:
+     from pkg_resources import (
+         get_distribution as get_version,
+     )  # Fallback for older versions

- # Only expose the submodules, not individual classes
- from . import constraints
- from . import core
- from . import datasets
- from . import descriptor
- from . import metrics
- from . import networks
- from . import utils
-
- # Define __all__ to specify that the submodules are accessible, but not classes directly.
- __all__ = [
-     "constraints",
-     "core",
-     "datasets",
-     "descriptor",
-     "metrics",
-     "networks",
-     "utils",
- ]
+ try:
+     version = get_version("congrads")  # Replace with your package name
+ except Exception:
+     version = "0.0.0"  # Fallback if the package isn't installed
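As a quick illustration of what this change means for consumers, the sketch below reads the new dynamic version attribute at runtime. It assumes the congrads wheel is installed in the environment; otherwise the hunk above falls back to "0.0.0".

    # Minimal usage sketch for the dynamic version lookup introduced above.
    # Assumes the congrads distribution is installed.
    import congrads

    print(congrads.version)  # reports the installed distribution's version, e.g. "0.3.1"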
@@ -0,0 +1,360 @@
+ """Callback and Operation Framework for Modular Training Pipelines.
+
+ This module provides a structured system for defining and executing
+ callbacks and operations at different stages of a training lifecycle.
+ It is designed to support:
+
+ - Stateless, reusable operations that produce outputs merged into
+   the event-local data.
+ - Callbacks that group operations and/or custom logic for specific
+   stages of training, epochs, batches, and steps.
+ - A central CallbackManager to orchestrate multiple callbacks,
+   maintain shared context, and execute stage-specific pipelines
+   in deterministic order.
+
+ Stages supported:
+ - on_train_start
+ - on_train_end
+ - on_epoch_start
+ - on_epoch_end
+ - on_batch_start
+ - on_batch_end
+ - on_test_start
+ - on_test_end
+ - on_train_batch_start
+ - on_train_batch_end
+ - on_valid_batch_start
+ - on_valid_batch_end
+ - on_test_batch_start
+ - on_test_batch_end
+ - after_train_forward
+ - after_valid_forward
+ - after_test_forward
+
+ Usage:
+     1. Define Operations by subclassing `Operation` and implementing
+        the `compute` method.
+     2. Create a Callback subclass or instance and register Operations
+        to stages via `add(stage, operation)`.
+     3. Register callbacks with `CallbackManager`.
+     4. Invoke `CallbackManager.run(stage, data)` at appropriate points
+        in the training loop, passing in event-local data.
+ """
+
+ from abc import ABC, abstractmethod
+ from collections.abc import Iterable
+ from typing import Any, Literal, Self
+
+ __all__ = ["Callback", "CallbackManager", "Operation"]
+
+
+ Stage = Literal[
+     "on_train_start",
+     "on_train_end",
+     "on_epoch_start",
+     "on_epoch_end",
+     "on_test_start",
+     "on_test_end",
+     "on_batch_start",
+     "on_batch_end",
+     "on_train_batch_start",
+     "on_train_batch_end",
+     "on_valid_batch_start",
+     "on_valid_batch_end",
+     "on_test_batch_start",
+     "on_test_batch_end",
+     "after_train_forward",
+     "after_valid_forward",
+     "after_test_forward",
+ ]
+
+ STAGES: tuple[Stage, ...] = (
+     "on_train_start",
+     "on_train_end",
+     "on_epoch_start",
+     "on_epoch_end",
+     "on_test_start",
+     "on_test_end",
+     "on_batch_start",
+     "on_batch_end",
+     "on_train_batch_start",
+     "on_train_batch_end",
+     "on_valid_batch_start",
+     "on_valid_batch_end",
+     "on_test_batch_start",
+     "on_test_batch_end",
+     "after_train_forward",
+     "after_valid_forward",
+     "after_test_forward",
+ )
+
+
+ class Operation(ABC):
+     """Abstract base class representing a stateless unit of work executed inside a callback stage.
+
+     Subclasses should implement the `compute` method which returns
+     a dictionary of outputs to merge into the running event data.
+     """
+
+     def __repr__(self) -> str:
+         """Return a concise string representation of the operation."""
+         return f"<{self.__class__.__name__}>"
+
+     def __call__(self, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
+         """Execute the operation with the given event-local data and shared context.
+
+         Args:
+             data (dict[str, Any]): Event-local dictionary containing data for this stage.
+             ctx (dict[str, Any]): Shared context dictionary accessible by all operations and callbacks.
+
+         Returns:
+             dict[str, Any]: Outputs produced by the operation to merge into the running data.
+                 Returns an empty dict if `compute` returns None.
+         """
+         out = self.compute(data, ctx)
+         if out is None:
+             return {}
+         if not isinstance(out, dict):
+             raise TypeError(f"{self.__class__.__name__}.compute must return dict or None")
+         return out
+
+     @abstractmethod
+     def compute(self, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any] | None:
+         """Perform the operation's computation.
+
+         Args:
+             data (dict[str, Any]): Event-local dictionary containing the current data.
+             ctx (dict[str, Any]): Shared context dictionary.
+
+         Returns:
+             dict[str, Any] or None: Outputs to merge into the running data.
+                 Returning None is equivalent to {}.
+         """
+         raise NotImplementedError
+
+
+ class Callback(ABC):  # noqa: B024
+     """Abstract base class representing a callback that can have multiple operations registered to different stages of the training lifecycle.
+
+     Each stage method executes all operations registered for that stage
+     in insertion order. Operations can modify the event-local data dictionary.
+     """
+
+     def __init__(self):
+         """Initialize the callback with empty operation lists for all stages."""
+         self._ops_by_stage: dict[Stage, list[Operation]] = {s: [] for s in STAGES}
+
+     def __repr__(self) -> str:
+         """Return a concise string showing number of operations per stage."""
+         ops_summary = {stage: len(ops) for stage, ops in self._ops_by_stage.items() if ops}
+         return f"<{self.__class__.__name__} ops={ops_summary}>"
+
+     def add(self, stage: Stage, op: Operation) -> Self:
+         """Register an operation to execute at the given stage.
+
+         Args:
+             stage (Stage): Lifecycle stage at which to run the operation.
+             op (Operation): Operation instance to add.
+
+         Returns:
+             Self: Returns self for method chaining.
+         """
+         self._ops_by_stage[stage].append(op)
+         return self
+
+     def _run_ops(self, stage: Stage, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
+         """Execute all operations registered for a specific stage.
+
+         Args:
+             stage (Stage): Lifecycle stage to execute.
+             data (dict[str, Any]): Event-local data to pass to operations.
+             ctx (dict[str, Any]): Shared context across callbacks and operations.
+
+         Returns:
+             dict[str, Any]: Merged data dictionary after executing all operations.
+
+         Notes:
+             - Operations are executed in insertion order.
+             - If an operation overwrites existing keys, a warning is issued.
+         """
+         out = dict(data)
+
+         for operation in self._ops_by_stage[stage]:
+             try:
+                 produced = operation(out, ctx) or {}
+             except Exception as e:
+                 raise RuntimeError(f"Error in operation {operation} at stage {stage}") from e
+
+             collisions = set(produced.keys()) & set(out.keys())
+             if collisions:
+                 import warnings
+
+                 warnings.warn(
+                     f"Operation {operation} at stage '{stage}' is overwriting keys: {collisions}",
+                     stacklevel=2,
+                 )
+
+             out.update(produced)
+
+         return out
+
+     # --- training ---
+     def on_train_start(self, data: dict[str, Any], ctx: dict[str, Any]):
+         """Execute operations registered for the 'on_train_start' stage."""
+         self._run_ops("on_train_start", data, ctx)
+
+     def on_train_end(self, data: dict[str, Any], ctx: dict[str, Any]):
+         """Execute operations registered for the 'on_train_end' stage."""
+         self._run_ops("on_train_end", data, ctx)
+
+     # --- epoch ---
+     def on_epoch_start(self, data: dict[str, Any], ctx: dict[str, Any]):
+         """Execute operations registered for the 'on_epoch_start' stage."""
+         self._run_ops("on_epoch_start", data, ctx)
+
+     def on_epoch_end(self, data: dict[str, Any], ctx: dict[str, Any]):
+         """Execute operations registered for the 'on_epoch_end' stage."""
+         self._run_ops("on_epoch_end", data, ctx)
+
+     # --- test ---
+     def on_test_start(self, data: dict[str, Any], ctx: dict[str, Any]):
+         """Execute operations registered for the 'on_test_start' stage."""
+         self._run_ops("on_test_start", data, ctx)
+
+     def on_test_end(self, data: dict[str, Any], ctx: dict[str, Any]):
+         """Execute operations registered for the 'on_test_end' stage."""
+         self._run_ops("on_test_end", data, ctx)
+
+     # --- batch ---
+     def on_batch_start(self, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
+         """Execute operations registered for the 'on_batch_start' stage."""
+         return self._run_ops("on_batch_start", data, ctx)
+
+     def on_batch_end(self, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
+         """Execute operations registered for the 'on_batch_end' stage."""
+         return self._run_ops("on_batch_end", data, ctx)
+
+     def on_train_batch_start(self, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
+         """Execute operations registered for the 'on_train_batch_start' stage."""
+         return self._run_ops("on_train_batch_start", data, ctx)
+
+     def on_train_batch_end(self, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
+         """Execute operations registered for the 'on_train_batch_end' stage."""
+         return self._run_ops("on_train_batch_end", data, ctx)
+
+     def on_valid_batch_start(self, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
+         """Execute operations registered for the 'on_valid_batch_start' stage."""
+         return self._run_ops("on_valid_batch_start", data, ctx)
+
+     def on_valid_batch_end(self, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
+         """Execute operations registered for the 'on_valid_batch_end' stage."""
+         return self._run_ops("on_valid_batch_end", data, ctx)
+
+     def on_test_batch_start(self, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
+         """Execute operations registered for the 'on_test_batch_start' stage."""
+         return self._run_ops("on_test_batch_start", data, ctx)
+
+     def on_test_batch_end(self, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
+         """Execute operations registered for the 'on_test_batch_end' stage."""
+         return self._run_ops("on_test_batch_end", data, ctx)
+
+     def after_train_forward(self, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
+         """Execute operations registered for the 'after_train_forward' stage."""
+         return self._run_ops("after_train_forward", data, ctx)
+
+     def after_valid_forward(self, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
+         """Execute operations registered for the 'after_valid_forward' stage."""
+         return self._run_ops("after_valid_forward", data, ctx)
+
+     def after_test_forward(self, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
+         """Execute operations registered for the 'after_test_forward' stage."""
+         return self._run_ops("after_test_forward", data, ctx)
+
+
+ class CallbackManager:
+     """Orchestrates multiple callbacks and executes them at specific lifecycle stages.
+
+     - Callbacks are executed in registration order.
+     - Event-local data flows through all callbacks.
+     - Shared context is available for cross-callback communication.
+     """
+
+     def __init__(self, callbacks: Iterable[Callback] | None = None):
+         """Initialize a CallbackManager instance.
+
+         Args:
+             callbacks (Iterable[Callback] | None): Optional initial callbacks to register.
+                 If None, starts with an empty callback list.
+
+         Attributes:
+             _callbacks (list[Callback]): Internal list of registered callbacks.
+             ctx (dict[str, Any]): Shared context dictionary accessible to all callbacks
+                 and operations for cross-event communication.
+         """
+         self._callbacks: list[Callback] = list(callbacks) if callbacks else []
+         self.ctx: dict[str, Any] = {}
+
+     def __repr__(self) -> str:
+         """Return a concise representation showing registered callbacks and ctx keys."""
+         names = [cb.__class__.__name__ for cb in self._callbacks]
+         return f"<CallbackManager callbacks={names} ctx_keys={list(self.ctx.keys())}>"
+
+     def add(self, callback: Callback) -> Self:
+         """Register a single callback.
+
+         Args:
+             callback (Callback): Callback instance to add.
+
+         Returns:
+             Self: Returns self for fluent chaining.
+         """
+         self._callbacks.append(callback)
+         return self
+
+     def extend(self, callbacks: Iterable[Callback]) -> None:
+         """Register multiple callbacks at once.
+
+         Args:
+             callbacks (Iterable[Callback]): Iterable of callbacks to add.
+         """
+         self._callbacks.extend(callbacks)
+
+     def run(self, stage: Stage, data: dict[str, Any]) -> dict[str, Any]:
+         """Execute all registered callbacks for a specific stage.
+
+         Args:
+             stage (Stage): Lifecycle stage to run (e.g., "on_batch_start").
+             data (dict[str, Any]): Event-local data dictionary to pass through callbacks.
+
+         Returns:
+             dict[str, Any]: The final merged data dictionary after executing all callbacks.
+
+         Raises:
+             ValueError: If a callback does not implement the requested stage.
+             RuntimeError: If any callback raises an exception during execution.
+         """
+         for cb in self._callbacks:
+             if not hasattr(cb, stage):
+                 raise ValueError(
+                     f"Callback {cb.__class__.__name__} has no handler for stage {stage}"
+                 )
+             handler = getattr(cb, stage)
+
+             try:
+                 new_data = handler(data, self.ctx)
+                 if new_data is not None:
+                     data = new_data
+
+             except Exception as e:
+                 raise RuntimeError(f"Error in callback {cb.__class__.__name__}.{stage}") from e
+
+         return data
+
+     @property
+     def callbacks(self) -> tuple[Callback, ...]:
+         """Return a read-only tuple of registered callbacks.
+
+         Returns:
+             tuple[Callback, ...]: Registered callbacks.
+         """
+         return tuple(self._callbacks)
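To make the lifecycle described in the module docstring concrete, here is a minimal end-to-end sketch of the four usage steps. The Operation, Callback, and CallbackManager classes and the stage name come straight from the hunk above; the import path is only inferred from the sibling module's `from ..callbacks.base import Callback`, and ScaleLoss, "loss", and "loss_scale" are made-up names for illustration.

    # Hypothetical usage sketch; the import path is inferred, not confirmed by the diff.
    from congrads.callbacks.base import Callback, CallbackManager, Operation


    class ScaleLoss(Operation):
        # Illustrative operation: reads "loss" from the event-local data and
        # returns a new key, which _run_ops merges back into the running data.
        def compute(self, data, ctx):
            return {"scaled_loss": data["loss"] * ctx.get("loss_scale", 1.0)}


    callback = Callback()                      # plain instance; Callback defines no abstract methods
    callback.add("on_batch_end", ScaleLoss())  # step 2: register the operation for one stage

    manager = CallbackManager([callback])      # step 3: register the callback
    manager.ctx["loss_scale"] = 0.5            # shared context visible to every operation

    # Step 4: inside the training loop, pass event-local data through all callbacks.
    result = manager.run("on_batch_end", {"loss": 4.0})
    print(result)  # {'loss': 4.0, 'scaled_loss': 2.0}

Because ScaleLoss produces a key that is not already present, no overwrite warning is emitted; an operation that returned an existing key such as "loss" would trigger the collision warning in `_run_ops`.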
@@ -0,0 +1,165 @@
+ """Holds all callback implementations for use in the training workflow.
+
+ This module acts as a central registry for defining and storing different
+ callback classes, such as logging, checkpointing, or custom behaviors
+ triggered during training, validation, or testing. It is intended to
+ collect all callback implementations in one place for easy reference
+ and import, and can be extended as new callbacks are added.
+ """
+
+ from torch import Tensor
+ from torch.utils.tensorboard import SummaryWriter
+
+ from ..callbacks.base import Callback
+ from ..metrics import MetricManager
+ from ..utils.utility import CSVLogger
+
+ __all__ = ["LoggerCallback"]
+
+
+ class LoggerCallback(Callback):
+     """Callback to periodically aggregate and store metrics during training and testing.
+
+     This callback works in conjunction with a MetricManager that accumulates metrics
+     internally (e.g. per batch). Metrics are:
+
+     - Aggregated at a configurable epoch interval (`aggregate_interval`)
+     - Cached in memory (GPU-resident tensors)
+     - Written to TensorBoard and CSV at a separate interval (`store_interval`)
+
+     Aggregation and storage are decoupled to avoid unnecessary GPU-to-CPU
+     synchronization. Any remaining cached metrics are flushed at the end of training.
+
+     Methods implemented:
+     - on_epoch_end: Periodically aggregates and stores training metrics.
+     - on_train_end: Flushes any remaining cached training metrics.
+     - on_test_end: Aggregates and stores test metrics immediately.
+     """
+
+     def __init__(
+         self,
+         metric_manager: MetricManager,
+         tensorboard_logger: SummaryWriter,
+         csv_logger: CSVLogger,
+         *,
+         aggregate_interval: int = 1,
+         store_interval: int = 1,
+     ):
+         """Initialize the LoggerCallback.
+
+         Args:
+             metric_manager: Instance of MetricManager used to collect metrics.
+             tensorboard_logger: TensorBoard SummaryWriter instance for logging scalars.
+             csv_logger: CSVLogger instance for logging metrics to CSV files.
+             aggregate_interval: Number of epochs between metric aggregation.
+             store_interval: Number of epochs between metric storage.
+         """
+         super().__init__()
+
+         # Input validation
+         if aggregate_interval <= 0 or store_interval <= 0:
+             raise ValueError("Intervals must be positive integers")
+
+         if store_interval % aggregate_interval != 0:
+             raise ValueError("store_interval must be a multiple of aggregate_interval")
+
+         # Store references
+         self.metric_manager = metric_manager
+         self.tensorboard_logger = tensorboard_logger
+         self.csv_logger = csv_logger
+         self.aggregate_interval = aggregate_interval
+         self.store_interval = store_interval
+
+         # Cached metrics on GPU by epoch
+         self._accumulated_metrics: dict[int, dict[str, Tensor]] = {}
+
+     def on_epoch_end(self, data: dict[str, any], ctx: dict[str, any]):
+         """Handle end-of-epoch training logic.
+
+         At the end of each epoch, this method may:
+         - Aggregate training metrics from the MetricManager (every `aggregate_interval` epochs)
+         - Cache aggregated metrics keyed by epoch
+         - Store cached metrics to disk (every `store_interval` epochs)
+
+         Metric aggregation resets the MetricManager accumulation state.
+         Metric storage triggers GPU-to-CPU synchronization and writes to loggers.
+
+         Args:
+             data: Dictionary containing epoch context (must include 'epoch').
+             ctx: Additional context dictionary (unused).
+
+         Returns:
+             data: The same input dictionary, unmodified.
+         """
+         epoch = data["epoch"]
+
+         # Cache training metrics
+         if epoch % self.aggregate_interval == 0:
+             metrics = self.metric_manager.aggregate("during_training")
+             self._accumulated_metrics[epoch] = metrics
+             self.metric_manager.reset("during_training")
+
+         # Store metrics to disk
+         if epoch % self.store_interval == 0:
+             self._save(self._accumulated_metrics)
+             self._accumulated_metrics.clear()
+
+         return data
+
+     def on_train_end(self, data, ctx):
+         """Flush any remaining cached training metrics at the end of training.
+
+         This ensures that aggregated metrics that were not yet written due to
+         `store_interval` alignment are persisted before training terminates.
+
+         Args:
+             data: Dictionary containing training context (unused).
+             ctx: Additional context dictionary (unused).
+
+         Returns:
+             data: The same input dictionary, unmodified.
+         """
+         if self._accumulated_metrics:
+             self._save(self._accumulated_metrics)
+             self._accumulated_metrics.clear()
+
+         return data
+
+     def on_test_end(self, data: dict[str, any], ctx: dict[str, any]):
+         """Aggregate and store test metrics at the end of testing.
+
+         Test metrics are aggregated once and written immediately to disk.
+         Interval-based aggregation and caching are not applied to testing.
+
+         Args:
+             data: Dictionary containing test context (must include 'epoch').
+             ctx: Additional context dictionary (unused).
+
+         Returns:
+             data: The same input dictionary, unmodified.
+         """
+         epoch = data["epoch"]
+
+         # Save test metrics
+         metrics = self.metric_manager.aggregate("after_training")
+         self._save({epoch: metrics})
+         self.metric_manager.reset("after_training")
+
+         return data
+
+     def _save(self, metrics: dict[int, dict[str, Tensor]]):
+         """Write aggregated metrics to TensorBoard and CSV loggers.
+
+         Args:
+             metrics: Mapping from epoch to a dictionary of metric name to scalar tensor.
+                 Tensors are expected to be detached and graph-free.
+         """
+         for epoch, metrics_by_name in metrics.items():
+             for name, value in metrics_by_name.items():
+                 cpu_value = value.item()
+                 self.tensorboard_logger.add_scalar(name, cpu_value, epoch)
+                 self.csv_logger.add_value(name, cpu_value, epoch)
+
+         # Flush/save
+         self.tensorboard_logger.flush()
+         self.csv_logger.save()
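A note on the decoupled intervals in LoggerCallback: `aggregate_interval` controls how often metrics are pulled from the MetricManager and cached (still as GPU tensors), while `store_interval` controls how often that cache is synchronized to TensorBoard and CSV. The self-contained sketch below mirrors only the two modulo checks in `on_epoch_end`; the interval values and epoch numbering are illustrative and nothing here touches the congrads classes.

    # Illustrative cadence with aggregate_interval=1 and store_interval=5:
    # metrics are cached every epoch but written out only every fifth epoch
    # (on_train_end flushes whatever is still cached when training stops).
    aggregate_interval, store_interval = 1, 5
    assert store_interval % aggregate_interval == 0  # enforced in LoggerCallback.__init__

    for epoch in range(1, 11):
        will_aggregate = epoch % aggregate_interval == 0  # cache metrics for this epoch
        will_store = epoch % store_interval == 0          # flush the cache to the loggers
        print(f"epoch {epoch:2d}: aggregate={will_aggregate}, store={will_store}")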