congrads 1.0.6__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- congrads/__init__.py +2 -3
- congrads/checkpoints.py +73 -127
- congrads/constraints.py +813 -476
- congrads/core.py +521 -345
- congrads/datasets.py +491 -191
- congrads/descriptor.py +118 -82
- congrads/metrics.py +55 -127
- congrads/networks.py +35 -81
- congrads/py.typed +0 -0
- congrads/transformations.py +65 -88
- congrads/utils.py +499 -131
- {congrads-1.0.6.dist-info → congrads-1.1.0.dist-info}/METADATA +48 -41
- congrads-1.1.0.dist-info/RECORD +14 -0
- congrads-1.1.0.dist-info/WHEEL +4 -0
- congrads-1.0.6.dist-info/LICENSE +0 -26
- congrads-1.0.6.dist-info/RECORD +0 -15
- congrads-1.0.6.dist-info/WHEEL +0 -5
- congrads-1.0.6.dist-info/top_level.txt +0 -1
congrads/utils.py
CHANGED
|
@@ -1,45 +1,45 @@
|
|
|
1
|
-
"""
|
|
2
|
-
This module provides utility functions and classes for managing data,
|
|
3
|
-
logging, and preprocessing in machine learning workflows. The functionalities
|
|
4
|
-
include logging key-value pairs to CSV files, splitting datasets into
|
|
5
|
-
training, validation, and test sets, preprocessing functions for various
|
|
6
|
-
data sets and validator functions for type checking.
|
|
7
|
-
"""
|
|
1
|
+
"""This module holds utility functions and classes for the congrads package."""
|
|
8
2
|
|
|
9
3
|
import os
|
|
4
|
+
import random
|
|
10
5
|
|
|
11
6
|
import numpy as np
|
|
12
7
|
import pandas as pd
|
|
13
8
|
import torch
|
|
14
|
-
from torch import Generator
|
|
9
|
+
from torch import Generator, Tensor, argsort, cat, int32, unique
|
|
10
|
+
from torch.nn.modules.loss import _Loss
|
|
15
11
|
from torch.utils.data import DataLoader, Dataset, random_split
|
|
16
12
|
|
|
17
13
|
|
|
18
14
|
class CSVLogger:
|
|
19
|
-
"""
|
|
20
|
-
|
|
21
|
-
|
|
15
|
+
"""A utility class for logging key-value pairs to a CSV file, organized by epochs.
|
|
16
|
+
|
|
17
|
+
Supports merging with existing logs or overwriting them.
|
|
22
18
|
|
|
23
19
|
Args:
|
|
24
20
|
file_path (str): The path to the CSV file for logging.
|
|
25
|
-
overwrite (bool): If True, overwrites any existing file at
|
|
26
|
-
|
|
27
|
-
merge (bool): If True, merges new values with existing data
|
|
28
|
-
in the file.
|
|
21
|
+
overwrite (bool): If True, overwrites any existing file at the file_path.
|
|
22
|
+
merge (bool): If True, merges new values with existing data in the file.
|
|
29
23
|
|
|
30
24
|
Raises:
|
|
31
25
|
ValueError: If both overwrite and merge are True.
|
|
32
|
-
FileExistsError: If the file already exists and neither
|
|
33
|
-
overwrite nor merge is True.
|
|
26
|
+
FileExistsError: If the file already exists and neither overwrite nor merge is True.
|
|
34
27
|
"""
|
|
35
28
|
|
|
36
|
-
def __init__(
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
29
|
+
def __init__(self, file_path: str, overwrite: bool = False, merge: bool = True):
|
|
30
|
+
"""Initializes the CSVLogger.
|
|
31
|
+
|
|
32
|
+
Supports merging with existing logs or overwriting them.
|
|
33
|
+
|
|
34
|
+
Args:
|
|
35
|
+
file_path (str): The path to the CSV file for logging.
|
|
36
|
+
overwrite (optional, bool): If True, overwrites any existing file at the file_path. Defaults to False.
|
|
37
|
+
merge (optional, bool): If True, merges new values with existing data in the file. Defaults to True.
|
|
42
38
|
|
|
39
|
+
Raises:
|
|
40
|
+
ValueError: If both overwrite and merge are True.
|
|
41
|
+
FileExistsError: If the file already exists and neither overwrite nor merge is True.
|
|
42
|
+
"""
|
|
43
43
|
self.file_path = file_path
|
|
44
44
|
self.values: dict[tuple[int, str], float] = {}
|
|
45
45
|
|
|
@@ -62,53 +62,43 @@ class CSVLogger:
|
|
|
62
62
|
)
|
|
63
63
|
|
|
64
64
|
def add_value(self, name: str, value: float, epoch: int):
|
|
65
|
-
"""
|
|
66
|
-
Adds a value to the logger for a specific epoch and name.
|
|
65
|
+
"""Adds a value to the logger for a specific epoch and name.
|
|
67
66
|
|
|
68
67
|
Args:
|
|
69
68
|
name (str): The name of the metric or value to log.
|
|
70
69
|
value (float): The value to log.
|
|
71
70
|
epoch (int): The epoch associated with the value.
|
|
72
71
|
"""
|
|
73
|
-
|
|
74
72
|
self.values[epoch, name] = value
|
|
75
73
|
|
|
76
74
|
def save(self):
|
|
77
|
-
"""
|
|
78
|
-
Saves the logged values to the specified CSV file.
|
|
75
|
+
"""Saves the logged values to the specified CSV file.
|
|
79
76
|
|
|
80
77
|
If the file exists and merge is enabled, merges the current data
|
|
81
78
|
with the existing file.
|
|
82
79
|
"""
|
|
83
|
-
|
|
84
80
|
data = self.to_dataframe(self.values)
|
|
85
81
|
data.to_csv(self.file_path, index=False)
|
|
86
82
|
|
|
87
83
|
def load(self):
|
|
88
|
-
"""
|
|
89
|
-
Loads data from the CSV file into the logger.
|
|
84
|
+
"""Loads data from the CSV file into the logger.
|
|
90
85
|
|
|
91
86
|
Converts the CSV data into the internal dictionary format for
|
|
92
87
|
further updates or operations.
|
|
93
88
|
"""
|
|
94
|
-
|
|
95
89
|
df = pd.read_csv(self.file_path)
|
|
96
90
|
self.values = self.to_dict(df)
|
|
97
91
|
|
|
98
92
|
@staticmethod
|
|
99
93
|
def to_dataframe(values: dict[tuple[int, str], float]) -> pd.DataFrame:
|
|
100
|
-
"""
|
|
101
|
-
Converts a dictionary of values into a DataFrame.
|
|
94
|
+
"""Converts a dictionary of values into a DataFrame.
|
|
102
95
|
|
|
103
96
|
Args:
|
|
104
|
-
values (dict[tuple[int, str], float]): A dictionary of values
|
|
105
|
-
keyed by (epoch, name).
|
|
97
|
+
values (dict[tuple[int, str], float]): A dictionary of values keyed by (epoch, name).
|
|
106
98
|
|
|
107
99
|
Returns:
|
|
108
|
-
pd.DataFrame: A DataFrame where epochs are rows, names are
|
|
109
|
-
columns, and values are the cell data.
|
|
100
|
+
pd.DataFrame: A DataFrame where epochs are rows, names are columns, and values are the cell data.
|
|
110
101
|
"""
|
|
111
|
-
|
|
112
102
|
# Convert to a DataFrame
|
|
113
103
|
df = pd.DataFrame.from_dict(values, orient="index", columns=["value"])
|
|
114
104
|
|
|
@@ -126,10 +116,7 @@ class CSVLogger:
|
|
|
126
116
|
|
|
127
117
|
@staticmethod
|
|
128
118
|
def to_dict(df: pd.DataFrame) -> dict[tuple[int, str], float]:
|
|
129
|
-
"""
|
|
130
|
-
Converts a DataFrame with epochs as rows and names as columns
|
|
131
|
-
back into a dictionary of the format {(epoch, name): value}.
|
|
132
|
-
"""
|
|
119
|
+
"""Converts a CSVLogger DataFrame to a dictionary in the format {(epoch, name): value}."""
|
|
133
120
|
# Set the epoch column as the index (if not already)
|
|
134
121
|
df = df.set_index("epoch")
|
|
135
122
|
|
|
@@ -153,9 +140,7 @@ def split_data_loaders(
|
|
|
153
140
|
test_size: float = 0.1,
|
|
154
141
|
split_generator: Generator = None,
|
|
155
142
|
) -> tuple[DataLoader, DataLoader, DataLoader]:
|
|
156
|
-
"""
|
|
157
|
-
Splits a dataset into training, validation, and test sets,
|
|
158
|
-
and returns corresponding DataLoader objects.
|
|
143
|
+
"""Splits a dataset into training, validation, and test sets, and returns corresponding DataLoader objects.
|
|
159
144
|
|
|
160
145
|
Args:
|
|
161
146
|
data (Dataset): The dataset to be split.
|
|
@@ -185,12 +170,9 @@ def split_data_loaders(
|
|
|
185
170
|
ValueError: If the train_size, valid_size, and test_size are not
|
|
186
171
|
between 0 and 1, or if their sum does not equal 1.
|
|
187
172
|
"""
|
|
188
|
-
|
|
189
173
|
# Validate split sizes
|
|
190
174
|
if not (0 < train_size < 1 and 0 < valid_size < 1 and 0 < test_size < 1):
|
|
191
|
-
raise ValueError(
|
|
192
|
-
"train_size, valid_size, and test_size must be between 0 and 1."
|
|
193
|
-
)
|
|
175
|
+
raise ValueError("train_size, valid_size, and test_size must be between 0 and 1.")
|
|
194
176
|
if not abs(train_size + valid_size + test_size - 1.0) < 1e-6:
|
|
195
177
|
raise ValueError("train_size, valid_size, and test_size must sum to 1.")
|
|
196
178
|
|
|
@@ -220,11 +202,8 @@ def split_data_loaders(
|
|
|
220
202
|
return train_generator, valid_generator, test_generator
|
|
221
203
|
|
|
222
204
|
|
|
223
|
-
#
|
|
224
|
-
|
|
225
|
-
"""
|
|
226
|
-
Preprocesses the given dataframe for bias correction by
|
|
227
|
-
performing a series of transformations.
|
|
205
|
+
def preprocess_BiasCorrection(df: pd.DataFrame) -> pd.DataFrame: # noqa: N802
|
|
206
|
+
"""Preprocesses the given dataframe for bias correction by performing a series of transformations.
|
|
228
207
|
|
|
229
208
|
The function sequentially:
|
|
230
209
|
|
|
@@ -245,8 +224,7 @@ def preprocess_BiasCorrection(df: pd.DataFrame) -> pd.DataFrame:
|
|
|
245
224
|
"""
|
|
246
225
|
|
|
247
226
|
def date_to_datetime(df: pd.DataFrame) -> pd.DataFrame:
|
|
248
|
-
"""Transform the string that denotes the date to
|
|
249
|
-
the datetime format in pandas."""
|
|
227
|
+
"""Transform the string that denotes the date to the datetime format in pandas."""
|
|
250
228
|
# make copy of dataframe
|
|
251
229
|
df_temp = df.copy()
|
|
252
230
|
# add new column at the front where the date string is
|
|
@@ -255,8 +233,7 @@ def preprocess_BiasCorrection(df: pd.DataFrame) -> pd.DataFrame:
|
|
|
255
233
|
return df_temp
|
|
256
234
|
|
|
257
235
|
def add_year(df: pd.DataFrame) -> pd.DataFrame:
|
|
258
|
-
"""Extract the year from the datetime cell and add it
|
|
259
|
-
as a new column to the dataframe at the front."""
|
|
236
|
+
"""Extract the year from the datetime cell and add it as a new column to the dataframe at the front."""
|
|
260
237
|
# make copy of dataframe
|
|
261
238
|
df_temp = df.copy()
|
|
262
239
|
# extract year and add new column at the front containing these numbers
|
|
@@ -264,8 +241,7 @@ def preprocess_BiasCorrection(df: pd.DataFrame) -> pd.DataFrame:
|
|
|
264
241
|
return df_temp
|
|
265
242
|
|
|
266
243
|
def add_month(df: pd.DataFrame) -> pd.DataFrame:
|
|
267
|
-
"""Extract the month from the datetime cell and add it
|
|
268
|
-
as a new column to the dataframe at the front."""
|
|
244
|
+
"""Extract the month from the datetime cell and add it as a new column to the dataframe at the front."""
|
|
269
245
|
# make copy of dataframe
|
|
270
246
|
df_temp = df.copy()
|
|
271
247
|
# extract month and add new column at index 1 containing these numbers
|
|
@@ -273,8 +249,7 @@ def preprocess_BiasCorrection(df: pd.DataFrame) -> pd.DataFrame:
|
|
|
273
249
|
return df_temp
|
|
274
250
|
|
|
275
251
|
def add_day(df: pd.DataFrame) -> pd.DataFrame:
|
|
276
|
-
"""Extract the day from the datetime cell and add it
|
|
277
|
-
as a new column to the dataframe at the front."""
|
|
252
|
+
"""Extract the day from the datetime cell and add it as a new column to the dataframe at the front."""
|
|
278
253
|
# make copy of dataframe
|
|
279
254
|
df_temp = df.copy()
|
|
280
255
|
# extract day and add new column at index 2 containing these numbers
|
|
@@ -282,8 +257,7 @@ def preprocess_BiasCorrection(df: pd.DataFrame) -> pd.DataFrame:
|
|
|
282
257
|
return df_temp
|
|
283
258
|
|
|
284
259
|
def add_input_output_temperature(df: pd.DataFrame) -> pd.DataFrame:
|
|
285
|
-
"""Add a multiindex denoting if the column
|
|
286
|
-
is an input or output variable."""
|
|
260
|
+
"""Add a multiindex denoting if the column is an input or output variable."""
|
|
287
261
|
# copy the dataframe
|
|
288
262
|
temp_df = df.copy()
|
|
289
263
|
# extract all the column names
|
|
@@ -297,9 +271,7 @@ def preprocess_BiasCorrection(df: pd.DataFrame) -> pd.DataFrame:
|
|
|
297
271
|
input_output_list = input_list + output_list
|
|
298
272
|
# define multi index for attaching this 'Input' and 'Output' list with
|
|
299
273
|
# the column names already existing
|
|
300
|
-
multiindex_bias = pd.MultiIndex.from_arrays(
|
|
301
|
-
[input_output_list, column_names]
|
|
302
|
-
)
|
|
274
|
+
multiindex_bias = pd.MultiIndex.from_arrays([input_output_list, column_names])
|
|
303
275
|
# transpose such that index can be adjusted to multi index
|
|
304
276
|
new_df = pd.DataFrame(df.transpose().to_numpy(), index=multiindex_bias)
|
|
305
277
|
# transpose back such that columns are the same as before
|
|
@@ -308,9 +280,11 @@ def preprocess_BiasCorrection(df: pd.DataFrame) -> pd.DataFrame:
|
|
|
308
280
|
|
|
309
281
|
def normalize_columns_bias(df: pd.DataFrame) -> pd.DataFrame:
|
|
310
282
|
"""Normalize the columns for the bias correction dataset.
|
|
283
|
+
|
|
311
284
|
This is different from normalizing all the columns separately
|
|
312
285
|
because the upper and lower bounds for the output variables
|
|
313
|
-
are assumed to be the same.
|
|
286
|
+
are assumed to be the same.
|
|
287
|
+
"""
|
|
314
288
|
# copy the dataframe
|
|
315
289
|
temp_df = df.copy()
|
|
316
290
|
# normalize each column
|
|
@@ -325,17 +299,13 @@ def preprocess_BiasCorrection(df: pd.DataFrame) -> pd.DataFrame:
|
|
|
325
299
|
else:
|
|
326
300
|
max_value = df[feature_name].max()
|
|
327
301
|
min_value = df[feature_name].min()
|
|
328
|
-
temp_df[feature_name] = (df[feature_name] - min_value) / (
|
|
329
|
-
max_value - min_value
|
|
330
|
-
)
|
|
302
|
+
temp_df[feature_name] = (df[feature_name] - min_value) / (max_value - min_value)
|
|
331
303
|
return temp_df
|
|
332
304
|
|
|
333
305
|
def sample_2500_examples(df: pd.DataFrame) -> pd.DataFrame:
|
|
334
306
|
"""Sample 2500 examples from the dataframe without replacement."""
|
|
335
307
|
temp_df = df.copy()
|
|
336
|
-
sample_df = temp_df.sample(
|
|
337
|
-
n=2500, replace=False, random_state=3, axis=0
|
|
338
|
-
)
|
|
308
|
+
sample_df = temp_df.sample(n=2500, replace=False, random_state=3, axis=0)
|
|
339
309
|
return sample_df
|
|
340
310
|
|
|
341
311
|
return (
|
|
@@ -363,11 +333,8 @@ def preprocess_BiasCorrection(df: pd.DataFrame) -> pd.DataFrame:
|
|
|
363
333
|
)
|
|
364
334
|
|
|
365
335
|
|
|
366
|
-
#
|
|
367
|
-
|
|
368
|
-
"""
|
|
369
|
-
Preprocesses the given Family Income dataframe by applying a
|
|
370
|
-
series of transformations and constraints.
|
|
336
|
+
def preprocess_FamilyIncome(df: pd.DataFrame) -> pd.DataFrame: # noqa: N802
|
|
337
|
+
"""Preprocesses the given Family Income dataframe.
|
|
371
338
|
|
|
372
339
|
The function sequentially:
|
|
373
340
|
|
|
@@ -393,27 +360,43 @@ def preprocess_FamilyIncome(df: pd.DataFrame) -> pd.DataFrame:
|
|
|
393
360
|
"""
|
|
394
361
|
|
|
395
362
|
def normalize_columns_income(df: pd.DataFrame) -> pd.DataFrame:
|
|
396
|
-
"""Normalize
|
|
397
|
-
|
|
398
|
-
function
|
|
363
|
+
"""Normalize each column of the dataframe independently.
|
|
364
|
+
|
|
365
|
+
This function scales each column to have values between 0 and 1
|
|
366
|
+
using min-max scaling, (x - min) / (max - min),
|
|
367
|
+
making it suitable for numerical processing. While designed for
|
|
368
|
+
the Family Income dataset, it can be applied to any dataframe
|
|
369
|
+
with numeric columns.
|
|
370
|
+
|
|
371
|
+
Args:
|
|
372
|
+
df (pd.DataFrame): Input dataframe to normalize.
|
|
373
|
+
|
|
374
|
+
Returns:
|
|
375
|
+
pd.DataFrame: Dataframe with each column normalized independently.
|
|
376
|
+
"""
|
|
399
377
|
# copy the dataframe
|
|
400
378
|
temp_df = df.copy()
|
|
401
379
|
# normalize each column
|
|
402
380
|
for feature_name in df.columns:
|
|
403
381
|
max_value = df[feature_name].max()
|
|
404
382
|
min_value = df[feature_name].min()
|
|
405
|
-
temp_df[feature_name] = (df[feature_name] - min_value) / (
|
|
406
|
-
max_value - min_value
|
|
407
|
-
)
|
|
383
|
+
temp_df[feature_name] = (df[feature_name] - min_value) / (max_value - min_value)
|
|
408
384
|
return temp_df
|
|
409
385
|
|
|
410
386
|
def check_constraints_income(df: pd.DataFrame) -> pd.DataFrame:
|
|
411
|
-
"""
|
|
412
|
-
|
|
413
|
-
This function
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
387
|
+
"""Filter rows that violate income-related constraints.
|
|
388
|
+
|
|
389
|
+
This function is specific to the Family Income dataset. It removes rows
|
|
390
|
+
that do not satisfy the following constraints:
|
|
391
|
+
1. Household income must be greater than all expenses.
|
|
392
|
+
2. Food expense must be greater than or equal to the sum of detailed food expenses.
|
|
393
|
+
|
|
394
|
+
Args:
|
|
395
|
+
df (pd.DataFrame): Input dataframe containing income and expense data.
|
|
396
|
+
|
|
397
|
+
Returns:
|
|
398
|
+
pd.DataFrame: Filtered dataframe containing only rows that satisfy
|
|
399
|
+
all constraints.
|
|
417
400
|
"""
|
|
418
401
|
temp_df = df.copy()
|
|
419
402
|
# check that household income is larger than expenses in the output
|
|
@@ -421,9 +404,7 @@ def preprocess_FamilyIncome(df: pd.DataFrame) -> pd.DataFrame:
|
|
|
421
404
|
income_array = np.add(
|
|
422
405
|
np.multiply(
|
|
423
406
|
input_array[:, [0, 1]],
|
|
424
|
-
np.subtract(
|
|
425
|
-
np.asarray([11815988, 9234485]), np.asarray([11285, 0])
|
|
426
|
-
),
|
|
407
|
+
np.subtract(np.asarray([11815988, 9234485]), np.asarray([11285, 0])),
|
|
427
408
|
),
|
|
428
409
|
np.asarray([11285, 0]),
|
|
429
410
|
)
|
|
@@ -483,9 +464,7 @@ def preprocess_FamilyIncome(df: pd.DataFrame) -> pd.DataFrame:
|
|
|
483
464
|
food_mul_expense_array = expense_reduced_array[:, [1, 2, 3]]
|
|
484
465
|
food_mul_expense_array_sum = np.sum(food_mul_expense_array, axis=1)
|
|
485
466
|
food_expense_array = expense_reduced_array[:, 0]
|
|
486
|
-
sanity_check_array = np.greater_equal(
|
|
487
|
-
food_expense_array, food_mul_expense_array_sum
|
|
488
|
-
)
|
|
467
|
+
sanity_check_array = np.greater_equal(food_expense_array, food_mul_expense_array_sum)
|
|
489
468
|
drop_reduction["Unimportant"] = sanity_check_array.tolist()
|
|
490
469
|
new_reduction = drop_reduction[drop_reduction.Unimportant]
|
|
491
470
|
satisfied_constraints_df = new_reduction.drop("Unimportant", axis=1)
|
|
@@ -493,9 +472,7 @@ def preprocess_FamilyIncome(df: pd.DataFrame) -> pd.DataFrame:
|
|
|
493
472
|
return satisfied_constraints_df
|
|
494
473
|
|
|
495
474
|
def add_input_output_family_income(df: pd.DataFrame) -> pd.DataFrame:
|
|
496
|
-
"""Add a multiindex denoting if the column is
|
|
497
|
-
an input or output variable."""
|
|
498
|
-
|
|
475
|
+
"""Add a multiindex denoting if the column is an input or output variable."""
|
|
499
476
|
# copy the dataframe
|
|
500
477
|
temp_df = df.copy()
|
|
501
478
|
# extract all the column names
|
|
@@ -510,9 +487,7 @@ def preprocess_FamilyIncome(df: pd.DataFrame) -> pd.DataFrame:
|
|
|
510
487
|
input_output_list = input_list_start + output_list + input_list_end
|
|
511
488
|
# define multi index for attaching this 'Input' and
|
|
512
489
|
# 'Output' list with the column names already existing
|
|
513
|
-
multiindex_bias = pd.MultiIndex.from_arrays(
|
|
514
|
-
[input_output_list, column_names]
|
|
515
|
-
)
|
|
490
|
+
multiindex_bias = pd.MultiIndex.from_arrays([input_output_list, column_names])
|
|
516
491
|
# transpose such that index can be adjusted to multi index
|
|
517
492
|
new_df = pd.DataFrame(df.transpose().to_numpy(), index=multiindex_bias)
|
|
518
493
|
# transpose back such that columns are the same as
|
|
@@ -522,9 +497,7 @@ def preprocess_FamilyIncome(df: pd.DataFrame) -> pd.DataFrame:
|
|
|
522
497
|
def sample_2500_examples(df: pd.DataFrame) -> pd.DataFrame:
|
|
523
498
|
"""Sample 2500 examples from the dataframe without replacement."""
|
|
524
499
|
temp_df = df.copy()
|
|
525
|
-
sample_df = temp_df.sample(
|
|
526
|
-
n=2500, replace=False, random_state=3, axis=0
|
|
527
|
-
)
|
|
500
|
+
sample_df = temp_df.sample(n=2500, replace=False, random_state=3, axis=0)
|
|
528
501
|
return sample_df
|
|
529
502
|
|
|
530
503
|
return (
|
|
@@ -573,9 +546,90 @@ def preprocess_FamilyIncome(df: pd.DataFrame) -> pd.DataFrame:
|
|
|
573
546
|
)
|
|
574
547
|
|
|
575
548
|
|
|
576
|
-
def
|
|
549
|
+
def preprocess_AdultCensusIncome(df: pd.DataFrame) -> pd.DataFrame: # noqa: N802
|
|
550
|
+
"""Preprocesses the Adult Census Income dataset for PyTorch ML.
|
|
551
|
+
|
|
552
|
+
Sequential steps:
|
|
553
|
+
- Drop rows with missing values.
|
|
554
|
+
- Encode categorical variables to integer labels.
|
|
555
|
+
- Map the target 'income' column to 0/1.
|
|
556
|
+
- Convert all data to float32.
|
|
557
|
+
- Add a multiindex to denote Input vs Output columns.
|
|
558
|
+
|
|
559
|
+
Args:
|
|
560
|
+
df (pd.DataFrame): Raw dataframe containing Adult Census Income data.
|
|
561
|
+
|
|
562
|
+
Returns:
|
|
563
|
+
pd.DataFrame: Preprocessed dataframe.
|
|
577
564
|
"""
|
|
578
|
-
|
|
565
|
+
|
|
566
|
+
def drop_missing(df: pd.DataFrame) -> pd.DataFrame:
|
|
567
|
+
"""Drop rows with any missing values."""
|
|
568
|
+
return df.dropna(how="any")
|
|
569
|
+
|
|
570
|
+
def drop_columns(df: pd.DataFrame) -> pd.DataFrame:
|
|
571
|
+
return df.drop(columns=["fnlwgt", "education.num"], errors="ignore")
|
|
572
|
+
|
|
573
|
+
def label_encode_column(series: pd.Series, col_name: str = None) -> pd.Series:
|
|
574
|
+
"""Encode a pandas Series of categorical strings into integers."""
|
|
575
|
+
categories = series.dropna().unique().tolist()
|
|
576
|
+
cat_to_int = {cat: i for i, cat in enumerate(categories)}
|
|
577
|
+
if col_name:
|
|
578
|
+
print(f"Column '{col_name}' encoding:")
|
|
579
|
+
for cat, idx in cat_to_int.items():
|
|
580
|
+
print(f" {cat} -> {idx}")
|
|
581
|
+
return series.map(cat_to_int).astype(int)
|
|
582
|
+
|
|
583
|
+
def encode_categorical(df: pd.DataFrame) -> pd.DataFrame:
|
|
584
|
+
"""Convert categorical string columns to integer labels using label_encode_column."""
|
|
585
|
+
df_temp = df.copy()
|
|
586
|
+
categorical_cols = [
|
|
587
|
+
"workclass",
|
|
588
|
+
"education",
|
|
589
|
+
"marital.status",
|
|
590
|
+
"occupation",
|
|
591
|
+
"relationship",
|
|
592
|
+
"race",
|
|
593
|
+
"sex",
|
|
594
|
+
"native.country",
|
|
595
|
+
]
|
|
596
|
+
for col in categorical_cols:
|
|
597
|
+
df_temp[col] = label_encode_column(df_temp[col].astype(str), col_name=col)
|
|
598
|
+
return df_temp
|
|
599
|
+
|
|
600
|
+
def map_target(df: pd.DataFrame) -> pd.DataFrame:
|
|
601
|
+
"""Map income column to 0 (<=50K) and 1 (>50K)."""
|
|
602
|
+
df_temp = df.copy()
|
|
603
|
+
df_temp["income"] = df_temp["income"].map({"<=50K": 0, ">50K": 1})
|
|
604
|
+
return df_temp
|
|
605
|
+
|
|
606
|
+
def convert_float32(df: pd.DataFrame) -> pd.DataFrame:
|
|
607
|
+
"""Convert all data to float32 for PyTorch compatibility."""
|
|
608
|
+
return df.astype("float32")
|
|
609
|
+
|
|
610
|
+
def add_input_output_index(df: pd.DataFrame) -> pd.DataFrame:
|
|
611
|
+
"""Add a multiindex indicating input and output columns."""
|
|
612
|
+
temp_df = df.copy()
|
|
613
|
+
column_names = temp_df.columns.tolist()
|
|
614
|
+
# Only the 'income' column is output
|
|
615
|
+
input_list = ["Input"] * (len(column_names) - 1)
|
|
616
|
+
output_list = ["Output"]
|
|
617
|
+
multiindex_list = input_list + output_list
|
|
618
|
+
multiindex = pd.MultiIndex.from_arrays([multiindex_list, column_names])
|
|
619
|
+
return pd.DataFrame(temp_df.to_numpy(), columns=multiindex)
|
|
620
|
+
|
|
621
|
+
return (
|
|
622
|
+
df.pipe(drop_missing)
|
|
623
|
+
.pipe(drop_columns)
|
|
624
|
+
.pipe(encode_categorical)
|
|
625
|
+
.pipe(map_target)
|
|
626
|
+
.pipe(convert_float32)
|
|
627
|
+
.pipe(add_input_output_index)
|
|
628
|
+
)
|
|
629
|
+
|
|
630
|
+
|
|
631
|
+
def validate_type(name, value, expected_types, allow_none=False):
|
|
632
|
+
"""Validate that a value is of the specified type(s).
|
|
579
633
|
|
|
580
634
|
Args:
|
|
581
635
|
name (str): Name of the argument for error messages.
|
|
@@ -587,7 +641,6 @@ def validate_type(name, value, expected_types, allow_none=False):
|
|
|
587
641
|
Raises:
|
|
588
642
|
TypeError: If the value is not of the expected type(s).
|
|
589
643
|
"""
|
|
590
|
-
|
|
591
644
|
if value is None:
|
|
592
645
|
if not allow_none:
|
|
593
646
|
raise TypeError(f"Argument {name} cannot be None.")
|
|
@@ -604,12 +657,11 @@ def validate_iterable(
|
|
|
604
657
|
name,
|
|
605
658
|
value,
|
|
606
659
|
expected_element_types,
|
|
607
|
-
allowed_iterables=(list, set),
|
|
660
|
+
allowed_iterables=(list, set, tuple),
|
|
661
|
+
allow_empty=False,
|
|
608
662
|
allow_none=False,
|
|
609
663
|
):
|
|
610
|
-
"""
|
|
611
|
-
Validate that a value is an iterable (e.g., list, set) with elements of
|
|
612
|
-
the specified type(s).
|
|
664
|
+
"""Validate that a value is an iterable (e.g., list, set) with elements of the specified type(s).
|
|
613
665
|
|
|
614
666
|
Args:
|
|
615
667
|
name (str): Name of the argument for error messages.
|
|
@@ -618,6 +670,7 @@ def validate_iterable(
|
|
|
618
670
|
for the elements.
|
|
619
671
|
allowed_iterables (tuple of types): Iterable types that are
|
|
620
672
|
allowed (default: list, set and tuple).
|
|
673
|
+
allow_empty (bool): Whether to allow empty iterables. Defaults to False.
|
|
621
674
|
allow_none (bool): Whether to allow the value to be None.
|
|
622
675
|
Defaults to False.
|
|
623
676
|
|
|
@@ -625,20 +678,22 @@ def validate_iterable(
|
|
|
625
678
|
TypeError: If the value is not an allowed iterable type or if
|
|
626
679
|
any element is not of the expected type(s).
|
|
627
680
|
"""
|
|
628
|
-
|
|
629
681
|
if value is None:
|
|
630
682
|
if not allow_none:
|
|
631
683
|
raise TypeError(f"Argument {name} cannot be None.")
|
|
632
684
|
return
|
|
633
685
|
|
|
686
|
+
if len(value) == 0:
|
|
687
|
+
if not allow_empty:
|
|
688
|
+
raise TypeError(f"Argument {name} cannot be an empty iterable.")
|
|
689
|
+
return
|
|
690
|
+
|
|
634
691
|
if not isinstance(value, allowed_iterables):
|
|
635
692
|
raise TypeError(
|
|
636
693
|
f"Argument {name} '{str(value)}' is not supported. "
|
|
637
694
|
f"Only values of type {str(allowed_iterables)} are allowed."
|
|
638
695
|
)
|
|
639
|
-
if not all(
|
|
640
|
-
isinstance(element, expected_element_types) for element in value
|
|
641
|
-
):
|
|
696
|
+
if not all(isinstance(element, expected_element_types) for element in value):
|
|
642
697
|
raise TypeError(
|
|
643
698
|
f"Invalid elements in {name} '{str(value)}'. "
|
|
644
699
|
f"Only elements of type {str(expected_element_types)} are allowed."
|
|
@@ -646,8 +701,7 @@ def validate_iterable(
|
|
|
646
701
|
|
|
647
702
|
|
|
648
703
|
def validate_comparator_pytorch(name, value):
|
|
649
|
-
"""
|
|
650
|
-
Validate that a value is a callable PyTorch comparator function.
|
|
704
|
+
"""Validate that a value is a callable PyTorch comparator function.
|
|
651
705
|
|
|
652
706
|
Args:
|
|
653
707
|
name (str): Name of the argument for error messages.
|
|
@@ -656,7 +710,6 @@ def validate_comparator_pytorch(name, value):
|
|
|
656
710
|
Raises:
|
|
657
711
|
TypeError: If the value is not callable or not a PyTorch comparator.
|
|
658
712
|
"""
|
|
659
|
-
|
|
660
713
|
# List of valid PyTorch comparator functions
|
|
661
714
|
pytorch_comparators = {torch.gt, torch.lt, torch.ge, torch.le}
|
|
662
715
|
|
|
@@ -664,8 +717,7 @@ def validate_comparator_pytorch(name, value):
|
|
|
664
717
|
# the PyTorch comparator functions
|
|
665
718
|
if not callable(value):
|
|
666
719
|
raise TypeError(
|
|
667
|
-
f"Argument {name} '{str(value)}' is not supported. "
|
|
668
|
-
"Only callable functions are allowed."
|
|
720
|
+
f"Argument {name} '{str(value)}' is not supported. Only callable functions are allowed."
|
|
669
721
|
)
|
|
670
722
|
|
|
671
723
|
if value not in pytorch_comparators:
|
|
@@ -677,8 +729,7 @@ def validate_comparator_pytorch(name, value):
|
|
|
677
729
|
|
|
678
730
|
|
|
679
731
|
def validate_callable(name, value, allow_none=False):
|
|
680
|
-
"""
|
|
681
|
-
Validate that a value is callable function.
|
|
732
|
+
"""Validate that a value is a callable function.
|
|
682
733
|
|
|
683
734
|
Args:
|
|
684
735
|
name (str): Name of the argument for error messages.
|
|
@@ -689,22 +740,339 @@ def validate_callable(name, value, allow_none=False):
|
|
|
689
740
|
Raises:
|
|
690
741
|
TypeError: If the value is not callable.
|
|
691
742
|
"""
|
|
692
|
-
|
|
693
743
|
if value is None:
|
|
694
744
|
if not allow_none:
|
|
695
745
|
raise TypeError(f"Argument {name} cannot be None.")
|
|
696
746
|
return
|
|
697
747
|
|
|
698
748
|
if not callable(value):
|
|
749
|
+
raise TypeError(
|
|
750
|
+
f"Argument {name} '{str(value)}' is not supported. Only callable functions are allowed."
|
|
751
|
+
)
|
|
752
|
+
|
|
753
|
+
|
|
754
|
+
def validate_callable_iterable(
|
|
755
|
+
name,
|
|
756
|
+
value,
|
|
757
|
+
allowed_iterables=(list, set, tuple),
|
|
758
|
+
allow_none=False,
|
|
759
|
+
):
|
|
760
|
+
"""Validate that a value is an iterable containing only callable elements.
|
|
761
|
+
|
|
762
|
+
This function ensures that the given value is an iterable
|
|
763
|
+
(e.g., list or set) and that all its elements are callable functions.
|
|
764
|
+
|
|
765
|
+
Args:
|
|
766
|
+
name (str): Name of the argument for error messages.
|
|
767
|
+
value: The value to validate.
|
|
768
|
+
allowed_iterables (tuple of types, optional): Iterable types that are
|
|
769
|
+
allowed. Defaults to (list, set, tuple).
|
|
770
|
+
allow_none (bool, optional): Whether to allow the value to be None.
|
|
771
|
+
Defaults to False.
|
|
772
|
+
|
|
773
|
+
Raises:
|
|
774
|
+
TypeError: If the value is not an allowed iterable type or if any
|
|
775
|
+
element is not callable.
|
|
776
|
+
"""
|
|
777
|
+
if value is None:
|
|
778
|
+
if not allow_none:
|
|
779
|
+
raise TypeError(f"Argument {name} cannot be None.")
|
|
780
|
+
return
|
|
781
|
+
|
|
782
|
+
if not isinstance(value, allowed_iterables):
|
|
699
783
|
raise TypeError(
|
|
700
784
|
f"Argument {name} '{str(value)}' is not supported. "
|
|
701
|
-
"Only
|
|
785
|
+
f"Only values of type {str(allowed_iterables)} are allowed."
|
|
786
|
+
)
|
|
787
|
+
|
|
788
|
+
if not all(callable(element) for element in value):
|
|
789
|
+
raise TypeError(
|
|
790
|
+
f"Invalid elements in {name} '{str(value)}'. Only callable functions are allowed."
|
|
791
|
+
)
|
|
792
|
+
|
|
793
|
+
|
|
794
|
+
def validate_loaders(name: str, loaders: tuple[DataLoader, DataLoader, DataLoader]):
|
|
795
|
+
"""Validates that `loaders` is a tuple of three DataLoader instances.
|
|
796
|
+
|
|
797
|
+
Args:
|
|
798
|
+
name (str): The name of the parameter being validated.
|
|
799
|
+
loaders (tuple[DataLoader, DataLoader, DataLoader]): A tuple of
|
|
800
|
+
three DataLoader instances.
|
|
801
|
+
|
|
802
|
+
Raises:
|
|
803
|
+
TypeError: If `loaders` is not a tuple of three DataLoader
|
|
804
|
+
instances or contains invalid types.
|
|
805
|
+
"""
|
|
806
|
+
if not isinstance(loaders, tuple) or len(loaders) != 3:
|
|
807
|
+
raise TypeError(f"{name} must be a tuple of three DataLoader instances.")
|
|
808
|
+
|
|
809
|
+
for i, loader in enumerate(loaders):
|
|
810
|
+
if not isinstance(loader, DataLoader):
|
|
811
|
+
raise TypeError(
|
|
812
|
+
f"{name}[{i}] must be an instance of DataLoader, got {type(loader).__name__}."
|
|
813
|
+
)
|
|
814
|
+
|
|
815
|
+
|
|
816
|
+
class ZeroLoss(_Loss):
    """Loss module whose value is zero for finite predictions.

    The target and any extra keyword arguments are ignored. The result is
    obtained by multiplying the predictions by zero and summing, so it is a
    scalar computed from `predictions` rather than a free-standing constant
    (presumably to keep the result connected to the autograd graph). Handy as
    a placeholder when no real loss term is wanted, e.g. while debugging.

    Args:
        reduction (str, optional): Stored on the module through the `_Loss`
            base class. Defaults to "mean". It does not influence the
            returned value.
    """

    def __init__(self, reduction: str = "mean"):
        """Create the loss and forward `reduction` to the `_Loss` base class.

        Args:
            reduction (str): Reduction mode recorded on the module. Defaults to "mean".
        """
        super().__init__(reduction=reduction)

    def forward(self, predictions: Tensor, target: Tensor, **kwargs) -> torch.Tensor:
        """Return the sum of `predictions` scaled by zero; `target` is unused."""
        # NOTE(review): non-finite entries (NaN/inf) times zero are NaN, so the
        # result is only zero for finite predictions.
        zeroed = predictions * 0
        return zeroed.sum()
|
|
840
|
+
|
|
841
|
+
|
|
842
|
+
def is_torch_loss(criterion) -> bool:
    """Return True if `criterion`'s class is defined in `torch.nn.modules.loss`.

    The check looks at the module in which the object's class was defined, so
    it matches the loss classes shipped in that module but not user-defined
    subclasses that live in other modules.

    Args:
        criterion: Any object; typically a loss module or a callable.

    Returns:
        bool: True when the class of `criterion` reports
        `torch.nn.modules.loss` as its defining module, False otherwise.
    """
    # Read the defining module directly instead of parsing `str(type(...))`:
    # the textual form depends on the metaclass `__repr__` and on the class
    # nesting depth, whereas `__module__` is the module name itself.
    defining_module = getattr(type(criterion), "__module__", None)
    return defining_module == "torch.nn.modules.loss"
|
|
848
|
+
|
|
849
|
+
|
|
850
|
+
def torch_loss_wrapper(criterion: _Loss) -> _Loss:
    """Wrap a loss so that its forward pass tolerates extra keyword arguments.

    The returned module calls the wrapped criterion with the positional
    arguments only; any keyword arguments passed to the wrapper are dropped.
    This lets a loss whose `forward` does not accept `**kwargs` be called the
    same way as one that does.

    Args:
        criterion (_Loss): The PyTorch loss function to wrap.

    Returns:
        _Loss: A module exposing the wrapped loss as `.criterion` and
        reporting the same `reduction` as the wrapped loss.
    """

    class WrappedCriterion(_Loss):
        """Adapter that forwards positional arguments and discards keywords."""

        def __init__(self, criterion):
            # Mirror the wrapped loss's reduction so that `wrapper.reduction`
            # reflects the mode actually applied, instead of always "mean".
            super().__init__(reduction=getattr(criterion, "reduction", "mean"))
            self.criterion = criterion

        def forward(self, *args, **kwargs):
            # Keyword arguments are intentionally not passed on.
            return self.criterion(*args)

    return WrappedCriterion(criterion)
|
|
869
|
+
|
|
870
|
+
|
|
871
|
+
def process_data_monotonicity_constraint(
    data: Tensor, ordering: Tensor, identifiers: Tensor
) -> tuple[Tensor, Tensor, Tensor]:
    """Reorders input samples to support monotonicity checking.

    Reorders input samples such that:
    1. Samples from the same run are grouped together, runs appearing in
       ascending order of their identifier.
    2. Within each run, samples are sorted by `ordering` (ascending).
    3. Samples of one run that share the same `ordering` value keep their
       original relative order.

    Args:
        data (Tensor): The input data, indexed along dimension 0.
        ordering (Tensor): On what to order the data within a run.
        identifiers (Tensor): Identifiers specifying different runs.

    Returns:
        Tuple[Tensor, Tensor, Tensor]: Sorted data, ordering, and
            identifiers. Empty inputs yield empty outputs.
    """
    # assumes `ordering` and `identifiers` hold one value per sample, i.e.
    # shape (N,) or (N, 1) - TODO confirm against callers.

    # Pass 1: stable sort on the secondary key (position within a run).
    by_order = argsort(ordering, stable=True, dim=0).reshape(-1)

    # Pass 2: stable sort of that arrangement on the primary key (run id).
    # Because the sort is stable, the within-run order from pass 1 survives.
    # This replaces per-run slicing driven by `unique(..., sorted=False)`,
    # whose count order is not guaranteed to follow the sorted identifiers.
    by_run = argsort(identifiers[by_order], stable=True, dim=0).reshape(-1)

    # Compose both passes into a single permutation of the original samples.
    permutation = by_order[by_run]

    return data[permutation], ordering[permutation], identifiers[permutation]
|
|
925
|
+
|
|
704
926
|
|
|
705
|
-
|
|
927
|
+
class DictDatasetWrapper(Dataset):
|
|
928
|
+
"""A wrapper for PyTorch datasets that converts each sample into a dictionary.
|
|
929
|
+
|
|
930
|
+
This class takes any PyTorch dataset and returns its samples as dictionaries,
|
|
931
|
+
where each element of the original sample is mapped to a key. This is useful
|
|
932
|
+
for integration with the Congrads toolbox or other frameworks that expect
|
|
933
|
+
dictionary-formatted data.
|
|
934
|
+
|
|
935
|
+
Attributes:
|
|
936
|
+
base_dataset (Dataset): The underlying PyTorch dataset being wrapped.
|
|
937
|
+
field_names (list[str] | None): Names assigned to each field of a sample.
|
|
938
|
+
If None, default names like 'field0', 'field1', ... are generated.
|
|
939
|
+
|
|
940
|
+
Args:
|
|
941
|
+
base_dataset (Dataset): The PyTorch dataset to wrap.
|
|
942
|
+
field_names (list[str] | None, optional): Custom names for each field.
|
|
943
|
+
If provided, the list is truncated or extended to match the number
|
|
944
|
+
of elements in a sample. Defaults to None.
|
|
945
|
+
|
|
946
|
+
Example:
|
|
947
|
+
Wrapping a TensorDataset with custom field names:
|
|
948
|
+
|
|
949
|
+
>>> from torch.utils.data import TensorDataset
|
|
950
|
+
>>> import torch
|
|
951
|
+
>>> dataset = TensorDataset(torch.randn(5, 3), torch.randint(0, 2, (5,)))
|
|
952
|
+
>>> wrapped = DictDatasetWrapper(dataset, field_names=["features", "label"])
|
|
953
|
+
>>> wrapped[0]
|
|
954
|
+
{'features': tensor([...]), 'label': tensor(1)}
|
|
955
|
+
|
|
956
|
+
Wrapping a built-in dataset like CIFAR10:
|
|
957
|
+
|
|
958
|
+
>>> from torchvision.datasets import CIFAR10
|
|
959
|
+
>>> from torchvision import transforms
|
|
960
|
+
>>> cifar = CIFAR10(
|
|
961
|
+
... root="./data", train=True, download=True, transform=transforms.ToTensor()
|
|
962
|
+
... )
|
|
963
|
+
>>> wrapped_cifar = DictDatasetWrapper(cifar, field_names=["input", "output"])
|
|
964
|
+
>>> wrapped_cifar[0]
|
|
965
|
+
{'input': tensor([...]), 'output': tensor(6)}
|
|
706
966
|
"""
|
|
707
|
-
|
|
967
|
+
|
|
968
|
+
def __init__(self, base_dataset: Dataset, field_names: list[str] | None = None):
|
|
969
|
+
"""Initialize the DictDatasetWrapper.
|
|
970
|
+
|
|
971
|
+
Args:
|
|
972
|
+
base_dataset (Dataset): The PyTorch dataset to wrap.
|
|
973
|
+
field_names (list[str] | None, optional): Optional list of field names
|
|
974
|
+
for the dictionary output. Defaults to None, in which case
|
|
975
|
+
automatic names 'field0', 'field1', ... are generated.
|
|
976
|
+
"""
|
|
977
|
+
self.base_dataset = base_dataset
|
|
978
|
+
self.field_names = field_names
|
|
979
|
+
|
|
980
|
+
def __getitem__(self, idx: int):
|
|
981
|
+
"""Retrieve a sample from the dataset as a dictionary.
|
|
982
|
+
|
|
983
|
+
Each element in the original sample is mapped to a key in the dictionary.
|
|
984
|
+
If the sample is not a tuple or list, it is converted into a single-element
|
|
985
|
+
tuple. Numerical values (int or float) are automatically converted to tensors.
|
|
986
|
+
|
|
987
|
+
Args:
|
|
988
|
+
idx (int): Index of the sample to retrieve.
|
|
989
|
+
|
|
990
|
+
Returns:
|
|
991
|
+
dict: A dictionary mapping field names to sample values.
|
|
992
|
+
"""
|
|
993
|
+
sample = self.base_dataset[idx]
|
|
994
|
+
|
|
995
|
+
# Ensure sample is always a tuple
|
|
996
|
+
if not isinstance(sample, (tuple, list)):
|
|
997
|
+
sample = (sample,)
|
|
998
|
+
|
|
999
|
+
n_fields = len(sample)
|
|
1000
|
+
|
|
1001
|
+
# Generate default field names if none are provided
|
|
1002
|
+
if self.field_names is None:
|
|
1003
|
+
names = [f"field{i}" for i in range(n_fields)]
|
|
1004
|
+
else:
|
|
1005
|
+
names = list(self.field_names)
|
|
1006
|
+
if len(names) < n_fields:
|
|
1007
|
+
names.extend([f"field{i}" for i in range(len(names), n_fields)])
|
|
1008
|
+
names = names[:n_fields] # truncate if too long
|
|
1009
|
+
|
|
1010
|
+
# Build dictionary
|
|
1011
|
+
out = {}
|
|
1012
|
+
for name, value in zip(names, sample, strict=False):
|
|
1013
|
+
if isinstance(value, (int, float)):
|
|
1014
|
+
value = torch.tensor(value)
|
|
1015
|
+
out[name] = value
|
|
1016
|
+
|
|
1017
|
+
return out
|
|
1018
|
+
|
|
1019
|
+
def __len__(self):
|
|
1020
|
+
"""Return the number of samples in the dataset.
|
|
1021
|
+
|
|
1022
|
+
Returns:
|
|
1023
|
+
int: Length of the underlying dataset.
|
|
1024
|
+
"""
|
|
1025
|
+
return len(self.base_dataset)
|
|
1026
|
+
|
|
1027
|
+
|
|
1028
|
+
class Seeder:
    """Source of reproducible seeds derived from a single base seed.

    A private pseudo-random generator, initialised with the base seed, hands
    out a fixed sequence of integer seeds. Those seeds can be used to
    initialise Python's `random`, NumPy and PyTorch so that repeated runs
    with the same base seed start from the same random state.
    """

    def __init__(self, base_seed: int):
        """Create the internal generator from `base_seed`.

        Args:
            base_seed (int): Seed of the private generator that all rolled
                seeds are drawn from.
        """
        self._rng = random.Random(base_seed)

    def roll_seed(self) -> int:
        """Draw the next seed from the internal generator.

        Successive calls walk through a sequence that is fully determined by
        the base seed given at construction.

        Returns:
            int: A pseudo-random integer seed in the range [0, 2**31 - 1].
        """
        upper_bound = 2**31 - 1  # inclusive
        return self._rng.randint(0, upper_bound)

    def set_reproducible(self) -> None:
        """Seed the global random number generators and pin CuDNN behaviour.

        One freshly rolled seed is consumed per target, in this order:
            - Python's built-in `random`
            - NumPy's global random number generator
            - PyTorch (`torch.manual_seed`)
            - PyTorch CUDA devices (`torch.cuda.manual_seed_all`)

        Afterwards CuDNN deterministic mode is switched on and CuDNN
        benchmarking is switched off.
        """
        # The order is significant: each target receives the next seed of the
        # sequence, so reordering would hand different seeds to each library.
        seeding_functions = (
            random.seed,
            np.random.seed,
            torch.manual_seed,
            torch.cuda.manual_seed_all,
        )
        for apply_seed in seeding_functions:
            apply_seed(self.roll_seed())

        cudnn = torch.backends.cudnn
        cudnn.deterministic = True
        cudnn.benchmark = False
|