congrads 0.1.0-py3-none-any.whl → 0.3.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- congrads/__init__.py +10 -20
- congrads/callbacks/base.py +357 -0
- congrads/callbacks/registry.py +106 -0
- congrads/checkpoints.py +178 -0
- congrads/constraints/base.py +242 -0
- congrads/constraints/registry.py +1255 -0
- congrads/core/batch_runner.py +200 -0
- congrads/core/congradscore.py +271 -0
- congrads/core/constraint_engine.py +209 -0
- congrads/core/epoch_runner.py +119 -0
- congrads/datasets/registry.py +799 -0
- congrads/descriptor.py +147 -43
- congrads/metrics.py +116 -41
- congrads/networks/registry.py +68 -0
- congrads/py.typed +0 -0
- congrads/transformations/base.py +37 -0
- congrads/transformations/registry.py +86 -0
- congrads/utils/preprocessors.py +439 -0
- congrads/utils/utility.py +506 -0
- congrads/utils/validation.py +182 -0
- congrads-0.3.0.dist-info/METADATA +234 -0
- congrads-0.3.0.dist-info/RECORD +23 -0
- congrads-0.3.0.dist-info/WHEEL +4 -0
- congrads/constraints.py +0 -507
- congrads/core.py +0 -211
- congrads/datasets.py +0 -742
- congrads/learners.py +0 -233
- congrads/networks.py +0 -91
- congrads-0.1.0.dist-info/LICENSE +0 -34
- congrads-0.1.0.dist-info/METADATA +0 -196
- congrads-0.1.0.dist-info/RECORD +0 -13
- congrads-0.1.0.dist-info/WHEEL +0 -5
- congrads-0.1.0.dist-info/top_level.txt +0 -1
congrads/descriptor.py
CHANGED

@@ -1,65 +1,169 @@
+"""This module defines the `Descriptor` class, which allows assigning tags to parts in the network.
+
+It is designed to manage the mapping between tags, their corresponding data dictionary keys and indices,
+and additional properties such as constant or variable status. It provides a way to easily
+place constraints on parts of your network, by referencing the tags
+instead of indices.
+
+The `Descriptor` class allows for easy constraint definitions on parts of
+your neural network. It supports registering tags with associated data dictionary keys,
+indices, and optional attributes, such as whether the data is constant or variable.
+"""
+
+from torch import Tensor
+
+from .utils.validation import validate_type
+
+
 class Descriptor:
-    """
-    A class to manage the mapping of neurons to layers and their properties
-    (e.g., output, constant, or variable) in a neural network.
+    """A class to manage the mapping between tags.
 
-    such as
-    as outputs, constants, or variables.
+    It represents data locations in the data dictionary and holds the dictionary keys, indices,
+    and additional properties (such as min/max values, output, and constant variables).
 
-    This
+    This class is designed to manage the relationships between the assigned tags and the
+    data dictionary keys in a neural network model. It allows for the assignment of properties
+    (like minimum and maximum values, and whether data is an output, constant, or variable) to
+    each tag. The data is stored in dictionaries and sets for efficient lookups.
+
+    Attributes:
+        constant_keys (set): A set of keys that represent constant data in the data dictionary.
+        variable_keys (set): A set of keys that represent variable data in the data dictionary.
+        affects_loss_keys (set): A set of keys that represent data affecting the loss computation.
     """
 
     def __init__(
         self,
     ):
-        """
-        - `neuron_to_layer`: A dictionary mapping neuron names to their corresponding layer names.
-        - `neuron_to_index`: A dictionary mapping neuron names to their corresponding index within a layer.
-        - `output_layers`: A set that holds the names of layers marked as output layers.
-        - `constant_layers`: A set that holds the names of layers marked as constant layers.
-        - `variable_layers`: A set that holds the names of layers marked as variable layers.
-        """
-
-        # Define dictionaries that will translate neuron names to layer and index
-        self.neuron_to_layer: dict[str, str] = {}
-        self.neuron_to_index: dict[str, int] = {}
+        """Initializes the Descriptor object."""
+        # Define dictionaries that will translate tags to keys and indices
+        self._tag_to_key: dict[str, str] = {}
+        self._tag_to_index: dict[str, int] = {}
 
-        # Define sets that will hold the
-        self.
-        self.
-        self.
+        # Define sets that will hold the keys based on which type
+        self.constant_keys: set[str] = set()
+        self.variable_keys: set[str] = set()
+        self.affects_loss_keys: set[str] = set()
 
     def add(
         self,
+        key: str,
+        tag: str,
+        index: int = None,
         constant: bool = False,
+        affects_loss: bool = True,
     ):
-        """
+        """Adds a tag to the descriptor with its associated key, index, and properties.
+
+        This method registers a tag name and associates it with a
+        data dictionary key, its index, and optional properties such as whether
+        the key holds output or constant data.
 
         Args:
-            constant (bool, optional):
+            key (str): The key on which the tagged data is located in the data dictionary.
+            tag (str): The identifier of the tag.
+            index (int): The index where the data is present. Defaults to None.
+            constant (bool, optional): Whether the data is constant and is not learned. Defaults to False.
+            affects_loss (bool, optional): Whether the data affects the loss computation. Defaults to True.
+
+        Raises:
+            TypeError: If a provided attribute has an incompatible type.
+            ValueError: If a key or index is already assigned for a tag or a duplicate index is used within a key.
         """
+        # Type checking
+        validate_type("key", key, str)
+        validate_type("tag", tag, str)
+        validate_type("index", index, int, allow_none=True)
+        validate_type("constant", constant, bool)
+        validate_type("affects_loss", affects_loss, bool)
 
+        # Other validations
+        if tag in self._tag_to_key:
+            raise ValueError(
+                f"There already is a key registered for the tag '{tag}'. "
+                "Please use a unique key name for each tag."
+            )
 
+        if tag in self._tag_to_index:
+            raise ValueError(
+                f"There already is an index registered for the tag '{tag}'. "
+                "Please use a unique name for each tag."
+            )
+
+        for existing_tag, assigned_index in self._tag_to_index.items():
+            if assigned_index == index and self._tag_to_key[existing_tag] == key:
+                raise ValueError(
+                    f"The index {index} on key {key} is already "
+                    "assigned. Every tag must be assigned a different "
+                    "index that matches the network's output."
+                )
+
+        # Add to dictionaries and sets
+        # TODO this now happens on key level, can this also be done on tag level?
         if constant:
-            self.
+            self.constant_keys.add(key)
         else:
-            self.
+            self.variable_keys.add(key)
+
+        if affects_loss:
+            self.affects_loss_keys.add(key)
+
+        self._tag_to_key[tag] = key
+        self._tag_to_index[tag] = index
 
+    def exists(self, tag: str) -> bool:
+        """Check if a tag is registered in the descriptor.
+
+        Args:
+            tag (str): The tag identifier to check.
+
+        Returns:
+            bool: True if the tag is registered, False otherwise.
+        """
+        return tag in self._tag_to_key and tag in self._tag_to_index
+
+    def location(self, tag: str) -> tuple[str, int | None]:
+        """Get the key and index for a given tag.
+
+        Looks up the mapping for a registered tag and returns the associated
+        dictionary key and the index.
+
+        Args:
+            tag (str): The tag identifier. Must be registered.
+
+        Returns:
+            tuple ((str, int | None)): A tuple containing:
+                - The key in the data dictionary which holds the data (str).
+                - The tensor index where the data is present or None (int | None).
+
+        Raises:
+            ValueError: If the tag is not registered in the descriptor.
+        """
+        key = self._tag_to_key.get(tag)
+        index = self._tag_to_index.get(tag)
+        if key is None:
+            raise ValueError(f"Tag '{tag}' is not registered in descriptor.")
+        return key, index
+
+    def select(self, tag: str, data: dict[str, Tensor]) -> Tensor:
+        """Extract prediction values for a specific tag.
+
+        Retrieves the key and index associated with a tag and selects
+        the corresponding slice from the given prediction tensor.
+        Returns the full tensor if no index was specified when registering the tag.
+
+        Args:
+            tag (str): The tag identifier. Must be registered.
+            data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.
+
+        Returns:
+            Tensor: A tensor slice of shape ``(batch_size, 1)`` containing
+                the predictions for the specified tag, or the full tensor
+                if no index was specified when registering the tag.
+
+        Raises:
+            ValueError: If the tag is not registered in the descriptor.
+        """
+        key, index = self.location(tag)
+        if index is None:
+            return data[key]
+        return data[key][:, index : index + 1]
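
For orientation, here is a minimal usage sketch of the new tag-based Descriptor API shown above. The tag names, the two-column "output" tensor, and the batch size are illustrative assumptions, not taken from the package's documentation:

    from torch import randn

    from congrads.descriptor import Descriptor

    descriptor = Descriptor()
    descriptor.add(key="output", tag="flow", index=0)
    descriptor.add(key="output", tag="pressure", index=1, affects_loss=False)

    # Pretend batch of 8 predictions with two output columns.
    data = {"output": randn(8, 2)}

    descriptor.exists("flow")        # True
    descriptor.location("pressure")  # ("output", 1)
    descriptor.select("flow", data)  # slice of shape (8, 1)

Constraints can then reference "flow" or "pressure" instead of raw tensor indices, which is the point of the 0.3.0 redesign.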
congrads/metrics.py
CHANGED

@@ -1,64 +1,139 @@
-
-from torchmetrics import Metric
+"""Module for managing metrics during training.
 
+Provides the `Metric` and `MetricManager` classes for accumulating,
+aggregating, and resetting metrics over training batches. Supports
+grouping metrics and using custom accumulation functions.
+"""
 
+from collections.abc import Callable
 
-It computes the proportion of constraints that have been satisfied,
-where satisfaction is determined based on the provided constraint results.
+from torch import Tensor, cat, nanmean, tensor
+
+from .utils.validation import validate_callable, validate_type
 
-This metric tracks the number of unsatisfied constraints and the total number of constraints
-during the training process, and computes the ratio of satisfied constraints once all updates
-have been made.
 
-    total (Tensor): Tracks the total number of constraints processed.
+class Metric:
+    """Represents a single metric to be accumulated and aggregated.
 
-    at https://lightning.ai/docs/torchmetrics/stable/pages/implement.html
+    Stores metric values over multiple batches and computes an aggregated
+    result using a specified accumulation function.
     """
 
-    def __init__(self,
-        """
-        Initializes the ConstraintSatisfactionRatio metric by setting up the
-        state variables to track the number of unsatisfied and total constraints.
+    def __init__(self, name: str, accumulator: Callable[..., Tensor] = nanmean) -> None:
+        """Initialize a Metric instance.
 
         Args:
+            name (str): Name of the metric.
+            accumulator (Callable[..., Tensor], optional): Function to aggregate
+                accumulated values. Defaults to `torch.nanmean`.
         """
+        # Type checking
+        validate_type("name", name, str)
+        validate_callable("accumulator", accumulator)
+
+        self.name = name
+        self.accumulator = accumulator
+        self.values: list[Tensor] = []
+        self.sample_count = 0
 
+    def accumulate(self, value: Tensor) -> None:
+        """Accumulate a new value for the metric.
+
+        Args:
+            value (Tensor): Metric values for the current batch.
+        """
+        self.values.append(value.detach().clone())
+        self.sample_count += value.size(0)
 
-        self.add_state("total", default=tensor(0), dist_reduce_fx="sum")
+    def aggregate(self) -> Tensor:
+        """Compute the aggregated value of the metric.
 
+        Returns:
+            Tensor: The aggregated metric value. Returns NaN if no values
+                have been accumulated.
         """
+        if not self.values:
+            return tensor(float("nan"))
+
+        combined = cat(self.values)
+        return self.accumulator(combined)
+
+    def reset(self) -> None:
+        """Reset the accumulated values and sample count for the metric."""
+        self.values = []
+        self.sample_count = 0
+
+
+class MetricManager:
+    """Manages multiple metrics and groups for training or evaluation.
+
+    Supports registering metrics, accumulating values by name, aggregating
+    metrics by group, and resetting metrics by group.
+    """
+
+    def __init__(self) -> None:
+        """Initialize a MetricManager instance."""
+        self.metrics: dict[str, Metric] = {}
+        self.groups: dict[str, str] = {}
+
+    def register(
+        self, name: str, group: str = "default", accumulator: Callable[..., Tensor] = nanmean
+    ) -> None:
+        """Register a new metric under a specified group.
 
         Args:
-            1 for unsatisfied).
+            name (str): Name of the metric.
+            group (str, optional): Group name for the metric. Defaults to "default".
+            accumulator (Callable[..., Tensor], optional): Function to aggregate
+                accumulated values. Defaults to `torch.nanmean`.
         """
+        # Type checking
+        validate_type("name", name, str)
+        validate_type("group", group, str)
+        validate_callable("accumulator", accumulator)
+
+        self.metrics[name] = Metric(name, accumulator)
+        self.groups[name] = group
+
+    def accumulate(self, name: str, value: Tensor) -> None:
+        """Accumulate a value for a specific metric by name.
 
+        Args:
+            name (str): Name of the metric.
+            value (Tensor): Metric values for the current batch.
         """
+        if name not in self.metrics:
+            raise KeyError(f"Metric '{name}' is not registered.")
+
+        self.metrics[name].accumulate(value)
+
+    def aggregate(self, group: str = "default") -> dict[str, Tensor]:
+        """Aggregate all metrics in a specified group.
+
+        Args:
+            group (str, optional): The group of metrics to aggregate. Defaults to "default".
 
         Returns:
-            Tensor:
+            dict[str, Tensor]: Dictionary mapping metric names to their
+                aggregated values.
         """
-        return
+        return {
+            name: metric.aggregate()
+            for name, metric in self.metrics.items()
+            if self.groups[name] == group
+        }
+
+    def reset(self, group: str = "default") -> None:
+        """Reset all metrics in a specified group.
+
+        Args:
+            group (str, optional): The group of metrics to reset. Defaults to "default".
+        """
+        for name, metric in self.metrics.items():
+            if self.groups[name] == group:
+                metric.reset()
+
+    def reset_all(self) -> None:
+        """Reset all metrics across all groups."""
+        for metric in self.metrics.values():
+            metric.reset()
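
As a rough sketch of how the new Metric/MetricManager pair is used in place of the old torchmetrics-based class, with illustrative metric names, group name, and values:

    from torch import tensor

    from congrads.metrics import MetricManager

    manager = MetricManager()
    manager.register("loss", group="train")  # aggregated with torch.nanmean by default
    manager.register("violation", group="train")

    # Accumulate per-sample values batch by batch.
    manager.accumulate("loss", tensor([0.9, 0.7]))
    manager.accumulate("loss", tensor([0.5]))
    manager.accumulate("violation", tensor([0.0, 1.0]))

    manager.aggregate(group="train")  # {"loss": tensor(0.7000), "violation": tensor(0.5000)}
    manager.reset(group="train")      # clear accumulated values, e.g. at epoch end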
congrads/networks/registry.py
ADDED

@@ -0,0 +1,68 @@
+"""Module defining the network architectures and components."""
+
+from torch import Tensor
+from torch.nn import Linear, Module, ReLU, Sequential
+
+
+class MLPNetwork(Module):
+    """A multi-layer perceptron (MLP) neural network with configurable hidden layers."""
+
+    def __init__(
+        self,
+        n_inputs,
+        n_outputs,
+        n_hidden_layers=3,
+        hidden_dim=35,
+        activation=None,
+    ):
+        """Initialize the MLPNetwork.
+
+        Args:
+            n_inputs (int): Number of input features.
+            n_outputs (int): Number of output features.
+            n_hidden_layers (int, optional): Number of hidden layers. Defaults to 3.
+            hidden_dim (int, optional): Dimensionality of hidden layers. Defaults to 35.
+            activation (nn.Module, optional): Activation function module (e.g.,
+                `ReLU()`, `Tanh()`, `LeakyReLU(0.1)`). Defaults to `ReLU()`.
+        """
+        super().__init__()
+
+        # Init object variables
+        self.n_inputs = n_inputs
+        self.n_outputs = n_outputs
+        self.n_hidden_layers = n_hidden_layers
+        self.hidden_dim = hidden_dim
+
+        # Default activation function
+        if activation is None:
+            activation = ReLU()
+        self.activation = activation
+
+        # Build network layers
+        layers = []
+
+        # Input layer with activation
+        layers.append(Linear(n_inputs, hidden_dim))
+        layers.append(self.activation)
+
+        # Hidden layers (with activation after each)
+        for _ in range(n_hidden_layers - 1):
+            layers.append(Linear(hidden_dim, hidden_dim))
+            layers.append(self.activation)
+
+        # Output layer (no activation by default)
+        layers.append(Linear(hidden_dim, n_outputs))
+
+        self.network = Sequential(*layers)
+
+    def forward(self, data: dict[str, Tensor]):
+        """Run a forward pass through the network.
+
+        Args:
+            data (dict[str, Tensor]): Input data to be processed by the network.
+
+        Returns:
+            dict: The original data dictionary augmented with the network's output (under key "output").
+        """
+        data["output"] = self.network(data["input"])
+        return data
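
A short sketch of constructing and calling the MLPNetwork defined above; the dimensions and batch size are arbitrary, and the module path follows the file list at the top of this diff:

    from torch import randn

    from congrads.networks.registry import MLPNetwork

    network = MLPNetwork(n_inputs=4, n_outputs=2, n_hidden_layers=2, hidden_dim=16)

    # The forward pass reads data["input"] and writes data["output"] in place.
    data = {"input": randn(8, 4)}
    data = network(data)
    data["output"].shape  # torch.Size([8, 2])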
congrads/py.typed
ADDED

File without changes
congrads/transformations/base.py
ADDED

@@ -0,0 +1,37 @@
+"""Module defining transformations and components."""
+
+from abc import ABC, abstractmethod
+
+from torch import Tensor
+
+from ..utils.validation import validate_type
+
+
+class Transformation(ABC):
+    """Abstract base class for tag data transformations."""
+
+    def __init__(self, tag: str):
+        """Initialize a Transformation.
+
+        Args:
+            tag (str): Tag this transformation applies to.
+        """
+        validate_type("tag", tag, str)
+
+        super().__init__()
+        self.tag = tag
+
+    @abstractmethod
+    def __call__(self, data: Tensor) -> Tensor:
+        """Apply the transformation to the input tensor.
+
+        Args:
+            data (Tensor): Input tensor representing network data.
+
+        Returns:
+            Tensor: Transformed tensor.
+
+        Raises:
+            NotImplementedError: Must be implemented by subclasses.
+        """
+        raise NotImplementedError
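
Since `Transformation` is abstract, a subclass only needs to implement `__call__`. A hypothetical example (ClampNonNegative is not part of the package):

    from torch import Tensor

    from congrads.transformations.base import Transformation

    class ClampNonNegative(Transformation):
        """Hypothetical transformation that clamps tagged data at zero."""

        def __call__(self, data: Tensor) -> Tensor:
            # Element-wise clamp; shape and dtype are preserved.
            return data.clamp(min=0)

    transform = ClampNonNegative(tag="flow")
    # transform(some_tensor) returns the clamped tensor for the "flow" tag.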
congrads/transformations/registry.py
ADDED

@@ -0,0 +1,86 @@
+"""Module holding specific transformation implementations."""
+
+from numbers import Number
+
+from torch import Tensor
+
+from ..utils.validation import validate_callable, validate_type
+from .base import Transformation
+
+
+class IdentityTransformation(Transformation):
+    """A transformation that returns the input unchanged."""
+
+    def __call__(self, data: Tensor) -> Tensor:
+        """Return the input tensor without any modification.
+
+        Args:
+            data (Tensor): Input tensor.
+
+        Returns:
+            Tensor: The same input tensor.
+        """
+        return data
+
+
+class DenormalizeMinMax(Transformation):
+    """A transformation that denormalizes data using min-max scaling."""
+
+    def __init__(self, tag: str, min: Number, max: Number):
+        """Initialize a min-max denormalization transformation.
+
+        Args:
+            tag (str): Tag this transformation applies to.
+            min (Number): Minimum value used for denormalization.
+            max (Number): Maximum value used for denormalization.
+        """
+        validate_type("min", min, Number)
+        validate_type("max", max, Number)
+
+        super().__init__(tag)
+
+        self.min = min
+        self.max = max
+
+    def __call__(self, data: Tensor) -> Tensor:
+        """Denormalize the input tensor using the min-max range.
+
+        Args:
+            data (Tensor): Normalized input tensor (typically in range [0, 1]).
+
+        Returns:
+            Tensor: Denormalized tensor in the range [min, max].
+        """
+        return data * (self.max - self.min) + self.min
+
+
+class ApplyOperator(Transformation):
+    """A transformation that applies a binary operator to the input tensor."""
+
+    def __init__(self, tag: str, operator: callable, value: Number):
+        """Initialize an operator-based transformation.
+
+        Args:
+            tag (str): Tag this transformation applies to.
+            operator (callable): A callable that takes two arguments (tensor, value)
+                and returns a tensor.
+            value (Number): The value to use as the second argument in the operator.
+        """
+        validate_callable("operator", operator)
+        validate_type("value", value, Number)
+
+        super().__init__(tag)
+
+        self.operator = operator
+        self.value = value
+
+    def __call__(self, data: Tensor) -> Tensor:
+        """Apply the operator to the input tensor and the specified value.
+
+        Args:
+            data (Tensor): Input tensor.
+
+        Returns:
+            Tensor: Result of applying `operator(data, value)`.
+        """
+        return self.operator(data, self.value)
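
To illustrate the concrete transformations above, a small sketch with made-up tag names, bounds, and values:

    import operator

    from torch import tensor

    from congrads.transformations.registry import ApplyOperator, DenormalizeMinMax

    # Map normalized predictions back to a physical range.
    denorm = DenormalizeMinMax(tag="temperature", min=-10.0, max=40.0)
    denorm(tensor([0.0, 0.5, 1.0]))  # tensor([-10., 15., 40.])

    # Apply a fixed binary operation to the tagged data.
    scale = ApplyOperator(tag="temperature", operator=operator.mul, value=2.0)
    scale(tensor([1.0, 2.0]))        # tensor([2., 4.])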