congrads 1.1.1-py3-none-any.whl → 1.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,271 @@
+"""This module provides the core CongradsCore class for the main training functionality.
+
+It integrates constraint-guided optimization into neural network training,
+extending traditional training processes by enforcing specific constraints
+on the model's outputs, ensuring that the network satisfies domain-specific
+requirements during both training and evaluation.
+
+The `CongradsCore` class serves as the central engine for managing the
+training, validation, and testing phases of a neural network model,
+incorporating constraints that influence the loss function and model updates.
+The model is trained with standard loss functions while also incorporating
+constraint-based adjustments, which are tracked and logged
+throughout the process.
+
+Key features:
+- Support for various constraints that can influence the training process.
+- Integration with PyTorch's `DataLoader` for efficient batch processing.
+- Metric management for tracking loss and constraint satisfaction.
+- Checkpoint management for saving and evaluating model states.
+
+"""
+
+from collections.abc import Callable
+
+import torch
+from torch import Tensor, sum
+from torch.nn import Module
+from torch.nn.modules.loss import _Loss
+from torch.optim import Optimizer
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from congrads.utils.utility import LossWrapper
+
+from ..callbacks.base import CallbackManager
+from ..checkpoints import CheckpointManager
+from ..constraints.base import Constraint
+from ..core.batch_runner import BatchRunner
+from ..core.constraint_engine import ConstraintEngine
+from ..core.epoch_runner import EpochRunner
+from ..descriptor import Descriptor
+from ..metrics import MetricManager
+
+
+class CongradsCore:
+    """The CongradsCore class is the central training engine for constraint-guided optimization.
+
+    It integrates standard neural network training
+    with additional constraint-driven adjustments to the loss function, ensuring
+    that the network satisfies domain-specific constraints during training.
+    """
+
+    def __init__(
+        self,
+        descriptor: Descriptor,
+        constraints: list[Constraint],
+        network: Module,
+        criterion: _Loss,
+        optimizer: Optimizer,
+        device: torch.device,
+        dataloader_train: DataLoader,
+        dataloader_valid: DataLoader | None = None,
+        dataloader_test: DataLoader | None = None,
+        metric_manager: MetricManager | None = None,
+        callback_manager: CallbackManager | None = None,
+        checkpoint_manager: CheckpointManager | None = None,
+        network_uses_grad: bool = False,
+        epsilon: float = 1e-6,
+        constraint_aggregator: Callable[..., Tensor] = sum,
+        enforce_all: bool = True,
+        disable_progress_bar_epoch: bool = False,
+        disable_progress_bar_batch: bool = False,
+        epoch_runner_cls: type["EpochRunner"] | None = None,
+        batch_runner_cls: type["BatchRunner"] | None = None,
+        constraint_engine_cls: type["ConstraintEngine"] | None = None,
+    ):
+        """Initialize the CongradsCore object.
+
+        Args:
+            descriptor (Descriptor): Describes variable layers in the network.
+            constraints (list[Constraint]): List of constraints to guide training.
+            network (Module): The neural network model to train.
+            criterion (_Loss): The loss function used for
+                training and validation.
+            optimizer (Optimizer): The optimizer used for updating model parameters.
+            device (torch.device): The device (e.g., CPU or GPU) for computations.
+            dataloader_train (DataLoader): DataLoader for training data.
+            dataloader_valid (DataLoader, optional): DataLoader for validation data.
+                If not provided, validation is skipped.
+            dataloader_test (DataLoader, optional): DataLoader for test data.
+                If not provided, testing is skipped.
+            metric_manager (MetricManager, optional): Manages metric tracking and recording.
+            callback_manager (CallbackManager, optional): Manages training callbacks.
+            checkpoint_manager (CheckpointManager, optional): Manages
+                checkpointing. If not set, no checkpointing is done.
+            network_uses_grad (bool, optional): Whether the network itself performs
+                gradient computations; if True, gradient tracking stays enabled
+                during validation and testing. Defaults to False.
+            epsilon (float, optional): A small value to avoid division by zero
+                in gradient calculations. Defaults to 1e-6.
+            constraint_aggregator (Callable[..., Tensor], optional): A function
+                to aggregate the constraint rescale loss. Defaults to `sum`.
+            enforce_all (bool, optional): If set to False, constraints are only monitored
+                and do not influence the training process. Overrides constraint-specific
+                `enforce` parameters. Defaults to True.
+            disable_progress_bar_epoch (bool, optional): If set to True, the epoch
+                progress bar is not shown. Defaults to False.
+            disable_progress_bar_batch (bool, optional): If set to True, the batch
+                progress bar is not shown. Defaults to False.
+            epoch_runner_cls (type[EpochRunner], optional): Custom EpochRunner class.
+                If not provided, the default EpochRunner is used.
+            batch_runner_cls (type[BatchRunner], optional): Custom BatchRunner class.
+                If not provided, the default BatchRunner is used.
+            constraint_engine_cls (type[ConstraintEngine], optional): Custom ConstraintEngine class.
+                If not provided, the default ConstraintEngine is used.
+
+        Note:
+            A warning is logged if the descriptor has no variable layers,
+            as at least one variable layer is required for the constraint logic
+            to influence the training process.
+        """
+        # Init object variables
+        self.device = device
+        self.network = network.to(device)
+        self.criterion = LossWrapper(criterion)
+        self.optimizer = optimizer
+        self.descriptor = descriptor
+
+        self.constraints = constraints or []
+        self.epsilon = epsilon
+        self.aggregator = constraint_aggregator
+        self.enforce_all = enforce_all
+
+        self.metric_manager = metric_manager
+        self.callback_manager = callback_manager
+        self.checkpoint_manager = checkpoint_manager
+
+        self.disable_progress_bar_epoch = disable_progress_bar_epoch
+        self.disable_progress_bar_batch = disable_progress_bar_batch
+
+        # Initialize constraint engine
+        self.constraint_engine = (constraint_engine_cls or ConstraintEngine)(
+            constraints=self.constraints,
+            descriptor=self.descriptor,
+            device=self.device,
+            epsilon=self.epsilon,
+            aggregator=self.aggregator,
+            enforce_all=self.enforce_all,
+            metric_manager=self.metric_manager,
+        )
+
+        # Initialize runners
+        self.batch_runner = (batch_runner_cls or BatchRunner)(
+            network=self.network,
+            criterion=self.criterion,
+            optimizer=self.optimizer,
+            constraint_engine=self.constraint_engine,
+            metric_manager=self.metric_manager,
+            callback_manager=self.callback_manager,
+            device=self.device,
+        )
+
+        self.epoch_runner = (epoch_runner_cls or EpochRunner)(
+            batch_runner=self.batch_runner,
+            network=self.network,
+            train_loader=dataloader_train,
+            valid_loader=dataloader_valid,
+            test_loader=dataloader_test,
+            network_uses_grad=network_uses_grad,
+            disable_progress_bar=self.disable_progress_bar_batch,
+        )
+
+        # Initialize constraint metrics
+        if self.metric_manager is not None:
+            self._initialize_metrics()
+
+    def _initialize_metrics(self) -> None:
+        """Register metrics for loss, constraint satisfaction ratio (CSR), and constraints.
+
+        This method registers the following metrics:
+
+        - Loss/train: Training loss.
+        - Loss/valid: Validation loss.
+        - Loss/test: Test loss after training.
+        - CSR/train: Constraint satisfaction ratio during training.
+        - CSR/valid: Constraint satisfaction ratio during validation.
+        - CSR/test: Constraint satisfaction ratio after training.
+        - One metric per constraint, for both training and validation.
+
+        """
+        self.metric_manager.register("Loss/train", "during_training")
+        self.metric_manager.register("Loss/valid", "during_training")
+        self.metric_manager.register("Loss/test", "after_training")
+
+        if len(self.constraints) > 0:
+            self.metric_manager.register("CSR/train", "during_training")
+            self.metric_manager.register("CSR/valid", "during_training")
+            self.metric_manager.register("CSR/test", "after_training")
+
+        for constraint in self.constraints:
+            self.metric_manager.register(f"{constraint.name}/train", "during_training")
+            self.metric_manager.register(f"{constraint.name}/valid", "during_training")
+            self.metric_manager.register(f"{constraint.name}/test", "after_training")
+
+    def fit(
+        self,
+        *,
+        start_epoch: int = 0,
+        max_epochs: int = 100,
+        test_model: bool = True,
+        final_checkpoint_name: str = "checkpoint_final.pth",
+    ) -> None:
+        """Run the full training loop, including optional validation, testing, and checkpointing.
+
+        This method performs training over multiple epochs with the following steps:
+        1. Trigger "on_train_start" callbacks if a callback manager is present.
+        2. For each epoch:
+           - Trigger "on_epoch_start" callbacks.
+           - Run a training epoch via the EpochRunner.
+           - Run a validation epoch via the EpochRunner.
+           - Evaluate checkpoint criteria if a checkpoint manager is present.
+           - Trigger "on_epoch_end" callbacks.
+        3. Trigger "on_train_end" callbacks after all epochs.
+        4. Optionally run a test epoch via the EpochRunner if `test_model` is True,
+           with corresponding "on_test_start" and "on_test_end" callbacks.
+        5. Save a final checkpoint using the checkpoint manager.
+
+        Args:
+            start_epoch: Index of the starting epoch (default 0). Useful for resuming training.
+            max_epochs: Maximum number of epochs to run (default 100).
+            test_model: Whether to run a test epoch after training (default True).
+            final_checkpoint_name: Filename for the final checkpoint saved at the end of training
+                (default "checkpoint_final.pth").
+
+        Returns:
+            None
+        """
+        if self.callback_manager:
+            self.callback_manager.run("on_train_start", {"epoch": start_epoch})
+
+        for epoch in tqdm(
+            range(start_epoch, max_epochs),
+            initial=start_epoch,
+            desc="Epoch",
+            disable=self.disable_progress_bar_epoch,
+        ):
+            if self.callback_manager:
+                self.callback_manager.run("on_epoch_start", {"epoch": epoch})
+
+            self.epoch_runner.train()
+            self.epoch_runner.validate()
+
+            if self.checkpoint_manager:
+                self.checkpoint_manager.evaluate_criteria(epoch)
+
+            if self.callback_manager:
+                self.callback_manager.run("on_epoch_end", {"epoch": epoch})
+
+        if self.callback_manager:
+            self.callback_manager.run("on_train_end", {"epoch": epoch})
+
+        if test_model:
+            if self.callback_manager:
+                self.callback_manager.run("on_test_start", {"epoch": epoch})
+
+            self.epoch_runner.test()
+
+            if self.callback_manager:
+                self.callback_manager.run("on_test_end", {"epoch": epoch})
+
+        if self.checkpoint_manager:
+            self.checkpoint_manager.save(epoch, final_checkpoint_name)
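
For orientation, here is a minimal usage sketch of the class above. Only the CongradsCore constructor and fit() signatures are taken from this diff; the import path, the Descriptor construction, and the constraint list are assumptions (their definitions are not shown here), and the batch format your DataLoader yields must match whatever the package's Descriptor and BatchRunner expect.

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    from congrads.core import CongradsCore  # import path assumed, not shown in this diff

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    network = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 1))
    loader = DataLoader(
        TensorDataset(torch.randn(128, 4), torch.randn(128, 1)),
        batch_size=32,
        shuffle=True,
    )

    core = CongradsCore(
        descriptor=...,  # congrads Descriptor instance; constructor not shown in this diff
        constraints=[],  # list[Constraint]; concrete constraint classes not shown in this diff
        network=network,
        criterion=nn.MSELoss(),
        optimizer=torch.optim.Adam(network.parameters(), lr=1e-3),
        device=device,
        dataloader_train=loader,
        dataloader_valid=loader,  # optional; validation is skipped when omitted
    )
    core.fit(max_epochs=10, test_model=False)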
@@ -0,0 +1,170 @@
+"""Manages the evaluation and optional enforcement of constraints on neural network outputs.
+
+Responsibilities:
+- Compute and log the Constraint Satisfaction Rate (CSR) for training, validation, and test batches.
+- Optionally adjust the loss during training based on constraint directions and rescale factors.
+- Handle gradient computation and the application of constraint-guided gradient descent (CGGD).
+"""
+
+import torch
+from torch import Tensor, no_grad
+from torch.linalg import vector_norm
+
+from ..constraints.base import Constraint
+from ..descriptor import Descriptor
+from ..metrics import MetricManager
+
+
+class ConstraintEngine:
+    """Manages constraint evaluation and enforcement for a neural network.
+
+    The ConstraintEngine coordinates constraints defined in Constraint objects,
+    computes gradients for layers that affect the loss, logs metrics, and optionally
+    modifies the loss during training according to the constraints. It supports
+    separate phases for training, validation, and testing.
+    """
+
+    def __init__(
+        self,
+        *,
+        constraints: list[Constraint],
+        descriptor: Descriptor,
+        metric_manager: MetricManager,
+        device: torch.device,
+        epsilon: float,
+        aggregator: callable,
+        enforce_all: bool,
+    ) -> None:
+        """Initialize the ConstraintEngine.
+
+        Args:
+            constraints: List of Constraint objects to evaluate and optionally enforce.
+            descriptor: Descriptor containing metadata about network layers and which
+                variables affect the loss.
+            metric_manager: MetricManager instance for logging CSR metrics.
+            device: Torch device where tensors will be allocated (CPU or GPU).
+            epsilon: Small positive value to avoid division by zero in gradient norms.
+            aggregator: Callable used to reduce per-layer constraint contributions
+                to a scalar loss adjustment.
+            enforce_all: If False, constraints are only monitored and never adjust
+                the loss, overriding constraint-specific `enforce` flags.
+        """
+        self.constraints = constraints
+        self.descriptor = descriptor
+        self.metric_manager = metric_manager
+        self.device = device
+        self.epsilon = epsilon
+        self.enforce_all = enforce_all
+        self.aggregator = aggregator
+
+    def train(self, data: dict[str, Tensor], loss: Tensor) -> Tensor:
+        """Apply all active constraints during training.
+
+        Computes the original loss gradients for layers that affect the loss,
+        evaluates each constraint, logs the Constraint Satisfaction Rate (CSR),
+        and adjusts the loss according to constraint satisfaction.
+
+        Args:
+            data: Dictionary containing input and prediction tensors for the batch.
+            loss: The original loss tensor computed from the network output.
+
+        Returns:
+            Tensor: The loss tensor after applying constraint-based adjustments.
+        """
+        return self._apply_constraints(data, loss, phase="train", enforce=True)
+
+    def validate(self, data: dict[str, Tensor], loss: Tensor) -> Tensor:
+        """Evaluate constraints during validation without modifying the loss.
+
+        Computes and logs the Constraint Satisfaction Rate (CSR) for each constraint,
+        but does not apply rescale adjustments to the loss.
+
+        Args:
+            data: Dictionary containing input and prediction tensors for the batch.
+            loss: The original loss tensor computed from the network output.
+
+        Returns:
+            Tensor: The original loss tensor, unchanged.
+        """
+        return self._apply_constraints(data, loss, phase="valid", enforce=False)
+
+    def test(self, data: dict[str, Tensor], loss: Tensor) -> Tensor:
+        """Evaluate constraints during testing without modifying the loss.
+
+        Computes and logs the Constraint Satisfaction Rate (CSR) for each constraint,
+        but does not apply rescale adjustments to the loss.
+
+        Args:
+            data: Dictionary containing input and prediction tensors for the batch.
+            loss: The original loss tensor computed from the network output.
+
+        Returns:
+            Tensor: The original loss tensor, unchanged.
+        """
+        return self._apply_constraints(data, loss, phase="test", enforce=False)
+
+    def _apply_constraints(
+        self, data: dict[str, Tensor], loss: Tensor, phase: str, enforce: bool
+    ) -> Tensor:
+        """Evaluate constraints, log CSR, and optionally adjust the loss.
+
+        During training, computes loss gradients for variable layers that affect the loss.
+        Iterates over all constraints, logging the Constraint Satisfaction Rate (CSR)
+        and, if enforcement is enabled, adjusts the loss using constraint-specific
+        directions and rescale factors.
+
+        Args:
+            data: Dictionary containing input and prediction tensors for the batch.
+            loss: Original loss tensor computed from the network output.
+            phase: Current phase, one of "train", "valid", or "test".
+            enforce: If True, constraint-based adjustments are applied to the loss.
+
+        Returns:
+            Tensor: The combined loss after applying constraints (or the original
+                loss if enforcement is disabled or the phase is not "train").
+        """
+        total_rescale_loss = torch.tensor(0.0, device=self.device, dtype=loss.dtype)
+
+        # Precompute gradients for variable layers affecting the loss
+        if phase == "train":
+            norm_loss_grad = {}
+            variable_keys = self.descriptor.variable_keys & self.descriptor.affects_loss_keys
+
+            for key in variable_keys:
+                grad = torch.autograd.grad(
+                    outputs=loss, inputs=data[key], retain_graph=True, allow_unused=True
+                )[0]
+
+                if grad is None:
+                    raise RuntimeError(
+                        f"Unable to compute loss gradients for layer '{key}'. "
+                        "Set has_loss=False in Descriptor if this layer does not affect loss."
+                    )
+
+                grad_flat = grad.view(grad.shape[0], -1)
+                norm_loss_grad[key] = (
+                    vector_norm(grad_flat, dim=1, ord=2, keepdim=True)
+                    .clamp(min=self.epsilon)
+                    .detach()
+                )
+
+        # Iterate constraints
+        for constraint in self.constraints:
+            checks, mask = constraint.check_constraint(data)
+            directions = constraint.calculate_direction(data)
+
+            # Log CSR
+            csr = (torch.sum(checks * mask) / torch.sum(mask)).unsqueeze(0)
+            self.metric_manager.accumulate(f"{constraint.name}/{phase}", csr)
+            self.metric_manager.accumulate(f"CSR/{phase}", csr)
+
+            # Skip adjustment if not enforcing
+            if not enforce or not constraint.enforce or not self.enforce_all or phase != "train":
+                continue
+
+            # Compute constraint-based rescale loss
+            for key in constraint.layers & self.descriptor.variable_keys:
+                with no_grad():
+                    rescale = (1 - checks) * directions[key] * constraint.rescale_factor
+                total_rescale_loss += self.aggregator(data[key] * rescale * norm_loss_grad[key])
+
+        return loss + total_rescale_loss
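
To make the arithmetic in _apply_constraints concrete, here is a standalone toy computation (plain torch, no congrads imports) mirroring the CSR logging and the rescale term for one constraint and one layer; all tensor values are invented for illustration.

    import torch

    epsilon = 1e-6
    rescale_factor = 1.0

    # Per-sample output of constraint.check_constraint: 1.0 = satisfied.
    checks = torch.tensor([[1.0], [0.0], [1.0], [0.0]])
    mask = torch.ones_like(checks)  # every sample counts toward the CSR

    # CSR as logged above: satisfied / considered -> 0.5 here.
    csr = (torch.sum(checks * mask) / torch.sum(mask)).unsqueeze(0)

    # Direction from constraint.calculate_direction and the layer's output.
    directions = torch.tensor([[1.0], [-1.0], [1.0], [-1.0]])
    output = torch.tensor([[0.2], [0.7], [0.1], [0.9]])

    # Stand-in for the detached, epsilon-clamped per-sample loss-gradient norm.
    norm_loss_grad = torch.full((4, 1), 0.5).clamp(min=epsilon)

    # Only violated samples (checks == 0) contribute to the rescale loss.
    rescale = (1 - checks) * directions * rescale_factor
    total_rescale_loss = torch.sum(output * rescale * norm_loss_grad)

    print(csr)                 # tensor([0.5000])
    print(total_rescale_loss)  # tensor(-0.8000)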
@@ -0,0 +1,119 @@
+"""Defines the EpochRunner class for running full training, validation, and test epochs.
+
+This module handles:
+- Switching the network between training and evaluation modes
+- Iterating over DataLoaders with optional progress bars
+- Delegating per-batch processing to a BatchRunner instance
+- Optional gradient tracking control for evaluation phases
+- Warnings when validation or test loaders are not provided
+"""
+
+import warnings
+
+from torch import set_grad_enabled
+from torch.nn import Module
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from ..core.batch_runner import BatchRunner
+
+
+class EpochRunner:
+    """Runs full epochs over DataLoaders.
+
+    Responsibilities:
+    - Model mode switching
+    - Iteration over DataLoader
+    - Delegation to BatchRunner
+    - Progress bars
+    """
+
+    def __init__(
+        self,
+        network: Module,
+        batch_runner: BatchRunner,
+        train_loader: DataLoader,
+        valid_loader: DataLoader | None = None,
+        test_loader: DataLoader | None = None,
+        *,
+        network_uses_grad: bool = False,
+        disable_progress_bar: bool = False,
+    ):
+        """Initialize the EpochRunner.
+
+        Args:
+            network: The neural network module to train/validate/test.
+            batch_runner: The BatchRunner instance for processing batches.
+            train_loader: DataLoader for training data.
+            valid_loader: DataLoader for validation data. Defaults to None.
+            test_loader: DataLoader for test data. Defaults to None.
+            network_uses_grad: Whether the network uses gradient computation. Defaults to False.
+            disable_progress_bar: Whether to disable progress bar display. Defaults to False.
+        """
+        self.network = network
+        self.batch_runner = batch_runner
+        self.train_loader = train_loader
+        self.valid_loader = valid_loader
+        self.test_loader = test_loader
+        self.network_uses_grad = network_uses_grad
+        self.disable_progress_bar = disable_progress_bar
+
+    def train(self) -> None:
+        """Run a training epoch over the training DataLoader.
+
+        Sets the network to training mode and iterates over batches,
+        delegating each batch to the BatchRunner for processing.
+        """
+        self.network.train()
+
+        for batch in tqdm(
+            self.train_loader,
+            desc="Training batches",
+            leave=False,
+            disable=self.disable_progress_bar,
+        ):
+            self.batch_runner.train_batch(batch)
+
+    def validate(self) -> None:
+        """Run a validation epoch over the validation DataLoader.
+
+        Sets the network to evaluation mode and iterates over batches,
+        delegating each batch to the BatchRunner for processing.
+        Skips validation if no valid_loader is provided.
+        """
+        if self.valid_loader is None:
+            warnings.warn("Validation skipped: no valid_loader provided.", stacklevel=2)
+            return
+
+        with set_grad_enabled(self.network_uses_grad):
+            self.network.eval()
+
+            for batch in tqdm(
+                self.valid_loader,
+                desc="Validation batches",
+                leave=False,
+                disable=self.disable_progress_bar,
+            ):
+                self.batch_runner.valid_batch(batch)
+
+    def test(self) -> None:
+        """Run a test epoch over the test DataLoader.
+
+        Sets the network to evaluation mode and iterates over batches,
+        delegating each batch to the BatchRunner for processing.
+        Skips testing if no test_loader is provided.
+        """
+        if self.test_loader is None:
+            warnings.warn("Testing skipped: no test_loader provided.", stacklevel=2)
+            return
+
+        with set_grad_enabled(self.network_uses_grad):
+            self.network.eval()
+
+            for batch in tqdm(
+                self.test_loader,
+                desc="Test batches",
+                leave=False,
+                disable=self.disable_progress_bar,
+            ):
+                self.batch_runner.test_batch(batch)
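
The network_uses_grad flag matters here because validation and test batches run inside set_grad_enabled(self.network_uses_grad). A standalone illustration of that PyTorch behavior (plain torch, not congrads code):

    import torch
    from torch import set_grad_enabled

    x = torch.randn(2, 3, requires_grad=True)
    w = torch.randn(3, 1, requires_grad=True)

    # Default evaluation setting (network_uses_grad=False): no autograd graph
    # is built, which saves memory but makes backward() impossible.
    with set_grad_enabled(False):
        y = x @ w
    print(y.requires_grad)  # False

    # network_uses_grad=True keeps autograd on, for networks whose forward
    # pass itself needs gradients (e.g. computing input gradients internally).
    with set_grad_enabled(True):
        y = x @ w
    print(y.requires_grad)  # True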
congrads/descriptor.py CHANGED
@@ -12,7 +12,7 @@ indices, and optional attributes, such as whether the data is constant or variab
 
 from torch import Tensor
 
-from .utils import validate_type
+from .utils.validation import validate_type
 
 
 class Descriptor:
congrads/metrics.py CHANGED
@@ -9,7 +9,7 @@ from collections.abc import Callable
 
 from torch import Tensor, cat, nanmean, tensor
 
-from .utils import validate_callable, validate_type
+from .utils.validation import validate_callable, validate_type
 
 
 class Metric:
@@ -0,0 +1,37 @@
+"""Module defining the abstract Transformation base class for tag data transformations."""
+
+from abc import ABC, abstractmethod
+
+from torch import Tensor
+
+from ..utils.validation import validate_type
+
+
+class Transformation(ABC):
+    """Abstract base class for tag data transformations."""
+
+    def __init__(self, tag: str):
+        """Initialize a Transformation.
+
+        Args:
+            tag (str): Tag this transformation applies to.
+        """
+        validate_type("tag", tag, str)
+
+        super().__init__()
+        self.tag = tag
+
+    @abstractmethod
+    def __call__(self, data: Tensor) -> Tensor:
+        """Apply the transformation to the input tensor.
+
+        Args:
+            data (Tensor): Input tensor representing network data.
+
+        Returns:
+            Tensor: Transformed tensor.
+
+        Raises:
+            NotImplementedError: Must be implemented by subclasses.
+        """
+        raise NotImplementedError
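
Because Transformation only fixes a tag-validating constructor and an abstract __call__, a subclass needs just those two pieces. A hypothetical sketch (the ScaleTransformation name and the import path are assumptions; this hunk does not show the new file's location):

    from torch import Tensor

    from congrads.transformations.base import Transformation  # path assumed

    class ScaleTransformation(Transformation):
        """Multiply the tagged data by a constant factor (illustrative only)."""

        def __init__(self, tag: str, factor: float):
            super().__init__(tag)  # base class validates that tag is a str
            self.factor = factor

        def __call__(self, data: Tensor) -> Tensor:
            return data * self.factor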
@@ -1,41 +1,11 @@
-"""Module defining transformations and components."""
+"""Module holding specific transformation implementations."""
 
-from abc import ABC, abstractmethod
 from numbers import Number
 
 from torch import Tensor
 
-from .utils import validate_callable, validate_type
-
-
-class Transformation(ABC):
-    """Abstract base class for tag data transformations."""
-
-    def __init__(self, tag: str):
-        """Initialize a Transformation.
-
-        Args:
-            tag (str): Tag this transformation applies to.
-        """
-        validate_type("tag", tag, str)
-
-        super().__init__()
-        self.tag = tag
-
-    @abstractmethod
-    def __call__(self, data: Tensor) -> Tensor:
-        """Apply the transformation to the input tensor.
-
-        Args:
-            data (Tensor): Input tensor representing network data.
-
-        Returns:
-            Tensor: Transformed tensor.
-
-        Raises:
-            NotImplementedError: Must be implemented by subclasses.
-        """
-        raise NotImplementedError
+from ..utils.validation import validate_callable, validate_type
+from .base import Transformation
 
 
 class IdentityTransformation(Transformation):