congrads 0.1.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,799 @@
1
+ """This module defines several PyTorch dataset classes for loading and working with various datasets.
2
+
3
+ Each dataset class extends the `torch.utils.data.Dataset` class and provides functionality for
4
+ downloading, loading, and transforming specific datasets where applicable.
5
+
6
+ Classes:
7
+
8
+ - SyntheticClusterDataset: A dataset class for generating synthetic clustered 2D data with labels.
9
+ - BiasCorrection: A dataset class for the Bias Correction dataset focused on temperature forecast data.
10
+ - FamilyIncome: A dataset class for the Family Income and Expenditure dataset.
11
+
12
+ Each dataset class provides methods for downloading the data
13
+ (if not already available or synthetic), checking the integrity of the dataset, loading
14
+ the data from CSV files or generating synthetic data, and applying
15
+ transformations to the data.
16
+ """
17
+
18
+ import os
19
+ import random
20
+ from collections.abc import Callable
21
+ from pathlib import Path
22
+ from urllib.error import URLError
23
+
24
+ import numpy as np
25
+ import pandas as pd
26
+ import torch
27
+ from torch.distributions import Dirichlet
28
+ from torch.utils.data import Dataset
29
+ from torchvision.datasets.utils import (
30
+ check_integrity,
31
+ download_and_extract_archive,
32
+ )
33
+
34
+
35
+ class BiasCorrection(Dataset):
36
+ """A dataset class for accessing the Bias Correction dataset.
37
+
38
+ This class extends the `Dataset` class and provides functionality for
39
+ downloading, loading, and transforming the Bias Correction dataset.
40
+ The dataset is focused on temperature forecast data and is made available
41
+ for use with PyTorch. If `download` is set to True, the dataset will be
42
+ downloaded if it is not already available. The data is then loaded,
43
+ and a transformation function is applied to it.
44
+
45
+ Args:
46
+ root (Union[str, Path]): The root directory where the dataset
47
+ will be stored or loaded from.
48
+ transform (Callable): A function to transform the dataset
49
+ (e.g., preprocessing).
50
+ download (bool, optional): Whether to download the dataset if it's
51
+ not already present. Defaults to False.
52
+
53
+ Raises:
54
+ RuntimeError: If the dataset is not found and `download`
55
+ is not set to True or if all mirrors fail to provide the dataset.
56
+ """
57
+
58
+ mirrors = ["https://archive.ics.uci.edu/static/public/514/"]
59
+ resources = [
60
+ (
61
+ "bias+correction+of+numerical+prediction+model+temperature+forecast.zip",
62
+ "3deee56d461a2686887c4ae38fe3ccf3",
63
+ )
64
+ ]
65
+
66
+ def __init__(self, root: str | Path, transform: Callable, download: bool = False) -> None:
67
+ """Constructor method to initialize the dataset."""
68
+ super().__init__()
69
+ self.root = root
70
+ self.transform = transform
71
+
72
+ if download:
73
+ self.download()
74
+
75
+ if not self._check_exists():
76
+ raise RuntimeError("Dataset not found. You can use download=True to download it")
77
+
78
+ self.data_input, self.data_output = self._load_data()
79
+
80
+ def _load_data(self):
81
+ """Loads the dataset from the CSV file and applies the transformation.
82
+
83
+ The data is read from the `Bias_correction_ucl.csv` file, and the
84
+ transformation function is applied to it.
85
+ The input and output data are separated and returned as numpy arrays.
86
+
87
+ Returns:
88
+ Tuple[numpy.ndarray, numpy.ndarray]: A tuple containing the input
89
+ and output data as numpy arrays.
90
+ """
91
+ data: pd.DataFrame = pd.read_csv(
92
+ os.path.join(self.data_folder, "Bias_correction_ucl.csv")
93
+ ).pipe(self.transform)
94
+
95
+ data_input = data["Input"].to_numpy(dtype=np.float32)
96
+ data_output = data["Output"].to_numpy(dtype=np.float32)
97
+
98
+ return data_input, data_output
99
+
100
+ def __len__(self):
101
+ """Returns the number of examples in the dataset.
102
+
103
+ Returns:
104
+ int: The number of examples in the dataset
105
+ (i.e., the number of rows in the input data).
106
+ """
107
+ return self.data_input.shape[0]
108
+
109
+ def __getitem__(self, idx):
110
+ """Returns the input-output pair for a given index.
111
+
112
+ Args:
113
+ idx (int): The index of the example to retrieve.
114
+
115
+ Returns:
116
+ dict: A dictionary with the following keys:
117
+ - "input" (torch.Tensor): The input features for the example.
118
+ - "target" (torch.Tensor): The target output for the example.
119
+ """
120
+ example = self.data_input[idx, :]
121
+ target = self.data_output[idx, :]
122
+ example = torch.tensor(example)
123
+ target = torch.tensor(target)
124
+ return {"input": example, "target": target}
125
+
126
+ @property
127
+ def data_folder(self) -> str:
128
+ """Returns the path to the folder where the dataset is stored.
129
+
130
+ Returns:
131
+ str: The path to the dataset folder.
132
+ """
133
+ return os.path.join(self.root, self.__class__.__name__)
134
+
135
+ def _check_exists(self) -> bool:
136
+ """Checks if the dataset is already downloaded and verified.
137
+
138
+ This method checks that all required files exist and
139
+ their integrity is validated via MD5 checksums.
140
+
141
+ Returns:
142
+ bool: True if all resources exist and their
143
+ integrity is valid, False otherwise.
144
+ """
145
+ return all(
146
+ check_integrity(os.path.join(self.data_folder, file_path), checksum)
147
+ for file_path, checksum in self.resources
148
+ )
149
+
150
+ def download(self) -> None:
151
+ """Downloads and extracts the dataset.
152
+
153
+ This method attempts to download the dataset from the mirrors and
154
+ extract it into the appropriate folder. If any error occurs during
155
+ downloading, it will try each mirror in sequence.
156
+
157
+ Raises:
158
+ RuntimeError: If all mirrors fail to provide the dataset.
159
+ """
160
+ if self._check_exists():
161
+ return
162
+
163
+ os.makedirs(self.data_folder, exist_ok=True)
164
+
165
+ # download files
166
+ for filename, md5 in self.resources:
167
+ errors = []
168
+ for mirror in self.mirrors:
169
+ url = f"{mirror}{filename}"
170
+ try:
171
+ download_and_extract_archive(
172
+ url, download_root=self.data_folder, filename=filename, md5=md5
173
+ )
174
+ except URLError as e:
175
+ errors.append(e)
176
+ continue
177
+ break
178
+ else:
179
+ s = f"Error downloading {filename}:\n"
180
+ for mirror, err in zip(self.mirrors, errors, strict=False):
181
+ s += f"Tried {mirror}, got:\n{str(err)}\n"
182
+ raise RuntimeError(s)
183
+
184
+
185
+ class FamilyIncome(Dataset):
186
+ """A dataset class for accessing the Family Income and Expenditure dataset.
187
+
188
+ This class extends the `Dataset` class and provides functionality for
189
+ downloading, loading, and transforming the Family Income and
190
+ Expenditure dataset. The dataset is intended for use with
191
+ PyTorch-based projects, offering a convenient interface for data handling.
192
+ This class provides access to the Family Income and Expenditure dataset
193
+ for use with PyTorch. If `download` is set to True, the dataset will be
194
+ downloaded if it is not already available. The data is then loaded,
195
+ and a user-defined transformation function is applied to it.
196
+
197
+ Args:
198
+ root (Union[str, Path]): The root directory where the dataset will
199
+ be stored or loaded from.
200
+ transform (Callable): A function to transform the dataset
201
+ (e.g., preprocessing).
202
+ download (bool, optional): Whether to download the dataset if it's
203
+ not already present. Defaults to False.
204
+
205
+ Raises:
206
+ RuntimeError: If the dataset is not found and `download`
207
+ is not set to True or if all mirrors fail to provide the dataset.
208
+ """
209
+
210
+ mirrors = [
211
+ "https://www.kaggle.com/api/v1/datasets/download/grosvenpaul/family-income-and-expenditure"
212
+ ]
213
+ resources = [
214
+ (
215
+ "archive.zip",
216
+ "7d74bc7facc3d7c07c4df1c1c6ac563e",
217
+ )
218
+ ]
219
+
220
+ def __init__(self, root: str | Path, transform: Callable, download: bool = False) -> None:
221
+ """Constructor method to initialize the dataset."""
222
+ super().__init__()
223
+ self.root = root
224
+ self.transform = transform
225
+
226
+ if download:
227
+ self.download()
228
+
229
+ if not self._check_exists():
230
+ raise RuntimeError("Dataset not found. You can use download=True to download it.")
231
+
232
+ self.data_input, self.data_output = self._load_data()
233
+
234
+ def _load_data(self):
235
+ """Load and transform the Family Income and Expenditure dataset.
236
+
237
+ Reads the data from the `Family Income and Expenditure.csv` file located
238
+ in `self.data_folder` and applies the transformation function. The input
239
+ and output columns are extracted and returned as NumPy arrays.
240
+
241
+ Returns:
242
+ Tuple[np.ndarray, np.ndarray]: A tuple containing:
243
+ - input data as a NumPy array of type float32
244
+ - output data as a NumPy array of type float32
245
+ """
246
+ data: pd.DataFrame = pd.read_csv(
247
+ os.path.join(self.data_folder, "Family Income and Expenditure.csv")
248
+ ).pipe(self.transform)
249
+
250
+ data_input = data["Input"].to_numpy(dtype=np.float32)
251
+ data_output = data["Output"].to_numpy(dtype=np.float32)
252
+
253
+ return data_input, data_output
254
+
255
+ def __len__(self):
256
+ """Returns the number of examples in the dataset.
257
+
258
+ Returns:
259
+ int: The number of examples in the dataset
260
+ (i.e., the number of rows in the input data).
261
+ """
262
+ return self.data_input.shape[0]
263
+
264
+ def __getitem__(self, idx):
265
+ """Returns the input-output pair for a given index.
266
+
267
+ Args:
268
+ idx (int): The index of the example to retrieve.
269
+
270
+ Returns:
271
+ dict: A dictionary with the following keys:
272
+ - "input" (torch.Tensor): The input features for the example.
273
+ - "target" (torch.Tensor): The target output for the example.
274
+ """
275
+ example = self.data_input[idx, :]
276
+ target = self.data_output[idx, :]
277
+ example = torch.tensor(example)
278
+ target = torch.tensor(target)
279
+ return {"input": example, "target": target}
280
+
281
+ @property
282
+ def data_folder(self) -> str:
283
+ """Returns the path to the folder where the dataset is stored.
284
+
285
+ Returns:
286
+ str: The path to the dataset folder.
287
+ """
288
+ return os.path.join(self.root, self.__class__.__name__)
289
+
290
+ def _check_exists(self) -> bool:
291
+ """Checks if the dataset is already downloaded and verified.
292
+
293
+ This method checks that all required files exist and
294
+ their integrity is validated via MD5 checksums.
295
+
296
+ Returns:
297
+ bool: True if all resources exist and their
298
+ integrity is valid, False otherwise.
299
+ """
300
+ return all(
301
+ check_integrity(os.path.join(self.data_folder, file_path), checksum)
302
+ for file_path, checksum in self.resources
303
+ )
304
+
305
+ def download(self) -> None:
306
+ """Downloads and extracts the dataset.
307
+
308
+ This method attempts to download the dataset from the mirrors
309
+ and extract it into the appropriate folder. If any error occurs
310
+ during downloading, it will try each mirror in sequence.
311
+
312
+ Raises:
313
+ RuntimeError: If all mirrors fail to provide the dataset.
314
+ """
315
+ if self._check_exists():
316
+ return
317
+
318
+ os.makedirs(self.data_folder, exist_ok=True)
319
+
320
+ # download files
321
+ for filename, md5 in self.resources:
322
+ errors = []
323
+ for mirror in self.mirrors:
324
+ url = f"{mirror}"
325
+ try:
326
+ download_and_extract_archive(
327
+ url, download_root=self.data_folder, filename=filename, md5=md5
328
+ )
329
+ except URLError as e:
330
+ errors.append(e)
331
+ continue
332
+ break
333
+ else:
334
+ s = f"Error downloading {filename}:\n"
335
+ for mirror, err in zip(self.mirrors, errors, strict=False):
336
+ s += f"Tried {mirror}, got:\n{str(err)}\n"
337
+ raise RuntimeError(s)
338
+
339
+
340
+ class SectionedGaussians(Dataset):
341
+ """Synthetic dataset generating smoothly varying Gaussian signals across multiple sections.
342
+
343
+ Each section defines a subrange of x-values with its own Gaussian distribution
344
+ (mean and standard deviation). Instead of abrupt transitions, the parameters
345
+ are blended smoothly between sections using a sigmoid function.
346
+
347
+ The resulting signal can represent a continuous process where statistical
348
+ properties gradually evolve over time or position.
349
+
350
+ Features:
351
+ - Input: Gaussian signal samples (y-values)
352
+ - Context: Concatenation of time (x) and normalized energy feature
353
+ - Target: Exponential decay ground truth from 1 at x_min to 0 at x_max
354
+
355
+ Attributes:
356
+ sections (list[dict]): List of section definitions.
357
+ n_samples (int): Total number of samples across all sections.
358
+ n_runs (int): Number of random waveforms generated from base configuration.
359
+ time (torch.Tensor): Sampled x-values, shape [n_samples, 1].
360
+ signal (torch.Tensor): Generated Gaussian signal values, shape [n_samples, 1].
361
+ energies (torch.Tensor): Normalized energy feature, shape [n_samples, 1].
362
+ context (torch.Tensor): Concatenation of time and energy, shape [n_samples, 2].
363
+ x_min (float): Minimum x-value across all sections.
364
+ x_max (float): Maximum x-value across all sections.
365
+ ground_truth_steepness (float): Exponential decay steepness for target output.
366
+ blend_k (float): Sharpness parameter controlling how rapidly means and
367
+ standard deviations transition between sections.
368
+ """
369
+
370
+ def __init__(
371
+ self,
372
+ sections: list[dict],
373
+ n_samples: int = 1000,
374
+ n_runs: int = 1,
375
+ seed: int | None = None,
376
+ device="cpu",
377
+ ground_truth_steepness: float = 0.0,
378
+ blend_k: float = 10.0,
379
+ ):
380
+ """Initializes the dataset and generates smoothly blended Gaussian signals.
381
+
382
+ Args:
383
+ sections (list[dict]): List of sections. Each section must define:
384
+ - range (tuple[float, float]): Start and end of the x-interval.
385
+ - mean (float): Mean of the Gaussian for this section.
386
+ - std (float): Standard deviation of the Gaussian for this section.
387
+ - max_splits (int): Max number of extra splits per section.
388
+ - split_prob (float): Probability of splitting a section.
389
+ - mean_var (float): How much to vary mean (fraction of original).
390
+ - std_var (float): How much to vary std (fraction of original).
391
+ - range_var (float): How much to vary start/end positions (fraction of section length).
392
+ n_samples (int, optional): Total number of samples to generate. Defaults to 1000.
393
+ n_runs (int, optional): Number of random waveforms to generate from the
394
+ base configuration. Defaults to 1.
395
+ seed (int or None, optional): Random seed for reproducibility. Defaults to None.
396
+ device (str or torch.device, optional): Device on which tensors are allocated.
397
+ Defaults to "cpu".
398
+ ground_truth_steepness (float, optional): Controls how sharply the ground-truth
399
+ exponential decay decreases from 1 to 0. Defaults to 0.0 (linear decay).
400
+ blend_k (float, optional): Controls the sharpness of the sigmoid blending
401
+ between sections. Higher values make the transition steeper; lower
402
+ values make it smoother. Defaults to 10.0.
403
+ """
404
+ self.sections = sections
405
+ self.n_samples = n_samples
406
+ self.n_runs = n_runs
407
+ self.device = device
408
+ self.blend_k = blend_k
409
+ self.ground_truth_steepness = torch.tensor(ground_truth_steepness, device=device)
410
+
411
+ if seed is not None:
412
+ torch.manual_seed(seed)
413
+
414
+ time: list[torch.Tensor] = []
415
+ signal: list[torch.Tensor] = []
416
+ energy: list[torch.Tensor] = []
417
+ identifier: list[torch.Tensor] = []
418
+
419
+ # Compute global min/max for time
420
+ self.t_min = min(s["range"][0] for s in sections)
421
+ self.t_max = max(s["range"][1] for s in sections)
422
+
423
+ # Generate waveforms from base configuration
424
+ for run in range(self.n_runs):
425
+ waveform = self._generate_waveform()
426
+ time.append(waveform[0])
427
+ signal.append(waveform[1])
428
+ energy.append(waveform[2])
429
+ identifier.append(torch.full_like(waveform[0], float(run)))
430
+
431
+ # Concatenate runs into single tensors
432
+ self.time = torch.cat(time, dim=0)
433
+ self.signal = torch.cat(signal, dim=0)
434
+ self.energy = torch.cat(energy, dim=0)
435
+ self.identifier = torch.cat(identifier, dim=0)
436
+ self.context = torch.hstack([self.time, self.energy, self.identifier])
437
+
438
+ # Adjust n_samples in case of rounding mismatch
439
+ self.n_samples = len(self.time)
440
+
441
+ def _generate_waveform(
442
+ self,
443
+ ):
444
+ sections = []
445
+
446
+ prev_mean = 0
447
+
448
+ for sec in self.sections:
449
+ start, end = sec["range"]
450
+ mean, std = sec["add_mean"], sec["std"]
451
+ max_splits = sec["max_splits"]
452
+ split_prob = sec["split_prob"]
453
+ mean_scale = sec["mean_var"]
454
+ std_scale = sec["std_var"]
455
+ range_scale = sec["range_var"]
456
+ section_len = end - start
457
+
458
+ # Decide whether to split this section into multiple contiguous parts
459
+ n_splits = 1
460
+ if random.random() < split_prob:
461
+ n_splits += random.randint(1, max_splits)
462
+
463
+ # Randomly divide the section into contiguous subsections
464
+ random_fracs = Dirichlet(torch.ones(n_splits) * 3).sample()
465
+ sub_lengths = (random_fracs * section_len).tolist()
466
+
467
+ sub_start = start
468
+ for base_len in sub_lengths:
469
+ # Compute unperturbed subsection boundaries
470
+ base_start = sub_start
471
+ base_end = base_start + base_len
472
+
473
+ # Define smooth range variation parameters
474
+ max_shift = range_scale * base_len / 2
475
+
476
+ # Decide how much to increase the mean overall
477
+ prev_mean += mean * random.uniform(0.0, 2 * mean_scale)
478
+
479
+ # Apply small random shifts to start/end, relative to subsection size
480
+ if range_scale > 0:
481
+ # Keep total ordering consistent (avoid overlaps or inversions)
482
+ local_shift = random.uniform(-max_shift / n_splits, max_shift / n_splits)
483
+ new_start = base_start + local_shift
484
+ new_end = base_end + local_shift
485
+ else:
486
+ new_start, new_end = base_start, base_end
487
+
488
+ # Clamp to valid time bounds
489
+ new_start = max(self.t_min, min(new_start, end))
490
+ new_end = max(new_start, min(new_end, end))
491
+
492
+ # Randomize mean and std within allowed ranges
493
+ mean_var = prev_mean * (1 + random.uniform(-mean_scale, mean_scale))
494
+ std_var = std * (1 + random.uniform(-std_scale, std_scale))
495
+
496
+ sections.append(
497
+ {
498
+ "range": (new_start, new_end),
499
+ "mean": mean_var,
500
+ "std": std_var,
501
+ }
502
+ )
503
+
504
+ # Prepare for next subsection
505
+ sub_start = base_end
506
+ prev_mean = mean_var
507
+
508
+ # Ensure last section ends exactly at x_max
509
+ sections[-1]["range"] = (sections[-1]["range"][0], self.t_max)
510
+
511
+ # Precompute exact float counts
512
+ section_lengths = [s["range"][1] - s["range"][0] for s in sections]
513
+ total_length = sum(section_lengths)
514
+ exact_counts = [length / total_length * self.n_samples for length in section_lengths]
515
+
516
+ # Allocate integer samples with rounding, track residual
517
+ n_section_samples_list = []
518
+ residual = 0.0
519
+ for exact in exact_counts[:-1]:
520
+ n = int(round(exact + residual))
521
+ n_section_samples_list.append(n)
522
+ residual += exact - n # carry over rounding error
523
+
524
+ # Assign remaining samples to last section
525
+ n_section_samples_list.append(self.n_samples - sum(n_section_samples_list))
526
+
527
+ # Generate data for each section
528
+ signal_segments: list[torch.Tensor] = []
529
+ for i, section in enumerate(sections):
530
+ start, end = section["range"]
531
+ mean, std = section["mean"], section["std"]
532
+
533
+ # Samples proportional to section length
534
+ n_section_samples = n_section_samples_list[i]
535
+
536
+ # Determine next section’s parameters for blending
537
+ if i < len(sections) - 1:
538
+ next_mean = sections[i + 1]["mean"]
539
+ next_std = sections[i + 1]["std"]
540
+ else:
541
+ next_mean = mean
542
+ next_std = std
543
+
544
+ # Sigmoid-based blend curve from 0 → 1
545
+ x = torch.linspace(
546
+ -self.blend_k, self.blend_k, n_section_samples, device=self.device
547
+ ).unsqueeze(1)
548
+ fade = torch.sigmoid(x)
549
+ fade = (fade - fade.min()) / (fade.max() - fade.min())
550
+
551
+ # Interpolate mean and std within the section
552
+ mean_curve = mean + (next_mean - mean) * fade
553
+ std_curve = std + (next_std - std) * fade
554
+
555
+ # Sample from a gradually changing Gaussian
556
+ y = torch.normal(mean=mean_curve, std=std_curve)
557
+ signal_segments.append(y)
558
+
559
+ # Concatenate tensors
560
+ time: torch.Tensor = torch.linspace(
561
+ self.t_min, self.t_max, self.n_samples, device=self.device
562
+ ).unsqueeze(1)
563
+ signal: torch.Tensor = torch.cat(signal_segments, dim=0)
564
+
565
+ # Compute and normalize energy feature
566
+ energy = torch.linalg.vector_norm(
567
+ signal - signal[0].reshape([1, -1]), ord=2, dim=1
568
+ ).unsqueeze(1)
569
+ min_e, max_e = energy.min(), energy.max()
570
+ energy = 1 - (energy - min_e) / (max_e - min_e) * 3
571
+
572
+ # Combine into context tensor
573
+ return time, signal, energy
574
+
575
+ def _compute_ground_truth(self, time: torch.Tensor) -> torch.Tensor:
576
+ """Computes the ground-truth exponential decay at coordinate t.
577
+
578
+ Args:
579
+ time (torch.Tensor): Input coordinate (shape: [1]).
580
+
581
+ Returns:
582
+ torch.Tensor: Corresponding target value between 0 and 1.
583
+ """
584
+ time_norm = (time - self.t_min) / (self.t_max - self.t_min)
585
+ k = self.ground_truth_steepness
586
+ if k == 0:
587
+ return 1 - time_norm
588
+ return (torch.exp(-k * time_norm) - torch.exp(-k)) / (1 - torch.exp(-k))
589
+
590
+ def __len__(self):
591
+ """Returns the total number of generated samples."""
592
+ return self.n_samples
593
+
594
+ def __getitem__(self, idx):
595
+ """Retrieves a single dataset sample.
596
+
597
+ Args:
598
+ idx (int): Sample index.
599
+
600
+ Returns:
601
+ dict:
602
+ - "input": Gaussian signal value (torch.Tensor),
603
+ - "context: Concatenated features (time, energy, identifier),
604
+ - "target": Ground-truth exponential decay value
605
+
606
+ """
607
+ sigal = self.signal[idx]
608
+ context = self.context[idx]
609
+ time = self.time[idx]
610
+ target = self._compute_ground_truth(time)
611
+
612
+ return {
613
+ "input": sigal,
614
+ "context": context,
615
+ "target": target,
616
+ }
617
+
618
+
class SyntheticMonotonicity(Dataset):
    """Synthetic 1D regression dataset with a monotone ground truth plus structured noise.

    True function:
        y_true(x) = log(1 + x)

    Observed:
        y(x) = y_true(x) + oscillatory_perturbation(x) + heteroscedastic_noise(x)

    The two extra terms are:

    - The oscillatory perturbation ``osc_amplitude * sin(osc_frequency * x)``.
      It is applied to each sample independently with probability ``osc_prob``.
    - Zero-mean Gaussian noise with standard deviation
      ``noise_base + noise_scale / (1 + exp(-noise_sharpness * (x - noise_center)))``.

    Inputs are drawn uniformly from ``x_range`` and sorted in increasing order.

    Args:
        n_samples (int): number of data points (default 200)
        x_range (tuple): range of x values (default [0, 5])
        noise_base (float): baseline noise level (default 0.05)
        noise_scale (float): scale of heteroscedastic noise (default 0.15)
        noise_sharpness (float): steepness of heteroscedastic transition (default 4.0)
        noise_center (float): center point of heteroscedastic increase (default 2.5)
        osc_amplitude (float): amplitude of oscillatory perturbation (default 0.08)
        osc_frequency (float): frequency of oscillation (default 6.0)
        osc_prob (float): probability each sample receives oscillation (default 0.5)
        seed (int or None): random seed

    Attributes:
        x (np.ndarray): Sorted inputs, shape (n_samples,).
        y (np.ndarray): Observed targets, shape (n_samples,).
        inputs (torch.Tensor): ``x`` as float32, shape (n_samples, 1).
        targets (torch.Tensor): ``y`` as float32, shape (n_samples, 1).
    """

    def __init__(
        self,
        n_samples=200,
        x_range=(0.0, 5.0),
        noise_base=0.05,
        noise_scale=0.15,
        noise_sharpness=4.0,
        noise_center=2.5,
        osc_amplitude=0.08,
        osc_frequency=6.0,
        osc_prob=0.5,
        seed=None,
    ):
        """Generate the samples.

        Args:
            n_samples (int): Number of data points.
            x_range (tuple): (low, high) range of x values.
            noise_base (float): Baseline noise standard deviation.
            noise_scale (float): Additional noise std reached for large x.
            noise_sharpness (float): Steepness of the logistic noise increase.
            noise_center (float): x at which the noise increase is centered.
            osc_amplitude (float): Amplitude of the oscillatory perturbation.
            osc_frequency (float): Angular frequency of the oscillation.
            osc_prob (float): Probability that a sample receives the oscillation.
            seed (int or None): If given, data are drawn from a private
                ``np.random.RandomState(seed)``. This reproduces the values that
                ``np.random.seed(seed)`` would give, without reseeding NumPy's
                global RNG. If None, the global RNG is used.
        """
        super().__init__()
        # A private RandomState yields the same stream as seeding the global RNG,
        # but leaves the caller's global NumPy random state untouched.
        rng = np.random if seed is None else np.random.RandomState(seed)

        self.n_samples = n_samples
        self.x_range = x_range

        # Sample inputs uniformly in x_range, then sort them
        self.x = rng.rand(n_samples) * (x_range[1] - x_range[0]) + x_range[0]
        self.x = np.sort(self.x)

        # Heteroscedastic noise std (logistic growth with x)
        noise_sigma = noise_base + noise_scale / (
            1 + np.exp(-noise_sharpness * (self.x - noise_center))
        )

        # Oscillatory perturbation, applied per sample with probability osc_prob
        mask = (rng.rand(n_samples) < osc_prob).astype(float)
        osc = mask * (osc_amplitude * np.sin(osc_frequency * self.x))

        # Final observed target
        self.y = np.log1p(self.x) + osc + noise_sigma * rng.randn(n_samples)

        # Convert to column tensors
        self.inputs = torch.tensor(self.x, dtype=torch.float32).unsqueeze(1)
        self.targets = torch.tensor(self.y, dtype=torch.float32).unsqueeze(1)

    def __len__(self):
        """Return the total number of samples in the dataset.

        Returns:
            int: The number of generated data points.
        """
        return self.n_samples

    def __getitem__(self, idx) -> dict:
        """Retrieve a single sample from the dataset.

        Args:
            idx (int): Index of the sample to retrieve.

        Returns:
            dict: A dictionary with the following keys:
                - "input" (torch.Tensor): The x value, shape (1,).
                - "target" (torch.Tensor): The corresponding y value, shape (1,).
        """
        return {"input": self.inputs[idx], "target": self.targets[idx]}
721
+
722
+
723
+ class SyntheticClusters(Dataset):
724
+ """PyTorch dataset for generating synthetic clustered 2D data with labels.
725
+
726
+ Each cluster is defined by its center, size, spread (standard deviation), and label.
727
+ The dataset samples points from a Gaussian distribution centered at the cluster mean.
728
+
729
+ Args:
730
+ cluster_centers (list[tuple[float, float]]): Coordinates of each cluster center,
731
+ e.g. [(x1, y1), (x2, y2), ...].
732
+ cluster_sizes (list[int]): Number of points to generate in each cluster.
733
+ cluster_std (list[float]): Standard deviation (spread) of each cluster.
734
+ cluster_labels (list[int]): Class label for each cluster (e.g., 0 or 1).
735
+
736
+ Raises:
737
+ AssertionError: If the input lists do not all have the same length.
738
+
739
+ Attributes:
740
+ data (torch.Tensor): A concatenated tensor of all generated points with shape (N, 2).
741
+ labels (torch.Tensor): A concatenated tensor of class labels with shape (N,),
742
+ where N is the total number of generated points.
743
+ """
744
+
745
+ def __init__(self, cluster_centers, cluster_sizes, cluster_std, cluster_labels):
746
+ """Initialize the ClusterDataset.
747
+
748
+ Args:
749
+ cluster_centers (list[tuple[float, float]]): Coordinates of each cluster center,
750
+ e.g. [(x1, y1), (x2, y2), ...].
751
+ cluster_sizes (list[int]): Number of points to generate in each cluster.
752
+ cluster_std (list[float]): Standard deviation (spread) of each cluster.
753
+ cluster_labels (list[int]): Class label for each cluster (e.g., 0 or 1).
754
+
755
+ Raises:
756
+ AssertionError: If the input lists do not all have the same length.
757
+ """
758
+ assert (
759
+ len(cluster_centers) == len(cluster_sizes) == len(cluster_std) == len(cluster_labels)
760
+ ), "All input lists must have the same length"
761
+
762
+ self.data = []
763
+ self.labels = []
764
+
765
+ # Generate points for each cluster
766
+ for center, size, std, label in zip(
767
+ cluster_centers, cluster_sizes, cluster_std, cluster_labels, strict=False
768
+ ):
769
+ x = torch.normal(mean=center[0], std=std, size=(size, 1))
770
+ y = torch.normal(mean=center[1], std=std, size=(size, 1))
771
+ points = torch.cat([x, y], dim=1)
772
+
773
+ self.data.append(points)
774
+ self.labels.append(torch.full((size,), label, dtype=torch.long))
775
+
776
+ # Concatenate all clusters
777
+ self.data = torch.cat(self.data, dim=0)
778
+ self.labels = torch.cat(self.labels, dim=0)
779
+
780
+ def __len__(self):
781
+ """Return the total number of samples in the dataset.
782
+
783
+ Returns:
784
+ int: The number of generated data points.
785
+ """
786
+ return len(self.data)
787
+
788
+ def __getitem__(self, idx) -> dict:
789
+ """Retrieve a single sample from the dataset.
790
+
791
+ Args:
792
+ idx (int): Index of the sample to retrieve.
793
+
794
+ Returns:
795
+ dict: A dictionary with the following keys:
796
+ - "input" (torch.Tensor): The 2D point at index `idx`.
797
+ - "target" (torch.Tensor): The corresponding class label.
798
+ """
799
+ return {"input": self.data[idx], "target": self.labels[idx]}