congrads-1.0.7-py3-none-any.whl → congrads-1.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
congrads/datasets.py CHANGED
@@ -1,51 +1,30 @@
- """
- This module defines several PyTorch dataset classes for loading and
- working with various datasets. Each dataset class extends the
- `torch.utils.data.Dataset` class and provides functionality for
- downloading, loading, and transforming specific datasets.
+ """This module defines several PyTorch dataset classes for loading and working with various datasets.
+
+ Each dataset class extends the `torch.utils.data.Dataset` class and provides functionality for
+ downloading, loading, and transforming specific datasets where applicable.

  Classes:

- - BiasCorrection: A dataset class for the Bias Correction dataset
- focused on temperature forecast data.
- - FamilyIncome: A dataset class for the Family Income and
- Expenditure dataset.
- - NoisySines: A dataset class that generates noisy sine wave
- samples with added Gaussian noise.
+ - SyntheticClusters: A dataset class for generating synthetic clustered 2D data with labels.
+ - BiasCorrection: A dataset class for the Bias Correction dataset focused on temperature forecast data.
+ - FamilyIncome: A dataset class for the Family Income and Expenditure dataset.

- Each dataset class provides methods for downloading the data
- (if not already available), checking the integrity of the dataset, loading
- the data from CSV files or generating synthetic data, and applying
+ Each dataset class provides methods for downloading the data
+ (if not already available or synthetic), checking the integrity of the dataset, loading
+ the data from CSV files or generating synthetic data, and applying
  transformations to the data.
-
- Key Methods:
-
- - `__init__`: Initializes the dataset by specifying the root directory,
- transformation function, and optional download flag.
- - `__getitem__`: Retrieves a specific data point given its index,
- returning input-output pairs.
- - `__len__`: Returns the total number of examples in the dataset.
- - `download`: Downloads and extracts the dataset from
- the specified mirrors.
- - `_load_data`: Loads the dataset from CSV files and
- applies transformations.
- - `_check_exists`: Checks if the dataset is already
- downloaded and verified.
-
- Each dataset class allows the user to apply custom transformations to the
- dataset through the `transform` argument to allow pre-processing and offers
- the ability to download the dataset if it's not already present on
- the local disk.
  """

  import os
+ import random
+ from collections.abc import Callable
  from pathlib import Path
- from typing import Callable, Union
  from urllib.error import URLError

  import numpy as np
  import pandas as pd
  import torch
+ from torch.distributions import Dirichlet
  from torch.utils.data import Dataset
  from torchvision.datasets.utils import (
  check_integrity,
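The `transform` argument threads user pre-processing through `DataFrame.pipe` before the input/output split (see the `_load_data` hunks below), so the callable must accept and return a `pandas.DataFrame`. A minimal sketch of such a callable — the cleaning steps here are hypothetical, not part of the package:

```python
import pandas as pd


def example_transform(df: pd.DataFrame) -> pd.DataFrame:
    """Hypothetical pre-processing: drop incomplete rows, min-max scale numeric columns."""
    df = df.dropna()
    numeric = df.select_dtypes("number").columns
    df[numeric] = (df[numeric] - df[numeric].min()) / (df[numeric].max() - df[numeric].min())
    return df
```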
@@ -54,8 +33,7 @@ from torchvision.datasets.utils import (


  class BiasCorrection(Dataset):
- """
- A dataset class for accessing the Bias Correction dataset.
+ """A dataset class for accessing the Bias Correction dataset.

  This class extends the `Dataset` class and provides functionality for
  downloading, loading, and transforming the Bias Correction dataset.
@@ -77,28 +55,16 @@ class BiasCorrection(Dataset):
  is not set to True or if all mirrors fail to provide the dataset.
  """

- mirrors = [
- "https://archive.ics.uci.edu/static/public/514/",
- ]
-
+ mirrors = ["https://archive.ics.uci.edu/static/public/514/"]
  resources = [
  (
- # pylint: disable-next=line-too-long
  "bias+correction+of+numerical+prediction+model+temperature+forecast.zip",
  "3deee56d461a2686887c4ae38fe3ccf3",
- ),
+ )
  ]

- def __init__(
- self,
- root: Union[str, Path],
- transform: Callable,
- download: bool = False,
- ) -> None:
- """
- Constructor method to initialize the dataset.
- """
-
+ def __init__(self, root: str | Path, transform: Callable, download: bool = False) -> None:
+ """Constructor method to initialize the dataset."""
  super().__init__()
  self.root = root
  self.transform = transform
@@ -107,15 +73,12 @@ class BiasCorrection(Dataset):
  self.download()

  if not self._check_exists():
- raise RuntimeError(
- "Dataset not found. You can use download=True to download it"
- )
+ raise RuntimeError("Dataset not found. You can use download=True to download it")

  self.data_input, self.data_output = self._load_data()

  def _load_data(self):
- """
- Loads the dataset from the CSV file and applies the transformation.
+ """Loads the dataset from the CSV file and applies the transformation.

  The data is read from the `Bias_correction_ucl.csv` file, and the
  transformation function is applied to it.
@@ -125,7 +88,6 @@ class BiasCorrection(Dataset):
  Tuple[numpy.ndarray, numpy.ndarray]: A tuple containing the input
  and output data as numpy arrays.
  """
-
  data: pd.DataFrame = pd.read_csv(
  os.path.join(self.data_folder, "Bias_correction_ucl.csv")
  ).pipe(self.transform)
@@ -136,48 +98,42 @@ class BiasCorrection(Dataset):
  return data_input, data_output

  def __len__(self):
- """
- Returns the number of examples in the dataset.
+ """Returns the number of examples in the dataset.

  Returns:
  int: The number of examples in the dataset
  (i.e., the number of rows in the input data).
  """
-
  return self.data_input.shape[0]

  def __getitem__(self, idx):
- """
- Returns the input-output pair for a given index.
+ """Returns the input-output pair for a given index.

  Args:
  idx (int): The index of the example to retrieve.

  Returns:
- Tuple[torch.Tensor, torch.Tensor]: The input-output pair
- as PyTorch tensors.
+ dict: A dictionary with the following keys:
+ - "input" (torch.Tensor): The input features for the example.
+ - "target" (torch.Tensor): The target output for the example.
  """
-
  example = self.data_input[idx, :]
  target = self.data_output[idx, :]
  example = torch.tensor(example)
  target = torch.tensor(target)
- return example, target
+ return {"input": example, "target": target}

  @property
  def data_folder(self) -> str:
- """
- Returns the path to the folder where the dataset is stored.
+ """Returns the path to the folder where the dataset is stored.

  Returns:
  str: The path to the dataset folder.
  """
-
  return os.path.join(self.root, self.__class__.__name__)

  def _check_exists(self) -> bool:
- """
- Checks if the dataset is already downloaded and verified.
+ """Checks if the dataset is already downloaded and verified.

  This method checks that all required files exist and
  their integrity is validated via MD5 checksums.
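This switch from tuple samples to dict samples is the breaking API change of the release and recurs in every dataset class below. PyTorch's default `collate_fn` batches dict samples key-wise, so downstream loops move from tuple unpacking to key access. A self-contained sketch with a stand-in dataset (not part of the package):

```python
import torch
from torch.utils.data import DataLoader, Dataset


class _DictToy(Dataset):
    """Stand-in with the same dict-sample contract as the congrads datasets."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return {"input": torch.randn(3), "target": torch.randn(1)}


# The default collate_fn batches dict samples key-wise into dict batches.
loader = DataLoader(_DictToy(), batch_size=4)
for batch in loader:
    inputs, targets = batch["input"], batch["target"]  # 1.0.7 did: inputs, targets = batch
    print(inputs.shape, targets.shape)  # torch.Size([4, 3]) torch.Size([4, 1])
```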
@@ -186,15 +142,13 @@ class BiasCorrection(Dataset):
  bool: True if all resources exist and their
  integrity is valid, False otherwise.
  """
-
  return all(
  check_integrity(os.path.join(self.data_folder, file_path), checksum)
  for file_path, checksum in self.resources
  )

  def download(self) -> None:
- """
- Downloads and extracts the dataset.
+ """Downloads and extracts the dataset.

  This method attempts to download the dataset from the mirrors and
  extract it into the appropriate folder. If any error occurs during
@@ -203,7 +157,6 @@ class BiasCorrection(Dataset):
  Raises:
  RuntimeError: If all mirrors fail to provide the dataset.
  """
-
  if self._check_exists():
  return

@@ -216,10 +169,7 @@ class BiasCorrection(Dataset):
  url = f"{mirror}{filename}"
  try:
  download_and_extract_archive(
- url,
- download_root=self.data_folder,
- filename=filename,
- md5=md5,
+ url, download_root=self.data_folder, filename=filename, md5=md5
  )
  except URLError as e:
  errors.append(e)
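The next hunk closes this loop with Python's `for ... else`: the `else` body runs only when the loop finishes without hitting `break`, i.e. when every mirror failed. A stripped-down sketch of the same pattern, with a hypothetical `fetch` stand-in:

```python
def fetch(url: str) -> None:
    """Hypothetical download call; pretend the first mirror is down."""
    if "mirror-a" in url:
        raise OSError(f"unreachable: {url}")


mirrors = ["https://mirror-a.example/", "https://mirror-b.example/"]
errors: list[Exception] = []

for mirror in mirrors:
    try:
        fetch(mirror)
    except OSError as e:
        errors.append(e)
        continue
    break  # success: the else clause below is skipped
else:
    # Runs only when the loop never hit `break`, i.e. every mirror failed.
    raise RuntimeError(f"All mirrors failed: {errors}")

print("downloaded from", mirror)  # -> https://mirror-b.example/
```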
@@ -227,14 +177,13 @@ class BiasCorrection(Dataset):
  break
  else:
  s = f"Error downloading {filename}:\n"
- for mirror, err in zip(self.mirrors, errors):
+ for mirror, err in zip(self.mirrors, errors, strict=False):
  s += f"Tried {mirror}, got:\n{str(err)}\n"
  raise RuntimeError(s)


  class FamilyIncome(Dataset):
- """
- A dataset class for accessing the Family Income and Expenditure dataset.
+ """A dataset class for accessing the Family Income and Expenditure dataset.

  This class extends the `Dataset` class and provides functionality for
  downloading, loading, and transforming the Family Income and
@@ -259,27 +208,17 @@ class FamilyIncome(Dataset):
  """

  mirrors = [
- # pylint: disable-next=line-too-long
- "https://www.kaggle.com/api/v1/datasets/download/grosvenpaul/family-income-and-expenditure",
+ "https://www.kaggle.com/api/v1/datasets/download/grosvenpaul/family-income-and-expenditure"
  ]
-
  resources = [
  (
  "archive.zip",
  "7d74bc7facc3d7c07c4df1c1c6ac563e",
- ),
+ )
  ]

- def __init__(
- self,
- root: Union[str, Path],
- transform: Callable,
- download: bool = False,
- ) -> None:
- """
- Constructor method to initialize the dataset.
- """
-
+ def __init__(self, root: str | Path, transform: Callable, download: bool = False) -> None:
+ """Constructor method to initialize the dataset."""
  super().__init__()
  self.root = root
  self.transform = transform
@@ -288,26 +227,22 @@ class FamilyIncome(Dataset):
  self.download()

  if not self._check_exists():
- raise RuntimeError(
- "Dataset not found. You can use download=True to download it."
- )
+ raise RuntimeError("Dataset not found. You can use download=True to download it.")

  self.data_input, self.data_output = self._load_data()

  def _load_data(self):
- """
- Loads the Family Income and Expenditure dataset from the CSV file
- and applies the transformation.
+ """Load and transform the Family Income and Expenditure dataset.

- The data is read from the `Family Income and Expenditure.csv` file,
- and the transformation function is applied to it. The input and
- output data are separated and returned as numpy arrays.
+ Reads the data from the `Family Income and Expenditure.csv` file located
+ in `self.data_folder` and applies the transformation function. The input
+ and output columns are extracted and returned as NumPy arrays.

  Returns:
- Tuple[numpy.ndarray, numpy.ndarray]: A tuple containing the input
- and output data as numpy arrays.
+ Tuple[np.ndarray, np.ndarray]: A tuple containing:
+ - input data as a NumPy array of type float32
+ - output data as a NumPy array of type float32
  """
-
  data: pd.DataFrame = pd.read_csv(
  os.path.join(self.data_folder, "Family Income and Expenditure.csv")
  ).pipe(self.transform)
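`DataFrame.pipe(self.transform)` is simply `self.transform(df)` written in chained form, which is why the `transform` argument must accept and return a DataFrame. For example (hypothetical columns, not from this dataset):

```python
import pandas as pd

df = pd.DataFrame({"income": [100.0, 200.0], "expenses": [80.0, 150.0]})


def add_savings(frame: pd.DataFrame) -> pd.DataFrame:
    # Derive a new column; df.pipe(add_savings) == add_savings(df)
    return frame.assign(savings=frame["income"] - frame["expenses"])


assert df.pipe(add_savings).equals(add_savings(df))
```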
@@ -318,48 +253,42 @@ class FamilyIncome(Dataset):
  return data_input, data_output

  def __len__(self):
- """
- Returns the number of examples in the dataset.
+ """Returns the number of examples in the dataset.

  Returns:
  int: The number of examples in the dataset
  (i.e., the number of rows in the input data).
  """
-
  return self.data_input.shape[0]

  def __getitem__(self, idx):
- """
- Returns the input-output pair for a given index.
+ """Returns the input-output pair for a given index.

  Args:
  idx (int): The index of the example to retrieve.

  Returns:
- Tuple[torch.Tensor, torch.Tensor]: The input-output pair
- as PyTorch tensors.
+ dict: A dictionary with the following keys:
+ - "input" (torch.Tensor): The input features for the example.
+ - "target" (torch.Tensor): The target output for the example.
  """
-
  example = self.data_input[idx, :]
  target = self.data_output[idx, :]
  example = torch.tensor(example)
  target = torch.tensor(target)
- return example, target
+ return {"input": example, "target": target}

  @property
  def data_folder(self) -> str:
- """
- Returns the path to the folder where the dataset is stored.
+ """Returns the path to the folder where the dataset is stored.

  Returns:
  str: The path to the dataset folder.
  """
-
  return os.path.join(self.root, self.__class__.__name__)

  def _check_exists(self) -> bool:
- """
- Checks if the dataset is already downloaded and verified.
+ """Checks if the dataset is already downloaded and verified.

  This method checks that all required files exist and
  their integrity is validated via MD5 checksums.
@@ -368,15 +297,13 @@ class FamilyIncome(Dataset):
  bool: True if all resources exist and their
  integrity is valid, False otherwise.
  """
-
  return all(
  check_integrity(os.path.join(self.data_folder, file_path), checksum)
  for file_path, checksum in self.resources
  )

  def download(self) -> None:
- """
- Downloads and extracts the dataset.
+ """Downloads and extracts the dataset.

  This method attempts to download the dataset from the mirrors
  and extract it into the appropriate folder. If any error occurs
@@ -385,7 +312,6 @@ class FamilyIncome(Dataset):
  Raises:
  RuntimeError: If all mirrors fail to provide the dataset.
  """
-
  if self._check_exists():
  return

@@ -398,10 +324,7 @@ class FamilyIncome(Dataset):
  url = f"{mirror}"
  try:
  download_and_extract_archive(
- url,
- download_root=self.data_folder,
- filename=filename,
- md5=md5,
+ url, download_root=self.data_folder, filename=filename, md5=md5
  )
  except URLError as e:
  errors.append(e)
@@ -409,91 +332,468 @@ class FamilyIncome(Dataset):
  break
  else:
  s = f"Error downloading {filename}:\n"
- for mirror, err in zip(self.mirrors, errors):
+ for mirror, err in zip(self.mirrors, errors, strict=False):
  s += f"Tried {mirror}, got:\n{str(err)}\n"
  raise RuntimeError(s)


- class NoisySines(Dataset):
+ class SectionedGaussians(Dataset):
+ """Synthetic dataset generating smoothly varying Gaussian signals across multiple sections.
+
+ Each section defines a subrange of x-values with its own Gaussian distribution
+ (mean and standard deviation). Instead of abrupt transitions, the parameters
+ are blended smoothly between sections using a sigmoid function.
+
+ The resulting signal can represent a continuous process where statistical
+ properties gradually evolve over time or position.
+
+ Features:
+ - Input: Gaussian signal samples (y-values)
+ - Context: Concatenation of time (x), normalized energy feature, and run identifier
+ - Target: Exponential decay ground truth from 1 at t_min to 0 at t_max
+
+ Attributes:
+ sections (list[dict]): List of section definitions.
+ n_samples (int): Total number of samples across all sections.
+ n_runs (int): Number of random waveforms generated from base configuration.
+ time (torch.Tensor): Sampled x-values, shape [n_samples, 1].
+ signal (torch.Tensor): Generated Gaussian signal values, shape [n_samples, 1].
+ energy (torch.Tensor): Normalized energy feature, shape [n_samples, 1].
+ context (torch.Tensor): Concatenation of time, energy, and run identifier,
+ shape [n_samples, 3].
+ t_min (float): Minimum x-value across all sections.
+ t_max (float): Maximum x-value across all sections.
+ ground_truth_steepness (float): Exponential decay steepness for target output.
+ blend_k (float): Sharpness parameter controlling how rapidly means and
+ standard deviations transition between sections.
  """
- A PyTorch dataset generating samples from a causal
- sine wave with added noise.
+
+ def __init__(
+ self,
+ sections: list[dict],
+ n_samples: int = 1000,
+ n_runs: int = 1,
+ seed: int | None = None,
+ device="cpu",
+ ground_truth_steepness: float = 0.0,
+ blend_k: float = 10.0,
+ ):
+ """Initializes the dataset and generates smoothly blended Gaussian signals.
+
+ Args:
+ sections (list[dict]): List of sections. Each section must define:
+ - range (tuple[float, float]): Start and end of the x-interval.
+ - add_mean (float): Increment added to the running mean of this section's Gaussian.
+ - std (float): Standard deviation of the Gaussian for this section.
+ - max_splits (int): Max number of extra splits per section.
+ - split_prob (float): Probability of splitting a section.
+ - mean_var (float): How much to vary mean (fraction of original).
+ - std_var (float): How much to vary std (fraction of original).
+ - range_var (float): How much to vary start/end positions (fraction of section length).
+ n_samples (int, optional): Total number of samples to generate. Defaults to 1000.
+ n_runs (int, optional): Number of random waveforms to generate from the
+ base configuration. Defaults to 1.
+ seed (int or None, optional): Random seed for reproducibility. Defaults to None.
+ device (str or torch.device, optional): Device on which tensors are allocated.
+ Defaults to "cpu".
+ ground_truth_steepness (float, optional): Controls how sharply the ground-truth
+ exponential decay decreases from 1 to 0. Defaults to 0.0 (linear decay).
+ blend_k (float, optional): Controls the sharpness of the sigmoid blending
+ between sections. Higher values make the transition steeper; lower
+ values make it smoother. Defaults to 10.0.
+ """
+ self.sections = sections
+ self.n_samples = n_samples
+ self.n_runs = n_runs
+ self.device = device
+ self.blend_k = blend_k
+ self.ground_truth_steepness = torch.tensor(ground_truth_steepness, device=device)
+
+ if seed is not None:
+ torch.manual_seed(seed)
+
+ time: list[torch.Tensor] = []
+ signal: list[torch.Tensor] = []
+ energy: list[torch.Tensor] = []
+ identifier: list[torch.Tensor] = []
+
+ # Compute global min/max for time
+ self.t_min = min(s["range"][0] for s in sections)
+ self.t_max = max(s["range"][1] for s in sections)
+
+ # Generate waveforms from base configuration
+ for run in range(self.n_runs):
+ waveform = self._generate_waveform()
+ time.append(waveform[0])
+ signal.append(waveform[1])
+ energy.append(waveform[2])
+ identifier.append(torch.full_like(waveform[0], float(run)))
+
+ # Concatenate runs into single tensors
+ self.time = torch.cat(time, dim=0)
+ self.signal = torch.cat(signal, dim=0)
+ self.energy = torch.cat(energy, dim=0)
+ self.identifier = torch.cat(identifier, dim=0)
+ self.context = torch.hstack([self.time, self.energy, self.identifier])
+
+ # Adjust n_samples in case of rounding mismatch
+ self.n_samples = len(self.time)
+
+ def _generate_waveform(
+ self,
+ ):
+ sections = []
+
+ prev_mean = 0
+
+ for sec in self.sections:
+ start, end = sec["range"]
+ mean, std = sec["add_mean"], sec["std"]
+ max_splits = sec["max_splits"]
+ split_prob = sec["split_prob"]
+ mean_scale = sec["mean_var"]
+ std_scale = sec["std_var"]
+ range_scale = sec["range_var"]
+ section_len = end - start
+
+ # Decide whether to split this section into multiple contiguous parts
+ n_splits = 1
+ if random.random() < split_prob:
+ n_splits += random.randint(1, max_splits)
+
+ # Randomly divide the section into contiguous subsections
+ random_fracs = Dirichlet(torch.ones(n_splits) * 3).sample()
+ sub_lengths = (random_fracs * section_len).tolist()
+
+ sub_start = start
+ for base_len in sub_lengths:
+ # Compute unperturbed subsection boundaries
+ base_start = sub_start
+ base_end = base_start + base_len
+
+ # Define smooth range variation parameters
+ max_shift = range_scale * base_len / 2
+
+ # Decide how much to increase the mean overall
+ prev_mean += mean * random.uniform(0.0, 2 * mean_scale)
+
+ # Apply small random shifts to start/end, relative to subsection size
+ if range_scale > 0:
+ # Keep total ordering consistent (avoid overlaps or inversions)
+ local_shift = random.uniform(-max_shift / n_splits, max_shift / n_splits)
+ new_start = base_start + local_shift
+ new_end = base_end + local_shift
+ else:
+ new_start, new_end = base_start, base_end
+
+ # Clamp to valid time bounds
+ new_start = max(self.t_min, min(new_start, end))
+ new_end = max(new_start, min(new_end, end))
+
+ # Randomize mean and std within allowed ranges
+ mean_var = prev_mean * (1 + random.uniform(-mean_scale, mean_scale))
+ std_var = std * (1 + random.uniform(-std_scale, std_scale))
+
+ sections.append(
+ {
+ "range": (new_start, new_end),
+ "mean": mean_var,
+ "std": std_var,
+ }
+ )
+
+ # Prepare for next subsection
+ sub_start = base_end
+ prev_mean = mean_var
+
+ # Ensure last section ends exactly at t_max
+ sections[-1]["range"] = (sections[-1]["range"][0], self.t_max)
+
+ # Precompute exact float counts
+ section_lengths = [s["range"][1] - s["range"][0] for s in sections]
+ total_length = sum(section_lengths)
+ exact_counts = [length / total_length * self.n_samples for length in section_lengths]
+
+ # Allocate integer samples with rounding, track residual
+ n_section_samples_list = []
+ residual = 0.0
+ for exact in exact_counts[:-1]:
+ n = int(round(exact + residual))
+ n_section_samples_list.append(n)
+ residual += exact - n  # carry over rounding error
+
+ # Assign remaining samples to last section
+ n_section_samples_list.append(self.n_samples - sum(n_section_samples_list))
+
+ # Generate data for each section
+ signal_segments: list[torch.Tensor] = []
+ for i, section in enumerate(sections):
+ start, end = section["range"]
+ mean, std = section["mean"], section["std"]
+
+ # Samples proportional to section length
+ n_section_samples = n_section_samples_list[i]
+
+ # Determine next section's parameters for blending
+ if i < len(sections) - 1:
+ next_mean = sections[i + 1]["mean"]
+ next_std = sections[i + 1]["std"]
+ else:
+ next_mean = mean
+ next_std = std
+
+ # Sigmoid-based blend curve from 0 → 1
+ x = torch.linspace(
+ -self.blend_k, self.blend_k, n_section_samples, device=self.device
+ ).unsqueeze(1)
+ fade = torch.sigmoid(x)
+ fade = (fade - fade.min()) / (fade.max() - fade.min())
+
+ # Interpolate mean and std within the section
+ mean_curve = mean + (next_mean - mean) * fade
+ std_curve = std + (next_std - std) * fade
+
+ # Sample from a gradually changing Gaussian
+ y = torch.normal(mean=mean_curve, std=std_curve)
+ signal_segments.append(y)
+
+ # Concatenate tensors
+ time: torch.Tensor = torch.linspace(
+ self.t_min, self.t_max, self.n_samples, device=self.device
+ ).unsqueeze(1)
+ signal: torch.Tensor = torch.cat(signal_segments, dim=0)
+
+ # Compute and normalize energy feature
+ energy = torch.linalg.vector_norm(
+ signal - signal[0].reshape([1, -1]), ord=2, dim=1
+ ).unsqueeze(1)
+ min_e, max_e = energy.min(), energy.max()
+ energy = 1 - (energy - min_e) / (max_e - min_e) * 3
+
+ # Return the per-run time, signal, and energy tensors
+ return time, signal, energy
+
+ def _compute_ground_truth(self, time: torch.Tensor) -> torch.Tensor:
+ """Computes the ground-truth exponential decay at coordinate t.
+
+ Args:
+ time (torch.Tensor): Input coordinate (shape: [1]).
+
+ Returns:
+ torch.Tensor: Corresponding target value between 0 and 1.
+ """
+ time_norm = (time - self.t_min) / (self.t_max - self.t_min)
+ k = self.ground_truth_steepness
+ if k == 0:
+ return 1 - time_norm
+ return (torch.exp(-k * time_norm) - torch.exp(-k)) / (1 - torch.exp(-k))
+
+ def __len__(self):
+ """Returns the total number of generated samples."""
+ return self.n_samples
+
+ def __getitem__(self, idx):
+ """Retrieves a single dataset sample.
+
+ Args:
+ idx (int): Sample index.
+
+ Returns:
+ dict:
+ - "input": Gaussian signal value (torch.Tensor),
+ - "context": Concatenated features (time, energy, identifier),
+ - "target": Ground-truth exponential decay value
+ """
+ signal = self.signal[idx]
+ context = self.context[idx]
+ time = self.time[idx]
+ target = self._compute_ground_truth(time)
+
+ return {
+ "input": signal,
+ "context": context,
+ "target": target,
+ }
+
+
+ class SyntheticMonotonicity(Dataset):
+ """Synthetic 1D dataset with monotone ground truth (log(1+x)), plus configurable structured noise.
+
+ True function:
+ y_true(x) = log(1 + x)
+
+ Observed:
+ y(x) = y_true(x) + heteroscedastic_noise(x) + local oscillatory perturbation

  Args:
- length (int): Number of data points in the dataset.
- amplitude (float): Amplitude of the sine wave.
- frequency (float): Frequency of the sine wave in Hz.
- noise_std (float): Standard deviation of the Gaussian noise.
- bias (float): Offset from zero.
-
- The sine wave is zero for times before t=0 and follows a
- standard sine wave after t=0, with Gaussian noise added to all points.
+ n_samples (int): number of data points (default 200)
+ x_range (tuple): range of x values (default [0, 5])
+ noise_base (float): baseline noise level (default 0.05)
+ noise_scale (float): scale of heteroscedastic noise (default 0.15)
+ noise_sharpness (float): steepness of heteroscedastic transition (default 4.0)
+ noise_center (float): center point of heteroscedastic increase (default 2.5)
+ osc_amplitude (float): amplitude of oscillatory perturbation (default 0.08)
+ osc_frequency (float): frequency of oscillation (default 6.0)
+ osc_prob (float): probability each sample receives oscillation (default 0.5)
+ seed (int or None): random seed
  """

  def __init__(
  self,
- length,
- amplitude=1,
- frequency=10.0,
- noise_std=0.05,
- bias=0,
- random_seed=42,
+ n_samples=200,
+ x_range=(0.0, 5.0),
+ noise_base=0.05,
+ noise_scale=0.15,
+ noise_sharpness=4.0,
+ noise_center=2.5,
+ osc_amplitude=0.08,
+ osc_frequency=6.0,
+ osc_prob=0.5,
+ seed=None,
  ):
+ """Synthetic 1D dataset with monotone ground truth (log(1+x)), plus configurable structured noise.
+
+ True function:
+ y_true(x) = log(1 + x)
+
+ Observed:
+ y(x) = y_true(x) + heteroscedastic_noise(x) + local oscillatory perturbation
+
+ Args:
+ n_samples (int): number of data points (default 200)
+ x_range (tuple): range of x values (default [0, 5])
+ noise_base (float): baseline noise level (default 0.05)
+ noise_scale (float): scale of heteroscedastic noise (default 0.15)
+ noise_sharpness (float): steepness of heteroscedastic transition (default 4.0)
+ noise_center (float): center point of heteroscedastic increase (default 2.5)
+ osc_amplitude (float): amplitude of oscillatory perturbation (default 0.08)
+ osc_frequency (float): frequency of oscillation (default 6.0)
+ osc_prob (float): probability each sample receives oscillation (default 0.5)
+ seed (int or None): random seed
  """
- Initializes the NoisyCausalSine dataset.
- """
- self.length = length
- self.amplitude = amplitude
- self.frequency = frequency
- self.noise_std = noise_std
- self.bias = bias
- self.random_seed = random_seed
+ super().__init__()
+ if seed is not None:
+ np.random.seed(seed)

- np.random.seed(self.random_seed)
- self.time = np.linspace(0, 1, length)
- self.noise = np.random.normal(0, self.noise_std, length)
+ self.n_samples = n_samples
+ self.x_range = x_range

- def __getitem__(self, idx):
+ # Sample inputs
+ self.x = np.random.rand(n_samples) * (x_range[1] - x_range[0]) + x_range[0]
+ self.x = np.sort(self.x)
+
+ # Heteroscedastic noise (logistic growth with x)
+ noise_sigma = noise_base + noise_scale / (
+ 1 + np.exp(-noise_sharpness * (self.x - noise_center))
+ )
+
+ # Oscillatory perturbation (applied with probability osc_prob)
+ mask = (np.random.rand(n_samples) < osc_prob).astype(float)
+ osc = mask * (osc_amplitude * np.sin(osc_frequency * self.x))
+
+ # Final observed target
+ self.y = np.log1p(self.x) + osc + noise_sigma * np.random.randn(n_samples)
+
+ # Convert to tensors
+ self.inputs = torch.tensor(self.x, dtype=torch.float32).unsqueeze(1)
+ self.targets = torch.tensor(self.y, dtype=torch.float32).unsqueeze(1)
+
+ def __len__(self):
+ """Return the total number of samples in the dataset.
+
+ Returns:
+ int: The number of generated data points.
  """
- Returns the time and noisy sine wave value for a given index.
+ return self.n_samples
+
+ def __getitem__(self, idx) -> dict:
+ """Retrieve a single sample from the dataset.

  Args:
- idx (int): Index of the data point to retrieve.
+ idx (int): Index of the sample to retrieve.

  Returns:
- Tuple[torch.Tensor, torch.Tensor]: A tuple containing the
- time value and the noisy sine wave value.
+ dict: A dictionary with the following keys:
+ - "input" (torch.Tensor): The x value.
+ - "target" (torch.Tensor): The corresponding y value.
  """
+ return {"input": self.inputs[idx], "target": self.targets[idx]}

- t = self.time[idx]
- if idx < self.length // 2:
- sine_value = self.bias
- cosine_value = self.bias
- else:
- sine_value = (
- self.amplitude * np.sin(2 * np.pi * self.frequency * t)
- + self.bias
- )
- cosine_value = (
- self.amplitude * np.cos(2 * np.pi * self.frequency * t)
- + self.bias
- )
-
- # Add noise to the signals
- noisy_sine = sine_value + self.noise[idx]
- noisy_cosine = cosine_value + self.noise[idx]
-
- # Convert to tensor
- example, target = torch.tensor([t], dtype=torch.float32), torch.tensor(
- [noisy_sine, noisy_cosine], dtype=torch.float32
- )
- return example, target
+
+ class SyntheticClusters(Dataset):
+ """PyTorch dataset for generating synthetic clustered 2D data with labels.
+
+ Each cluster is defined by its center, size, spread (standard deviation), and label.
+ The dataset samples points from a Gaussian distribution centered at the cluster mean.
+
+ Args:
+ cluster_centers (list[tuple[float, float]]): Coordinates of each cluster center,
+ e.g. [(x1, y1), (x2, y2), ...].
+ cluster_sizes (list[int]): Number of points to generate in each cluster.
+ cluster_std (list[float]): Standard deviation (spread) of each cluster.
+ cluster_labels (list[int]): Class label for each cluster (e.g., 0 or 1).
+
+ Raises:
+ AssertionError: If the input lists do not all have the same length.
+
+ Attributes:
+ data (torch.Tensor): A concatenated tensor of all generated points with shape (N, 2).
+ labels (torch.Tensor): A concatenated tensor of class labels with shape (N,),
+ where N is the total number of generated points.
+ """
+
+ def __init__(self, cluster_centers, cluster_sizes, cluster_std, cluster_labels):
+ """Initialize the SyntheticClusters dataset.
+
+ Args:
+ cluster_centers (list[tuple[float, float]]): Coordinates of each cluster center,
+ e.g. [(x1, y1), (x2, y2), ...].
+ cluster_sizes (list[int]): Number of points to generate in each cluster.
+ cluster_std (list[float]): Standard deviation (spread) of each cluster.
+ cluster_labels (list[int]): Class label for each cluster (e.g., 0 or 1).
+
+ Raises:
+ AssertionError: If the input lists do not all have the same length.
+ """
+ assert (
+ len(cluster_centers) == len(cluster_sizes) == len(cluster_std) == len(cluster_labels)
+ ), "All input lists must have the same length"
+
+ self.data = []
+ self.labels = []
+
+ # Generate points for each cluster
+ for center, size, std, label in zip(
+ cluster_centers, cluster_sizes, cluster_std, cluster_labels, strict=False
+ ):
+ x = torch.normal(mean=center[0], std=std, size=(size, 1))
+ y = torch.normal(mean=center[1], std=std, size=(size, 1))
+ points = torch.cat([x, y], dim=1)
+
+ self.data.append(points)
+ self.labels.append(torch.full((size,), label, dtype=torch.long))
+
+ # Concatenate all clusters
+ self.data = torch.cat(self.data, dim=0)
+ self.labels = torch.cat(self.labels, dim=0)

  def __len__(self):
+ """Return the total number of samples in the dataset.
+
+ Returns:
+ int: The number of generated data points.
  """
- Returns the total number of data points in the dataset.
+ return len(self.data)
+
+ def __getitem__(self, idx) -> dict:
+ """Retrieve a single sample from the dataset.
+
+ Args:
+ idx (int): Index of the sample to retrieve.

  Returns:
- int: The length of the dataset.
+ dict: A dictionary with the following keys:
+ - "input" (torch.Tensor): The 2D point at index `idx`.
+ - "target" (torch.Tensor): The corresponding class label.
  """
- return self.length
+ return {"input": self.data[idx], "target": self.labels[idx]}
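For orientation, a usage sketch of the three new synthetic datasets. The values are illustrative only; note that `SectionedGaussians` section dicts must carry the `add_mean` key that `_generate_waveform` actually reads:

```python
from congrads.datasets import SectionedGaussians, SyntheticClusters, SyntheticMonotonicity

# SectionedGaussians: two sections with illustrative blending/split settings.
sections = [
    {"range": (0.0, 5.0), "add_mean": 0.5, "std": 0.2, "max_splits": 2,
     "split_prob": 0.3, "mean_var": 0.1, "std_var": 0.1, "range_var": 0.05},
    {"range": (5.0, 10.0), "add_mean": 1.0, "std": 0.4, "max_splits": 2,
     "split_prob": 0.3, "mean_var": 0.1, "std_var": 0.1, "range_var": 0.05},
]
gaussians = SectionedGaussians(sections, n_samples=1000, n_runs=2, seed=0)
sample = gaussians[0]
print(sample["input"].shape, sample["context"].shape, sample["target"].shape)

# SyntheticMonotonicity: noisy observations of log(1 + x).
mono = SyntheticMonotonicity(n_samples=200, seed=0)
print(mono[0])  # {"input": tensor([x]), "target": tensor([y])}

# SyntheticClusters: two Gaussian blobs with binary labels.
clusters = SyntheticClusters(
    cluster_centers=[(-2.0, 0.0), (2.0, 0.0)],
    cluster_sizes=[100, 100],
    cluster_std=[0.5, 0.5],
    cluster_labels=[0, 1],
)
print(len(clusters), clusters[0])
```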