congrads-0.1.0-py3-none-any.whl → congrads-0.3.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- congrads/__init__.py +10 -20
- congrads/callbacks/base.py +357 -0
- congrads/callbacks/registry.py +106 -0
- congrads/checkpoints.py +178 -0
- congrads/constraints/base.py +242 -0
- congrads/constraints/registry.py +1255 -0
- congrads/core/batch_runner.py +200 -0
- congrads/core/congradscore.py +271 -0
- congrads/core/constraint_engine.py +209 -0
- congrads/core/epoch_runner.py +119 -0
- congrads/datasets/registry.py +799 -0
- congrads/descriptor.py +147 -43
- congrads/metrics.py +116 -41
- congrads/networks/registry.py +68 -0
- congrads/py.typed +0 -0
- congrads/transformations/base.py +37 -0
- congrads/transformations/registry.py +86 -0
- congrads/utils/preprocessors.py +439 -0
- congrads/utils/utility.py +506 -0
- congrads/utils/validation.py +182 -0
- congrads-0.3.0.dist-info/METADATA +234 -0
- congrads-0.3.0.dist-info/RECORD +23 -0
- congrads-0.3.0.dist-info/WHEEL +4 -0
- congrads/constraints.py +0 -507
- congrads/core.py +0 -211
- congrads/datasets.py +0 -742
- congrads/learners.py +0 -233
- congrads/networks.py +0 -91
- congrads-0.1.0.dist-info/LICENSE +0 -34
- congrads-0.1.0.dist-info/METADATA +0 -196
- congrads-0.1.0.dist-info/RECORD +0 -13
- congrads-0.1.0.dist-info/WHEEL +0 -5
- congrads-0.1.0.dist-info/top_level.txt +0 -1
congrads/datasets/registry.py

@@ -0,0 +1,799 @@
"""This module defines several PyTorch dataset classes for loading and working with various datasets.

Each dataset class extends the `torch.utils.data.Dataset` class and provides functionality for
downloading, loading, and transforming specific datasets where applicable.

Classes:

- BiasCorrection: A dataset class for the Bias Correction dataset focused on temperature forecast data.
- FamilyIncome: A dataset class for the Family Income and Expenditure dataset.
- SectionedGaussians: A synthetic dataset of smoothly blended Gaussian signals across sections.
- SyntheticMonotonicity: A synthetic 1D dataset with a monotone ground truth plus structured noise.
- SyntheticClusters: A dataset class for generating synthetic clustered 2D data with labels.

Each dataset class provides methods for downloading the data
(if not already available or synthetic), checking the integrity of the dataset, loading
the data from CSV files or generating synthetic data, and applying
transformations to the data.
"""

import os
import random
from collections.abc import Callable
from pathlib import Path
from urllib.error import URLError

import numpy as np
import pandas as pd
import torch
from torch.distributions import Dirichlet
from torch.utils.data import Dataset
from torchvision.datasets.utils import (
    check_integrity,
    download_and_extract_archive,
)

class BiasCorrection(Dataset):
    """A dataset class for accessing the Bias Correction dataset.

    This class extends the `Dataset` class and provides functionality for
    downloading, loading, and transforming the Bias Correction dataset.
    The dataset is focused on temperature forecast data and is made available
    for use with PyTorch. If `download` is set to True, the dataset will be
    downloaded if it is not already available. The data is then loaded,
    and a transformation function is applied to it.

    Args:
        root (str | Path): The root directory where the dataset
            will be stored or loaded from.
        transform (Callable): A function to transform the dataset
            (e.g., preprocessing).
        download (bool, optional): Whether to download the dataset if it's
            not already present. Defaults to False.

    Raises:
        RuntimeError: If the dataset is not found and `download`
            is not set to True, or if all mirrors fail to provide the dataset.
    """

    mirrors = ["https://archive.ics.uci.edu/static/public/514/"]
    resources = [
        (
            "bias+correction+of+numerical+prediction+model+temperature+forecast.zip",
            "3deee56d461a2686887c4ae38fe3ccf3",
        )
    ]

    def __init__(self, root: str | Path, transform: Callable, download: bool = False) -> None:
        """Constructor method to initialize the dataset."""
        super().__init__()
        self.root = root
        self.transform = transform

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it.")

        self.data_input, self.data_output = self._load_data()

    def _load_data(self):
        """Loads the dataset from the CSV file and applies the transformation.

        The data is read from the `Bias_correction_ucl.csv` file, and the
        transformation function is applied to it.
        The input and output data are separated and returned as numpy arrays.

        Returns:
            Tuple[numpy.ndarray, numpy.ndarray]: A tuple containing the input
            and output data as numpy arrays.
        """
        data: pd.DataFrame = pd.read_csv(
            os.path.join(self.data_folder, "Bias_correction_ucl.csv")
        ).pipe(self.transform)

        data_input = data["Input"].to_numpy(dtype=np.float32)
        data_output = data["Output"].to_numpy(dtype=np.float32)

        return data_input, data_output

    def __len__(self):
        """Returns the number of examples in the dataset.

        Returns:
            int: The number of examples in the dataset
            (i.e., the number of rows in the input data).
        """
        return self.data_input.shape[0]

    def __getitem__(self, idx):
        """Returns the input-output pair for a given index.

        Args:
            idx (int): The index of the example to retrieve.

        Returns:
            dict: A dictionary with the following keys:
                - "input" (torch.Tensor): The input features for the example.
                - "target" (torch.Tensor): The target output for the example.
        """
        example = self.data_input[idx, :]
        target = self.data_output[idx, :]
        example = torch.tensor(example)
        target = torch.tensor(target)
        return {"input": example, "target": target}

    @property
    def data_folder(self) -> str:
        """Returns the path to the folder where the dataset is stored.

        Returns:
            str: The path to the dataset folder.
        """
        return os.path.join(self.root, self.__class__.__name__)

    def _check_exists(self) -> bool:
        """Checks if the dataset is already downloaded and verified.

        This method checks that all required files exist and
        their integrity is validated via MD5 checksums.

        Returns:
            bool: True if all resources exist and their
            integrity is valid, False otherwise.
        """
        return all(
            check_integrity(os.path.join(self.data_folder, file_path), checksum)
            for file_path, checksum in self.resources
        )

    def download(self) -> None:
        """Downloads and extracts the dataset.

        This method attempts to download the dataset from the mirrors and
        extract it into the appropriate folder. If any error occurs during
        downloading, it will try each mirror in sequence.

        Raises:
            RuntimeError: If all mirrors fail to provide the dataset.
        """
        if self._check_exists():
            return

        os.makedirs(self.data_folder, exist_ok=True)

        # Download files, falling back to the next mirror on failure
        for filename, md5 in self.resources:
            errors = []
            for mirror in self.mirrors:
                url = f"{mirror}{filename}"
                try:
                    download_and_extract_archive(
                        url, download_root=self.data_folder, filename=filename, md5=md5
                    )
                except URLError as e:
                    errors.append(e)
                    continue
                break
            else:
                s = f"Error downloading {filename}:\n"
                for mirror, err in zip(self.mirrors, errors, strict=False):
                    s += f"Tried {mirror}, got:\n{str(err)}\n"
                raise RuntimeError(s)

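# A minimal usage sketch (hypothetical helper, not part of the package).
# `BiasCorrection._load_data` selects data["Input"] / data["Output"] after
# piping through `transform`, so the transform is assumed to return a
# DataFrame whose columns form a MultiIndex with top-level groups "Input"
# and "Output". The column split below is a hypothetical choice for the
# UCI temperature-forecast CSV.
def _demo_bias_correction(root: str = "./data") -> None:
    def transform(df: pd.DataFrame) -> pd.DataFrame:
        # Keep numeric columns only, drop incomplete rows
        df = df.select_dtypes("number").dropna()
        # Hypothetical split: the two forecast targets are outputs,
        # everything else is input
        groups = [
            ("Output", c) if c in ("Next_Tmax", "Next_Tmin") else ("Input", c)
            for c in df.columns
        ]
        df.columns = pd.MultiIndex.from_tuples(groups)
        return df

    dataset = BiasCorrection(root=root, transform=transform, download=True)
    sample = dataset[0]
    print(len(dataset), sample["input"].shape, sample["target"].shape)
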
class FamilyIncome(Dataset):
    """A dataset class for accessing the Family Income and Expenditure dataset.

    This class extends the `Dataset` class and provides functionality for
    downloading, loading, and transforming the Family Income and
    Expenditure dataset. The dataset is intended for use with
    PyTorch-based projects, offering a convenient interface for data handling.
    If `download` is set to True, the dataset will be downloaded if it is not
    already available. The data is then loaded, and a user-defined
    transformation function is applied to it.

    Args:
        root (str | Path): The root directory where the dataset will
            be stored or loaded from.
        transform (Callable): A function to transform the dataset
            (e.g., preprocessing).
        download (bool, optional): Whether to download the dataset if it's
            not already present. Defaults to False.

    Raises:
        RuntimeError: If the dataset is not found and `download`
            is not set to True, or if all mirrors fail to provide the dataset.
    """

    mirrors = [
        "https://www.kaggle.com/api/v1/datasets/download/grosvenpaul/family-income-and-expenditure"
    ]
    resources = [
        (
            "archive.zip",
            "7d74bc7facc3d7c07c4df1c1c6ac563e",
        )
    ]

    def __init__(self, root: str | Path, transform: Callable, download: bool = False) -> None:
        """Constructor method to initialize the dataset."""
        super().__init__()
        self.root = root
        self.transform = transform

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it.")

        self.data_input, self.data_output = self._load_data()

    def _load_data(self):
        """Load and transform the Family Income and Expenditure dataset.

        Reads the data from the `Family Income and Expenditure.csv` file located
        in `self.data_folder` and applies the transformation function. The input
        and output columns are extracted and returned as NumPy arrays.

        Returns:
            Tuple[np.ndarray, np.ndarray]: A tuple containing:
                - input data as a NumPy array of type float32
                - output data as a NumPy array of type float32
        """
        data: pd.DataFrame = pd.read_csv(
            os.path.join(self.data_folder, "Family Income and Expenditure.csv")
        ).pipe(self.transform)

        data_input = data["Input"].to_numpy(dtype=np.float32)
        data_output = data["Output"].to_numpy(dtype=np.float32)

        return data_input, data_output

    def __len__(self):
        """Returns the number of examples in the dataset.

        Returns:
            int: The number of examples in the dataset
            (i.e., the number of rows in the input data).
        """
        return self.data_input.shape[0]

    def __getitem__(self, idx):
        """Returns the input-output pair for a given index.

        Args:
            idx (int): The index of the example to retrieve.

        Returns:
            dict: A dictionary with the following keys:
                - "input" (torch.Tensor): The input features for the example.
                - "target" (torch.Tensor): The target output for the example.
        """
        example = self.data_input[idx, :]
        target = self.data_output[idx, :]
        example = torch.tensor(example)
        target = torch.tensor(target)
        return {"input": example, "target": target}

    @property
    def data_folder(self) -> str:
        """Returns the path to the folder where the dataset is stored.

        Returns:
            str: The path to the dataset folder.
        """
        return os.path.join(self.root, self.__class__.__name__)

    def _check_exists(self) -> bool:
        """Checks if the dataset is already downloaded and verified.

        This method checks that all required files exist and
        their integrity is validated via MD5 checksums.

        Returns:
            bool: True if all resources exist and their
            integrity is valid, False otherwise.
        """
        return all(
            check_integrity(os.path.join(self.data_folder, file_path), checksum)
            for file_path, checksum in self.resources
        )

    def download(self) -> None:
        """Downloads and extracts the dataset.

        This method attempts to download the dataset from the mirrors
        and extract it into the appropriate folder. If any error occurs
        during downloading, it will try each mirror in sequence.

        Raises:
            RuntimeError: If all mirrors fail to provide the dataset.
        """
        if self._check_exists():
            return

        os.makedirs(self.data_folder, exist_ok=True)

        # Download files; each mirror entry is already a complete download URL
        for filename, md5 in self.resources:
            errors = []
            for mirror in self.mirrors:
                url = mirror
                try:
                    download_and_extract_archive(
                        url, download_root=self.data_folder, filename=filename, md5=md5
                    )
                except URLError as e:
                    errors.append(e)
                    continue
                break
            else:
                s = f"Error downloading {filename}:\n"
                for mirror, err in zip(self.mirrors, errors, strict=False):
                    s += f"Tried {mirror}, got:\n{str(err)}\n"
                raise RuntimeError(s)

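# A minimal loading sketch (hypothetical helper, not part of the package).
# The column grouping is an assumption; any transform yielding "Input"/"Output"
# column groups works. download=True assumes the Kaggle mirror is reachable
# from your environment.
def _demo_family_income(root: str = "./data") -> None:
    from torch.utils.data import DataLoader

    def transform(df: pd.DataFrame) -> pd.DataFrame:
        df = df.select_dtypes("number").dropna()
        # Hypothetical split: treat "Total Household Income" as the only output
        groups = [
            ("Output", c) if c == "Total Household Income" else ("Input", c)
            for c in df.columns
        ]
        df.columns = pd.MultiIndex.from_tuples(groups)
        return df

    dataset = FamilyIncome(root=root, transform=transform, download=True)
    loader = DataLoader(dataset, batch_size=64, shuffle=True)
    batch = next(iter(loader))
    print(batch["input"].shape, batch["target"].shape)
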
class SectionedGaussians(Dataset):
    """Synthetic dataset generating smoothly varying Gaussian signals across multiple sections.

    Each section defines a subrange of t-values with its own Gaussian distribution
    (mean and standard deviation). Instead of abrupt transitions, the parameters
    are blended smoothly between sections using a sigmoid function.

    The resulting signal can represent a continuous process where statistical
    properties gradually evolve over time or position.

    Features:
    - Input: Gaussian signal samples (y-values)
    - Context: Concatenation of time (t), normalized energy feature, and run identifier
    - Target: Exponential decay ground truth from 1 at t_min to 0 at t_max

    Attributes:
        sections (list[dict]): List of section definitions.
        n_samples (int): Total number of samples across all sections.
        n_runs (int): Number of random waveforms generated from the base configuration.
        time (torch.Tensor): Sampled t-values, shape [n_samples, 1].
        signal (torch.Tensor): Generated Gaussian signal values, shape [n_samples, 1].
        energy (torch.Tensor): Normalized energy feature, shape [n_samples, 1].
        identifier (torch.Tensor): Run identifier per sample, shape [n_samples, 1].
        context (torch.Tensor): Concatenation of time, energy, and identifier,
            shape [n_samples, 3].
        t_min (float): Minimum t-value across all sections.
        t_max (float): Maximum t-value across all sections.
        ground_truth_steepness (float): Exponential decay steepness for target output.
        blend_k (float): Sharpness parameter controlling how rapidly means and
            standard deviations transition between sections.
    """

    def __init__(
        self,
        sections: list[dict],
        n_samples: int = 1000,
        n_runs: int = 1,
        seed: int | None = None,
        device="cpu",
        ground_truth_steepness: float = 0.0,
        blend_k: float = 10.0,
    ):
        """Initializes the dataset and generates smoothly blended Gaussian signals.

        Args:
            sections (list[dict]): List of sections. Each section must define:
                - range (tuple[float, float]): Start and end of the t-interval.
                - add_mean (float): Additive increment applied to the running mean
                  of the Gaussian for this section.
                - std (float): Standard deviation of the Gaussian for this section.
                - max_splits (int): Max number of extra splits per section.
                - split_prob (float): Probability of splitting a section.
                - mean_var (float): How much to vary the mean (fraction of original).
                - std_var (float): How much to vary the std (fraction of original).
                - range_var (float): How much to vary start/end positions (fraction of section length).
            n_samples (int, optional): Total number of samples to generate. Defaults to 1000.
            n_runs (int, optional): Number of random waveforms to generate from the
                base configuration. Defaults to 1.
            seed (int or None, optional): Random seed for reproducibility. Defaults to None.
            device (str or torch.device, optional): Device on which tensors are allocated.
                Defaults to "cpu".
            ground_truth_steepness (float, optional): Controls how sharply the ground-truth
                exponential decay decreases from 1 to 0. Defaults to 0.0 (linear decay).
            blend_k (float, optional): Controls the sharpness of the sigmoid blending
                between sections. Higher values make the transition steeper; lower
                values make it smoother. Defaults to 10.0.
        """
        self.sections = sections
        self.n_samples = n_samples
        self.n_runs = n_runs
        self.device = device
        self.blend_k = blend_k
        self.ground_truth_steepness = torch.tensor(ground_truth_steepness, device=device)

        if seed is not None:
            torch.manual_seed(seed)

        time: list[torch.Tensor] = []
        signal: list[torch.Tensor] = []
        energy: list[torch.Tensor] = []
        identifier: list[torch.Tensor] = []

        # Compute global min/max for time
        self.t_min = min(s["range"][0] for s in sections)
        self.t_max = max(s["range"][1] for s in sections)

        # Generate waveforms from the base configuration
        for run in range(self.n_runs):
            waveform = self._generate_waveform()
            time.append(waveform[0])
            signal.append(waveform[1])
            energy.append(waveform[2])
            identifier.append(torch.full_like(waveform[0], float(run)))

        # Concatenate runs into single tensors
        self.time = torch.cat(time, dim=0)
        self.signal = torch.cat(signal, dim=0)
        self.energy = torch.cat(energy, dim=0)
        self.identifier = torch.cat(identifier, dim=0)
        self.context = torch.hstack([self.time, self.energy, self.identifier])

        # Adjust n_samples in case of rounding mismatch
        self.n_samples = len(self.time)

441
|
+
def _generate_waveform(
|
|
442
|
+
self,
|
|
443
|
+
):
|
|
444
|
+
sections = []
|
|
445
|
+
|
|
446
|
+
prev_mean = 0
|
|
447
|
+
|
|
448
|
+
for sec in self.sections:
|
|
449
|
+
start, end = sec["range"]
|
|
450
|
+
mean, std = sec["add_mean"], sec["std"]
|
|
451
|
+
max_splits = sec["max_splits"]
|
|
452
|
+
split_prob = sec["split_prob"]
|
|
453
|
+
mean_scale = sec["mean_var"]
|
|
454
|
+
std_scale = sec["std_var"]
|
|
455
|
+
range_scale = sec["range_var"]
|
|
456
|
+
section_len = end - start
|
|
457
|
+
|
|
458
|
+
# Decide whether to split this section into multiple contiguous parts
|
|
459
|
+
n_splits = 1
|
|
460
|
+
if random.random() < split_prob:
|
|
461
|
+
n_splits += random.randint(1, max_splits)
|
|
462
|
+
|
|
463
|
+
# Randomly divide the section into contiguous subsections
|
|
464
|
+
random_fracs = Dirichlet(torch.ones(n_splits) * 3).sample()
|
|
465
|
+
sub_lengths = (random_fracs * section_len).tolist()
|
|
466
|
+
|
|
467
|
+
sub_start = start
|
|
468
|
+
for base_len in sub_lengths:
|
|
469
|
+
# Compute unperturbed subsection boundaries
|
|
470
|
+
base_start = sub_start
|
|
471
|
+
base_end = base_start + base_len
|
|
472
|
+
|
|
473
|
+
# Define smooth range variation parameters
|
|
474
|
+
max_shift = range_scale * base_len / 2
|
|
475
|
+
|
|
476
|
+
# Decide how much to increase the mean overall
|
|
477
|
+
prev_mean += mean * random.uniform(0.0, 2 * mean_scale)
|
|
478
|
+
|
|
479
|
+
# Apply small random shifts to start/end, relative to subsection size
|
|
480
|
+
if range_scale > 0:
|
|
481
|
+
# Keep total ordering consistent (avoid overlaps or inversions)
|
|
482
|
+
local_shift = random.uniform(-max_shift / n_splits, max_shift / n_splits)
|
|
483
|
+
new_start = base_start + local_shift
|
|
484
|
+
new_end = base_end + local_shift
|
|
485
|
+
else:
|
|
486
|
+
new_start, new_end = base_start, base_end
|
|
487
|
+
|
|
488
|
+
# Clamp to valid time bounds
|
|
489
|
+
new_start = max(self.t_min, min(new_start, end))
|
|
490
|
+
new_end = max(new_start, min(new_end, end))
|
|
491
|
+
|
|
492
|
+
# Randomize mean and std within allowed ranges
|
|
493
|
+
mean_var = prev_mean * (1 + random.uniform(-mean_scale, mean_scale))
|
|
494
|
+
std_var = std * (1 + random.uniform(-std_scale, std_scale))
|
|
495
|
+
|
|
496
|
+
sections.append(
|
|
497
|
+
{
|
|
498
|
+
"range": (new_start, new_end),
|
|
499
|
+
"mean": mean_var,
|
|
500
|
+
"std": std_var,
|
|
501
|
+
}
|
|
502
|
+
)
|
|
503
|
+
|
|
504
|
+
# Prepare for next subsection
|
|
505
|
+
sub_start = base_end
|
|
506
|
+
prev_mean = mean_var
|
|
507
|
+
|
|
508
|
+
# Ensure last section ends exactly at x_max
|
|
509
|
+
sections[-1]["range"] = (sections[-1]["range"][0], self.t_max)
|
|
510
|
+
|
|
511
|
+
# Precompute exact float counts
|
|
512
|
+
section_lengths = [s["range"][1] - s["range"][0] for s in sections]
|
|
513
|
+
total_length = sum(section_lengths)
|
|
514
|
+
exact_counts = [length / total_length * self.n_samples for length in section_lengths]
|
|
515
|
+
|
|
516
|
+
# Allocate integer samples with rounding, track residual
|
|
517
|
+
n_section_samples_list = []
|
|
518
|
+
residual = 0.0
|
|
519
|
+
for exact in exact_counts[:-1]:
|
|
520
|
+
n = int(round(exact + residual))
|
|
521
|
+
n_section_samples_list.append(n)
|
|
522
|
+
residual += exact - n # carry over rounding error
|
|
523
|
+
|
|
524
|
+
# Assign remaining samples to last section
|
|
525
|
+
n_section_samples_list.append(self.n_samples - sum(n_section_samples_list))
|
|
526
|
+
|
|
527
|
+
# Generate data for each section
|
|
528
|
+
signal_segments: list[torch.Tensor] = []
|
|
529
|
+
for i, section in enumerate(sections):
|
|
530
|
+
start, end = section["range"]
|
|
531
|
+
mean, std = section["mean"], section["std"]
|
|
532
|
+
|
|
533
|
+
# Samples proportional to section length
|
|
534
|
+
n_section_samples = n_section_samples_list[i]
|
|
535
|
+
|
|
536
|
+
# Determine next section’s parameters for blending
|
|
537
|
+
if i < len(sections) - 1:
|
|
538
|
+
next_mean = sections[i + 1]["mean"]
|
|
539
|
+
next_std = sections[i + 1]["std"]
|
|
540
|
+
else:
|
|
541
|
+
next_mean = mean
|
|
542
|
+
next_std = std
|
|
543
|
+
|
|
544
|
+
# Sigmoid-based blend curve from 0 → 1
|
|
545
|
+
x = torch.linspace(
|
|
546
|
+
-self.blend_k, self.blend_k, n_section_samples, device=self.device
|
|
547
|
+
).unsqueeze(1)
|
|
548
|
+
fade = torch.sigmoid(x)
|
|
549
|
+
fade = (fade - fade.min()) / (fade.max() - fade.min())
|
|
550
|
+
|
|
551
|
+
# Interpolate mean and std within the section
|
|
552
|
+
mean_curve = mean + (next_mean - mean) * fade
|
|
553
|
+
std_curve = std + (next_std - std) * fade
|
|
554
|
+
|
|
555
|
+
# Sample from a gradually changing Gaussian
|
|
556
|
+
y = torch.normal(mean=mean_curve, std=std_curve)
|
|
557
|
+
signal_segments.append(y)
|
|
558
|
+
|
|
559
|
+
# Concatenate tensors
|
|
560
|
+
time: torch.Tensor = torch.linspace(
|
|
561
|
+
self.t_min, self.t_max, self.n_samples, device=self.device
|
|
562
|
+
).unsqueeze(1)
|
|
563
|
+
signal: torch.Tensor = torch.cat(signal_segments, dim=0)
|
|
564
|
+
|
|
565
|
+
# Compute and normalize energy feature
|
|
566
|
+
energy = torch.linalg.vector_norm(
|
|
567
|
+
signal - signal[0].reshape([1, -1]), ord=2, dim=1
|
|
568
|
+
).unsqueeze(1)
|
|
569
|
+
min_e, max_e = energy.min(), energy.max()
|
|
570
|
+
energy = 1 - (energy - min_e) / (max_e - min_e) * 3
|
|
571
|
+
|
|
572
|
+
# Combine into context tensor
|
|
573
|
+
return time, signal, energy
|
|
574
|
+
|
|
575
|
+
def _compute_ground_truth(self, time: torch.Tensor) -> torch.Tensor:
|
|
576
|
+
"""Computes the ground-truth exponential decay at coordinate t.
|
|
577
|
+
|
|
578
|
+
Args:
|
|
579
|
+
time (torch.Tensor): Input coordinate (shape: [1]).
|
|
580
|
+
|
|
581
|
+
Returns:
|
|
582
|
+
torch.Tensor: Corresponding target value between 0 and 1.
|
|
583
|
+
"""
|
|
584
|
+
time_norm = (time - self.t_min) / (self.t_max - self.t_min)
|
|
585
|
+
k = self.ground_truth_steepness
|
|
586
|
+
if k == 0:
|
|
587
|
+
return 1 - time_norm
|
|
588
|
+
return (torch.exp(-k * time_norm) - torch.exp(-k)) / (1 - torch.exp(-k))
|
|
589
|
+
|
|
590
|
+
def __len__(self):
|
|
591
|
+
"""Returns the total number of generated samples."""
|
|
592
|
+
return self.n_samples
|
|
593
|
+
|
|
594
|
+
def __getitem__(self, idx):
|
|
595
|
+
"""Retrieves a single dataset sample.
|
|
596
|
+
|
|
597
|
+
Args:
|
|
598
|
+
idx (int): Sample index.
|
|
599
|
+
|
|
600
|
+
Returns:
|
|
601
|
+
dict:
|
|
602
|
+
- "input": Gaussian signal value (torch.Tensor),
|
|
603
|
+
- "context: Concatenated features (time, energy, identifier),
|
|
604
|
+
- "target": Ground-truth exponential decay value
|
|
605
|
+
|
|
606
|
+
"""
|
|
607
|
+
sigal = self.signal[idx]
|
|
608
|
+
context = self.context[idx]
|
|
609
|
+
time = self.time[idx]
|
|
610
|
+
target = self._compute_ground_truth(time)
|
|
611
|
+
|
|
612
|
+
return {
|
|
613
|
+
"input": sigal,
|
|
614
|
+
"context": context,
|
|
615
|
+
"target": target,
|
|
616
|
+
}
|
|
617
|
+
|
|
618
|
+
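# A minimal configuration sketch (hypothetical helper, not part of the package).
# Each section supplies every key read by `_generate_waveform`; the numeric
# values below are arbitrary.
def _demo_sectioned_gaussians() -> None:
    sections = [
        {
            "range": (0.0, 5.0),
            "add_mean": 1.0,
            "std": 0.2,
            "max_splits": 2,
            "split_prob": 0.5,
            "mean_var": 0.1,
            "std_var": 0.2,
            "range_var": 0.1,
        },
        {
            "range": (5.0, 10.0),
            "add_mean": 0.5,
            "std": 0.4,
            "max_splits": 1,
            "split_prob": 0.3,
            "mean_var": 0.1,
            "std_var": 0.2,
            "range_var": 0.1,
        },
    ]
    dataset = SectionedGaussians(sections, n_samples=1000, n_runs=2, seed=0)
    sample = dataset[0]
    # context = [time, energy, run identifier]
    print(sample["input"].shape, sample["context"].shape, sample["target"])
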
class SyntheticMonotonicity(Dataset):
    """Synthetic 1D dataset with monotone ground truth (log(1+x)), plus configurable structured noise.

    True function:
        y_true(x) = log(1 + x)

    Observed:
        y(x) = y_true(x) + heteroscedastic_noise(x) + local oscillatory perturbation

    Args:
        n_samples (int): number of data points (default 200)
        x_range (tuple): range of x values (default (0.0, 5.0))
        noise_base (float): baseline noise level (default 0.05)
        noise_scale (float): scale of heteroscedastic noise (default 0.15)
        noise_sharpness (float): steepness of heteroscedastic transition (default 4.0)
        noise_center (float): center point of heteroscedastic increase (default 2.5)
        osc_amplitude (float): amplitude of oscillatory perturbation (default 0.08)
        osc_frequency (float): frequency of oscillation (default 6.0)
        osc_prob (float): probability each sample receives oscillation (default 0.5)
        seed (int or None): random seed
    """

    def __init__(
        self,
        n_samples=200,
        x_range=(0.0, 5.0),
        noise_base=0.05,
        noise_scale=0.15,
        noise_sharpness=4.0,
        noise_center=2.5,
        osc_amplitude=0.08,
        osc_frequency=6.0,
        osc_prob=0.5,
        seed=None,
    ):
        """Constructor method to initialize the dataset."""
        super().__init__()
        if seed is not None:
            np.random.seed(seed)

        self.n_samples = n_samples
        self.x_range = x_range

        # Sample inputs uniformly over x_range, then sort
        self.x = np.random.rand(n_samples) * (x_range[1] - x_range[0]) + x_range[0]
        self.x = np.sort(self.x)

        # Heteroscedastic noise (logistic growth with x)
        noise_sigma = noise_base + noise_scale / (
            1 + np.exp(-noise_sharpness * (self.x - noise_center))
        )

        # Oscillatory perturbation (applied with probability osc_prob)
        mask = (np.random.rand(n_samples) < osc_prob).astype(float)
        osc = mask * (osc_amplitude * np.sin(osc_frequency * self.x))

        # Final observed target
        self.y = np.log1p(self.x) + osc + noise_sigma * np.random.randn(n_samples)

        # Convert to tensors
        self.inputs = torch.tensor(self.x, dtype=torch.float32).unsqueeze(1)
        self.targets = torch.tensor(self.y, dtype=torch.float32).unsqueeze(1)

    def __len__(self):
        """Return the total number of samples in the dataset.

        Returns:
            int: The number of generated data points.
        """
        return self.n_samples

    def __getitem__(self, idx) -> dict:
        """Retrieve a single sample from the dataset.

        Args:
            idx (int): Index of the sample to retrieve.

        Returns:
            dict: A dictionary with the following keys:
                - "input" (torch.Tensor): The x value.
                - "target" (torch.Tensor): The corresponding y value.
        """
        return {"input": self.inputs[idx], "target": self.targets[idx]}

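# A minimal usage sketch (hypothetical helper, not part of the package):
# generate the noisy monotone data with default noise settings and inspect
# one sample.
def _demo_synthetic_monotonicity() -> None:
    dataset = SyntheticMonotonicity(n_samples=100, seed=42)
    first = dataset[0]
    print(len(dataset), first["input"].item(), first["target"].item())
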
class SyntheticClusters(Dataset):
    """PyTorch dataset for generating synthetic clustered 2D data with labels.

    Each cluster is defined by its center, size, spread (standard deviation), and label.
    The dataset samples points from a Gaussian distribution centered at the cluster mean.

    Args:
        cluster_centers (list[tuple[float, float]]): Coordinates of each cluster center,
            e.g. [(x1, y1), (x2, y2), ...].
        cluster_sizes (list[int]): Number of points to generate in each cluster.
        cluster_std (list[float]): Standard deviation (spread) of each cluster.
        cluster_labels (list[int]): Class label for each cluster (e.g., 0 or 1).

    Raises:
        AssertionError: If the input lists do not all have the same length.

    Attributes:
        data (torch.Tensor): A concatenated tensor of all generated points with shape (N, 2).
        labels (torch.Tensor): A concatenated tensor of class labels with shape (N,),
            where N is the total number of generated points.
    """

    def __init__(self, cluster_centers, cluster_sizes, cluster_std, cluster_labels):
        """Constructor method to initialize the dataset."""
        assert (
            len(cluster_centers) == len(cluster_sizes) == len(cluster_std) == len(cluster_labels)
        ), "All input lists must have the same length"

        self.data = []
        self.labels = []

        # Generate points for each cluster
        for center, size, std, label in zip(
            cluster_centers, cluster_sizes, cluster_std, cluster_labels, strict=False
        ):
            x = torch.normal(mean=center[0], std=std, size=(size, 1))
            y = torch.normal(mean=center[1], std=std, size=(size, 1))
            points = torch.cat([x, y], dim=1)

            self.data.append(points)
            self.labels.append(torch.full((size,), label, dtype=torch.long))

        # Concatenate all clusters
        self.data = torch.cat(self.data, dim=0)
        self.labels = torch.cat(self.labels, dim=0)

    def __len__(self):
        """Return the total number of samples in the dataset.

        Returns:
            int: The number of generated data points.
        """
        return len(self.data)

    def __getitem__(self, idx) -> dict:
        """Retrieve a single sample from the dataset.

        Args:
            idx (int): Index of the sample to retrieve.

        Returns:
            dict: A dictionary with the following keys:
                - "input" (torch.Tensor): The 2D point at index `idx`.
                - "target" (torch.Tensor): The corresponding class label.
        """
        return {"input": self.data[idx], "target": self.labels[idx]}
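
# A minimal usage sketch (hypothetical helper, not part of the package):
# two Gaussian blobs with binary labels, e.g. as a toy classification dataset.
def _demo_synthetic_clusters() -> None:
    dataset = SyntheticClusters(
        cluster_centers=[(0.0, 0.0), (3.0, 3.0)],
        cluster_sizes=[100, 100],
        cluster_std=[0.5, 0.5],
        cluster_labels=[0, 1],
    )
    sample = dataset[0]
    print(len(dataset), sample["input"], sample["target"])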