congrads 0.2.0-py3-none-any.whl → 1.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- congrads/__init__.py +17 -10
- congrads/checkpoints.py +232 -0
- congrads/constraints.py +664 -134
- congrads/core.py +482 -110
- congrads/datasets.py +315 -11
- congrads/descriptor.py +100 -20
- congrads/metrics.py +178 -16
- congrads/networks.py +47 -23
- congrads/requirements.txt +6 -0
- congrads/transformations.py +139 -0
- congrads/utils.py +439 -39
- congrads-1.0.1.dist-info/METADATA +208 -0
- congrads-1.0.1.dist-info/RECORD +16 -0
- {congrads-0.2.0.dist-info → congrads-1.0.1.dist-info}/WHEEL +1 -1
- congrads-0.2.0.dist-info/METADATA +0 -222
- congrads-0.2.0.dist-info/RECORD +0 -13
- {congrads-0.2.0.dist-info → congrads-1.0.1.dist-info}/LICENSE +0 -0
- {congrads-0.2.0.dist-info → congrads-1.0.1.dist-info}/top_level.txt +0 -0
congrads/datasets.py
CHANGED
@@ -1,16 +1,81 @@
+"""
+This module defines several PyTorch dataset classes for loading and
+working with various datasets. Each dataset class extends the
+`torch.utils.data.Dataset` class and provides functionality for
+downloading, loading, and transforming specific datasets.
+
+Classes:
+
+- BiasCorrection: A dataset class for the Bias Correction dataset
+  focused on temperature forecast data.
+- FamilyIncome: A dataset class for the Family Income and
+  Expenditure dataset.
+- NoisySines: A dataset class that generates noisy sine wave
+  samples with added Gaussian noise.
+
+Each dataset class provides methods for downloading the data
+(if not already available), checking the integrity of the dataset, loading
+the data from CSV files or generating synthetic data, and applying
+transformations to the data.
+
+Key Methods:
+
+- `__init__`: Initializes the dataset by specifying the root directory,
+  transformation function, and optional download flag.
+- `__getitem__`: Retrieves a specific data point given its index,
+  returning input-output pairs.
+- `__len__`: Returns the total number of examples in the dataset.
+- `download`: Downloads and extracts the dataset from
+  the specified mirrors.
+- `_load_data`: Loads the dataset from CSV files and
+  applies transformations.
+- `_check_exists`: Checks if the dataset is already
+  downloaded and verified.
+
+Each dataset class allows the user to apply custom transformations to the
+dataset through the `transform` argument to allow pre-processing and offers
+the ability to download the dataset if it's not already present on
+the local disk.
+"""
+
 import os
-from urllib.error import URLError
-import numpy as np
 from pathlib import Path
 from typing import Callable, Union
+from urllib.error import URLError
+
+import numpy as np
 import pandas as pd
-from torch.utils.data import Dataset
 import torch
-
-from torchvision.datasets.utils import
+from torch.utils.data import Dataset
+from torchvision.datasets.utils import (
+    check_integrity,
+    download_and_extract_archive,
+)
 
 
 class BiasCorrection(Dataset):
+    """
+    A dataset class for accessing the Bias Correction dataset.
+
+    This class extends the `Dataset` class and provides functionality for
+    downloading, loading, and transforming the Bias Correction dataset.
+    The dataset is focused on temperature forecast data and is made available
+    for use with PyTorch. If `download` is set to True, the dataset will be
+    downloaded if it is not already available. The data is then loaded,
+    and a transformation function is applied to it.
+
+    Args:
+        root (Union[str, Path]): The root directory where the dataset
+            will be stored or loaded from.
+        transform (Callable): A function to transform the dataset
+            (e.g., preprocessing).
+        download (bool, optional): Whether to download the dataset if it's
+            not already present. Defaults to False.
+
+    Raises:
+        RuntimeError: If the dataset is not found and `download`
+            is not set to True or if all mirrors fail to provide the dataset.
+    """
 
     mirrors = [
         "https://archive.ics.uci.edu/static/public/514/",
@@ -18,6 +83,7 @@ class BiasCorrection(Dataset):
 
     resources = [
         (
+            # pylint: disable-next=line-too-long
            "bias+correction+of+numerical+prediction+model+temperature+forecast.zip",
            "3deee56d461a2686887c4ae38fe3ccf3",
         ),
@@ -29,6 +95,9 @@ class BiasCorrection(Dataset):
         transform: Callable,
         download: bool = False,
     ) -> None:
+        """
+        Constructor method to initialize the dataset.
+        """
 
         super().__init__()
         self.root = root
@@ -45,6 +114,17 @@ class BiasCorrection(Dataset):
         self.data_input, self.data_output = self._load_data()
 
     def _load_data(self):
+        """
+        Loads the dataset from the CSV file and applies the transformation.
+
+        The data is read from the `Bias_correction_ucl.csv` file, and the
+        transformation function is applied to it.
+        The input and output data are separated and returned as numpy arrays.
+
+        Returns:
+            Tuple[numpy.ndarray, numpy.ndarray]: A tuple containing the input
+            and output data as numpy arrays.
+        """
 
         data: pd.DataFrame = pd.read_csv(
             os.path.join(self.data_folder, "Bias_correction_ucl.csv")
@@ -56,10 +136,27 @@ class BiasCorrection(Dataset):
         return data_input, data_output
 
     def __len__(self):
+        """
+        Returns the number of examples in the dataset.
+
+        Returns:
+            int: The number of examples in the dataset
+            (i.e., the number of rows in the input data).
+        """
 
         return self.data_input.shape[0]
 
     def __getitem__(self, idx):
+        """
+        Returns the input-output pair for a given index.
+
+        Args:
+            idx (int): The index of the example to retrieve.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: The input-output pair
+            as PyTorch tensors.
+        """
 
         example = self.data_input[idx, :]
         target = self.data_output[idx, :]
@@ -69,16 +166,44 @@ class BiasCorrection(Dataset):
 
     @property
     def data_folder(self) -> str:
+        """
+        Returns the path to the folder where the dataset is stored.
+
+        Returns:
+            str: The path to the dataset folder.
+        """
 
         return os.path.join(self.root, self.__class__.__name__)
 
     def _check_exists(self) -> bool:
+        """
+        Checks if the dataset is already downloaded and verified.
+
+        This method checks that all required files exist and
+        their integrity is validated via MD5 checksums.
+
+        Returns:
+            bool: True if all resources exist and their
+            integrity is valid, False otherwise.
+        """
+
         return all(
             check_integrity(os.path.join(self.data_folder, file_path), checksum)
             for file_path, checksum in self.resources
         )
 
     def download(self) -> None:
+        """
+        Downloads and extracts the dataset.
+
+        This method attempts to download the dataset from the mirrors and
+        extract it into the appropriate folder. If any error occurs during
+        downloading, it will try each mirror in sequence.
+
+        Raises:
+            RuntimeError: If all mirrors fail to provide the dataset.
+        """
+
         if self._check_exists():
             return
 
@@ -91,7 +216,10 @@ class BiasCorrection(Dataset):
                 url = f"{mirror}{filename}"
                 try:
                     download_and_extract_archive(
-url,
+                        url,
+                        download_root=self.data_folder,
+                        filename=filename,
+                        md5=md5,
                     )
                 except URLError as e:
                     errors.append(e)
@@ -104,15 +232,40 @@ class BiasCorrection(Dataset):
             raise RuntimeError(s)
 
 
-class
+class FamilyIncome(Dataset):
+    """
+    A dataset class for accessing the Family Income and Expenditure dataset.
+
+    This class extends the `Dataset` class and provides functionality for
+    downloading, loading, and transforming the Family Income and
+    Expenditure dataset. The dataset is intended for use with
+    PyTorch-based projects, offering a convenient interface for data handling.
+    This class provides access to the Family Income and Expenditure dataset
+    for use with PyTorch. If `download` is set to True, the dataset will be
+    downloaded if it is not already available. The data is then loaded,
+    and a user-defined transformation function is applied to it.
+
+    Args:
+        root (Union[str, Path]): The root directory where the dataset will
+            be stored or loaded from.
+        transform (Callable): A function to transform the dataset
+            (e.g., preprocessing).
+        download (bool, optional): Whether to download the dataset if it's
+            not already present. Defaults to False.
+
+    Raises:
+        RuntimeError: If the dataset is not found and `download`
+            is not set to True or if all mirrors fail to provide the dataset.
+    """
 
     mirrors = [
-
+        # pylint: disable-next=line-too-long
+        "https://www.kaggle.com/api/v1/datasets/download/grosvenpaul/family-income-and-expenditure",
     ]
 
     resources = [
         (
-"
+            "archive.zip",
            "7d74bc7facc3d7c07c4df1c1c6ac563e",
         ),
     ]
@@ -123,6 +276,10 @@ class FiniteIncome(Dataset):
         transform: Callable,
         download: bool = False,
     ) -> None:
+        """
+        Constructor method to initialize the dataset.
+        """
+
         super().__init__()
         self.root = root
         self.transform = transform
@@ -138,6 +295,18 @@ class FiniteIncome(Dataset):
         self.data_input, self.data_output = self._load_data()
 
     def _load_data(self):
+        """
+        Loads the Family Income and Expenditure dataset from the CSV file
+        and applies the transformation.
+
+        The data is read from the `Family Income and Expenditure.csv` file,
+        and the transformation function is applied to it. The input and
+        output data are separated and returned as numpy arrays.
+
+        Returns:
+            Tuple[numpy.ndarray, numpy.ndarray]: A tuple containing the input
+            and output data as numpy arrays.
+        """
 
         data: pd.DataFrame = pd.read_csv(
             os.path.join(self.data_folder, "Family Income and Expenditure.csv")
@@ -149,9 +318,28 @@ class FiniteIncome(Dataset):
         return data_input, data_output
 
     def __len__(self):
+        """
+        Returns the number of examples in the dataset.
+
+        Returns:
+            int: The number of examples in the dataset
+            (i.e., the number of rows in the input data).
+        """
+
         return self.data_input.shape[0]
 
     def __getitem__(self, idx):
+        """
+        Returns the input-output pair for a given index.
+
+        Args:
+            idx (int): The index of the example to retrieve.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: The input-output pair
+            as PyTorch tensors.
+        """
+
         example = self.data_input[idx, :]
         target = self.data_output[idx, :]
         example = torch.tensor(example)
@@ -160,15 +348,43 @@ class FiniteIncome(Dataset):
 
     @property
     def data_folder(self) -> str:
+        """
+        Returns the path to the folder where the dataset is stored.
+
+        Returns:
+            str: The path to the dataset folder.
+        """
+
         return os.path.join(self.root, self.__class__.__name__)
 
     def _check_exists(self) -> bool:
+        """
+        Checks if the dataset is already downloaded and verified.
+
+        This method checks that all required files exist and
+        their integrity is validated via MD5 checksums.
+
+        Returns:
+            bool: True if all resources exist and their
+            integrity is valid, False otherwise.
+        """
+
         return all(
             check_integrity(os.path.join(self.data_folder, file_path), checksum)
             for file_path, checksum in self.resources
         )
 
     def download(self) -> None:
+        """
+        Downloads and extracts the dataset.
+
+        This method attempts to download the dataset from the mirrors
+        and extract it into the appropriate folder. If any error occurs
+        during downloading, it will try each mirror in sequence.
+
+        Raises:
+            RuntimeError: If all mirrors fail to provide the dataset.
+        """
 
         if self._check_exists():
             return
@@ -179,10 +395,13 @@ class FiniteIncome(Dataset):
         for filename, md5 in self.resources:
             errors = []
             for mirror in self.mirrors:
-url = f"{mirror}
+                url = f"{mirror}"
                 try:
                     download_and_extract_archive(
-url,
+                        url,
+                        download_root=self.data_folder,
+                        filename=filename,
+                        md5=md5,
                     )
                 except URLError as e:
                     errors.append(e)
@@ -193,3 +412,88 @@ class FiniteIncome(Dataset):
         for mirror, err in zip(self.mirrors, errors):
             s += f"Tried {mirror}, got:\n{str(err)}\n"
         raise RuntimeError(s)
+
+
+class NoisySines(Dataset):
+    """
+    A PyTorch dataset generating samples from a causal
+    sine wave with added noise.
+
+    Args:
+        length (int): Number of data points in the dataset.
+        amplitude (float): Amplitude of the sine wave.
+        frequency (float): Frequency of the sine wave in Hz.
+        noise_std (float): Standard deviation of the Gaussian noise.
+        bias (float): Offset from zero.
+
+    The sine wave is zero for times before t=0 and follows a
+    standard sine wave after t=0, with Gaussian noise added to all points.
+    """
+
+    def __init__(
+        self,
+        length,
+        amplitude=1,
+        frequency=10.0,
+        noise_std=0.05,
+        bias=0,
+        random_seed=42,
+    ):
+        """
+        Initializes the NoisyCausalSine dataset.
+        """
+        self.length = length
+        self.amplitude = amplitude
+        self.frequency = frequency
+        self.noise_std = noise_std
+        self.bias = bias
+        self.random_seed = random_seed
+
+        np.random.seed(self.random_seed)
+        self.time = np.linspace(0, 1, length)
+        self.noise = np.random.normal(0, self.noise_std, length)
+
+    def __getitem__(self, idx):
+        """
+        Returns the time and noisy sine wave value for a given index.
+
+        Args:
+            idx (int): Index of the data point to retrieve.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: A tuple containing the
+            time value and the noisy sine wave value.
+        """
+
+        t = self.time[idx]
+        if idx < self.length // 2:
+            sine_value = self.bias
+            cosine_value = self.bias
+        else:
+            sine_value = (
+                self.amplitude * np.sin(2 * np.pi * self.frequency * t)
+                + self.bias
+            )
+            cosine_value = (
+                self.amplitude * np.cos(2 * np.pi * self.frequency * t)
+                + self.bias
+            )
+
+        # Add noise to the signals
+        noisy_sine = sine_value + self.noise[idx]
+        noisy_cosine = cosine_value + self.noise[idx]
+
+        # Convert to tensor
+        example, target = torch.tensor([t], dtype=torch.float32), torch.tensor(
+            [noisy_sine, noisy_cosine], dtype=torch.float32
+        )
+        return example, target
+
+    def __len__(self):
+        """
+        Returns the total number of data points in the dataset.
+
+        Returns:
+            int: The length of the dataset.
+        """
+        return self.length
congrads/descriptor.py
CHANGED
@@ -1,17 +1,63 @@
+"""
+This module defines the `Descriptor` class, which is designed to manage
+the mapping between neuron names, their corresponding layers, and additional
+properties such as constant or variable status. It provides a way to easily
+place constraints on parts of your network, by referencing the neuron names
+instead of indices.
+
+The `Descriptor` class allows for easy constraint definitions on parts of
+your neural network. It supports registering neurons with associated layers,
+indices, and optional attributes, such as whether the layer is constant
+or variable.
+
+Key Methods:
+
+- `__init__`: Initializes the `Descriptor` object with empty mappings
+  and sets for managing neurons and layers.
+- `add`: Registers a neuron with its associated layer, index, and
+  optional constant status.
+
+"""
+
+from .utils import validate_type
+
+
 class Descriptor:
-
+    """
+    A class to manage the mapping between neuron names, their corresponding
+    layers, and additional properties (such as min/max values, output,
+    and constant variables).
+
+    This class is designed to track the relationship between neurons and
+    layers in a neural network. It allows for the assignment of properties
+    (like minimum and maximum values, and whether a layer is an output,
+    constant, or variable) to each neuron. The data is stored in
+    dictionaries and sets for efficient lookups.
+
+    Attributes:
+        neuron_to_layer (dict): A dictionary mapping neuron names to
+            their corresponding layer names.
+        neuron_to_index (dict): A dictionary mapping neuron names to
+            their corresponding indices in the layers.
+        constant_layers (set): A set of layer names that represent
+            constant layers.
+        variable_layers (set): A set of layer names that represent
+            variable layers.
+    """
 
     def __init__(
         self,
     ):
+        """
+        Initializes the Descriptor object.
+        """
 
-        # Define dictionaries that will translate neuron
+        # Define dictionaries that will translate neuron
+        # names to layer and index
         self.neuron_to_layer: dict[str, str] = {}
         self.neuron_to_index: dict[str, int] = {}
-        self.neuron_to_minmax: dict[str, tuple[float, float]] = {}
 
         # Define sets that will hold the layers based on which type
-        self.output_layers: set[str] = set()
         self.constant_layers: set[str] = set()
         self.variable_layers: set[str] = set()
 
@@ -20,31 +66,65 @@ class Descriptor:
         layer_name: str,
         index: int,
         neuron_name: str,
-        min: float = 0,
-        max: float = 1,
-        output: bool = False,
         constant: bool = False,
     ):
+        """
+        Adds a neuron to the descriptor with its associated layer,
+        index, and properties.
 
-
-
+        This method registers a neuron name and associates it with a
+        layer, its index, and optional properties such as whether
+        the layer is an output or constant layer.
 
-
-
-
-
+        Args:
+            layer_name (str): The name of the layer where the neuron is located.
+            index (int): The index of the neuron within the layer.
+            neuron_name (str): The name of the neuron.
+            constant (bool, optional): Whether the layer is a constant layer.
+                Defaults to False.
 
-
-
+        Raises:
+            TypeError: If a provided attribute has an incompatible type.
+            ValueError: If a layer or index is already assigned for a neuron
+                or a duplicate index is used within a layer.
+
+        """
 
-
+        # Type checking
+        validate_type("layer_name", layer_name, str)
+        validate_type("index", index, int)
+        validate_type("neuron_name", neuron_name, str)
+        validate_type("constant", constant, bool)
+
+        # Other validations
+        if neuron_name in self.neuron_to_layer:
            raise ValueError(
-
+                "There already is a layer registered for the neuron with name "
+                f"'{neuron_name}'. Please use a unique name for each neuron."
            )
 
-if
+        if neuron_name in self.neuron_to_index:
            raise ValueError(
-
+                "There already is an index registered for the neuron with name "
+                f"'{neuron_name}'. Please use a unique name for each neuron."
            )
 
-self.
+        for existing_neuron, assigned_index in self.neuron_to_index.items():
+            if (
+                assigned_index == index
+                and self.neuron_to_layer[existing_neuron] == layer_name
+            ):
+                raise ValueError(
+                    f"The index {index} in layer {layer_name} is already "
+                    "assigned. Every neuron must be assigned a different "
+                    "index that matches the network's output."
+                )
+
+        # Add to dictionaries and sets
+        if constant:
+            self.constant_layers.add(layer_name)
+        else:
+            self.variable_layers.add(layer_name)
+
+        self.neuron_to_layer[neuron_name] = layer_name
+        self.neuron_to_index[neuron_name] = index