congrads 1.1.2__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
congrads/__init__.py CHANGED
@@ -9,20 +9,3 @@ try:
  version = get_version("congrads") # Replace with your package name
  except Exception:
  version = "0.0.0" # Fallback if the package isn't installed
-
- # Only expose the submodules, not individual classes
- from . import constraints, core, datasets, descriptor, metrics, networks, utils
-
- # Define __all__ to specify that the submodules are accessible,
- # but not classes directly.
- __all__ = [
- "checkpoints",
- "constraints",
- "core",
- "datasets",
- "descriptor",
- "metrics",
- "networks",
- "transformations",
- "utils",
- ]
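Since the package root no longer re-exports its submodules, downstream code that reached them through attribute access on `import congrads` presumably needs explicit imports after upgrading. A hypothetical adjustment (the `callbacks.base` path is inferred from a relative import later in this diff):

    # Before (1.1.2): the root package imported its submodules eagerly.
    import congrads
    congrads.constraints  # worked via the package-level re-export

    # After (1.2.0): import submodules explicitly.
    from congrads import constraints, core
    from congrads.callbacks.base import Callback, CallbackManager, Operation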
@@ -0,0 +1,357 @@
+ """Callback and Operation Framework for Modular Training Pipelines.
+
+ This module provides a structured system for defining and executing
+ callbacks and operations at different stages of a training lifecycle.
+ It is designed to support:
+
+ - Stateless, reusable operations that produce outputs merged into
+ the event-local data.
+ - Callbacks that group operations and/or custom logic for specific
+ stages of training, epochs, batches, and steps.
+ - A central CallbackManager to orchestrate multiple callbacks,
+ maintain shared context, and execute stage-specific pipelines
+ in deterministic order.
+
+ Stages supported:
+ - on_train_start
+ - on_train_end
+ - on_epoch_start
+ - on_epoch_end
+ - on_batch_start
+ - on_batch_end
+ - on_test_start
+ - on_test_end
+ - on_train_batch_start
+ - on_train_batch_end
+ - on_valid_batch_start
+ - on_valid_batch_end
+ - on_test_batch_start
+ - on_test_batch_end
+ - after_train_forward
+ - after_valid_forward
+ - after_test_forward
+
+ Usage:
+ 1. Define Operations by subclassing `Operation` and implementing
+ the `compute` method.
+ 2. Create a Callback subclass or instance and register Operations
+ to stages via `add(stage, operation)`.
+ 3. Register callbacks with `CallbackManager`.
+ 4. Invoke `CallbackManager.run(stage, data)` at appropriate points
+ in the training loop, passing in event-local data.
+ """
+
+ from abc import ABC, abstractmethod
+ from collections.abc import Iterable
+ from typing import Any, Literal, Self
+
+ Stage = Literal[
+ "on_train_start",
+ "on_train_end",
+ "on_epoch_start",
+ "on_epoch_end",
+ "on_test_start",
+ "on_test_end",
+ "on_batch_start",
+ "on_batch_end",
+ "on_train_batch_start",
+ "on_train_batch_end",
+ "on_valid_batch_start",
+ "on_valid_batch_end",
+ "on_test_batch_start",
+ "on_test_batch_end",
+ "after_train_forward",
+ "after_valid_forward",
+ "after_test_forward",
+ ]
+
+ STAGES: tuple[Stage, ...] = (
+ "on_train_start",
+ "on_train_end",
+ "on_epoch_start",
+ "on_epoch_end",
+ "on_test_start",
+ "on_test_end",
+ "on_batch_start",
+ "on_batch_end",
+ "on_train_batch_start",
+ "on_train_batch_end",
+ "on_valid_batch_start",
+ "on_valid_batch_end",
+ "on_test_batch_start",
+ "on_test_batch_end",
+ "after_train_forward",
+ "after_valid_forward",
+ "after_test_forward",
+ )
+
+
+ class Operation(ABC):
+ """Abstract base class representing a stateless unit of work executed inside a callback stage.
+
+ Subclasses should implement the `compute` method which returns
+ a dictionary of outputs to merge into the running event data.
+ """
+
+ def __repr__(self) -> str:
+ """Return a concise string representation of the operation."""
+ return f"<{self.__class__.__name__}>"
+
+ def __call__(self, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
+ """Execute the operation with the given event-local data and shared context.
+
+ Args:
+ data (dict[str, Any]): Event-local dictionary containing data for this stage.
+ ctx (dict[str, Any]): Shared context dictionary accessible by all operations and callbacks.
+
+ Returns:
+ dict[str, Any]: Outputs produced by the operation to merge into the running data.
+ Returns an empty dict if `compute` returns None.
+ """
+ out = self.compute(data, ctx)
+ if out is None:
+ return {}
+ if not isinstance(out, dict):
+ raise TypeError(f"{self.__class__.__name__}.compute must return dict or None")
+ return out
+
+ @abstractmethod
+ def compute(self, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any] | None:
+ """Perform the operation's computation.
+
+ Args:
+ data (dict[str, Any]): Event-local dictionary containing the current data.
+ ctx (dict[str, Any]): Shared context dictionary.
+
+ Returns:
+ dict[str, Any] or None: Outputs to merge into the running data.
+ Returning None is equivalent to {}.
+ """
+ raise NotImplementedError
+
+
+ class Callback(ABC): # noqa: B024
+ """Abstract base class representing a callback that can have multiple operations registered to different stages of the training lifecycle.
+
+ Each stage method executes all operations registered for that stage
+ in insertion order. Operations can modify the event-local data dictionary.
+ """
+
+ def __init__(self):
+ """Initialize the callback with empty operation lists for all stages."""
+ self._ops_by_stage: dict[Stage, list[Operation]] = {s: [] for s in STAGES}
+
+ def __repr__(self) -> str:
+ """Return a concise string showing number of operations per stage."""
+ ops_summary = {stage: len(ops) for stage, ops in self._ops_by_stage.items() if ops}
+ return f"<{self.__class__.__name__} ops={ops_summary}>"
+
+ def add(self, stage: Stage, op: Operation) -> Self:
+ """Register an operation to execute at the given stage.
+
+ Args:
+ stage (Stage): Lifecycle stage at which to run the operation.
+ op (Operation): Operation instance to add.
+
+ Returns:
+ Self: Returns self for method chaining.
+ """
+ self._ops_by_stage[stage].append(op)
+ return self
+
+ def _run_ops(self, stage: Stage, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
+ """Execute all operations registered for a specific stage.
+
+ Args:
+ stage (Stage): Lifecycle stage to execute.
+ data (dict[str, Any]): Event-local data to pass to operations.
+ ctx (dict[str, Any]): Shared context across callbacks and operations.
+
+ Returns:
+ dict[str, Any]: Merged data dictionary after executing all operations.
+
+ Notes:
+ - Operations are executed in insertion order.
+ - If an operation overwrites existing keys, a warning is issued.
+ """
+ out = dict(data)
+
+ for operation in self._ops_by_stage[stage]:
+ try:
+ produced = operation(out, ctx) or {}
+ except Exception as e:
+ raise RuntimeError(f"Error in operation {operation} at stage {stage}") from e
+
+ collisions = set(produced.keys()) & set(out.keys())
+ if collisions:
+ import warnings
+
+ warnings.warn(
+ f"Operation {operation} at stage '{stage}' is overwriting keys: {collisions}",
+ stacklevel=2,
+ )
+
+ out.update(produced)
+
+ return out
+
+ # --- training ---
+ def on_train_start(self, data: dict[str, Any], ctx: dict[str, Any]):
+ """Execute operations registered for the 'on_train_start' stage."""
+ self._run_ops("on_train_start", data, ctx)
+
+ def on_train_end(self, data: dict[str, Any], ctx: dict[str, Any]):
+ """Execute operations registered for the 'on_train_end' stage."""
+ self._run_ops("on_train_end", data, ctx)
+
+ # --- epoch ---
+ def on_epoch_start(self, data: dict[str, Any], ctx: dict[str, Any]):
+ """Execute operations registered for the 'on_epoch_start' stage."""
+ self._run_ops("on_epoch_start", data, ctx)
+
+ def on_epoch_end(self, data: dict[str, Any], ctx: dict[str, Any]):
+ """Execute operations registered for the 'on_epoch_end' stage."""
+ self._run_ops("on_epoch_end", data, ctx)
+
+ # --- test ---
+ def on_test_start(self, data: dict[str, Any], ctx: dict[str, Any]):
+ """Execute operations registered for the 'on_test_start' stage."""
+ self._run_ops("on_test_start", data, ctx)
+
+ def on_test_end(self, data: dict[str, Any], ctx: dict[str, Any]):
+ """Execute operations registered for the 'on_test_end' stage."""
+ self._run_ops("on_test_end", data, ctx)
+
+ # --- batch ---
+ def on_batch_start(self, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
+ """Execute operations registered for the 'on_batch_start' stage."""
+ return self._run_ops("on_batch_start", data, ctx)
+
+ def on_batch_end(self, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
+ """Execute operations registered for the 'on_batch_end' stage."""
+ return self._run_ops("on_batch_end", data, ctx)
+
+ def on_train_batch_start(self, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
+ """Execute operations registered for the 'on_train_batch_start' stage."""
+ return self._run_ops("on_train_batch_start", data, ctx)
+
+ def on_train_batch_end(self, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
+ """Execute operations registered for the 'on_train_batch_end' stage."""
+ return self._run_ops("on_train_batch_end", data, ctx)
+
+ def on_valid_batch_start(self, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
+ """Execute operations registered for the 'on_valid_batch_start' stage."""
+ return self._run_ops("on_valid_batch_start", data, ctx)
+
+ def on_valid_batch_end(self, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
+ """Execute operations registered for the 'on_valid_batch_end' stage."""
+ return self._run_ops("on_valid_batch_end", data, ctx)
+
+ def on_test_batch_start(self, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
+ """Execute operations registered for the 'on_test_batch_start' stage."""
+ return self._run_ops("on_test_batch_start", data, ctx)
+
+ def on_test_batch_end(self, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
+ """Execute operations registered for the 'on_test_batch_end' stage."""
+ return self._run_ops("on_test_batch_end", data, ctx)
+
+ def after_train_forward(self, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
+ """Execute operations registered for the 'after_train_forward' stage."""
+ return self._run_ops("after_train_forward", data, ctx)
+
+ def after_valid_forward(self, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
+ """Execute operations registered for the 'after_valid_forward' stage."""
+ return self._run_ops("after_valid_forward", data, ctx)
+
+ def after_test_forward(self, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
+ """Execute operations registered for the 'after_test_forward' stage."""
+ return self._run_ops("after_test_forward", data, ctx)
+
+
+ class CallbackManager:
+ """Orchestrates multiple callbacks and executes them at specific lifecycle stages.
+
+ - Callbacks are executed in registration order.
+ - Event-local data flows through all callbacks.
+ - Shared context is available for cross-callback communication.
+ """
+
+ def __init__(self, callbacks: Iterable[Callback] | None = None):
+ """Initialize a CallbackManager instance.
+
+ Args:
+ callbacks (Iterable[Callback] | None): Optional initial callbacks to register.
+ If None, starts with an empty callback list.
+
+ Attributes:
+ _callbacks (list[Callback]): Internal list of registered callbacks.
+ ctx (dict[str, Any]): Shared context dictionary accessible to all callbacks
+ and operations for cross-event communication.
+ """
+ self._callbacks: list[Callback] = list(callbacks) if callbacks else []
+ self.ctx: dict[str, Any] = {}
+
+ def __repr__(self) -> str:
+ """Return a concise representation showing registered callbacks and ctx keys."""
+ names = [cb.__class__.__name__ for cb in self._callbacks]
+ return f"<CallbackManager callbacks={names} ctx_keys={list(self.ctx.keys())}>"
+
+ def add(self, callback: Callback) -> Self:
+ """Register a single callback.
+
+ Args:
+ callback (Callback): Callback instance to add.
+
+ Returns:
+ Self: Returns self for fluent chaining.
+ """
+ self._callbacks.append(callback)
+ return self
+
+ def extend(self, callbacks: Iterable[Callback]) -> None:
+ """Register multiple callbacks at once.
+
+ Args:
+ callbacks (Iterable[Callback]): Iterable of callbacks to add.
+ """
+ self._callbacks.extend(callbacks)
+
+ def run(self, stage: Stage, data: dict[str, Any]) -> dict[str, Any]:
+ """Execute all registered callbacks for a specific stage.
+
+ Args:
+ stage (Stage): Lifecycle stage to run (e.g., "on_batch_start").
+ data (dict[str, Any]): Event-local data dictionary to pass through callbacks.
+
+ Returns:
+ dict[str, Any]: The final merged data dictionary after executing all callbacks.
+
+ Raises:
+ ValueError: If a callback does not implement the requested stage.
+ RuntimeError: If any callback raises an exception during execution.
+ """
+ for cb in self._callbacks:
+ if not hasattr(cb, stage):
+ raise ValueError(
+ f"Callback {cb.__class__.__name__} has no handler for stage {stage}"
+ )
+ handler = getattr(cb, stage)
+
+ try:
+ new_data = handler(data, self.ctx)
+ if new_data is not None:
+ data = new_data
+
+ except Exception as e:
+ raise RuntimeError(f"Error in callback {cb.__class__.__name__}.{stage}") from e
+
+ return data
+
+ @property
+ def callbacks(self) -> tuple[Callback, ...]:
+ """Return a read-only tuple of registered callbacks.
+
+ Returns:
+ tuple[Callback, ...]: Registered callbacks.
+ """
+ return tuple(self._callbacks)
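Taken together, the intended call pattern is presumably a loop that invokes `run` at each lifecycle stage and threads the returned event data through; everything here besides the documented `run(stage, data)` signature is illustrative:

    # Continues the sketch above: `manager` already holds `callback`.
    batches = [{"loss": 0.9}, {"loss": 0.7}]  # stand-in event data
    for epoch in range(2):
        manager.run("on_epoch_start", {"epoch": epoch})
        for batch in batches:
            data = manager.run("on_batch_start", {"epoch": epoch, **batch})
            data = manager.run("on_batch_end", data)  # RunningLoss fires here
        manager.run("on_epoch_end", {"epoch": epoch})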
@@ -0,0 +1,108 @@
+ """Holds all callback implementations for use in the training workflow.
+
+ This module acts as a central registry for defining and storing different
+ callback classes, such as logging, checkpointing, or custom behaviors
+ triggered during training, validation, or testing. It is intended to
+ collect all callback implementations in one place for easy reference
+ and import, and can be extended as new callbacks are added.
+ """
+
+ from typing import Any
+
+ from torch.utils.tensorboard import SummaryWriter
+
+ from ..callbacks.base import Callback
+ from ..metrics import MetricManager
+ from ..utils.utility import CSVLogger
+
+
+ class LoggerCallback(Callback):
+ """Callback to log metrics to TensorBoard and CSV after each epoch or test.
+
+ This callback queries a MetricManager for aggregated metrics, writes them
+ to TensorBoard using SummaryWriter, and logs them to a CSV file via CSVLogger.
+ It also flushes loggers and resets metrics after logging.
+
+ Methods implemented:
+ - on_epoch_end: Logs metrics at the end of a training epoch.
+ - on_test_end: Logs metrics at the end of testing.
+ """
+
+ def __init__(
+ self,
+ metric_manager: MetricManager,
+ tensorboard_logger: SummaryWriter,
+ csv_logger: CSVLogger,
+ ):
+ """Initialize the LoggerCallback.
+
+ Args:
+ metric_manager: Instance of MetricManager used to collect metrics.
+ tensorboard_logger: TensorBoard SummaryWriter instance for logging scalars.
+ csv_logger: CSVLogger instance for logging metrics to CSV files.
+ """
+ super().__init__()
+ self.metric_manager = metric_manager
+ self.tensorboard_logger = tensorboard_logger
+ self.csv_logger = csv_logger
+
+ def on_epoch_end(self, data: dict[str, Any], ctx: dict[str, Any]):
+ """Log training metrics at the end of an epoch.
+
+ Aggregates metrics from the MetricManager under the 'during_training' category,
+ writes them to TensorBoard and CSV, flushes the loggers, and resets the metrics
+ for the next epoch.
+
+ Args:
+ data: Dictionary containing batch/epoch context (must include 'epoch').
+ ctx: Additional context dictionary (unused in this implementation).
+
+ Returns:
+ data: The same input dictionary, unmodified.
+ """
+ epoch = data["epoch"]
+
+ # Log training metrics
+ metrics = self.metric_manager.aggregate("during_training")
+ for name, value in metrics.items():
+ self.tensorboard_logger.add_scalar(name, value.item(), epoch)
+ self.csv_logger.add_value(name, value.item(), epoch)
+
+ # Flush/save
+ self.tensorboard_logger.flush()
+ self.csv_logger.save()
+
+ # Reset metric manager for training
+ self.metric_manager.reset("during_training")
+
+ return data
+
+ def on_test_end(self, data: dict[str, Any], ctx: dict[str, Any]):
+ """Log test metrics at the end of testing.
+
+ Aggregates metrics from the MetricManager under the 'after_training' category,
+ writes them to TensorBoard and CSV, flushes the loggers, and resets the metrics.
+
+ Args:
+ data: Dictionary containing test context (must include 'epoch').
+ ctx: Additional context dictionary (unused in this implementation).
+
+ Returns:
+ data: The same input dictionary, unmodified.
+ """
+ epoch = data["epoch"]
+
+ # Log test metrics
+ metrics = self.metric_manager.aggregate("after_training")
+ for name, value in metrics.items():
+ self.tensorboard_logger.add_scalar(name, value.item(), epoch)
+ self.csv_logger.add_value(name, value.item(), epoch)
+
+ # Flush/save
+ self.tensorboard_logger.flush()
+ self.csv_logger.save()
+
+ # Reset metric manager for test
+ self.metric_manager.reset("after_training")
+
+ return data
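As a rough wiring sketch: only the `LoggerCallback` signature above is authoritative; the `MetricManager` and `CSVLogger` constructor arguments are guesses, and the module that exports `LoggerCallback` is not named in this diff, so its import is omitted:

    from torch.utils.tensorboard import SummaryWriter

    from congrads.callbacks.base import CallbackManager
    from congrads.metrics import MetricManager
    from congrads.utils.utility import CSVLogger

    metric_manager = MetricManager()            # assumed default constructor
    writer = SummaryWriter("runs/experiment")
    csv_logger = CSVLogger("runs/metrics.csv")  # assumed path argument

    manager = CallbackManager([LoggerCallback(metric_manager, writer, csv_logger)])
    manager.run("on_epoch_end", {"epoch": 0})   # logs, flushes, then resets 'during_training'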
congrads/checkpoints.py CHANGED
@@ -14,7 +14,7 @@ from torch.nn import Module
  from torch.optim import Optimizer

  from .metrics import MetricManager
- from .utils import validate_callable, validate_type
+ from .utils.validation import validate_callable, validate_type


  class CheckpointManager:
@@ -0,0 +1,175 @@
+ """Defines the abstract base class `Constraint` for specifying constraints on neural network outputs.
+
+ A `Constraint` monitors whether the network predictions satisfy certain
+ conditions during training, validation, and testing. It can optionally
+ adjust the loss to enforce constraints, and logs the relevant metrics.
+
+ Responsibilities:
+ - Track which network layers/tags the constraint applies to
+ - Check constraint satisfaction for a batch of predictions
+ - Compute adjustment directions to enforce the constraint
+ - Provide a rescale factor and enforcement flag to influence loss adjustment
+
+ Subclasses must implement the abstract methods:
+ - `check_constraint(data)`: Evaluate constraint satisfaction for a batch
+ - `calculate_direction(data)`: Compute directions to adjust predictions
+ """
+
+ import random
+ import string
+ import warnings
+ from abc import ABC, abstractmethod
+ from numbers import Number
+
+ from torch import Tensor
+
+ from congrads.descriptor import Descriptor
+ from congrads.utils.validation import validate_iterable, validate_type
+
+
+ class Constraint(ABC):
+ """Abstract base class for defining constraints applied to neural networks.
+
+ A `Constraint` specifies conditions that the neural network outputs
+ should satisfy. It supports monitoring constraint satisfaction
+ during training and can adjust loss to enforce constraints. Subclasses
+ must implement the `check_constraint` and `calculate_direction` methods.
+
+ Args:
+ tags (set[str]): Tags referencing the parts of the network to which this constraint applies.
+ name (str, optional): A unique name for the constraint. If not provided,
+ a name is generated based on the class name and a random suffix.
+ enforce (bool, optional): If False, only monitor the constraint
+ without adjusting the loss. Defaults to True.
+ rescale_factor (Number, optional): Factor to scale the
+ constraint-adjusted loss. Defaults to 1.5. Should be greater
+ than 1 to give weight to the constraint.
+
+ Raises:
+ TypeError: If a provided attribute has an incompatible type.
+ ValueError: If any tag in `tags` is not
+ defined in the `descriptor`.
+
+ Note:
+ - If `rescale_factor <= 1`, a warning is issued.
+ - If `name` is not provided, a name is auto-generated,
+ and a warning is logged.
+
+ """
+
+ descriptor: Descriptor = None
+ device = None
+
+ def __init__(
+ self, tags: set[str], name: str = None, enforce: bool = True, rescale_factor: Number = 1.5
+ ) -> None:
+ """Initializes a new Constraint instance.
+
+ Args:
+ tags (set[str]): Tags referencing the parts of the network to which this constraint applies.
+ name (str, optional): A unique name for the constraint. If not
+ provided, a name is generated based on the class name and a
+ random suffix.
+ enforce (bool, optional): If False, only monitor the constraint
+ without adjusting the loss. Defaults to True.
+ rescale_factor (Number, optional): Factor to scale the
+ constraint-adjusted loss. Defaults to 1.5. Should be greater
+ than 1 to give weight to the constraint.
+
+ Raises:
+ TypeError: If a provided attribute has an incompatible type.
+ ValueError: If any tag in `tags` is not defined in the `descriptor`.
+
+ Note:
+ - If `rescale_factor <= 1`, a warning is issued.
+ - If `name` is not provided, a name is auto-generated, and a
+ warning is logged.
+ """
+ # Init parent class
+ super().__init__()
+
+ # Type checking
+ validate_iterable("tags", tags, str)
+ validate_type("name", name, str, allow_none=True)
+ validate_type("enforce", enforce, bool)
+ validate_type("rescale_factor", rescale_factor, Number)
+
+ # Init object variables
+ self.tags = tags
+ self.rescale_factor = rescale_factor
+ self.initial_rescale_factor = rescale_factor
+ self.enforce = enforce
+
+ # Perform checks
+ if rescale_factor <= 1:
+ warnings.warn(
+ f"Rescale factor for constraint {name} is <= 1. The network "
+ "will favor general loss over the constraint-adjusted loss. "
+ "Is this intended behavior? Normally, the rescale factor "
+ "should always be larger than 1.",
+ stacklevel=2,
+ )
+
+ # If no constraint_name is set, generate one based
+ # on the class name and a random suffix
+ if name:
+ self.name = name
+ else:
+ random_suffix = "".join(random.choices(string.ascii_uppercase + string.digits, k=6))
+ self.name = f"{self.__class__.__name__}_{random_suffix}"
+ warnings.warn(f"Name for constraint is not set. Using {self.name}.", stacklevel=2)
+
+ # Infer layers from descriptor and tags
+ self.layers = set()
+ for tag in self.tags:
+ if not self.descriptor.exists(tag):
+ raise ValueError(
+ f"The tag {tag} used with constraint "
+ f"{self.name} is not defined in the descriptor. Please "
+ "add it to the correct layer using "
+ "descriptor.add('layer', ...)."
+ )
+
+ layer, _ = self.descriptor.location(tag)
+ self.layers.add(layer)
+
+ @abstractmethod
+ def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
+ """Evaluates whether the given model predictions satisfy the constraint.
+
+ A value of 1 means the constraint is satisfied; 0 means it is not.
+
+ Args:
+ data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.
+
+ Returns:
+ tuple[Tensor, Tensor]: A tuple whose first element is a float tensor
+ indicating whether the constraint is satisfied (1.0 for satisfaction,
+ 0.0 for non-satisfaction), and whose second element is a mask tensor
+ indicating the relevance of each sample (`True` for relevant
+ samples and `False` for irrelevant ones).
+
+ Raises:
+ NotImplementedError: If not implemented in a subclass.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
+ """Compute adjustment directions to better satisfy the constraint.
+
+ Given the model predictions, input batch, and context, this method calculates the direction
+ in which the predictions referenced by a tag should be adjusted to satisfy the constraint.
+
+ Args:
+ data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.
+
+ Returns:
+ dict[str, Tensor]: Dictionary mapping network layers to tensors that
+ specify the adjustment direction for each tag.
+
+ Raises:
+ NotImplementedError: Must be implemented by subclasses.
+ """
+ raise NotImplementedError
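For orientation, a subclass satisfying this contract might look as follows. This is a sketch under assumptions: the module path, the tag name, keying `data` by tag, and the direction convention are invented, and it presumes the framework has already attached a `Descriptor` to the class (as the class-level `descriptor` attribute suggests); only the abstract signatures are taken from this diff.

    import torch
    from torch import Tensor

    from congrads.constraints import Constraint  # module path assumed

    class NonNegative(Constraint):
        """Hypothetical constraint: predictions tagged 'output' must be >= 0."""

        def __init__(self):
            super().__init__(tags={"output"}, name="NonNegative")

        def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
            preds = data["output"]                    # assumes predictions are keyed by tag
            satisfied = (preds >= 0).float()          # 1.0 satisfied, 0.0 not
            relevance = torch.ones_like(satisfied, dtype=torch.bool)  # every sample relevant
            return satisfied, relevance

        def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
            preds = data["output"]
            layer, _ = self.descriptor.location("output")  # map tag to its layer
            return {layer: (preds < 0).float()}            # push violating outputs upward (assumed convention)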