congrads-1.0.6-py3-none-any.whl → congrads-1.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- congrads/__init__.py +2 -3
- congrads/checkpoints.py +73 -127
- congrads/constraints.py +813 -476
- congrads/core.py +521 -345
- congrads/datasets.py +491 -191
- congrads/descriptor.py +118 -82
- congrads/metrics.py +55 -127
- congrads/networks.py +35 -81
- congrads/py.typed +0 -0
- congrads/transformations.py +65 -88
- congrads/utils.py +499 -131
- {congrads-1.0.6.dist-info → congrads-1.1.0.dist-info}/METADATA +48 -41
- congrads-1.1.0.dist-info/RECORD +14 -0
- congrads-1.1.0.dist-info/WHEEL +4 -0
- congrads-1.0.6.dist-info/LICENSE +0 -26
- congrads-1.0.6.dist-info/RECORD +0 -15
- congrads-1.0.6.dist-info/WHEEL +0 -5
- congrads-1.0.6.dist-info/top_level.txt +0 -1
congrads/datasets.py
CHANGED
@@ -1,51 +1,30 @@
-"""
-
-
-
-downloading, loading, and transforming specific datasets.
+"""This module defines several PyTorch dataset classes for loading and working with various datasets.
+
+Each dataset class extends the `torch.utils.data.Dataset` class and provides functionality for
+downloading, loading, and transforming specific datasets where applicable.

 Classes:

-    -
-
-    - FamilyIncome: A dataset class for the Family Income and
-        Expenditure dataset.
-    - NoisySines: A dataset class that generates noisy sine wave
-        samples with added Gaussian noise.
+    - SyntheticClusterDataset: A dataset class for generating synthetic clustered 2D data with labels.
+    - BiasCorrection: A dataset class for the Bias Correction dataset focused on temperature forecast data.
+    - FamilyIncome: A dataset class for the Family Income and Expenditure dataset.

-Each dataset class provides methods for downloading the data
-(if not already available), checking the integrity of the dataset, loading
-the data from CSV files or generating synthetic data, and applying
+Each dataset class provides methods for downloading the data
+(if not already available or synthetic), checking the integrity of the dataset, loading
+the data from CSV files or generating synthetic data, and applying
 transformations to the data.
-
-Key Methods:
-
-    - `__init__`: Initializes the dataset by specifying the root directory,
-        transformation function, and optional download flag.
-    - `__getitem__`: Retrieves a specific data point given its index,
-        returning input-output pairs.
-    - `__len__`: Returns the total number of examples in the dataset.
-    - `download`: Downloads and extracts the dataset from
-        the specified mirrors.
-    - `_load_data`: Loads the dataset from CSV files and
-        applies transformations.
-    - `_check_exists`: Checks if the dataset is already
-        downloaded and verified.
-
-Each dataset class allows the user to apply custom transformations to the
-dataset through the `transform` argument to allow pre-processing and offers
-the ability to download the dataset if it's not already present on
-the local disk.
 """

 import os
+import random
+from collections.abc import Callable
 from pathlib import Path
-from typing import Callable, Union
 from urllib.error import URLError

 import numpy as np
 import pandas as pd
 import torch
+from torch.distributions import Dirichlet
 from torch.utils.data import Dataset
 from torchvision.datasets.utils import (
     check_integrity,
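Note on the import changes above: dropping `typing.Callable`/`Union` in favour of `collections.abc.Callable` and PEP 604 unions such as `str | Path` means the annotations in this module are only valid at runtime on Python 3.10 or newer (the hunk does not show a `from __future__ import annotations`), so 1.1.0 effectively assumes a 3.10+ interpreter.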
@@ -54,8 +33,7 @@ from torchvision.datasets.utils import (


 class BiasCorrection(Dataset):
-    """
-    A dataset class for accessing the Bias Correction dataset.
+    """A dataset class for accessing the Bias Correction dataset.

     This class extends the `Dataset` class and provides functionality for
     downloading, loading, and transforming the Bias Correction dataset.
@@ -77,28 +55,16 @@ class BiasCorrection(Dataset):
             is not set to True or if all mirrors fail to provide the dataset.
     """

-    mirrors = [
-        "https://archive.ics.uci.edu/static/public/514/",
-    ]
-
+    mirrors = ["https://archive.ics.uci.edu/static/public/514/"]
     resources = [
         (
-            # pylint: disable-next=line-too-long
             "bias+correction+of+numerical+prediction+model+temperature+forecast.zip",
             "3deee56d461a2686887c4ae38fe3ccf3",
-        )
+        )
     ]

-    def __init__(
-        self,
-        root: Union[str, Path],
-        transform: Callable,
-        download: bool = False,
-    ) -> None:
-        """
-        Constructor method to initialize the dataset.
-        """
-
+    def __init__(self, root: str | Path, transform: Callable, download: bool = False) -> None:
+        """Constructor method to initialize the dataset."""
         super().__init__()
         self.root = root
         self.transform = transform
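A minimal construction sketch against the flattened 1.1.0 `__init__` signature. The `drop_missing` transform below is a hypothetical placeholder: the class only requires some callable that receives the raw `pandas.DataFrame` read from `Bias_correction_ucl.csv` and returns the pre-processed frame (how that frame is then split into input and output arrays happens outside this hunk):

    import pandas as pd

    from congrads.datasets import BiasCorrection

    def drop_missing(df: pd.DataFrame) -> pd.DataFrame:
        # Hypothetical pre-processing hook; the dataset pipes the CSV through it.
        return df.dropna()

    dataset = BiasCorrection(root="./data", transform=drop_missing, download=True)
    print(len(dataset))  # number of rows after the transform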
@@ -107,15 +73,12 @@ class BiasCorrection(Dataset):
             self.download()

         if not self._check_exists():
-            raise RuntimeError(
-                "Dataset not found. You can use download=True to download it"
-            )
+            raise RuntimeError("Dataset not found. You can use download=True to download it")

         self.data_input, self.data_output = self._load_data()

     def _load_data(self):
-        """
-        Loads the dataset from the CSV file and applies the transformation.
+        """Loads the dataset from the CSV file and applies the transformation.

         The data is read from the `Bias_correction_ucl.csv` file, and the
         transformation function is applied to it.
@@ -125,7 +88,6 @@ class BiasCorrection(Dataset):
             Tuple[numpy.ndarray, numpy.ndarray]: A tuple containing the input
             and output data as numpy arrays.
         """
-
         data: pd.DataFrame = pd.read_csv(
             os.path.join(self.data_folder, "Bias_correction_ucl.csv")
         ).pipe(self.transform)
@@ -136,48 +98,42 @@ class BiasCorrection(Dataset):
         return data_input, data_output

     def __len__(self):
-        """
-        Returns the number of examples in the dataset.
+        """Returns the number of examples in the dataset.

         Returns:
             int: The number of examples in the dataset
             (i.e., the number of rows in the input data).
         """
-
         return self.data_input.shape[0]

     def __getitem__(self, idx):
-        """
-        Returns the input-output pair for a given index.
+        """Returns the input-output pair for a given index.

         Args:
             idx (int): The index of the example to retrieve.

         Returns:
-
-
+            dict: A dictionary with the following keys:
+                - "input" (torch.Tensor): The input features for the example.
+                - "target" (torch.Tensor): The target output for the example.
         """
-
         example = self.data_input[idx, :]
         target = self.data_output[idx, :]
         example = torch.tensor(example)
         target = torch.tensor(target)
-        return example, target
+        return {"input": example, "target": target}

     @property
     def data_folder(self) -> str:
-        """
-        Returns the path to the folder where the dataset is stored.
+        """Returns the path to the folder where the dataset is stored.

         Returns:
             str: The path to the dataset folder.
         """
-
         return os.path.join(self.root, self.__class__.__name__)

     def _check_exists(self) -> bool:
-        """
-        Checks if the dataset is already downloaded and verified.
+        """Checks if the dataset is already downloaded and verified.

         This method checks that all required files exist and
         their integrity is validated via MD5 checksums.
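The `__getitem__` change above is the user-facing break in this release: samples are now returned as a dict instead of an `(example, target)` tuple. A short migration sketch, assuming `dataset` is a `BiasCorrection` instance as in the earlier example:

    # congrads 1.0.6
    example, target = dataset[0]

    # congrads 1.1.0
    sample = dataset[0]
    example, target = sample["input"], sample["target"]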
@@ -186,15 +142,13 @@ class BiasCorrection(Dataset):
             bool: True if all resources exist and their
             integrity is valid, False otherwise.
         """
-
         return all(
             check_integrity(os.path.join(self.data_folder, file_path), checksum)
             for file_path, checksum in self.resources
         )

     def download(self) -> None:
-        """
-        Downloads and extracts the dataset.
+        """Downloads and extracts the dataset.

         This method attempts to download the dataset from the mirrors and
         extract it into the appropriate folder. If any error occurs during
@@ -203,7 +157,6 @@ class BiasCorrection(Dataset):
         Raises:
             RuntimeError: If all mirrors fail to provide the dataset.
         """
-
         if self._check_exists():
             return

@@ -216,10 +169,7 @@ class BiasCorrection(Dataset):
             url = f"{mirror}{filename}"
             try:
                 download_and_extract_archive(
-                    url,
-                    download_root=self.data_folder,
-                    filename=filename,
-                    md5=md5,
+                    url, download_root=self.data_folder, filename=filename, md5=md5
                 )
             except URLError as e:
                 errors.append(e)
@@ -227,14 +177,13 @@ class BiasCorrection(Dataset):
                 break
             else:
                 s = f"Error downloading {filename}:\n"
-                for mirror, err in zip(self.mirrors, errors):
+                for mirror, err in zip(self.mirrors, errors, strict=False):
                     s += f"Tried {mirror}, got:\n{str(err)}\n"
                 raise RuntimeError(s)


 class FamilyIncome(Dataset):
-    """
-    A dataset class for accessing the Family Income and Expenditure dataset.
+    """A dataset class for accessing the Family Income and Expenditure dataset.

     This class extends the `Dataset` class and provides functionality for
     downloading, loading, and transforming the Family Income and
@@ -259,27 +208,17 @@ class FamilyIncome(Dataset):
     """

     mirrors = [
-
-        "https://www.kaggle.com/api/v1/datasets/download/grosvenpaul/family-income-and-expenditure",
+        "https://www.kaggle.com/api/v1/datasets/download/grosvenpaul/family-income-and-expenditure"
     ]
-
     resources = [
         (
             "archive.zip",
             "7d74bc7facc3d7c07c4df1c1c6ac563e",
-        )
+        )
     ]

-    def __init__(
-        self,
-        root: Union[str, Path],
-        transform: Callable,
-        download: bool = False,
-    ) -> None:
-        """
-        Constructor method to initialize the dataset.
-        """
-
+    def __init__(self, root: str | Path, transform: Callable, download: bool = False) -> None:
+        """Constructor method to initialize the dataset."""
         super().__init__()
         self.root = root
         self.transform = transform
@@ -288,26 +227,22 @@ class FamilyIncome(Dataset):
             self.download()

         if not self._check_exists():
-            raise RuntimeError(
-                "Dataset not found. You can use download=True to download it."
-            )
+            raise RuntimeError("Dataset not found. You can use download=True to download it.")

         self.data_input, self.data_output = self._load_data()

     def _load_data(self):
-        """
-        Loads the Family Income and Expenditure dataset from the CSV file
-        and applies the transformation.
+        """Load and transform the Family Income and Expenditure dataset.

-
-        and the transformation function
-        output
+        Reads the data from the `Family Income and Expenditure.csv` file located
+        in `self.data_folder` and applies the transformation function. The input
+        and output columns are extracted and returned as NumPy arrays.

         Returns:
-            Tuple[
-
+            Tuple[np.ndarray, np.ndarray]: A tuple containing:
+                - input data as a NumPy array of type float32
+                - output data as a NumPy array of type float32
         """
-
         data: pd.DataFrame = pd.read_csv(
             os.path.join(self.data_folder, "Family Income and Expenditure.csv")
         ).pipe(self.transform)
@@ -318,48 +253,42 @@ class FamilyIncome(Dataset):
         return data_input, data_output

     def __len__(self):
-        """
-        Returns the number of examples in the dataset.
+        """Returns the number of examples in the dataset.

         Returns:
             int: The number of examples in the dataset
             (i.e., the number of rows in the input data).
         """
-
         return self.data_input.shape[0]

     def __getitem__(self, idx):
-        """
-        Returns the input-output pair for a given index.
+        """Returns the input-output pair for a given index.

         Args:
             idx (int): The index of the example to retrieve.

         Returns:
-
-
+            dict: A dictionary with the following keys:
+                - "input" (torch.Tensor): The input features for the example.
+                - "target" (torch.Tensor): The target output for the example.
         """
-
         example = self.data_input[idx, :]
         target = self.data_output[idx, :]
         example = torch.tensor(example)
         target = torch.tensor(target)
-        return example, target
+        return {"input": example, "target": target}

     @property
     def data_folder(self) -> str:
-        """
-        Returns the path to the folder where the dataset is stored.
+        """Returns the path to the folder where the dataset is stored.

         Returns:
             str: The path to the dataset folder.
         """
-
         return os.path.join(self.root, self.__class__.__name__)

     def _check_exists(self) -> bool:
-        """
-        Checks if the dataset is already downloaded and verified.
+        """Checks if the dataset is already downloaded and verified.

         This method checks that all required files exist and
         their integrity is validated via MD5 checksums.
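Because both `BiasCorrection` and `FamilyIncome` now yield dict samples, `torch.utils.data.DataLoader` with its default collate function produces batches that are dicts of stacked tensors, so training loops index batches by key rather than by position. A minimal sketch, assuming `dataset` is any of the datasets in this module:

    from torch.utils.data import DataLoader

    loader = DataLoader(dataset, batch_size=32, shuffle=True)
    for batch in loader:
        inputs = batch["input"]    # stacked inputs, shape [batch_size, n_input_features]
        targets = batch["target"]  # stacked targets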
@@ -368,15 +297,13 @@ class FamilyIncome(Dataset):
             bool: True if all resources exist and their
             integrity is valid, False otherwise.
         """
-
         return all(
             check_integrity(os.path.join(self.data_folder, file_path), checksum)
             for file_path, checksum in self.resources
         )

     def download(self) -> None:
-        """
-        Downloads and extracts the dataset.
+        """Downloads and extracts the dataset.

         This method attempts to download the dataset from the mirrors
         and extract it into the appropriate folder. If any error occurs
@@ -385,7 +312,6 @@ class FamilyIncome(Dataset):
         Raises:
             RuntimeError: If all mirrors fail to provide the dataset.
         """
-
         if self._check_exists():
             return

@@ -398,10 +324,7 @@ class FamilyIncome(Dataset):
             url = f"{mirror}"
             try:
                 download_and_extract_archive(
-                    url,
-                    download_root=self.data_folder,
-                    filename=filename,
-                    md5=md5,
+                    url, download_root=self.data_folder, filename=filename, md5=md5
                 )
             except URLError as e:
                 errors.append(e)
@@ -409,91 +332,468 @@ class FamilyIncome(Dataset):
                 break
             else:
                 s = f"Error downloading {filename}:\n"
-                for mirror, err in zip(self.mirrors, errors):
+                for mirror, err in zip(self.mirrors, errors, strict=False):
                     s += f"Tried {mirror}, got:\n{str(err)}\n"
                 raise RuntimeError(s)


-class
+class SectionedGaussians(Dataset):
+    """Synthetic dataset generating smoothly varying Gaussian signals across multiple sections.
+
+    Each section defines a subrange of x-values with its own Gaussian distribution
+    (mean and standard deviation). Instead of abrupt transitions, the parameters
+    are blended smoothly between sections using a sigmoid function.
+
+    The resulting signal can represent a continuous process where statistical
+    properties gradually evolve over time or position.
+
+    Features:
+        - Input: Gaussian signal samples (y-values)
+        - Context: Concatenation of time (x) and normalized energy feature
+        - Target: Exponential decay ground truth from 1 at x_min to 0 at x_max
+
+    Attributes:
+        sections (list[dict]): List of section definitions.
+        n_samples (int): Total number of samples across all sections.
+        n_runs (int): Number of random waveforms generated from base configuration.
+        time (torch.Tensor): Sampled x-values, shape [n_samples, 1].
+        signal (torch.Tensor): Generated Gaussian signal values, shape [n_samples, 1].
+        energies (torch.Tensor): Normalized energy feature, shape [n_samples, 1].
+        context (torch.Tensor): Concatenation of time and energy, shape [n_samples, 2].
+        x_min (float): Minimum x-value across all sections.
+        x_max (float): Maximum x-value across all sections.
+        ground_truth_steepness (float): Exponential decay steepness for target output.
+        blend_k (float): Sharpness parameter controlling how rapidly means and
+            standard deviations transition between sections.
     """
-
-
+
+    def __init__(
+        self,
+        sections: list[dict],
+        n_samples: int = 1000,
+        n_runs: int = 1,
+        seed: int | None = None,
+        device="cpu",
+        ground_truth_steepness: float = 0.0,
+        blend_k: float = 10.0,
+    ):
+        """Initializes the dataset and generates smoothly blended Gaussian signals.
+
+        Args:
+            sections (list[dict]): List of sections. Each section must define:
+                - range (tuple[float, float]): Start and end of the x-interval.
+                - mean (float): Mean of the Gaussian for this section.
+                - std (float): Standard deviation of the Gaussian for this section.
+                - max_splits (int): Max number of extra splits per section.
+                - split_prob (float): Probability of splitting a section.
+                - mean_var (float): How much to vary mean (fraction of original).
+                - std_var (float): How much to vary std (fraction of original).
+                - range_var (float): How much to vary start/end positions (fraction of section length).
+            n_samples (int, optional): Total number of samples to generate. Defaults to 1000.
+            n_runs (int, optional): Number of random waveforms to generate from the
+                base configuration. Defaults to 1.
+            seed (int or None, optional): Random seed for reproducibility. Defaults to None.
+            device (str or torch.device, optional): Device on which tensors are allocated.
+                Defaults to "cpu".
+            ground_truth_steepness (float, optional): Controls how sharply the ground-truth
+                exponential decay decreases from 1 to 0. Defaults to 0.0 (linear decay).
+            blend_k (float, optional): Controls the sharpness of the sigmoid blending
+                between sections. Higher values make the transition steeper; lower
+                values make it smoother. Defaults to 10.0.
+        """
+        self.sections = sections
+        self.n_samples = n_samples
+        self.n_runs = n_runs
+        self.device = device
+        self.blend_k = blend_k
+        self.ground_truth_steepness = torch.tensor(ground_truth_steepness, device=device)
+
+        if seed is not None:
+            torch.manual_seed(seed)
+
+        time: list[torch.Tensor] = []
+        signal: list[torch.Tensor] = []
+        energy: list[torch.Tensor] = []
+        identifier: list[torch.Tensor] = []
+
+        # Compute global min/max for time
+        self.t_min = min(s["range"][0] for s in sections)
+        self.t_max = max(s["range"][1] for s in sections)
+
+        # Generate waveforms from base configuration
+        for run in range(self.n_runs):
+            waveform = self._generate_waveform()
+            time.append(waveform[0])
+            signal.append(waveform[1])
+            energy.append(waveform[2])
+            identifier.append(torch.full_like(waveform[0], float(run)))
+
+        # Concatenate runs into single tensors
+        self.time = torch.cat(time, dim=0)
+        self.signal = torch.cat(signal, dim=0)
+        self.energy = torch.cat(energy, dim=0)
+        self.identifier = torch.cat(identifier, dim=0)
+        self.context = torch.hstack([self.time, self.energy, self.identifier])
+
+        # Adjust n_samples in case of rounding mismatch
+        self.n_samples = len(self.time)
+
+    def _generate_waveform(
+        self,
+    ):
+        sections = []
+
+        prev_mean = 0
+
+        for sec in self.sections:
+            start, end = sec["range"]
+            mean, std = sec["add_mean"], sec["std"]
+            max_splits = sec["max_splits"]
+            split_prob = sec["split_prob"]
+            mean_scale = sec["mean_var"]
+            std_scale = sec["std_var"]
+            range_scale = sec["range_var"]
+            section_len = end - start
+
+            # Decide whether to split this section into multiple contiguous parts
+            n_splits = 1
+            if random.random() < split_prob:
+                n_splits += random.randint(1, max_splits)
+
+            # Randomly divide the section into contiguous subsections
+            random_fracs = Dirichlet(torch.ones(n_splits) * 3).sample()
+            sub_lengths = (random_fracs * section_len).tolist()
+
+            sub_start = start
+            for base_len in sub_lengths:
+                # Compute unperturbed subsection boundaries
+                base_start = sub_start
+                base_end = base_start + base_len
+
+                # Define smooth range variation parameters
+                max_shift = range_scale * base_len / 2
+
+                # Decide how much to increase the mean overall
+                prev_mean += mean * random.uniform(0.0, 2 * mean_scale)
+
+                # Apply small random shifts to start/end, relative to subsection size
+                if range_scale > 0:
+                    # Keep total ordering consistent (avoid overlaps or inversions)
+                    local_shift = random.uniform(-max_shift / n_splits, max_shift / n_splits)
+                    new_start = base_start + local_shift
+                    new_end = base_end + local_shift
+                else:
+                    new_start, new_end = base_start, base_end
+
+                # Clamp to valid time bounds
+                new_start = max(self.t_min, min(new_start, end))
+                new_end = max(new_start, min(new_end, end))
+
+                # Randomize mean and std within allowed ranges
+                mean_var = prev_mean * (1 + random.uniform(-mean_scale, mean_scale))
+                std_var = std * (1 + random.uniform(-std_scale, std_scale))
+
+                sections.append(
+                    {
+                        "range": (new_start, new_end),
+                        "mean": mean_var,
+                        "std": std_var,
+                    }
+                )
+
+                # Prepare for next subsection
+                sub_start = base_end
+                prev_mean = mean_var
+
+        # Ensure last section ends exactly at x_max
+        sections[-1]["range"] = (sections[-1]["range"][0], self.t_max)
+
+        # Precompute exact float counts
+        section_lengths = [s["range"][1] - s["range"][0] for s in sections]
+        total_length = sum(section_lengths)
+        exact_counts = [length / total_length * self.n_samples for length in section_lengths]
+
+        # Allocate integer samples with rounding, track residual
+        n_section_samples_list = []
+        residual = 0.0
+        for exact in exact_counts[:-1]:
+            n = int(round(exact + residual))
+            n_section_samples_list.append(n)
+            residual += exact - n  # carry over rounding error
+
+        # Assign remaining samples to last section
+        n_section_samples_list.append(self.n_samples - sum(n_section_samples_list))
+
+        # Generate data for each section
+        signal_segments: list[torch.Tensor] = []
+        for i, section in enumerate(sections):
+            start, end = section["range"]
+            mean, std = section["mean"], section["std"]
+
+            # Samples proportional to section length
+            n_section_samples = n_section_samples_list[i]
+
+            # Determine next section’s parameters for blending
+            if i < len(sections) - 1:
+                next_mean = sections[i + 1]["mean"]
+                next_std = sections[i + 1]["std"]
+            else:
+                next_mean = mean
+                next_std = std
+
+            # Sigmoid-based blend curve from 0 → 1
+            x = torch.linspace(
+                -self.blend_k, self.blend_k, n_section_samples, device=self.device
+            ).unsqueeze(1)
+            fade = torch.sigmoid(x)
+            fade = (fade - fade.min()) / (fade.max() - fade.min())
+
+            # Interpolate mean and std within the section
+            mean_curve = mean + (next_mean - mean) * fade
+            std_curve = std + (next_std - std) * fade
+
+            # Sample from a gradually changing Gaussian
+            y = torch.normal(mean=mean_curve, std=std_curve)
+            signal_segments.append(y)
+
+        # Concatenate tensors
+        time: torch.Tensor = torch.linspace(
+            self.t_min, self.t_max, self.n_samples, device=self.device
+        ).unsqueeze(1)
+        signal: torch.Tensor = torch.cat(signal_segments, dim=0)
+
+        # Compute and normalize energy feature
+        energy = torch.linalg.vector_norm(
+            signal - signal[0].reshape([1, -1]), ord=2, dim=1
+        ).unsqueeze(1)
+        min_e, max_e = energy.min(), energy.max()
+        energy = 1 - (energy - min_e) / (max_e - min_e) * 3
+
+        # Combine into context tensor
+        return time, signal, energy
+
+    def _compute_ground_truth(self, time: torch.Tensor) -> torch.Tensor:
+        """Computes the ground-truth exponential decay at coordinate t.
+
+        Args:
+            time (torch.Tensor): Input coordinate (shape: [1]).
+
+        Returns:
+            torch.Tensor: Corresponding target value between 0 and 1.
+        """
+        time_norm = (time - self.t_min) / (self.t_max - self.t_min)
+        k = self.ground_truth_steepness
+        if k == 0:
+            return 1 - time_norm
+        return (torch.exp(-k * time_norm) - torch.exp(-k)) / (1 - torch.exp(-k))
+
+    def __len__(self):
+        """Returns the total number of generated samples."""
+        return self.n_samples
+
+    def __getitem__(self, idx):
+        """Retrieves a single dataset sample.
+
+        Args:
+            idx (int): Sample index.
+
+        Returns:
+            dict:
+                - "input": Gaussian signal value (torch.Tensor),
+                - "context: Concatenated features (time, energy, identifier),
+                - "target": Ground-truth exponential decay value
+
+        """
+        sigal = self.signal[idx]
+        context = self.context[idx]
+        time = self.time[idx]
+        target = self._compute_ground_truth(time)
+
+        return {
+            "input": sigal,
+            "context": context,
+            "target": target,
+        }
+
+
+class SyntheticMonotonicity(Dataset):
+    """Synthetic 1D dataset with monotone ground truth (log(1+x)), plus configurable structured noise.
+
+    True function:
+        y_true(x) = log(1 + x)
+
+    Observed:
+        y(x) = y_true(x) + heteroscedastic_noise(x) + local oscillatory perturbation

     Args:
-
-
-
-
-
-
-
-
+        n_samples (int): number of data points (default 200)
+        x_range (tuple): range of x values (default [0, 5])
+        noise_base (float): baseline noise level (default 0.05)
+        noise_scale (float): scale of heteroscedastic noise (default 0.15)
+        noise_sharpness (float): steepness of heteroscedastic transition (default 4.0)
+        noise_center (float): center point of heteroscedastic increase (default 2.5)
+        osc_amplitude (float): amplitude of oscillatory perturbation (default 0.08)
+        osc_frequency (float): frequency of oscillation (default 6.0)
+        osc_prob (float): probability each sample receives oscillation (default 0.5)
+        seed (int or None): random seed
     """

     def __init__(
         self,
-
-
-
-
-
-
+        n_samples=200,
+        x_range=(0.0, 5.0),
+        noise_base=0.05,
+        noise_scale=0.15,
+        noise_sharpness=4.0,
+        noise_center=2.5,
+        osc_amplitude=0.08,
+        osc_frequency=6.0,
+        osc_prob=0.5,
+        seed=None,
     ):
+        """Synthetic 1D dataset with monotone ground truth (log(1+x)), plus configurable structured noise.
+
+        True function:
+            y_true(x) = log(1 + x)
+
+        Observed:
+            y(x) = y_true(x) + heteroscedastic_noise(x) + local oscillatory perturbation
+
+        Args:
+            n_samples (int): number of data points (default 200)
+            x_range (tuple): range of x values (default [0, 5])
+            noise_base (float): baseline noise level (default 0.05)
+            noise_scale (float): scale of heteroscedastic noise (default 0.15)
+            noise_sharpness (float): steepness of heteroscedastic transition (default 4.0)
+            noise_center (float): center point of heteroscedastic increase (default 2.5)
+            osc_amplitude (float): amplitude of oscillatory perturbation (default 0.08)
+            osc_frequency (float): frequency of oscillation (default 6.0)
+            osc_prob (float): probability each sample receives oscillation (default 0.5)
+            seed (int or None): random seed
         """
-
-
-
-        self.amplitude = amplitude
-        self.frequency = frequency
-        self.noise_std = noise_std
-        self.bias = bias
-        self.random_seed = random_seed
+        super().__init__()
+        if seed is not None:
+            np.random.seed(seed)

-
-        self.
-        self.noise = np.random.normal(0, self.noise_std, length)
+        self.n_samples = n_samples
+        self.x_range = x_range

-
+        # Sample inputs
+        self.x = np.random.rand(n_samples) * (x_range[1] - x_range[0]) + x_range[0]
+        self.x = np.sort(self.x)
+
+        # Heteroscedastic noise (logistic growth with x)
+        noise_sigma = noise_base + noise_scale / (
+            1 + np.exp(-noise_sharpness * (self.x - noise_center))
+        )
+
+        # Oscillatory perturbation (applied with probability osc_prob)
+        mask = (np.random.rand(n_samples) < osc_prob).astype(float)
+        osc = mask * (osc_amplitude * np.sin(osc_frequency * self.x))
+
+        # Final observed target
+        self.y = np.log1p(self.x) + osc + noise_sigma * np.random.randn(n_samples)
+
+        # Convert to tensors
+        self.inputs = torch.tensor(self.x, dtype=torch.float32).unsqueeze(1)
+        self.targets = torch.tensor(self.y, dtype=torch.float32).unsqueeze(1)
+
+    def __len__(self):
+        """Return the total number of samples in the dataset.
+
+        Returns:
+            int: The number of generated data points.
         """
-
+        return self.n_samples
+
+    def __getitem__(self, idx) -> dict:
+        """Retrieve a single sample from the dataset.

         Args:
-            idx (int): Index of the
+            idx (int): Index of the sample to retrieve.

         Returns:
-
-
+            dict: A dictionary with the following keys:
+                - "input" (torch.Tensor): The x value.
+                - "target" (torch.Tensor): The corresponding y value.
         """
+        return {"input": self.inputs[idx], "target": self.targets[idx]}

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+class SyntheticClusters(Dataset):
+    """PyTorch dataset for generating synthetic clustered 2D data with labels.
+
+    Each cluster is defined by its center, size, spread (standard deviation), and label.
+    The dataset samples points from a Gaussian distribution centered at the cluster mean.
+
+    Args:
+        cluster_centers (list[tuple[float, float]]): Coordinates of each cluster center,
+            e.g. [(x1, y1), (x2, y2), ...].
+        cluster_sizes (list[int]): Number of points to generate in each cluster.
+        cluster_std (list[float]): Standard deviation (spread) of each cluster.
+        cluster_labels (list[int]): Class label for each cluster (e.g., 0 or 1).
+
+    Raises:
+        AssertionError: If the input lists do not all have the same length.
+
+    Attributes:
+        data (torch.Tensor): A concatenated tensor of all generated points with shape (N, 2).
+        labels (torch.Tensor): A concatenated tensor of class labels with shape (N,),
+            where N is the total number of generated points.
+    """
+
+    def __init__(self, cluster_centers, cluster_sizes, cluster_std, cluster_labels):
+        """Initialize the ClusterDataset.
+
+        Args:
+            cluster_centers (list[tuple[float, float]]): Coordinates of each cluster center,
+                e.g. [(x1, y1), (x2, y2), ...].
+            cluster_sizes (list[int]): Number of points to generate in each cluster.
+            cluster_std (list[float]): Standard deviation (spread) of each cluster.
+            cluster_labels (list[int]): Class label for each cluster (e.g., 0 or 1).
+
+        Raises:
+            AssertionError: If the input lists do not all have the same length.
+        """
+        assert (
+            len(cluster_centers) == len(cluster_sizes) == len(cluster_std) == len(cluster_labels)
+        ), "All input lists must have the same length"
+
+        self.data = []
+        self.labels = []
+
+        # Generate points for each cluster
+        for center, size, std, label in zip(
+            cluster_centers, cluster_sizes, cluster_std, cluster_labels, strict=False
+        ):
+            x = torch.normal(mean=center[0], std=std, size=(size, 1))
+            y = torch.normal(mean=center[1], std=std, size=(size, 1))
+            points = torch.cat([x, y], dim=1)
+
+            self.data.append(points)
+            self.labels.append(torch.full((size,), label, dtype=torch.long))
+
+        # Concatenate all clusters
+        self.data = torch.cat(self.data, dim=0)
+        self.labels = torch.cat(self.labels, dim=0)

     def __len__(self):
+        """Return the total number of samples in the dataset.
+
+        Returns:
+            int: The number of generated data points.
         """
-
+        return len(self.data)
+
+    def __getitem__(self, idx) -> dict:
+        """Retrieve a single sample from the dataset.
+
+        Args:
+            idx (int): Index of the sample to retrieve.

         Returns:
-
+            dict: A dictionary with the following keys:
+                - "input" (torch.Tensor): The 2D point at index `idx`.
+                - "target" (torch.Tensor): The corresponding class label.
         """
-        return self.
+        return {"input": self.data[idx], "target": self.labels[idx]}
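A hedged usage sketch for the three new synthetic datasets; the concrete numbers below are illustrative only. One wrinkle worth noting: the `SectionedGaussians.__init__` docstring documents a `mean` key per section, but `_generate_waveform` actually reads `add_mean`, so the sketch follows the code:

    from congrads.datasets import SectionedGaussians, SyntheticClusters, SyntheticMonotonicity

    sections = [
        {"range": (0.0, 1.0), "add_mean": 0.5, "std": 0.1, "max_splits": 2,
         "split_prob": 0.3, "mean_var": 0.2, "std_var": 0.2, "range_var": 0.1},
        {"range": (1.0, 2.0), "add_mean": 1.0, "std": 0.2, "max_splits": 2,
         "split_prob": 0.3, "mean_var": 0.2, "std_var": 0.2, "range_var": 0.1},
    ]
    gaussians = SectionedGaussians(sections, n_samples=500, n_runs=2, seed=0)
    sample = gaussians[0]  # {"input": ..., "context": ..., "target": ...}

    mono = SyntheticMonotonicity(n_samples=200, seed=0)  # dict samples: "input", "target"
    clusters = SyntheticClusters(
        cluster_centers=[(0.0, 0.0), (3.0, 3.0)],
        cluster_sizes=[100, 100],
        cluster_std=[0.5, 0.5],
        cluster_labels=[0, 1],
    )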