congrads 1.1.2-py3-none-any.whl → 1.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- congrads/__init__.py +0 -17
- congrads/callbacks/base.py +357 -0
- congrads/callbacks/registry.py +106 -0
- congrads/checkpoints.py +1 -1
- congrads/constraints/base.py +174 -0
- congrads/{constraints.py → constraints/registry.py} +120 -158
- congrads/core/batch_runner.py +200 -0
- congrads/core/congradscore.py +271 -0
- congrads/core/constraint_engine.py +170 -0
- congrads/core/epoch_runner.py +119 -0
- congrads/descriptor.py +1 -1
- congrads/metrics.py +1 -1
- congrads/transformations/base.py +37 -0
- congrads/{transformations.py → transformations/registry.py} +3 -33
- congrads/utils/preprocessors.py +439 -0
- congrads/utils/utility.py +506 -0
- congrads/utils/validation.py +194 -0
- {congrads-1.1.2.dist-info → congrads-1.2.0.dist-info}/METADATA +1 -1
- congrads-1.2.0.dist-info/RECORD +23 -0
- congrads/core.py +0 -773
- congrads/utils.py +0 -1078
- congrads-1.1.2.dist-info/RECORD +0 -14
- /congrads/{datasets.py → datasets/registry.py} +0 -0
- /congrads/{networks.py → networks/registry.py} +0 -0
- {congrads-1.1.2.dist-info → congrads-1.2.0.dist-info}/WHEEL +0 -0
+++ congrads/utils/utility.py
@@ -0,0 +1,506 @@
"""This module holds utility functions and classes for the congrads package."""

import inspect
import os
import random
from collections.abc import Callable

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch import Generator, Tensor, argsort, cat, int32, unique
from torch.nn.modules.loss import _Loss
from torch.utils.data import DataLoader, Dataset, random_split

class CSVLogger:
    """A utility class for logging key-value pairs to a CSV file, organized by epochs.

    Supports merging with existing logs or overwriting them.

    Args:
        file_path (str): The path to the CSV file for logging.
        overwrite (bool): If True, overwrites any existing file at the file_path.
        merge (bool): If True, merges new values with existing data in the file.

    Raises:
        ValueError: If both overwrite and merge are True.
        FileExistsError: If the file already exists and neither overwrite nor merge is True.
    """

    def __init__(self, file_path: str, overwrite: bool = False, merge: bool = True):
        """Initializes the CSVLogger.

        Supports merging with existing logs or overwriting them.

        Args:
            file_path (str): The path to the CSV file for logging.
            overwrite (bool, optional): If True, overwrites any existing file at the file_path. Defaults to False.
            merge (bool, optional): If True, merges new values with existing data in the file. Defaults to True.

        Raises:
            ValueError: If both overwrite and merge are True.
            FileExistsError: If the file already exists and neither overwrite nor merge is True.
        """
        self.file_path = file_path
        self.values: dict[tuple[int, str], float] = {}

        if merge and overwrite:
            raise ValueError(
                "The attributes overwrite and merge cannot be True at the "
                "same time. Either specify overwrite=True or merge=True."
            )

        if not os.path.exists(file_path):
            pass
        elif merge:
            self.load()
        elif overwrite:
            pass
        else:
            raise FileExistsError(
                f"A CSV file already exists at {file_path}. Specify "
                "CSVLogger(..., overwrite=True) to overwrite the file."
            )

    def add_value(self, name: str, value: float, epoch: int):
        """Adds a value to the logger for a specific epoch and name.

        Args:
            name (str): The name of the metric or value to log.
            value (float): The value to log.
            epoch (int): The epoch associated with the value.
        """
        self.values[epoch, name] = value

    def save(self):
        """Saves the logged values to the specified CSV file.

        If the file exists and merge is enabled, merges the current data
        with the existing file.
        """
        data = self.to_dataframe(self.values)
        data.to_csv(self.file_path, index=False)

    def load(self):
        """Loads data from the CSV file into the logger.

        Converts the CSV data into the internal dictionary format for
        further updates or operations.
        """
        df = pd.read_csv(self.file_path)
        self.values = self.to_dict(df)

    @staticmethod
    def to_dataframe(values: dict[tuple[int, str], float]) -> pd.DataFrame:
        """Converts a dictionary of values into a DataFrame.

        Args:
            values (dict[tuple[int, str], float]): A dictionary of values keyed by (epoch, name).

        Returns:
            pd.DataFrame: A DataFrame where epochs are rows, names are columns, and values are the cell data.
        """
        # Convert to a DataFrame
        df = pd.DataFrame.from_dict(values, orient="index", columns=["value"])

        # Reset the index to separate epoch and name into columns
        df.index = pd.MultiIndex.from_tuples(df.index, names=["epoch", "name"])
        df = df.reset_index()

        # Pivot the DataFrame so epochs are rows and names are columns
        result = df.pivot(index="epoch", columns="name", values="value")

        # Optional: Reset the column names for a cleaner look
        result = result.reset_index().rename_axis(columns=None)

        return result

    @staticmethod
    def to_dict(df: pd.DataFrame) -> dict[tuple[int, str], float]:
        """Converts a CSVLogger DataFrame to a dictionary in the format {(epoch, name): value}."""
        # Set the epoch column as the index (if not already)
        df = df.set_index("epoch")

        # Stack the DataFrame to create a multi-index series
        stacked = df.stack()

        # Convert the multi-index series to a dictionary
        result = stacked.to_dict()

        return result

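For context, a minimal usage sketch of `CSVLogger` (the file name and metric names are illustrative, and the import path assumes the new `congrads/utils/utility.py` module is importable as written):

```python
from congrads.utils.utility import CSVLogger  # assumed import path for this release

logger = CSVLogger("metrics.csv", merge=True)  # illustrative file name
for epoch in range(3):
    logger.add_value("train_loss", 0.3 - 0.1 * epoch, epoch)
    logger.add_value("valid_loss", 0.4 - 0.1 * epoch, epoch)
logger.save()  # one row per epoch, columns: epoch, train_loss, valid_loss
```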
def split_data_loaders(
    data: Dataset,
    loader_args: dict = None,
    train_loader_args: dict = None,
    valid_loader_args: dict = None,
    test_loader_args: dict = None,
    train_size: float = 0.8,
    valid_size: float = 0.1,
    test_size: float = 0.1,
    split_generator: Generator = None,
) -> tuple[DataLoader, DataLoader, DataLoader]:
    """Splits a dataset into training, validation, and test sets, and returns corresponding DataLoader objects.

    Args:
        data (Dataset): The dataset to be split.
        loader_args (dict, optional): Default DataLoader arguments shared by
            all three loaders; overlapping keys are overridden by the
            loader-specific arguments below.
        train_loader_args (dict, optional): Training DataLoader arguments,
            merges with `loader_args`, overriding overlapping keys.
        valid_loader_args (dict, optional): Validation DataLoader arguments,
            merges with `loader_args`, overriding overlapping keys.
        test_loader_args (dict, optional): Test DataLoader arguments,
            merges with `loader_args`, overriding overlapping keys.
        train_size (float, optional): Proportion of data to be used for
            training. Defaults to 0.8.
        valid_size (float, optional): Proportion of data to be used for
            validation. Defaults to 0.1.
        test_size (float, optional): Proportion of data to be used for
            testing. Defaults to 0.1.
        split_generator (Generator, optional): Optional random number
            generator that controls the splitting of the dataset.

    Returns:
        tuple: A tuple containing three DataLoader objects: one each for the
            training, validation, and test set.

    Raises:
        ValueError: If train_size, valid_size, and test_size are not
            between 0 and 1, or if their sum does not equal 1.
    """
    # Validate split sizes
    if not (0 < train_size < 1 and 0 < valid_size < 1 and 0 < test_size < 1):
        raise ValueError("train_size, valid_size, and test_size must be between 0 and 1.")
    if not abs(train_size + valid_size + test_size - 1.0) < 1e-6:
        raise ValueError("train_size, valid_size, and test_size must sum to 1.")

    # Perform the splits
    train_val_data, test_data = random_split(
        data, [1 - test_size, test_size], generator=split_generator
    )
    train_data, valid_data = random_split(
        train_val_data,
        [
            train_size / (1 - test_size),
            valid_size / (1 - test_size),
        ],
        generator=split_generator,
    )

    # Set default arguments for each loader
    train_loader_args = dict(loader_args or {}, **(train_loader_args or {}))
    valid_loader_args = dict(loader_args or {}, **(valid_loader_args or {}))
    test_loader_args = dict(loader_args or {}, **(test_loader_args or {}))

    # Create the DataLoaders
    train_generator = DataLoader(train_data, **train_loader_args)
    valid_generator = DataLoader(valid_data, **valid_loader_args)
    test_generator = DataLoader(test_data, **test_loader_args)

    return train_generator, valid_generator, test_generator

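An illustrative call of `split_data_loaders` with the default 80/10/10 split; the dataset and loader arguments are placeholders and the import path is assumed:

```python
import torch
from torch.utils.data import TensorDataset
from congrads.utils.utility import split_data_loaders  # assumed import path

dataset = TensorDataset(torch.randn(100, 4), torch.randn(100, 1))
train_loader, valid_loader, test_loader = split_data_loaders(
    dataset,
    loader_args={"batch_size": 16},       # shared defaults for all three loaders
    train_loader_args={"shuffle": True},  # extends/overrides the defaults for training only
    split_generator=torch.Generator().manual_seed(42),
)
```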
class ZeroLoss(_Loss):
    """A loss function that always returns zero.

    This custom loss function ignores the input and target tensors
    and returns a constant zero loss, which can be useful for debugging
    or when no meaningful loss computation is required.

    Args:
        reduction (str, optional): Specifies the reduction to apply to
            the output. Defaults to "mean". Although specified, it has
            no effect as the loss is always zero.
    """

    def __init__(self, reduction: str = "mean"):
        """Initialize ZeroLoss with a specified reduction method.

        Args:
            reduction (str): Specifies the reduction to apply to the output. Defaults to "mean".
        """
        super().__init__(reduction=reduction)

    def forward(self, predictions: Tensor, target: Tensor, **kwargs) -> torch.Tensor:
        """Return a dummy loss of zero regardless of input and target."""
        return (predictions * 0).sum()

class LossWrapper:
    """Wraps a loss function to optionally accept batch-level data.

    This adapter allows both standard PyTorch loss functions (e.g.
    ``nn.MSELoss``) and custom loss functions that accept an additional
    ``data`` keyword argument to be used interchangeably.

    The wrapped loss can always be called with the same signature:

        loss(output, target, data=batch)

    If the underlying loss function does not accept ``data``, the
    argument is silently ignored.
    """

    def __init__(self, loss_fn: Callable):
        """Initializes the LossWrapper.

        Args:
            loss_fn (Callable): The underlying loss function or callable
                (e.g. a ``torch.nn.Module`` or a custom function).
        """
        self.loss_fn = loss_fn
        self.accepts_data = self._accepts_data()

    def _accepts_data(self) -> bool:
        """Checks whether the wrapped loss function accepts a ``data`` argument.

        The check returns ``True`` if either:
        - The function explicitly defines a ``data`` parameter, or
        - The function accepts arbitrary keyword arguments (``**kwargs``).

        Returns:
            bool: ``True`` if the loss function can accept ``data``,
                ``False`` otherwise.
        """
        # For nn.Module, inspect forward(), not __call__()
        if isinstance(self.loss_fn, nn.Module):
            fn = self.loss_fn.forward
        else:
            fn = self.loss_fn

        sig = inspect.signature(fn)

        return "data" in sig.parameters or any(
            p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
        )

    def __call__(self, output: Tensor, target: Tensor, *, data: dict | None = None) -> Tensor:
        """Computes the loss.

        Args:
            output (torch.Tensor): Model predictions.
            target (torch.Tensor): Ground-truth targets.
            data (dict, optional): Full batch data passed to custom loss
                functions that require additional context.

        Returns:
            torch.Tensor: Computed loss value.
        """
        if self.accepts_data:
            return self.loss_fn(output, target, data=data)
        return self.loss_fn(output, target)

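A sketch of the two call styles `LossWrapper` unifies; `weighted_mse` is a hypothetical custom loss that reads extra context from the batch dictionary, while `nn.MSELoss` simply never sees `data` (import path assumed):

```python
import torch
import torch.nn as nn
from congrads.utils.utility import LossWrapper  # assumed import path

def weighted_mse(output, target, data=None):  # hypothetical custom loss using batch context
    weight = data["weight"].mean() if data is not None else 1.0
    return weight * torch.mean((output - target) ** 2)

wrapped_std = LossWrapper(nn.MSELoss())  # accepts_data is False -> called as loss(output, target)
wrapped_ctx = LossWrapper(weighted_mse)  # accepts_data is True  -> data is forwarded

output, target = torch.randn(8, 1), torch.randn(8, 1)
batch = {"weight": torch.full((8,), 0.5)}
loss_a = wrapped_std(output, target, data=batch)  # data silently ignored
loss_b = wrapped_ctx(output, target, data=batch)  # data used by the custom loss
```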
def process_data_monotonicity_constraint(data: Tensor, ordering: Tensor, identifiers: Tensor):
    """Reorders input samples to support monotonicity checking.

    Reorders input samples such that:
    1. Samples from the same run are grouped together.
    2. Within each run, samples are sorted chronologically.

    Args:
        data (Tensor): The input data.
        ordering (Tensor): The values by which samples are ordered within
            each run (e.g., timestamps).
        identifiers (Tensor): Identifiers specifying different runs.

    Returns:
        Tuple[Tensor, Tensor, Tensor]: Sorted data, ordering, and
            identifiers.
    """
    # Step 1: Sort by run identifiers
    sorted_indices = argsort(identifiers, stable=True, dim=0).reshape(-1)
    data_sorted, ordering_sorted, identifiers_sorted = (
        data[sorted_indices],
        ordering[sorted_indices],
        identifiers[sorted_indices],
    )

    # Step 2: Get unique runs and their counts
    _, counts = unique(identifiers, sorted=False, return_counts=True)
    counts = counts.to(int32)  # Avoid repeated conversions

    sorted_data, sorted_ordering, sorted_identifiers = [], [], []
    index = 0  # Tracks the current batch element index

    # Step 3: Process each run independently
    for count in counts:
        end = index + count
        run_data, run_ordering, run_identifiers = (
            data_sorted[index:end],
            ordering_sorted[index:end],
            identifiers_sorted[index:end],
        )

        # Step 4: Sort within each run by time
        time_sorted_indices = argsort(run_ordering, stable=True, dim=0).reshape(-1)
        sorted_data.append(run_data[time_sorted_indices])
        sorted_ordering.append(run_ordering[time_sorted_indices])
        sorted_identifiers.append(run_identifiers[time_sorted_indices])

        index = end  # Move to next run

    # Step 5: Concatenate results and return
    return (
        cat(sorted_data, dim=0),
        cat(sorted_ordering, dim=0),
        cat(sorted_identifiers, dim=0),
    )

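A small worked example of the reordering, under the column-vector shapes the `reshape(-1)` calls suggest (each tensor is N x 1); two runs with equal sample counts keep the grouping unambiguous, and the import path is assumed:

```python
import torch
from congrads.utils.utility import process_data_monotonicity_constraint  # assumed import path

data = torch.tensor([[10.0], [20.0], [11.0], [21.0]])  # measurements
ordering = torch.tensor([[1], [1], [0], [0]])           # timestamps within each run
identifiers = torch.tensor([[0], [1], [0], [1]])         # run ids (interleaved)

data_s, ordering_s, ids_s = process_data_monotonicity_constraint(data, ordering, identifiers)
# data_s     -> [[11.], [10.], [21.], [20.]]  (run 0 first, time-ordered within each run)
# ordering_s -> [[0], [1], [0], [1]]
# ids_s      -> [[0], [0], [1], [1]]
```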
class DictDatasetWrapper(Dataset):
    """A wrapper for PyTorch datasets that converts each sample into a dictionary.

    This class takes any PyTorch dataset and returns its samples as dictionaries,
    where each element of the original sample is mapped to a key. This is useful
    for integration with the Congrads toolbox or other frameworks that expect
    dictionary-formatted data.

    Attributes:
        base_dataset (Dataset): The underlying PyTorch dataset being wrapped.
        field_names (list[str] | None): Names assigned to each field of a sample.
            If None, default names like 'field0', 'field1', ... are generated.

    Args:
        base_dataset (Dataset): The PyTorch dataset to wrap.
        field_names (list[str] | None, optional): Custom names for each field.
            If provided, the list is truncated or extended to match the number
            of elements in a sample. Defaults to None.

    Example:
        Wrapping a TensorDataset with custom field names:

        >>> from torch.utils.data import TensorDataset
        >>> import torch
        >>> dataset = TensorDataset(torch.randn(5, 3), torch.randint(0, 2, (5,)))
        >>> wrapped = DictDatasetWrapper(dataset, field_names=["features", "label"])
        >>> wrapped[0]
        {'features': tensor([...]), 'label': tensor(1)}

        Wrapping a built-in dataset like CIFAR10:

        >>> from torchvision.datasets import CIFAR10
        >>> from torchvision import transforms
        >>> cifar = CIFAR10(
        ...     root="./data", train=True, download=True, transform=transforms.ToTensor()
        ... )
        >>> wrapped_cifar = DictDatasetWrapper(cifar, field_names=["input", "output"])
        >>> wrapped_cifar[0]
        {'input': tensor([...]), 'output': tensor(6)}
    """

    def __init__(self, base_dataset: Dataset, field_names: list[str] | None = None):
        """Initialize the DictDatasetWrapper.

        Args:
            base_dataset (Dataset): The PyTorch dataset to wrap.
            field_names (list[str] | None, optional): Optional list of field names
                for the dictionary output. Defaults to None, in which case
                automatic names 'field0', 'field1', ... are generated.
        """
        self.base_dataset = base_dataset
        self.field_names = field_names

    def __getitem__(self, idx: int):
        """Retrieve a sample from the dataset as a dictionary.

        Each element in the original sample is mapped to a key in the dictionary.
        If the sample is not a tuple or list, it is converted into a single-element
        tuple. Numerical values (int or float) are automatically converted to tensors.

        Args:
            idx (int): Index of the sample to retrieve.

        Returns:
            dict: A dictionary mapping field names to sample values.
        """
        sample = self.base_dataset[idx]

        # Ensure sample is always a tuple
        if not isinstance(sample, (tuple, list)):
            sample = (sample,)

        n_fields = len(sample)

        # Generate default field names if none are provided
        if self.field_names is None:
            names = [f"field{i}" for i in range(n_fields)]
        else:
            names = list(self.field_names)
            if len(names) < n_fields:
                names.extend([f"field{i}" for i in range(len(names), n_fields)])
            names = names[:n_fields]  # truncate if too long

        # Build dictionary
        out = {}
        for name, value in zip(names, sample, strict=False):
            if isinstance(value, (int, float)):
                value = torch.tensor(value)
            out[name] = value

        return out

    def __len__(self):
        """Return the number of samples in the dataset.

        Returns:
            int: Length of the underlying dataset.
        """
        return len(self.base_dataset)

class Seeder:
    """A deterministic seed manager for reproducible experiments.

    This class provides a way to consistently generate pseudo-random
    seeds derived from a fixed base seed. It ensures that different
    libraries (Python's `random`, NumPy, and PyTorch) are initialized
    with reproducible seeds, making experiments deterministic across runs.
    """

    def __init__(self, base_seed: int):
        """Initialize the Seeder with a base seed.

        Args:
            base_seed (int): The initial seed from which all subsequent
                pseudo-random seeds are deterministically derived.
        """
        self._rng = random.Random(base_seed)

    def roll_seed(self) -> int:
        """Generate a new deterministic pseudo-random seed.

        Each call returns an integer seed derived from the internal
        pseudo-random generator, which itself is initialized by the
        base seed.

        Returns:
            int: A pseudo-random integer seed in the range [0, 2**31 - 1].
        """
        return self._rng.randint(0, 2**31 - 1)

    def set_reproducible(self) -> None:
        """Configure global random states for reproducibility.

        Seeds the following libraries with deterministically generated
        seeds based on the base seed:
        - Python's built-in `random`
        - NumPy's random number generator
        - PyTorch (CPU and GPU)

        Also enforces deterministic behavior in PyTorch by:
        - Seeding all CUDA devices
        - Disabling CuDNN benchmarking
        - Enabling CuDNN deterministic mode
        """
        random.seed(self.roll_seed())
        np.random.seed(self.roll_seed())
        torch.manual_seed(self.roll_seed())
        torch.cuda.manual_seed_all(self.roll_seed())

        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
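An illustrative reproducibility setup with `Seeder`; the base seed is arbitrary, and derived seeds can also feed other generators such as the `split_generator` used earlier (import path assumed):

```python
import torch
from congrads.utils.utility import Seeder  # assumed import path

seeder = Seeder(base_seed=1234)
seeder.set_reproducible()  # seeds random, NumPy and PyTorch, and forces deterministic CuDNN

# A derived, reproducible seed for an auxiliary generator:
split_generator = torch.Generator().manual_seed(seeder.roll_seed())
```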
+++ congrads/utils/validation.py
@@ -0,0 +1,194 @@
"""Validation utilities for type checking and argument validation.

This module provides utility functions for validating function arguments,
including type validation, callable validation, and PyTorch-specific
validation functions.
"""

import torch
from torch.utils.data import DataLoader

def validate_type(name, value, expected_types, allow_none=False):
    """Validate that a value is of the specified type(s).

    Args:
        name (str): Name of the argument for error messages.
        value: Value to validate.
        expected_types (type or tuple of types): Expected type(s) for the value.
        allow_none (bool): Whether to allow the value to be None.
            Defaults to False.

    Raises:
        TypeError: If the value is not of the expected type(s).
    """
    if value is None:
        if not allow_none:
            raise TypeError(f"Argument {name} cannot be None.")
        return

    if not isinstance(value, expected_types):
        raise TypeError(
            f"Argument {name} '{str(value)}' is not supported. "
            f"Only values of type {str(expected_types)} are allowed."
        )

def validate_iterable(
    name,
    value,
    expected_element_types,
    allowed_iterables=(list, set, tuple),
    allow_empty=False,
    allow_none=False,
):
    """Validate that a value is an iterable (e.g., list, set) with elements of the specified type(s).

    Args:
        name (str): Name of the argument for error messages.
        value: Value to validate.
        expected_element_types (type or tuple of types): Expected type(s)
            for the elements.
        allowed_iterables (tuple of types): Iterable types that are
            allowed. Defaults to (list, set, tuple).
        allow_empty (bool): Whether to allow empty iterables. Defaults to False.
        allow_none (bool): Whether to allow the value to be None.
            Defaults to False.

    Raises:
        TypeError: If the value is not an allowed iterable type or if
            any element is not of the expected type(s).
    """
    if value is None:
        if not allow_none:
            raise TypeError(f"Argument {name} cannot be None.")
        return

    if len(value) == 0:
        if not allow_empty:
            raise TypeError(f"Argument {name} cannot be an empty iterable.")
        return

    if not isinstance(value, allowed_iterables):
        raise TypeError(
            f"Argument {name} '{str(value)}' is not supported. "
            f"Only values of type {str(allowed_iterables)} are allowed."
        )
    if not all(isinstance(element, expected_element_types) for element in value):
        raise TypeError(
            f"Invalid elements in {name} '{str(value)}'. "
            f"Only elements of type {str(expected_element_types)} are allowed."
        )

def validate_comparator_pytorch(name, value):
    """Validate that a value is a callable PyTorch comparator function.

    Args:
        name (str): Name of the argument for error messages.
        value: Value to validate.

    Raises:
        TypeError: If the value is not callable or not a PyTorch comparator.
    """
    # List of valid PyTorch comparator functions
    pytorch_comparators = {torch.gt, torch.lt, torch.ge, torch.le}

    # Check if value is callable and if it's one of
    # the PyTorch comparator functions
    if not callable(value):
        raise TypeError(
            f"Argument {name} '{str(value)}' is not supported. Only callable functions are allowed."
        )

    if value not in pytorch_comparators:
        raise TypeError(
            f"Argument {name} '{str(value)}' is not a valid PyTorch comparator "
            "function. Only PyTorch functions like torch.gt, torch.lt, "
            "torch.ge, torch.le are allowed."
        )

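Only the four PyTorch comparison functions pass the comparator check above; a quick illustration (import path assumed):

```python
import torch
from congrads.utils.validation import validate_comparator_pytorch  # assumed import path

validate_comparator_pytorch("comparator", torch.ge)  # passes
# validate_comparator_pytorch("comparator", max)     # raises TypeError: not a PyTorch comparator
```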
def validate_callable(name, value, allow_none=False):
    """Validate that a value is a callable function.

    Args:
        name (str): Name of the argument for error messages.
        value: Value to validate.
        allow_none (bool): Whether to allow the value to be None.
            Defaults to False.

    Raises:
        TypeError: If the value is not callable.
    """
    if value is None:
        if not allow_none:
            raise TypeError(f"Argument {name} cannot be None.")
        return

    if not callable(value):
        raise TypeError(
            f"Argument {name} '{str(value)}' is not supported. Only callable functions are allowed."
        )

def validate_callable_iterable(
    name,
    value,
    allowed_iterables=(list, set, tuple),
    allow_none=False,
):
    """Validate that a value is an iterable containing only callable elements.

    This function ensures that the given value is an iterable
    (e.g., a list or set) and that all its elements are callable functions.

    Args:
        name (str): Name of the argument for error messages.
        value: The value to validate.
        allowed_iterables (tuple of types, optional): Iterable types that are
            allowed. Defaults to (list, set, tuple).
        allow_none (bool, optional): Whether to allow the value to be None.
            Defaults to False.

    Raises:
        TypeError: If the value is not an allowed iterable type or if any
            element is not callable.
    """
    if value is None:
        if not allow_none:
            raise TypeError(f"Argument {name} cannot be None.")
        return

    if not isinstance(value, allowed_iterables):
        raise TypeError(
            f"Argument {name} '{str(value)}' is not supported. "
            f"Only values of type {str(allowed_iterables)} are allowed."
        )

    if not all(callable(element) for element in value):
        raise TypeError(
            f"Invalid elements in {name} '{str(value)}'. Only callable functions are allowed."
        )

def validate_loaders(name: str, loaders: tuple[DataLoader, DataLoader, DataLoader]):
    """Validates that `loaders` is a tuple of three DataLoader instances.

    Args:
        name (str): The name of the parameter being validated.
        loaders (tuple[DataLoader, DataLoader, DataLoader]): A tuple of
            three DataLoader instances.

    Raises:
        TypeError: If `loaders` is not a tuple of three DataLoader
            instances or contains invalid types.
    """
    if not isinstance(loaders, tuple) or len(loaders) != 3:
        raise TypeError(f"{name} must be a tuple of three DataLoader instances.")

    for i, loader in enumerate(loaders):
        if not isinstance(loader, DataLoader):
            raise TypeError(
                f"{name}[{i}] must be an instance of DataLoader, got {type(loader).__name__}."
            )
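Taken together, these validators are typically called at the top of a public function; a short sketch with illustrative argument names (import path assumed, `configure` is a hypothetical caller):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from congrads.utils.validation import validate_iterable, validate_loaders, validate_type

def configure(names, epochs, loaders):  # hypothetical caller
    validate_iterable("names", names, expected_element_types=str)
    validate_type("epochs", epochs, int)
    validate_loaders("loaders", loaders)

dataset = TensorDataset(torch.randn(10, 2), torch.randn(10, 1))
loaders = tuple(DataLoader(dataset, batch_size=2) for _ in range(3))
configure(["pressure", "flow"], 5, loaders)  # passes
# configure(["pressure", 3], 5, loaders)     # raises TypeError: invalid elements in names
```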