congrads-0.2.0-py3-none-any.whl → congrads-1.0.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- congrads/__init__.py +17 -10
- congrads/checkpoints.py +232 -0
- congrads/constraints.py +664 -134
- congrads/core.py +482 -110
- congrads/datasets.py +315 -11
- congrads/descriptor.py +100 -20
- congrads/metrics.py +178 -16
- congrads/networks.py +47 -23
- congrads/transformations.py +139 -0
- congrads/utils.py +439 -39
- congrads-1.0.2.dist-info/METADATA +208 -0
- congrads-1.0.2.dist-info/RECORD +15 -0
- {congrads-0.2.0.dist-info → congrads-1.0.2.dist-info}/WHEEL +1 -1
- congrads-0.2.0.dist-info/METADATA +0 -222
- congrads-0.2.0.dist-info/RECORD +0 -13
- {congrads-0.2.0.dist-info → congrads-1.0.2.dist-info}/LICENSE +0 -0
- {congrads-0.2.0.dist-info → congrads-1.0.2.dist-info}/top_level.txt +0 -0
congrads/utils.py
CHANGED
@@ -1,9 +1,148 @@
-
+"""
+This module provides utility functions and classes for managing data,
+logging, and preprocessing in machine learning workflows. The functionalities
+include logging key-value pairs to CSV files, splitting datasets into
+training, validation, and test sets, preprocessing functions for various
+data sets and validator functions for type checking.
+"""
+
+import os
+
 import numpy as np
-
+import pandas as pd
+import torch
+from torch import Generator
+from torch.utils.data import DataLoader, Dataset, random_split
+
+
+class CSVLogger:
+    """
+    A utility class for logging key-value pairs to a CSV file, organized by
+    epochs. Supports merging with existing logs or overwriting them.
+
+    Args:
+        file_path (str): The path to the CSV file for logging.
+        overwrite (bool): If True, overwrites any existing file at
+            the file_path.
+        merge (bool): If True, merges new values with existing data
+            in the file.
+
+    Raises:
+        ValueError: If both overwrite and merge are True.
+        FileExistsError: If the file already exists and neither
+            overwrite nor merge is True.
+    """
+
+    def __init__(
+        self, file_path: str, overwrite: bool = False, merge: bool = True
+    ):
+        """
+        Initializes the CSVLogger.
+        """
+
+        self.file_path = file_path
+        self.values: dict[tuple[int, str], float] = {}
+
+        if merge and overwrite:
+            raise ValueError(
+                "The attributes overwrite and merge cannot be True at the "
+                "same time. Either specify overwrite=True or merge=True."
+            )
+
+        if not os.path.exists(file_path):
+            pass
+        elif merge:
+            self.load()
+        elif overwrite:
+            pass
+        else:
+            raise FileExistsError(
+                f"A CSV file already exists at {file_path}. Specify "
+                "CSVLogger(..., overwrite=True) to overwrite the file."
+            )
+
+    def add_value(self, name: str, value: float, epoch: int):
+        """
+        Adds a value to the logger for a specific epoch and name.
+
+        Args:
+            name (str): The name of the metric or value to log.
+            value (float): The value to log.
+            epoch (int): The epoch associated with the value.
+        """
+
+        self.values[epoch, name] = value
+
+    def save(self):
+        """
+        Saves the logged values to the specified CSV file.
+
+        If the file exists and merge is enabled, merges the current data
+        with the existing file.
+        """
 
+        data = self.to_dataframe(self.values)
+        data.to_csv(self.file_path, index=False)
 
-def splitDataLoaders(
+    def load(self):
+        """
+        Loads data from the CSV file into the logger.
+
+        Converts the CSV data into the internal dictionary format for
+        further updates or operations.
+        """
+
+        df = pd.read_csv(self.file_path)
+        self.values = self.to_dict(df)
+
+    @staticmethod
+    def to_dataframe(values: dict[tuple[int, str], float]) -> pd.DataFrame:
+        """
+        Converts a dictionary of values into a DataFrame.
+
+        Args:
+            values (dict[tuple[int, str], float]): A dictionary of values
+                keyed by (epoch, name).
+
+        Returns:
+            pd.DataFrame: A DataFrame where epochs are rows, names are
+                columns, and values are the cell data.
+        """
+
+        # Convert to a DataFrame
+        df = pd.DataFrame.from_dict(values, orient="index", columns=["value"])
+
+        # Reset the index to separate epoch and name into columns
+        df.index = pd.MultiIndex.from_tuples(df.index, names=["epoch", "name"])
+        df = df.reset_index()
+
+        # Pivot the DataFrame so epochs are rows and names are columns
+        result = df.pivot(index="epoch", columns="name", values="value")
+
+        # Optional: Reset the column names for a cleaner look
+        result = result.reset_index().rename_axis(columns=None)
+
+        return result
+
+    @staticmethod
+    def to_dict(df: pd.DataFrame) -> dict[tuple[int, str], float]:
+        """
+        Converts a DataFrame with epochs as rows and names as columns
+        back into a dictionary of the format {(epoch, name): value}.
+        """
+        # Set the epoch column as the index (if not already)
+        df = df.set_index("epoch")
+
+        # Stack the DataFrame to create a multi-index series
+        stacked = df.stack()
+
+        # Convert the multi-index series to a dictionary
+        result = stacked.to_dict()
+
+        return result
+
+
+def split_data_loaders(
     data: Dataset,
     loader_args: dict = None,
     train_loader_args: dict = None,
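Note on the CSVLogger class added in the hunk above (the split_data_loaders signature continues in the next hunk): values are stored keyed by (epoch, name) and pivoted into one row per epoch on save. A minimal usage sketch, assuming a writable file path (the path and metric names are illustrative, not from the diff):

    from congrads.utils import CSVLogger

    logger = CSVLogger("metrics.csv")  # merge=True by default
    for epoch in range(3):
        logger.add_value("train_loss", 1.0 / (epoch + 1), epoch)
        logger.add_value("valid_loss", 1.2 / (epoch + 1), epoch)
    logger.save()  # writes columns: epoch, train_loss, valid_loss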
@@ -12,7 +151,40 @@ def splitDataLoaders(
     train_size: float = 0.8,
     valid_size: float = 0.1,
     test_size: float = 0.1,
+    split_generator: Generator = None,
 ) -> tuple[DataLoader, DataLoader, DataLoader]:
+    """
+    Splits a dataset into training, validation, and test sets,
+    and returns corresponding DataLoader objects.
+
+    Args:
+        data (Dataset): The dataset to be split.
+        loader_args (dict, optional): Default DataLoader arguments; merged
+            with the loader-specific arguments, which supersede
+            overlapping keys.
+        train_loader_args (dict, optional): Training DataLoader arguments;
+            merged with `loader_args`, overriding overlapping keys.
+        valid_loader_args (dict, optional): Validation DataLoader arguments;
+            merged with `loader_args`, overriding overlapping keys.
+        test_loader_args (dict, optional): Test DataLoader arguments;
+            merged with `loader_args`, overriding overlapping keys.
+        train_size (float, optional): Proportion of data to be used for
+            training. Defaults to 0.8.
+        valid_size (float, optional): Proportion of data to be used for
+            validation. Defaults to 0.1.
+        test_size (float, optional): Proportion of data to be used for
+            testing. Defaults to 0.1.
+        split_generator (Generator, optional): Optional random generator
+            to control the splitting of the dataset.
+
+    Returns:
+        tuple: A tuple containing three DataLoader objects: one each for the
+            training, validation, and test sets.
+
+    Raises:
+        ValueError: If train_size, valid_size, and test_size are not
+            between 0 and 1, or if their sum does not equal 1.
+    """
 
     # Validate split sizes
     if not (0 < train_size < 1 and 0 < valid_size < 1 and 0 < test_size < 1):
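The loader_args docstring above relies on plain dict merging, where the loader-specific dict wins on overlapping keys. A quick illustration of that semantics in standard Python:

    defaults = {"batch_size": 64, "shuffle": False}
    merged = dict(defaults, **{"shuffle": True})
    # merged == {"batch_size": 64, "shuffle": True}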
@@ -23,19 +195,22 @@ def splitDataLoaders(
     raise ValueError("train_size, valid_size, and test_size must sum to 1.")
 
     # Perform the splits
-    train_val_data, test_data = random_split(
+    train_val_data, test_data = random_split(
+        data, [1 - test_size, test_size], generator=split_generator
+    )
     train_data, valid_data = random_split(
         train_val_data,
         [
             train_size / (1 - test_size),
             valid_size / (1 - test_size),
         ],
+        generator=split_generator,
     )
 
     # Set default arguments for each loader
-    train_loader_args =
-    valid_loader_args =
-    test_loader_args =
+    train_loader_args = dict(loader_args or {}, **(train_loader_args or {}))
+    valid_loader_args = dict(loader_args or {}, **(valid_loader_args or {}))
+    test_loader_args = dict(loader_args or {}, **(test_loader_args or {}))
 
     # Create the DataLoaders
     train_generator = DataLoader(train_data, **train_loader_args)
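Taken together, the changes above make the two-stage split reproducible and configurable. A hedged usage sketch (my_dataset is a placeholder for any torch Dataset):

    import torch
    from congrads.utils import split_data_loaders

    train_dl, valid_dl, test_dl = split_data_loaders(
        my_dataset,
        loader_args={"batch_size": 64},
        train_loader_args={"shuffle": True},  # overrides loader_args keys
        split_generator=torch.Generator().manual_seed(42),
    )
    # With the 0.8/0.1/0.1 defaults, the second split rescales the
    # remainder: 0.8 / (1 - 0.1) ≈ 0.889 train, 0.1 / 0.9 ≈ 0.111 valid.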
@@ -45,18 +220,43 @@ def splitDataLoaders(
     return train_generator, valid_generator, test_generator
 
 
+# pylint: disable-next=invalid-name
 def preprocess_BiasCorrection(df: pd.DataFrame) -> pd.DataFrame:
+    """
+    Preprocesses the given dataframe for bias correction by
+    performing a series of transformations.
+
+    The function sequentially:
+
+    - Drops rows with missing values.
+    - Converts a date string to datetime format and adds year, month,
+      and day columns.
+    - Normalizes the columns with specific logic for input and output variables.
+    - Adds a multi-index indicating which columns are input or output variables.
+    - Samples 2500 examples from the dataset without replacement.
+
+    Args:
+        df (pd.DataFrame): The input dataframe containing the data
+            to be processed.
+
+    Returns:
+        pd.DataFrame: The processed dataframe after applying
+            the transformations.
+    """
 
     def date_to_datetime(df: pd.DataFrame) -> pd.DataFrame:
-        """Transform the string that denotes the date to
+        """Transform the string that denotes the date to
+        the datetime format in pandas."""
         # make copy of dataframe
         df_temp = df.copy()
-        # add new column at the front where the date string is
+        # add new column at the front where the date string is
+        # transformed to the datetime format
        df_temp.insert(0, "DateTransformed", pd.to_datetime(df_temp["Date"]))
         return df_temp
 
     def add_year(df: pd.DataFrame) -> pd.DataFrame:
-        """Extract the year from the datetime cell and add it
+        """Extract the year from the datetime cell and add it
+        as a new column to the dataframe at the front."""
         # make copy of dataframe
         df_temp = df.copy()
         # extract year and add new column at the front containing these numbers
@@ -64,7 +264,8 @@ def preprocess_BiasCorrection(df: pd.DataFrame) -> pd.DataFrame:
         return df_temp
 
     def add_month(df: pd.DataFrame) -> pd.DataFrame:
-        """Extract the month from the datetime cell and add it
+        """Extract the month from the datetime cell and add it
+        as a new column to the dataframe at the front."""
         # make copy of dataframe
         df_temp = df.copy()
         # extract month and add new column at index 1 containing these numbers
@@ -72,7 +273,8 @@ def preprocess_BiasCorrection(df: pd.DataFrame) -> pd.DataFrame:
         return df_temp
 
     def add_day(df: pd.DataFrame) -> pd.DataFrame:
-        """Extract the day from the datetime cell and add it
+        """Extract the day from the datetime cell and add it
+        as a new column to the dataframe at the front."""
         # make copy of dataframe
         df_temp = df.copy()
         # extract day and add new column at index 2 containing these numbers
@@ -80,35 +282,46 @@ def preprocess_BiasCorrection(df: pd.DataFrame) -> pd.DataFrame:
         return df_temp
 
     def add_input_output_temperature(df: pd.DataFrame) -> pd.DataFrame:
-        """Add a multiindex denoting if the column
+        """Add a multiindex denoting if the column
+        is an input or output variable."""
         # copy the dataframe
         temp_df = df.copy()
         # extract all the column names
         column_names = temp_df.columns.tolist()
-        # only the last 2 columns are output variables, all others are input
+        # only the last 2 columns are output variables, all others are input
+        # variables. So make list of corresponding lengths of
+        # 'Input' and 'Output'
         input_list = ["Input"] * (len(column_names) - 2)
         output_list = ["Output"] * 2
         # concat both lists
         input_output_list = input_list + output_list
-        # define multi index for attaching this 'Input' and 'Output' list with
-
+        # define multi index for attaching this 'Input' and 'Output' list with
+        # the column names already existing
+        multiindex_bias = pd.MultiIndex.from_arrays(
+            [input_output_list, column_names]
+        )
         # transpose such that index can be adjusted to multi index
         new_df = pd.DataFrame(df.transpose().to_numpy(), index=multiindex_bias)
-        # transpose back such that columns are the same as before
+        # transpose back such that columns are the same as before
+        # except with different labels
         return new_df.transpose()
 
     def normalize_columns_bias(df: pd.DataFrame) -> pd.DataFrame:
-        """Normalize the columns for the bias correction dataset.
-
+        """Normalize the columns for the bias correction dataset.
+        This is different from normalizing all the columns separately
+        because the upper and lower bounds for the output variables
+        are assumed to be the same."""
         # copy the dataframe
         temp_df = df.copy()
         # normalize each column
         for feature_name in df.columns:
-            # the output columns are normalized using the same upper and
+            # the output columns are normalized using the same upper and
+            # lower bound for more efficient check of the inequality
            if feature_name == "Next_Tmax" or feature_name == "Next_Tmin":
                 max_value = 38.9
                 min_value = 11.3
-            # the input columns are normalized using their respective
+            # the input columns are normalized using their respective
+            # upper and lower bounds
             else:
                 max_value = df[feature_name].max()
                 min_value = df[feature_name].min()
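add_input_output_temperature attaches a two-level column index so that columns can later be selected as inputs or outputs. The diff does this via a transpose round-trip; a self-contained sketch of the same labelling pattern (toy columns, not the real dataset):

    import pandas as pd

    df = pd.DataFrame({"Tmax": [1.0], "Next_Tmax": [2.0]})
    labels = ["Input", "Output"]
    df.columns = pd.MultiIndex.from_arrays([labels, df.columns.tolist()])
    print(df["Output"])  # selects all output-labelled columns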
@@ -120,7 +333,9 @@ def preprocess_BiasCorrection(df: pd.DataFrame) -> pd.DataFrame:
     def sample_2500_examples(df: pd.DataFrame) -> pd.DataFrame:
         """Sample 2500 examples from the dataframe without replacement."""
         temp_df = df.copy()
-        sample_df = temp_df.sample(
+        sample_df = temp_df.sample(
+            n=2500, replace=False, random_state=3, axis=0
+        )
         return sample_df
 
     return (
@@ -140,18 +355,47 @@ def preprocess_BiasCorrection(df: pd.DataFrame) -> pd.DataFrame:
         .astype("float32")
         # normalize columns
         .pipe(normalize_columns_bias)
-        # add multi index indicating which columns are corresponding
+        # add multi index indicating which columns are corresponding
+        # to input and output variables
         .pipe(add_input_output_temperature)
         # sample 2500 examples out of the dataset
         .pipe(sample_2500_examples)
     )
 
 
-
+# pylint: disable-next=invalid-name
+def preprocess_FamilyIncome(df: pd.DataFrame) -> pd.DataFrame:
+    """
+    Preprocesses the given Family Income dataframe by applying a
+    series of transformations and constraints.
+
+    The function sequentially:
+
+    - Drops rows with missing values.
+    - Converts object columns to appropriate data types and
+      removes string columns.
+    - Removes certain unnecessary columns like
+      'Agricultural Household indicator' and related features.
+    - Adds labels to columns indicating whether they are
+      input or output variables.
+    - Normalizes the columns individually.
+    - Checks and removes rows that do not satisfy predefined constraints
+      (household income > expenses, food expenses > sub-expenses).
+    - Samples 2500 examples from the dataset without replacement.
+
+    Args:
+        df (pd.DataFrame): The input Family Income dataframe containing
+            the data to be processed.
+
+    Returns:
+        pd.DataFrame: The processed dataframe after applying the
+            transformations and constraints.
+    """
 
     def normalize_columns_income(df: pd.DataFrame) -> pd.DataFrame:
-        """Normalize the columns for the Family Income dataframe.
-
+        """Normalize the columns for the Family Income dataframe.
+        This can also be applied to other dataframes because this
+        function normalizes all columns individually."""
         # copy the dataframe
         temp_df = df.copy()
         # normalize each column
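normalize_columns_income (whose loop body falls between the displayed hunks) is, per its docstring, ordinary per-column min-max scaling. The whole loop reduces to one vectorized expression, assuming every column has max > min:

    normalized = (df - df.min()) / (df.max() - df.min())  # each column -> [0, 1]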
@@ -164,9 +408,12 @@ def preprocess_FiniteIncome(df: pd.DataFrame) -> pd.DataFrame:
         return temp_df
 
     def check_constraints_income(df: pd.DataFrame) -> pd.DataFrame:
-        """Check if all the constraints are satisfied for the dataframe
-
-
+        """Check if all the constraints are satisfied for the dataframe
+        and remove the examples that do not satisfy the constraint.
+        This function only works for the Family Income dataset and the
+        constraints are that the household income is larger than all the
+        expenses and the food expense is larger than the sum of
+        the other (more detailed) food expenses.
         """
         temp_df = df.copy()
         # check that household income is larger than expenses in the output
@@ -174,7 +421,9 @@ def preprocess_FiniteIncome(df: pd.DataFrame) -> pd.DataFrame:
         income_array = np.add(
             np.multiply(
                 input_array[:, [0, 1]],
-                np.subtract(
+                np.subtract(
+                    np.asarray([11815988, 9234485]), np.asarray([11285, 0])
+                ),
             ),
             np.asarray([11285, 0]),
         )
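The nested np calls above appear to invert that min-max scaling for the two income columns, i.e. x_raw = x_norm * (max - min) + min with hard-coded bounds. A hedged sketch of the same arithmetic:

    import numpy as np

    x_norm = np.array([[0.5, 0.5]])            # two normalized columns
    col_max = np.asarray([11815988, 9234485])  # bounds from the hunk above
    col_min = np.asarray([11285, 0])
    x_raw = x_norm * (col_max - col_min) + col_min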
@@ -244,28 +493,38 @@ def preprocess_FiniteIncome(df: pd.DataFrame) -> pd.DataFrame:
         return satisfied_constraints_df
 
     def add_input_output_family_income(df: pd.DataFrame) -> pd.DataFrame:
-        """Add a multiindex denoting if the column is
+        """Add a multiindex denoting if the column is
+        an input or output variable."""
+
         # copy the dataframe
         temp_df = df.copy()
         # extract all the column names
         column_names = temp_df.columns.tolist()
-        # the 2nd-9th columns correspond to output variables and all
+        # the 2nd-9th columns correspond to output variables and all
+        # others to input variables. So make list of corresponding
+        # lengths of 'Input' and 'Output'
         input_list_start = ["Input"]
         input_list_end = ["Input"] * (len(column_names) - 9)
         output_list = ["Output"] * 8
         # concat both lists
         input_output_list = input_list_start + output_list + input_list_end
-        # define multi index for attaching this 'Input' and
-
+        # define multi index for attaching this 'Input' and
+        # 'Output' list with the column names already existing
+        multiindex_bias = pd.MultiIndex.from_arrays(
+            [input_output_list, column_names]
+        )
         # transpose such that index can be adjusted to multi index
         new_df = pd.DataFrame(df.transpose().to_numpy(), index=multiindex_bias)
-        # transpose back such that columns are the same as
+        # transpose back such that columns are the same as
+        # before except with different labels
         return new_df.transpose()
 
     def sample_2500_examples(df: pd.DataFrame) -> pd.DataFrame:
         """Sample 2500 examples from the dataframe without replacement."""
         temp_df = df.copy()
-        sample_df = temp_df.sample(
+        sample_df = temp_df.sample(
+            n=2500, replace=False, random_state=3, axis=0
+        )
         return sample_df
 
     return (
@@ -273,13 +532,17 @@ def preprocess_FiniteIncome(df: pd.DataFrame) -> pd.DataFrame:
         df.dropna(how="any")
         # convert object to fitting dtype
         .convert_dtypes()
-        # remove all strings (no other dtypes are present
+        # remove all strings (no other dtypes are present
+        # except for integers and floats)
         .select_dtypes(exclude=["string"])
         # transform all numbers to same dtype
         .astype("float32")
-        # drop column with label Agricultural Household indicator
+        # drop column with label Agricultural Household indicator
+        # because this is not really a numerical input but
+        # rather a categorical/classification
         .drop(["Agricultural Household indicator"], axis=1, inplace=False)
-        # this column is dropped because it depends on
+        # this column is dropped because it depends on
+        # Agricultural Household indicator
         .drop(["Crop Farming and Gardening expenses"], axis=1, inplace=False)
         # use 8 output variables and 24 input variables
         .drop(
@@ -308,3 +571,140 @@ def preprocess_FiniteIncome(df: pd.DataFrame) -> pd.DataFrame:
         # sample 2500 examples
         .pipe(sample_2500_examples)
     )
+
+
+def validate_type(name, value, expected_types, allow_none=False):
+    """
+    Validate that a value is of the specified type(s).
+
+    Args:
+        name (str): Name of the argument for error messages.
+        value: Value to validate.
+        expected_types (type or tuple of types): Expected type(s) for the value.
+        allow_none (bool): Whether to allow the value to be None.
+            Defaults to False.
+
+    Raises:
+        TypeError: If the value is not of the expected type(s).
+    """
+
+    if value is None:
+        if not allow_none:
+            raise TypeError(f"Argument {name} cannot be None.")
+        return
+
+    if not isinstance(value, expected_types):
+        raise TypeError(
+            f"Argument {name} '{str(value)}' is not supported. "
+            f"Only values of type {str(expected_types)} are allowed."
+        )
+
+
+def validate_iterable(
+    name,
+    value,
+    expected_element_types,
+    allowed_iterables=(list, set),
+    allow_none=False,
+):
+    """
+    Validate that a value is an iterable (e.g., list, set) with elements of
+    the specified type(s).
+
+    Args:
+        name (str): Name of the argument for error messages.
+        value: Value to validate.
+        expected_element_types (type or tuple of types): Expected type(s)
+            for the elements.
+        allowed_iterables (tuple of types): Iterable types that are
+            allowed (default: list and set).
+        allow_none (bool): Whether to allow the value to be None.
+            Defaults to False.
+
+    Raises:
+        TypeError: If the value is not an allowed iterable type or if
+            any element is not of the expected type(s).
+    """
+
+    if value is None:
+        if not allow_none:
+            raise TypeError(f"Argument {name} cannot be None.")
+        return
+
+    if not isinstance(value, allowed_iterables):
+        raise TypeError(
+            f"Argument {name} '{str(value)}' is not supported. "
+            f"Only values of type {str(allowed_iterables)} are allowed."
+        )
+    if not all(
+        isinstance(element, expected_element_types) for element in value
+    ):
+        raise TypeError(
+            f"Invalid elements in {name} '{str(value)}'. "
+            f"Only elements of type {str(expected_element_types)} are allowed."
+        )
+
+
+def validate_comparator_pytorch(name, value):
+    """
+    Validate that a value is a callable PyTorch comparator function.
+
+    Args:
+        name (str): Name of the argument for error messages.
+        value: Value to validate.
+
+    Raises:
+        TypeError: If the value is not callable or not a PyTorch comparator.
+    """
+
+    # List of valid PyTorch comparator functions
+    pytorch_comparators = {torch.gt, torch.lt, torch.ge, torch.le}
+
+    # Check if value is callable and if it's one of
+    # the PyTorch comparator functions
+    if not callable(value):
+        raise TypeError(
+            f"Argument {name} '{str(value)}' is not supported. "
+            "Only callable functions are allowed."
+        )
+
+    if value not in pytorch_comparators:
+        raise TypeError(
+            f"Argument {name} '{str(value)}' is not a valid PyTorch comparator "
+            "function. Only PyTorch functions like torch.gt, torch.lt, "
+            "torch.ge, torch.le are allowed."
+        )
+
+
+def validate_callable(name, value, allow_none=False):
+    """
+    Validate that a value is a callable function.
+
+    Args:
+        name (str): Name of the argument for error messages.
+        value: Value to validate.
+        allow_none (bool): Whether to allow the value to be None.
+            Defaults to False.
+
+    Raises:
+        TypeError: If the value is not callable.
+    """
+
+    if value is None:
+        if not allow_none:
+            raise TypeError(f"Argument {name} cannot be None.")
+        return
+
+    if not callable(value):
+        raise TypeError(
+            f"Argument {name} '{str(value)}' is not supported. "
+            "Only callable functions are allowed."
+        )
+
+
+def validate_loaders() -> None:
+    """
+    TODO: implement function
+    """
+    # TODO complete
+    pass
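The validators added in this release fail fast with a TypeError instead of erroring deep inside training code. A brief sketch of how they might be called (argument names are illustrative):

    import torch
    from congrads.utils import (
        validate_type, validate_iterable, validate_comparator_pytorch
    )

    validate_type("train_size", 0.8, float)              # passes
    validate_iterable("names", ["Tmax", "Tmin"], str)    # passes
    validate_comparator_pytorch("comparator", torch.le)  # passes
    validate_type("epochs", "ten", int)                  # raises TypeError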