congrads 0.2.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,809 @@
+ """This module defines several PyTorch dataset classes for loading and working with various datasets.
+ 
+ Each dataset class extends the `torch.utils.data.Dataset` class and provides functionality for
+ downloading, loading, and transforming specific datasets where applicable.
+ 
+ Classes:
+ 
+ - BiasCorrection: A dataset class for the Bias Correction dataset focused on temperature forecast data.
+ - FamilyIncome: A dataset class for the Family Income and Expenditure dataset.
+ - SectionedGaussians: A synthetic dataset generating smoothly varying Gaussian signals across multiple sections.
+ - SyntheticMonotonicity: A synthetic 1D dataset with monotone ground truth (log(1+x)), plus configurable structured noise.
+ - SyntheticClusters: A dataset class for generating synthetic clustered 2D data with labels.
+ 
+ Each dataset class provides methods for downloading the data (if not already
+ available or synthetic), checking the integrity of the dataset, loading the
+ data from CSV files or generating synthetic data, and applying transformations
+ to the data.
+ """
+ 
+ import os
+ import random
+ from collections.abc import Callable
+ from pathlib import Path
+ from urllib.error import URLError
+ 
+ import numpy as np
+ import pandas as pd
+ import torch
+ from torch.distributions import Dirichlet
+ from torch.utils.data import Dataset
+ from torchvision.datasets.utils import (
+     check_integrity,
+     download_and_extract_archive,
+ )
+ 
+ __all__ = [
+     "BiasCorrection",
+     "FamilyIncome",
+     "SectionedGaussians",
+     "SyntheticMonotonicity",
+     "SyntheticClusters",
+ ]
+ 
+ 
+ class BiasCorrection(Dataset):
+     """A dataset class for accessing the Bias Correction dataset.
+ 
+     This class extends the `Dataset` class and provides functionality for
+     downloading, loading, and transforming the Bias Correction dataset.
+     The dataset contains temperature forecast data and is made available
+     for use with PyTorch. If `download` is set to True, the dataset is
+     downloaded if it is not already available. The data is then loaded,
+     and the transformation function is applied to it.
+ 
+     Args:
+         root (Union[str, Path]): The root directory where the dataset
+             will be stored or loaded from.
+         transform (Callable): A function to transform the dataset
+             (e.g., preprocessing).
+         download (bool, optional): Whether to download the dataset if it is
+             not already present. Defaults to False.
+ 
+     Raises:
+         RuntimeError: If the dataset is not found and `download` is not set
+             to True, or if all mirrors fail to provide the dataset.
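+ 
+     Example:
+         A minimal usage sketch. The transform below is an assumption about the
+         expected contract (it must return a DataFrame whose columns are grouped
+         under "Input" and "Output"); the column names are illustrative, not
+         guaranteed by this module:
+ 
+         >>> import pandas as pd
+         >>> def transform(df: pd.DataFrame) -> pd.DataFrame:
+         ...     df = df.dropna()
+         ...     inputs = df[["Present_Tmax", "Present_Tmin"]]  # hypothetical columns
+         ...     outputs = df[["Next_Tmax", "Next_Tmin"]]  # hypothetical columns
+         ...     return pd.concat({"Input": inputs, "Output": outputs}, axis=1)
+         >>> ds = BiasCorrection(root="data", transform=transform, download=True)
+         >>> sample = ds[0]  # {"input": tensor([...]), "target": tensor([...])}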
+     """
+ 
+     mirrors = ["https://archive.ics.uci.edu/static/public/514/"]
+     resources = [
+         (
+             "bias+correction+of+numerical+prediction+model+temperature+forecast.zip",
+             "3deee56d461a2686887c4ae38fe3ccf3",
+         )
+     ]
+ 
+     def __init__(self, root: str | Path, transform: Callable, download: bool = False) -> None:
+         """Constructor method to initialize the dataset."""
+         super().__init__()
+         self.root = root
+         self.transform = transform
+ 
+         if download:
+             self.download()
+ 
+         if not self._check_exists():
+             raise RuntimeError("Dataset not found. You can use download=True to download it.")
+ 
+         self.data_input, self.data_output = self._load_data()
+ 
+     def _load_data(self):
+         """Loads the dataset from the CSV file and applies the transformation.
+ 
+         The data is read from the `Bias_correction_ucl.csv` file, and the
+         transformation function is applied to it. The input and output data
+         are separated and returned as NumPy arrays.
+ 
+         Returns:
+             Tuple[numpy.ndarray, numpy.ndarray]: A tuple containing the input
+                 and output data as NumPy arrays.
+         """
+         data: pd.DataFrame = pd.read_csv(
+             os.path.join(self.data_folder, "Bias_correction_ucl.csv")
+         ).pipe(self.transform)
+ 
+         data_input = data["Input"].to_numpy(dtype=np.float32)
+         data_output = data["Output"].to_numpy(dtype=np.float32)
+ 
+         return data_input, data_output
+ 
+     def __len__(self):
+         """Returns the number of examples in the dataset.
+ 
+         Returns:
+             int: The number of examples in the dataset
+                 (i.e., the number of rows in the input data).
+         """
+         return self.data_input.shape[0]
+ 
+     def __getitem__(self, idx):
+         """Returns the input-output pair for a given index.
+ 
+         Args:
+             idx (int): The index of the example to retrieve.
+ 
+         Returns:
+             dict: A dictionary with the following keys:
+                 - "input" (torch.Tensor): The input features for the example.
+                 - "target" (torch.Tensor): The target output for the example.
+         """
+         example = self.data_input[idx, :]
+         target = self.data_output[idx, :]
+         example = torch.tensor(example)
+         target = torch.tensor(target)
+         return {"input": example, "target": target}
+ 
+     @property
+     def data_folder(self) -> str:
+         """Returns the path to the folder where the dataset is stored.
+ 
+         Returns:
+             str: The path to the dataset folder.
+         """
+         return os.path.join(self.root, self.__class__.__name__)
+ 
+     def _check_exists(self) -> bool:
+         """Checks if the dataset is already downloaded and verified.
+ 
+         This method checks that all required files exist and that
+         their integrity is validated via MD5 checksums.
+ 
+         Returns:
+             bool: True if all resources exist and their
+                 integrity is valid, False otherwise.
+         """
+         return all(
+             check_integrity(os.path.join(self.data_folder, file_path), checksum)
+             for file_path, checksum in self.resources
+         )
+ 
+     def download(self) -> None:
+         """Downloads and extracts the dataset.
+ 
+         This method attempts to download the dataset from the mirrors and
+         extract it into the appropriate folder. If any error occurs during
+         downloading, it tries each mirror in sequence.
+ 
+         Raises:
+             RuntimeError: If all mirrors fail to provide the dataset.
+         """
+         if self._check_exists():
+             return
+ 
+         os.makedirs(self.data_folder, exist_ok=True)
+ 
+         # Download and extract each resource, falling back through mirrors
+         for filename, md5 in self.resources:
+             errors = []
+             for mirror in self.mirrors:
+                 url = f"{mirror}{filename}"
+                 try:
+                     download_and_extract_archive(
+                         url, download_root=self.data_folder, filename=filename, md5=md5
+                     )
+                 except URLError as e:
+                     errors.append(e)
+                     continue
+                 break
+             else:
+                 # The for/else branch runs only if no mirror succeeded
+                 s = f"Error downloading {filename}:\n"
+                 for mirror, err in zip(self.mirrors, errors, strict=False):
+                     s += f"Tried {mirror}, got:\n{str(err)}\n"
+                 raise RuntimeError(s)
+ 
+ 
+ class FamilyIncome(Dataset):
+     """A dataset class for accessing the Family Income and Expenditure dataset.
+ 
+     This class extends the `Dataset` class and provides functionality for
+     downloading, loading, and transforming the Family Income and Expenditure
+     dataset, offering a convenient interface for use with PyTorch-based
+     projects. If `download` is set to True, the dataset is downloaded if it
+     is not already available. The data is then loaded, and a user-defined
+     transformation function is applied to it.
+ 
+     Args:
+         root (Union[str, Path]): The root directory where the dataset will
+             be stored or loaded from.
+         transform (Callable): A function to transform the dataset
+             (e.g., preprocessing).
+         download (bool, optional): Whether to download the dataset if it is
+             not already present. Defaults to False.
+ 
+     Raises:
+         RuntimeError: If the dataset is not found and `download` is not set
+             to True, or if all mirrors fail to provide the dataset.
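+ 
+     Example:
+         A minimal usage sketch. The transform must return a DataFrame whose
+         columns are grouped under "Input" and "Output"; the column names below
+         are assumptions about the CSV contents, not guaranteed by this module:
+ 
+         >>> import pandas as pd
+         >>> def transform(df: pd.DataFrame) -> pd.DataFrame:
+         ...     inputs = df[["Total Household Income"]]  # hypothetical column
+         ...     outputs = df[["Total Food Expenditure"]]  # hypothetical column
+         ...     return pd.concat({"Input": inputs, "Output": outputs}, axis=1)
+         >>> ds = FamilyIncome(root="data", transform=transform, download=True)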
218
+ """
219
+
220
+ mirrors = [
221
+ "https://www.kaggle.com/api/v1/datasets/download/grosvenpaul/family-income-and-expenditure"
222
+ ]
223
+ resources = [
224
+ (
225
+ "archive.zip",
226
+ "7d74bc7facc3d7c07c4df1c1c6ac563e",
227
+ )
228
+ ]
229
+
230
+ def __init__(self, root: str | Path, transform: Callable, download: bool = False) -> None:
231
+ """Constructor method to initialize the dataset."""
232
+ super().__init__()
233
+ self.root = root
234
+ self.transform = transform
235
+
236
+ if download:
237
+ self.download()
238
+
239
+ if not self._check_exists():
240
+ raise RuntimeError("Dataset not found. You can use download=True to download it.")
241
+
242
+ self.data_input, self.data_output = self._load_data()
243
+
244
+ def _load_data(self):
245
+ """Load and transform the Family Income and Expenditure dataset.
246
+
247
+ Reads the data from the `Family Income and Expenditure.csv` file located
248
+ in `self.data_folder` and applies the transformation function. The input
249
+ and output columns are extracted and returned as NumPy arrays.
250
+
251
+ Returns:
252
+ Tuple[np.ndarray, np.ndarray]: A tuple containing:
253
+ - input data as a NumPy array of type float32
254
+ - output data as a NumPy array of type float32
255
+ """
256
+ data: pd.DataFrame = pd.read_csv(
257
+ os.path.join(self.data_folder, "Family Income and Expenditure.csv")
258
+ ).pipe(self.transform)
259
+
260
+ data_input = data["Input"].to_numpy(dtype=np.float32)
261
+ data_output = data["Output"].to_numpy(dtype=np.float32)
262
+
263
+ return data_input, data_output
264
+
265
+ def __len__(self):
266
+ """Returns the number of examples in the dataset.
267
+
268
+ Returns:
269
+ int: The number of examples in the dataset
270
+ (i.e., the number of rows in the input data).
271
+ """
272
+ return self.data_input.shape[0]
273
+
274
+ def __getitem__(self, idx):
275
+ """Returns the input-output pair for a given index.
276
+
277
+ Args:
278
+ idx (int): The index of the example to retrieve.
279
+
280
+ Returns:
281
+ dict: A dictionary with the following keys:
282
+ - "input" (torch.Tensor): The input features for the example.
283
+ - "target" (torch.Tensor): The target output for the example.
284
+ """
285
+ example = self.data_input[idx, :]
286
+ target = self.data_output[idx, :]
287
+ example = torch.tensor(example)
288
+ target = torch.tensor(target)
289
+ return {"input": example, "target": target}
290
+
291
+ @property
292
+ def data_folder(self) -> str:
293
+ """Returns the path to the folder where the dataset is stored.
294
+
295
+ Returns:
296
+ str: The path to the dataset folder.
297
+ """
298
+ return os.path.join(self.root, self.__class__.__name__)
299
+
300
+ def _check_exists(self) -> bool:
301
+ """Checks if the dataset is already downloaded and verified.
302
+
303
+ This method checks that all required files exist and
304
+ their integrity is validated via MD5 checksums.
305
+
306
+ Returns:
307
+ bool: True if all resources exist and their
308
+ integrity is valid, False otherwise.
309
+ """
310
+ return all(
311
+ check_integrity(os.path.join(self.data_folder, file_path), checksum)
312
+ for file_path, checksum in self.resources
313
+ )
314
+
315
+ def download(self) -> None:
316
+ """Downloads and extracts the dataset.
317
+
318
+ This method attempts to download the dataset from the mirrors
319
+ and extract it into the appropriate folder. If any error occurs
320
+ during downloading, it will try each mirror in sequence.
321
+
322
+ Raises:
323
+ RuntimeError: If all mirrors fail to provide the dataset.
324
+ """
325
+ if self._check_exists():
326
+ return
327
+
328
+ os.makedirs(self.data_folder, exist_ok=True)
329
+
330
+ # download files
331
+ for filename, md5 in self.resources:
332
+ errors = []
333
+ for mirror in self.mirrors:
334
+ url = f"{mirror}"
335
+ try:
336
+ download_and_extract_archive(
337
+ url, download_root=self.data_folder, filename=filename, md5=md5
338
+ )
339
+ except URLError as e:
340
+ errors.append(e)
341
+ continue
342
+ break
343
+ else:
344
+ s = f"Error downloading {filename}:\n"
345
+ for mirror, err in zip(self.mirrors, errors, strict=False):
346
+ s += f"Tried {mirror}, got:\n{str(err)}\n"
347
+ raise RuntimeError(s)
348
+
349
+
+ class SectionedGaussians(Dataset):
+     """Synthetic dataset generating smoothly varying Gaussian signals across multiple sections.
+ 
+     Each section defines a subrange of x-values with its own Gaussian distribution
+     (mean and standard deviation). Instead of abrupt transitions, the parameters
+     are blended smoothly between sections using a sigmoid function.
+ 
+     The resulting signal can represent a continuous process whose statistical
+     properties gradually evolve over time or position.
+ 
+     Features:
+     - Input: Gaussian signal samples (y-values)
+     - Context: Concatenation of time (x), normalized energy feature, and run identifier
+     - Target: Exponential decay ground truth from 1 at t_min to 0 at t_max
+ 
+     Attributes:
+         sections (list[dict]): List of section definitions.
+         n_samples (int): Total number of samples across all sections and runs.
+         n_runs (int): Number of random waveforms generated from the base configuration.
+         time (torch.Tensor): Sampled x-values, shape [n_samples, 1].
+         signal (torch.Tensor): Generated Gaussian signal values, shape [n_samples, 1].
+         energy (torch.Tensor): Normalized energy feature, shape [n_samples, 1].
+         identifier (torch.Tensor): Run identifier for each sample, shape [n_samples, 1].
+         context (torch.Tensor): Concatenation of time, energy, and identifier,
+             shape [n_samples, 3].
+         t_min (float): Minimum x-value across all sections.
+         t_max (float): Maximum x-value across all sections.
+         ground_truth_steepness (torch.Tensor): Exponential decay steepness for
+             the target output.
+         blend_k (float): Sharpness parameter controlling how rapidly means and
+             standard deviations transition between sections.
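+ 
+     Example:
+         A minimal sketch of a two-section configuration; all values are
+         illustrative, not defaults:
+ 
+         >>> sections = [
+         ...     {"range": (0.0, 5.0), "add_mean": 1.0, "std": 0.2,
+         ...      "max_splits": 2, "split_prob": 0.5,
+         ...      "mean_var": 0.1, "std_var": 0.1, "range_var": 0.1},
+         ...     {"range": (5.0, 10.0), "add_mean": 0.5, "std": 0.4,
+         ...      "max_splits": 2, "split_prob": 0.5,
+         ...      "mean_var": 0.1, "std_var": 0.1, "range_var": 0.1},
+         ... ]
+         >>> ds = SectionedGaussians(sections, n_samples=500, n_runs=2, seed=0)
+         >>> ds[0]["context"].shape
+         torch.Size([3])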
378
+ """
379
+
380
+ def __init__(
381
+ self,
382
+ sections: list[dict],
383
+ n_samples: int = 1000,
384
+ n_runs: int = 1,
385
+ seed: int | None = None,
386
+ device="cpu",
387
+ ground_truth_steepness: float = 0.0,
388
+ blend_k: float = 10.0,
389
+ ):
390
+ """Initializes the dataset and generates smoothly blended Gaussian signals.
391
+
392
+ Args:
393
+ sections (list[dict]): List of sections. Each section must define:
394
+ - range (tuple[float, float]): Start and end of the x-interval.
395
+ - mean (float): Mean of the Gaussian for this section.
396
+ - std (float): Standard deviation of the Gaussian for this section.
397
+ - max_splits (int): Max number of extra splits per section.
398
+ - split_prob (float): Probability of splitting a section.
399
+ - mean_var (float): How much to vary mean (fraction of original).
400
+ - std_var (float): How much to vary std (fraction of original).
401
+ - range_var (float): How much to vary start/end positions (fraction of section length).
402
+ n_samples (int, optional): Total number of samples to generate. Defaults to 1000.
403
+ n_runs (int, optional): Number of random waveforms to generate from the
404
+ base configuration. Defaults to 1.
405
+ seed (int or None, optional): Random seed for reproducibility. Defaults to None.
406
+ device (str or torch.device, optional): Device on which tensors are allocated.
407
+ Defaults to "cpu".
408
+ ground_truth_steepness (float, optional): Controls how sharply the ground-truth
409
+ exponential decay decreases from 1 to 0. Defaults to 0.0 (linear decay).
410
+ blend_k (float, optional): Controls the sharpness of the sigmoid blending
411
+ between sections. Higher values make the transition steeper; lower
412
+ values make it smoother. Defaults to 10.0.
413
+ """
414
+ self.sections = sections
415
+ self.n_samples = n_samples
416
+ self.n_runs = n_runs
417
+ self.device = device
418
+ self.blend_k = blend_k
419
+ self.ground_truth_steepness = torch.tensor(ground_truth_steepness, device=device)
420
+
421
+ if seed is not None:
422
+ torch.manual_seed(seed)
423
+
424
+ time: list[torch.Tensor] = []
425
+ signal: list[torch.Tensor] = []
426
+ energy: list[torch.Tensor] = []
427
+ identifier: list[torch.Tensor] = []
428
+
429
+ # Compute global min/max for time
430
+ self.t_min = min(s["range"][0] for s in sections)
431
+ self.t_max = max(s["range"][1] for s in sections)
432
+
433
+ # Generate waveforms from base configuration
434
+ for run in range(self.n_runs):
435
+ waveform = self._generate_waveform()
436
+ time.append(waveform[0])
437
+ signal.append(waveform[1])
438
+ energy.append(waveform[2])
439
+ identifier.append(torch.full_like(waveform[0], float(run)))
440
+
441
+ # Concatenate runs into single tensors
442
+ self.time = torch.cat(time, dim=0)
443
+ self.signal = torch.cat(signal, dim=0)
444
+ self.energy = torch.cat(energy, dim=0)
445
+ self.identifier = torch.cat(identifier, dim=0)
446
+ self.context = torch.hstack([self.time, self.energy, self.identifier])
447
+
448
+ # Adjust n_samples in case of rounding mismatch
449
+ self.n_samples = len(self.time)
450
+
+     def _generate_waveform(self):
+         """Generates one randomized waveform from the base section configuration.
+ 
+         Returns:
+             Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Time, signal, and
+                 normalized energy tensors, each of shape [n_samples, 1].
+         """
+         sections = []
+ 
+         prev_mean = 0
+ 
+         for sec in self.sections:
+             start, end = sec["range"]
+             mean, std = sec["add_mean"], sec["std"]
+             max_splits = sec["max_splits"]
+             split_prob = sec["split_prob"]
+             mean_scale = sec["mean_var"]
+             std_scale = sec["std_var"]
+             range_scale = sec["range_var"]
+             section_len = end - start
+ 
+             # Decide whether to split this section into multiple contiguous parts
+             n_splits = 1
+             if random.random() < split_prob:
+                 n_splits += random.randint(1, max_splits)
+ 
+             # Randomly divide the section into contiguous subsections
+             random_fracs = Dirichlet(torch.ones(n_splits) * 3).sample()
+             sub_lengths = (random_fracs * section_len).tolist()
+ 
+             sub_start = start
+             for base_len in sub_lengths:
+                 # Compute unperturbed subsection boundaries
+                 base_start = sub_start
+                 base_end = base_start + base_len
+ 
+                 # Define smooth range variation parameters
+                 max_shift = range_scale * base_len / 2
+ 
+                 # Decide how much to increase the mean overall
+                 prev_mean += mean * random.uniform(0.0, 2 * mean_scale)
+ 
+                 # Apply small random shifts to start/end, relative to subsection size
+                 if range_scale > 0:
+                     # Keep total ordering consistent (avoid overlaps or inversions)
+                     local_shift = random.uniform(-max_shift / n_splits, max_shift / n_splits)
+                     new_start = base_start + local_shift
+                     new_end = base_end + local_shift
+                 else:
+                     new_start, new_end = base_start, base_end
+ 
+                 # Clamp to valid time bounds
+                 new_start = max(self.t_min, min(new_start, end))
+                 new_end = max(new_start, min(new_end, end))
+ 
+                 # Randomize mean and std within allowed ranges
+                 mean_var = prev_mean * (1 + random.uniform(-mean_scale, mean_scale))
+                 std_var = std * (1 + random.uniform(-std_scale, std_scale))
+ 
+                 sections.append(
+                     {
+                         "range": (new_start, new_end),
+                         "mean": mean_var,
+                         "std": std_var,
+                     }
+                 )
+ 
+                 # Prepare for the next subsection
+                 sub_start = base_end
+                 prev_mean = mean_var
+ 
+         # Ensure the last section ends exactly at t_max
+         sections[-1]["range"] = (sections[-1]["range"][0], self.t_max)
+ 
+         # Precompute exact float counts
+         section_lengths = [s["range"][1] - s["range"][0] for s in sections]
+         total_length = sum(section_lengths)
+         exact_counts = [length / total_length * self.n_samples for length in section_lengths]
+ 
+         # Allocate integer sample counts with rounding, tracking the residual
+         n_section_samples_list = []
+         residual = 0.0
+         for exact in exact_counts[:-1]:
+             n = int(round(exact + residual))
+             n_section_samples_list.append(n)
+             residual += exact - n  # carry over rounding error
+ 
+         # Assign the remaining samples to the last section
+         n_section_samples_list.append(self.n_samples - sum(n_section_samples_list))
+ 
+         # Generate data for each section
+         signal_segments: list[torch.Tensor] = []
+         for i, section in enumerate(sections):
+             start, end = section["range"]
+             mean, std = section["mean"], section["std"]
+ 
+             # Samples proportional to section length
+             n_section_samples = n_section_samples_list[i]
+ 
+             # Determine the next section's parameters for blending
+             if i < len(sections) - 1:
+                 next_mean = sections[i + 1]["mean"]
+                 next_std = sections[i + 1]["std"]
+             else:
+                 next_mean = mean
+                 next_std = std
+ 
+             # Sigmoid-based blend curve from 0 to 1
+             x = torch.linspace(
+                 -self.blend_k, self.blend_k, n_section_samples, device=self.device
+             ).unsqueeze(1)
+             fade = torch.sigmoid(x)
+             fade = (fade - fade.min()) / (fade.max() - fade.min())
+ 
+             # Interpolate mean and std within the section
+             mean_curve = mean + (next_mean - mean) * fade
+             std_curve = std + (next_std - std) * fade
+ 
+             # Sample from a gradually changing Gaussian
+             y = torch.normal(mean=mean_curve, std=std_curve)
+             signal_segments.append(y)
+ 
+         # Concatenate tensors
+         time: torch.Tensor = torch.linspace(
+             self.t_min, self.t_max, self.n_samples, device=self.device
+         ).unsqueeze(1)
+         signal: torch.Tensor = torch.cat(signal_segments, dim=0)
+ 
+         # Compute the energy feature and rescale it to the range [-2, 1]
+         energy = torch.linalg.vector_norm(
+             signal - signal[0].reshape([1, -1]), ord=2, dim=1
+         ).unsqueeze(1)
+         min_e, max_e = energy.min(), energy.max()
+         energy = 1 - (energy - min_e) / (max_e - min_e) * 3
+ 
+         return time, signal, energy
+ 
+     def _compute_ground_truth(self, time: torch.Tensor) -> torch.Tensor:
+         """Computes the ground-truth exponential decay at a given time coordinate.
+ 
+         Args:
+             time (torch.Tensor): Input coordinate (shape: [1]).
+ 
+         Returns:
+             torch.Tensor: Corresponding target value between 0 and 1.
+         """
+         time_norm = (time - self.t_min) / (self.t_max - self.t_min)
+         k = self.ground_truth_steepness
+         if k == 0:
+             # Steepness 0 degenerates to a linear decay from 1 to 0
+             return 1 - time_norm
+         return (torch.exp(-k * time_norm) - torch.exp(-k)) / (1 - torch.exp(-k))
+ 
+     def __len__(self):
+         """Returns the total number of generated samples."""
+         return self.n_samples
+ 
+     def __getitem__(self, idx):
+         """Retrieves a single dataset sample.
+ 
+         Args:
+             idx (int): Sample index.
+ 
+         Returns:
+             dict: A dictionary with the following keys:
+                 - "input": Gaussian signal value (torch.Tensor).
+                 - "context": Concatenated features (time, energy, identifier).
+                 - "target": Ground-truth exponential decay value.
+         """
+         signal = self.signal[idx]
+         context = self.context[idx]
+         time = self.time[idx]
+         target = self._compute_ground_truth(time)
+ 
+         return {
+             "input": signal,
+             "context": context,
+             "target": target,
+         }
+ 
+ 
+ class SyntheticMonotonicity(Dataset):
+     """Synthetic 1D dataset with a monotone ground truth (log(1+x)) plus configurable structured noise.
+ 
+     True function:
+         y_true(x) = log(1 + x)
+ 
+     Observed:
+         y(x) = y_true(x) + heteroscedastic_noise(x) + local oscillatory perturbation
+ 
+     Args:
+         n_samples (int): Number of data points (default 200).
+         x_range (tuple): Range of x values (default [0, 5]).
+         noise_base (float): Baseline noise level (default 0.05).
+         noise_scale (float): Scale of the heteroscedastic noise (default 0.15).
+         noise_sharpness (float): Steepness of the heteroscedastic transition (default 4.0).
+         noise_center (float): Center point of the heteroscedastic increase (default 2.5).
+         osc_amplitude (float): Amplitude of the oscillatory perturbation (default 0.08).
+         osc_frequency (float): Frequency of the oscillation (default 6.0).
+         osc_prob (float): Probability that each sample receives the oscillation (default 0.5).
+         seed (int or None): Random seed.
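+ 
+     Example:
+         A minimal sketch:
+ 
+         >>> ds = SyntheticMonotonicity(n_samples=100, seed=0)
+         >>> sample = ds[0]
+         >>> sample["input"].shape, sample["target"].shape
+         (torch.Size([1]), torch.Size([1]))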
+     """
+ 
+     def __init__(
+         self,
+         n_samples=200,
+         x_range=(0.0, 5.0),
+         noise_base=0.05,
+         noise_scale=0.15,
+         noise_sharpness=4.0,
+         noise_center=2.5,
+         osc_amplitude=0.08,
+         osc_frequency=6.0,
+         osc_prob=0.5,
+         seed=None,
+     ):
+         """Constructor method to initialize the dataset and generate the data."""
+         super().__init__()
+         if seed is not None:
+             np.random.seed(seed)
+ 
+         self.n_samples = n_samples
+         self.x_range = x_range
+ 
+         # Sample inputs uniformly over x_range, then sort
+         self.x = np.random.rand(n_samples) * (x_range[1] - x_range[0]) + x_range[0]
+         self.x = np.sort(self.x)
+ 
+         # Heteroscedastic noise (logistic growth with x)
+         noise_sigma = noise_base + noise_scale / (
+             1 + np.exp(-noise_sharpness * (self.x - noise_center))
+         )
+ 
+         # Oscillatory perturbation (applied with probability osc_prob)
+         mask = (np.random.rand(n_samples) < osc_prob).astype(float)
+         osc = mask * (osc_amplitude * np.sin(osc_frequency * self.x))
+ 
+         # Final observed target
+         self.y = np.log1p(self.x) + osc + noise_sigma * np.random.randn(n_samples)
+ 
+         # Convert to tensors
+         self.inputs = torch.tensor(self.x, dtype=torch.float32).unsqueeze(1)
+         self.targets = torch.tensor(self.y, dtype=torch.float32).unsqueeze(1)
+ 
+     def __len__(self):
+         """Return the total number of samples in the dataset.
+ 
+         Returns:
+             int: The number of generated data points.
+         """
+         return self.n_samples
+ 
+     def __getitem__(self, idx) -> dict:
+         """Retrieve a single sample from the dataset.
+ 
+         Args:
+             idx (int): Index of the sample to retrieve.
+ 
+         Returns:
+             dict: A dictionary with the following keys:
+                 - "input" (torch.Tensor): The x value.
+                 - "target" (torch.Tensor): The corresponding y value.
+         """
+         return {"input": self.inputs[idx], "target": self.targets[idx]}
+ 
+ 
+ class SyntheticClusters(Dataset):
+     """PyTorch dataset for generating synthetic clustered 2D data with labels.
+ 
+     Each cluster is defined by its center, size, spread (standard deviation), and label.
+     The dataset samples points from a Gaussian distribution centered at the cluster mean.
+ 
+     Args:
+         cluster_centers (list[tuple[float, float]]): Coordinates of each cluster center,
+             e.g. [(x1, y1), (x2, y2), ...].
+         cluster_sizes (list[int]): Number of points to generate in each cluster.
+         cluster_std (list[float]): Standard deviation (spread) of each cluster.
+         cluster_labels (list[int]): Class label for each cluster (e.g., 0 or 1).
+ 
+     Raises:
+         AssertionError: If the input lists do not all have the same length.
+ 
+     Attributes:
+         data (torch.Tensor): A concatenated tensor of all generated points with shape (N, 2).
+         labels (torch.Tensor): A concatenated tensor of class labels with shape (N,),
+             where N is the total number of generated points.
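+ 
+     Example:
+         A minimal sketch of two Gaussian blobs; the values are illustrative:
+ 
+         >>> ds = SyntheticClusters(
+         ...     cluster_centers=[(0.0, 0.0), (3.0, 3.0)],
+         ...     cluster_sizes=[100, 100],
+         ...     cluster_std=[0.5, 0.5],
+         ...     cluster_labels=[0, 1],
+         ... )
+         >>> len(ds)
+         200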
+     """
+ 
+     def __init__(self, cluster_centers, cluster_sizes, cluster_std, cluster_labels):
+         """Initialize the SyntheticClusters dataset.
+ 
+         Args:
+             cluster_centers (list[tuple[float, float]]): Coordinates of each cluster center,
+                 e.g. [(x1, y1), (x2, y2), ...].
+             cluster_sizes (list[int]): Number of points to generate in each cluster.
+             cluster_std (list[float]): Standard deviation (spread) of each cluster.
+             cluster_labels (list[int]): Class label for each cluster (e.g., 0 or 1).
+ 
+         Raises:
+             AssertionError: If the input lists do not all have the same length.
+         """
+         super().__init__()
+         assert (
+             len(cluster_centers) == len(cluster_sizes) == len(cluster_std) == len(cluster_labels)
+         ), "All input lists must have the same length"
+ 
+         self.data = []
+         self.labels = []
+ 
+         # Generate points for each cluster
+         for center, size, std, label in zip(
+             cluster_centers, cluster_sizes, cluster_std, cluster_labels, strict=False
+         ):
+             x = torch.normal(mean=center[0], std=std, size=(size, 1))
+             y = torch.normal(mean=center[1], std=std, size=(size, 1))
+             points = torch.cat([x, y], dim=1)
+ 
+             self.data.append(points)
+             self.labels.append(torch.full((size,), label, dtype=torch.long))
+ 
+         # Concatenate all clusters
+         self.data = torch.cat(self.data, dim=0)
+         self.labels = torch.cat(self.labels, dim=0)
+ 
+     def __len__(self):
+         """Return the total number of samples in the dataset.
+ 
+         Returns:
+             int: The number of generated data points.
+         """
+         return len(self.data)
+ 
+     def __getitem__(self, idx) -> dict:
+         """Retrieve a single sample from the dataset.
+ 
+         Args:
+             idx (int): Index of the sample to retrieve.
+ 
+         Returns:
+             dict: A dictionary with the following keys:
+                 - "input" (torch.Tensor): The 2D point at index `idx`.
+                 - "target" (torch.Tensor): The corresponding class label.
+         """
+         return {"input": self.data[idx], "target": self.labels[idx]}