congrads-1.1.2-py3-none-any.whl → congrads-1.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
congrads/utils.py DELETED
@@ -1,1078 +0,0 @@
1
- """This module holds utility functions and classes for the congrads package."""
2
-
3
- import os
4
- import random
5
-
6
- import numpy as np
7
- import pandas as pd
8
- import torch
9
- from torch import Generator, Tensor, argsort, cat, int32, unique
10
- from torch.nn.modules.loss import _Loss
11
- from torch.utils.data import DataLoader, Dataset, random_split
12
-
13
-
14
- class CSVLogger:
15
- """A utility class for logging key-value pairs to a CSV file, organized by epochs.
16
-
17
- Supports merging with existing logs or overwriting them.
18
-
19
- Args:
20
- file_path (str): The path to the CSV file for logging.
21
- overwrite (bool): If True, overwrites any existing file at the file_path.
22
- merge (bool): If True, merges new values with existing data in the file.
23
-
24
- Raises:
25
- ValueError: If both overwrite and merge are True.
26
- FileExistsError: If the file already exists and neither overwrite nor merge is True.
27
- """
28
-
29
- def __init__(self, file_path: str, overwrite: bool = False, merge: bool = True):
30
- """Initializes the CSVLogger.
31
-
32
- Supports merging with existing logs or overwriting them.
33
-
34
- Args:
35
- file_path (str): The path to the CSV file for logging.
36
- overwrite (optional, bool): If True, overwrites any existing file at the file_path. Defaults to False.
37
- merge (optional, bool): If True, merges new values with existing data in the file. Defaults to True.
38
-
39
- Raises:
40
- ValueError: If both overwrite and merge are True.
41
- FileExistsError: If the file already exists and neither overwrite nor merge is True.
42
- """
43
- self.file_path = file_path
44
- self.values: dict[tuple[int, str], float] = {}
45
-
46
- if merge and overwrite:
47
- raise ValueError(
48
- "The attributes overwrite and merge cannot be True at the "
49
- "same time. Either specify overwrite=True or merge=True."
50
- )
51
-
52
- if not os.path.exists(file_path):
53
- pass
54
- elif merge:
55
- self.load()
56
- elif overwrite:
57
- pass
58
- else:
59
- raise FileExistsError(
60
- f"A CSV file already exists at {file_path}. Specify "
61
- "CSVLogger(..., overwrite=True) to overwrite the file."
62
- )
63
-
64
- def add_value(self, name: str, value: float, epoch: int):
65
- """Adds a value to the logger for a specific epoch and name.
66
-
67
- Args:
68
- name (str): The name of the metric or value to log.
69
- value (float): The value to log.
70
- epoch (int): The epoch associated with the value.
71
- """
72
- self.values[epoch, name] = value
73
-
74
- def save(self):
75
- """Saves the logged values to the specified CSV file.
76
-
77
- If merging was enabled at construction time, any values loaded from
78
- the existing file are written back together with newly added values.
79
- """
80
- data = self.to_dataframe(self.values)
81
- data.to_csv(self.file_path, index=False)
82
-
83
- def load(self):
84
- """Loads data from the CSV file into the logger.
85
-
86
- Converts the CSV data into the internal dictionary format for
87
- further updates or operations.
88
- """
89
- df = pd.read_csv(self.file_path)
90
- self.values = self.to_dict(df)
91
-
92
- @staticmethod
93
- def to_dataframe(values: dict[tuple[int, str], float]) -> pd.DataFrame:
94
- """Converts a dictionary of values into a DataFrame.
95
-
96
- Args:
97
- values (dict[tuple[int, str], float]): A dictionary of values keyed by (epoch, name).
98
-
99
- Returns:
100
- pd.DataFrame: A DataFrame where epochs are rows, names are columns, and values are the cell data.
101
- """
102
- # Convert to a DataFrame
103
- df = pd.DataFrame.from_dict(values, orient="index", columns=["value"])
104
-
105
- # Reset the index to separate epoch and name into columns
106
- df.index = pd.MultiIndex.from_tuples(df.index, names=["epoch", "name"])
107
- df = df.reset_index()
108
-
109
- # Pivot the DataFrame so epochs are rows and names are columns
110
- result = df.pivot(index="epoch", columns="name", values="value")
111
-
112
- # Optional: Reset the column names for a cleaner look
113
- result = result.reset_index().rename_axis(columns=None)
114
-
115
- return result
116
-
117
- @staticmethod
118
- def to_dict(df: pd.DataFrame) -> dict[tuple[int, str], float]:
119
- """Converts a CSVLogger DataFrame to a dictionary the format {(epoch, name): value}."""
120
- # Set the epoch column as the index (if not already)
121
- df = df.set_index("epoch")
122
-
123
- # Stack the DataFrame to create a multi-index series
124
- stacked = df.stack()
125
-
126
- # Convert the multi-index series to a dictionary
127
- result = stacked.to_dict()
128
-
129
- return result
130
-
131
-
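For orientation, a minimal usage sketch of CSVLogger (not part of the diff; it assumes the 1.1.x import path congrads.utils, and the file name and metric values are placeholders):

from congrads.utils import CSVLogger

logger = CSVLogger("training_log.csv")  # merge=True by default; loads the file if it already exists
for epoch in range(3):
    logger.add_value("train_loss", 0.1 * (3 - epoch), epoch)
    logger.add_value("valid_loss", 0.1 * (4 - epoch), epoch)
logger.save()  # writes one row per epoch, one column per metric name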
132
- def split_data_loaders(
133
- data: Dataset,
134
- loader_args: dict = None,
135
- train_loader_args: dict = None,
136
- valid_loader_args: dict = None,
137
- test_loader_args: dict = None,
138
- train_size: float = 0.8,
139
- valid_size: float = 0.1,
140
- test_size: float = 0.1,
141
- split_generator: Generator = None,
142
- ) -> tuple[DataLoader, DataLoader, DataLoader]:
143
- """Splits a dataset into training, validation, and test sets, and returns corresponding DataLoader objects.
144
-
145
- Args:
146
- data (Dataset): The dataset to be split.
147
- loader_args (dict, optional): Default DataLoader arguments, merged
148
- with the loader-specific arguments; overlapping keys are
149
- overridden by the loader-specific arguments.
150
- train_loader_args (dict, optional): Training DataLoader arguments,
151
- merges with `loader_args`, overriding overlapping keys.
152
- valid_loader_args (dict, optional): Validation DataLoader arguments,
153
- merges with `loader_args`, overriding overlapping keys.
154
- test_loader_args (dict, optional): Test DataLoader arguments,
155
- merges with `loader_args`, overriding overlapping keys.
156
- train_size (float, optional): Proportion of data to be used for
157
- training. Defaults to 0.8.
158
- valid_size (float, optional): Proportion of data to be used for
159
- validation. Defaults to 0.1.
160
- test_size (float, optional): Proportion of data to be used for
161
- testing. Defaults to 0.1.
162
- split_generator (Generator, optional): Optional random number generator
163
- to control the splitting of the dataset.
164
-
165
- Returns:
166
- tuple: A tuple of three DataLoader objects, one each for the
167
- training, validation, and test sets.
168
-
169
- Raises:
170
- ValueError: If the train_size, valid_size, and test_size are not
171
- between 0 and 1, or if their sum does not equal 1.
172
- """
173
- # Validate split sizes
174
- if not (0 < train_size < 1 and 0 < valid_size < 1 and 0 < test_size < 1):
175
- raise ValueError("train_size, valid_size, and test_size must be between 0 and 1.")
176
- if not abs(train_size + valid_size + test_size - 1.0) < 1e-6:
177
- raise ValueError("train_size, valid_size, and test_size must sum to 1.")
178
-
179
- # Perform the splits
180
- train_val_data, test_data = random_split(
181
- data, [1 - test_size, test_size], generator=split_generator
182
- )
183
- train_data, valid_data = random_split(
184
- train_val_data,
185
- [
186
- train_size / (1 - test_size),
187
- valid_size / (1 - test_size),
188
- ],
189
- generator=split_generator,
190
- )
191
-
192
- # Set default arguments for each loader
193
- train_loader_args = dict(loader_args or {}, **(train_loader_args or {}))
194
- valid_loader_args = dict(loader_args or {}, **(valid_loader_args or {}))
195
- test_loader_args = dict(loader_args or {}, **(test_loader_args or {}))
196
-
197
- # Create the DataLoaders
198
- train_generator = DataLoader(train_data, **train_loader_args)
199
- valid_generator = DataLoader(valid_data, **valid_loader_args)
200
- test_generator = DataLoader(test_data, **test_loader_args)
201
-
202
- return train_generator, valid_generator, test_generator
203
-
204
-
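A usage sketch for split_data_loaders (illustrative only; the TensorDataset, batch size, and seed below are placeholders):

import torch
from torch.utils.data import TensorDataset
from congrads.utils import split_data_loaders

dataset = TensorDataset(torch.randn(100, 4), torch.randn(100, 1))
train_dl, valid_dl, test_dl = split_data_loaders(
    dataset,
    loader_args={"batch_size": 16},        # shared defaults for all three loaders
    train_loader_args={"shuffle": True},   # per-loader args override the shared defaults
    split_generator=torch.Generator().manual_seed(0),
)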
205
- def preprocess_BiasCorrection(df: pd.DataFrame) -> pd.DataFrame: # noqa: N802
206
- """Preprocesses the given dataframe for bias correction by performing a series of transformations.
207
-
208
- The function sequentially:
209
-
210
- - Drops rows with missing values.
211
- - Converts a date string to datetime format and adds year, month,
212
- and day columns.
213
- - Normalizes the columns with specific logic for input and output variables.
214
- - Adds a multi-index indicating which columns are input or output variables.
215
- - Samples 2500 examples from the dataset without replacement.
216
-
217
- Args:
218
- df (pd.DataFrame): The input dataframe containing the data
219
- to be processed.
220
-
221
- Returns:
222
- pd.DataFrame: The processed dataframe after applying
223
- the transformations.
224
- """
225
-
226
- def date_to_datetime(df: pd.DataFrame) -> pd.DataFrame:
227
- """Transform the string that denotes the date to the datetime format in pandas."""
228
- # make copy of dataframe
229
- df_temp = df.copy()
230
- # add new column at the front where the date string is
231
- # transformed to the datetime format
232
- df_temp.insert(0, "DateTransformed", pd.to_datetime(df_temp["Date"]))
233
- return df_temp
234
-
235
- def add_year(df: pd.DataFrame) -> pd.DataFrame:
236
- """Extract the year from the datetime cell and add it as a new column to the dataframe at the front."""
237
- # make copy of dataframe
238
- df_temp = df.copy()
239
- # extract year and add new column at the front containing these numbers
240
- df_temp.insert(0, "Year", df_temp["DateTransformed"].dt.year)
241
- return df_temp
242
-
243
- def add_month(df: pd.DataFrame) -> pd.DataFrame:
244
- """Extract the month from the datetime cell and add it as a new column to the dataframe at the front."""
245
- # make copy of dataframe
246
- df_temp = df.copy()
247
- # extract month and add new column at index 1 containing these numbers
248
- df_temp.insert(1, "Month", df_temp["DateTransformed"].dt.month)
249
- return df_temp
250
-
251
- def add_day(df: pd.DataFrame) -> pd.DataFrame:
252
- """Extract the day from the datetime cell and add it as a new column to the dataframe at the front."""
253
- # make copy of dataframe
254
- df_temp = df.copy()
255
- # extract day and add new column at index 2 containing these numbers
256
- df_temp.insert(2, "Day", df_temp["DateTransformed"].dt.day)
257
- return df_temp
258
-
259
- def add_input_output_temperature(df: pd.DataFrame) -> pd.DataFrame:
260
- """Add a multiindex denoting if the column is an input or output variable."""
261
- # copy the dataframe
262
- temp_df = df.copy()
263
- # extract all the column names
264
- column_names = temp_df.columns.tolist()
265
- # only the last 2 columns are output variables, all others are input
266
- # variables. So make list of corresponding lengths of
267
- # 'Input' and 'Output'
268
- input_list = ["Input"] * (len(column_names) - 2)
269
- output_list = ["Output"] * 2
270
- # concat both lists
271
- input_output_list = input_list + output_list
272
- # define multi index for attaching this 'Input' and 'Output' list with
273
- # the column names already existing
274
- multiindex_bias = pd.MultiIndex.from_arrays([input_output_list, column_names])
275
- # transpose such that index can be adjusted to multi index
276
- new_df = pd.DataFrame(df.transpose().to_numpy(), index=multiindex_bias)
277
- # transpose back such that columns are the same as before
278
- # except with different labels
279
- return new_df.transpose()
280
-
281
- def normalize_columns_bias(df: pd.DataFrame) -> pd.DataFrame:
282
- """Normalize the columns for the bias correction dataset.
283
-
284
- This is different from normalizing all the columns separately
285
- because the upper and lower bounds for the output variables
286
- are assumed to be the same.
287
- """
288
- # copy the dataframe
289
- temp_df = df.copy()
290
- # normalize each column
291
- for feature_name in df.columns:
292
- # the output columns are normalized using the same upper and
293
- # lower bound for more efficient check of the inequality
294
- if feature_name == "Next_Tmax" or feature_name == "Next_Tmin":
295
- max_value = 38.9
296
- min_value = 11.3
297
- # the input columns are normalized using their respective
298
- # upper and lower bounds
299
- else:
300
- max_value = df[feature_name].max()
301
- min_value = df[feature_name].min()
302
- temp_df[feature_name] = (df[feature_name] - min_value) / (max_value - min_value)
303
- return temp_df
304
-
305
- def sample_2500_examples(df: pd.DataFrame) -> pd.DataFrame:
306
- """Sample 2500 examples from the dataframe without replacement."""
307
- temp_df = df.copy()
308
- sample_df = temp_df.sample(n=2500, replace=False, random_state=3, axis=0)
309
- return sample_df
310
-
311
- return (
312
- # drop missing values
313
- df.dropna(how="any")
314
- # transform string date to datetime format
315
- .pipe(date_to_datetime)
316
- # add year as a single column
317
- .pipe(add_year)
318
- # add month as a single column
319
- .pipe(add_month)
320
- # add day as a single column
321
- .pipe(add_day)
322
- # remove original date string and the datetime format
323
- .drop(["Date", "DateTransformed"], axis=1, inplace=False)
324
- # convert all numbers to float32
325
- .astype("float32")
326
- # normalize columns
327
- .pipe(normalize_columns_bias)
328
- # add multi index indicating which columns are corresponding
329
- # to input and output variables
330
- .pipe(add_input_output_temperature)
331
- # sample 2500 examples out of the dataset
332
- .pipe(sample_2500_examples)
333
- )
334
-
335
-
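A hedged usage sketch; bias_correction.csv is a hypothetical local path to the bias-correction temperature data with the Date, Next_Tmax, and Next_Tmin columns referenced above:

import pandas as pd
from congrads.utils import preprocess_BiasCorrection

raw = pd.read_csv("bias_correction.csv")      # hypothetical path
processed = preprocess_BiasCorrection(raw)
print(processed["Input"].shape, processed["Output"].shape)  # 2500 rows each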
336
- def preprocess_FamilyIncome(df: pd.DataFrame) -> pd.DataFrame: # noqa: N802
337
- """Preprocesses the given Family Income dataframe.
338
-
339
- The function sequentially:
340
-
341
- - Drops rows with missing values.
342
- - Converts object columns to appropriate data types and
343
- removes string columns.
344
- - Removes certain unnecessary columns like
345
- 'Agricultural Household indicator' and related features.
346
- - Adds labels to columns indicating whether they are
347
- input or output variables.
348
- - Normalizes the columns individually.
349
- - Checks and removes rows that do not satisfy predefined constraints
350
- (household income > expenses, food expenses > sub-expenses).
351
- - Samples 2500 examples from the dataset without replacement.
352
-
353
- Args:
354
- df (pd.DataFrame): The input Family Income dataframe containing
355
- the data to be processed.
356
-
357
- Returns:
358
- pd.DataFrame: The processed dataframe after applying the
359
- transformations and constraints.
360
- """
361
-
362
- def normalize_columns_income(df: pd.DataFrame) -> pd.DataFrame:
363
- """Normalize each column of the dataframe independently.
364
-
365
- This function scales each column to have values between 0 and 1
366
- using min-max scaling,
367
- making it suitable for numerical processing. While designed for
368
- the Family Income dataset, it can be applied to any dataframe
369
- with numeric columns.
370
-
371
- Args:
372
- df (pd.DataFrame): Input dataframe to normalize.
373
-
374
- Returns:
375
- pd.DataFrame: Dataframe with each column normalized independently.
376
- """
377
- # copy the dataframe
378
- temp_df = df.copy()
379
- # normalize each column
380
- for feature_name in df.columns:
381
- max_value = df[feature_name].max()
382
- min_value = df[feature_name].min()
383
- temp_df[feature_name] = (df[feature_name] - min_value) / (max_value - min_value)
384
- return temp_df
385
-
386
- def check_constraints_income(df: pd.DataFrame) -> pd.DataFrame:
387
- """Filter rows that violate income-related constraints.
388
-
389
- This function is specific to the Family Income dataset. It removes rows
390
- that do not satisfy the following constraints:
391
- 1. Household income must be greater than all expenses.
392
- 2. Food expense must be greater than the sum of detailed food expenses.
393
-
394
- Args:
395
- df (pd.DataFrame): Input dataframe containing income and expense data.
396
-
397
- Returns:
398
- pd.DataFrame: Filtered dataframe containing only rows that satisfy
399
- all constraints.
400
- """
401
- temp_df = df.copy()
402
- # check that household income is larger than expenses in the output
403
- input_array = temp_df["Input"].to_numpy()
404
- income_array = np.add(
405
- np.multiply(
406
- input_array[:, [0, 1]],
407
- np.subtract(np.asarray([11815988, 9234485]), np.asarray([11285, 0])),
408
- ),
409
- np.asarray([11285, 0]),
410
- )
411
- expense_array = temp_df["Output"].to_numpy()
412
- expense_array = np.add(
413
- np.multiply(
414
- expense_array,
415
- np.subtract(
416
- np.asarray(
417
- [
418
- 791848,
419
- 437467,
420
- 140992,
421
- 74800,
422
- 2188560,
423
- 1049275,
424
- 149940,
425
- 731000,
426
- ]
427
- ),
428
- np.asarray([3704, 0, 0, 0, 1950, 0, 0, 0]),
429
- ),
430
- ),
431
- np.asarray([3704, 0, 0, 0, 1950, 0, 0, 0]),
432
- )
433
- expense_array_without_dup = expense_array[:, [0, 4, 5, 6, 7]]
434
- sum_expenses = np.sum(expense_array_without_dup, axis=1)
435
- total_income = np.sum(income_array, axis=1)
436
- sanity_check_array = np.greater_equal(total_income, sum_expenses)
437
- temp_df["Unimportant"] = sanity_check_array.tolist()
438
- reduction = temp_df[temp_df.Unimportant]
439
- drop_reduction = reduction.drop("Unimportant", axis=1)
440
-
441
- # check that the food expense is larger than all the sub expenses
442
- expense_reduced_array = drop_reduction["Output"].to_numpy()
443
- expense_reduced_array = np.add(
444
- np.multiply(
445
- expense_reduced_array,
446
- np.subtract(
447
- np.asarray(
448
- [
449
- 791848,
450
- 437467,
451
- 140992,
452
- 74800,
453
- 2188560,
454
- 1049275,
455
- 149940,
456
- 731000,
457
- ]
458
- ),
459
- np.asarray([3704, 0, 0, 0, 1950, 0, 0, 0]),
460
- ),
461
- ),
462
- np.asarray([3704, 0, 0, 0, 1950, 0, 0, 0]),
463
- )
464
- food_mul_expense_array = expense_reduced_array[:, [1, 2, 3]]
465
- food_mul_expense_array_sum = np.sum(food_mul_expense_array, axis=1)
466
- food_expense_array = expense_reduced_array[:, 0]
467
- sanity_check_array = np.greater_equal(food_expense_array, food_mul_expense_array_sum)
468
- drop_reduction["Unimportant"] = sanity_check_array.tolist()
469
- new_reduction = drop_reduction[drop_reduction.Unimportant]
470
- satisfied_constraints_df = new_reduction.drop("Unimportant", axis=1)
471
-
472
- return satisfied_constraints_df
473
-
474
- def add_input_output_family_income(df: pd.DataFrame) -> pd.DataFrame:
475
- """Add a multiindex denoting if the column is an input or output variable."""
476
- # copy the dataframe
477
- temp_df = df.copy()
478
- # extract all the column names
479
- column_names = temp_df.columns.tolist()
480
- # the 2nd-9th columns correspond to output variables and all
481
- # others to input variables. So make list of corresponding
482
- # lengths of 'Input' and 'Output'
483
- input_list_start = ["Input"]
484
- input_list_end = ["Input"] * (len(column_names) - 9)
485
- output_list = ["Output"] * 8
486
- # concat both lists
487
- input_output_list = input_list_start + output_list + input_list_end
488
- # define multi index for attaching this 'Input' and
489
- # 'Output' list with the column names already existing
490
- multiindex_bias = pd.MultiIndex.from_arrays([input_output_list, column_names])
491
- # transpose such that index can be adjusted to multi index
492
- new_df = pd.DataFrame(df.transpose().to_numpy(), index=multiindex_bias)
493
- # transpose back such that columns are the same as
494
- # before except with different labels
495
- return new_df.transpose()
496
-
497
- def sample_2500_examples(df: pd.DataFrame) -> pd.DataFrame:
498
- """Sample 2500 examples from the dataframe without replacement."""
499
- temp_df = df.copy()
500
- sample_df = temp_df.sample(n=2500, replace=False, random_state=3, axis=0)
501
- return sample_df
502
-
503
- return (
504
- # drop missing values
505
- df.dropna(how="any")
506
- # convert object to fitting dtype
507
- .convert_dtypes()
508
- # remove all strings (no other dtypes are present
509
- # except for integers and floats)
510
- .select_dtypes(exclude=["string"])
511
- # transform all numbers to same dtype
512
- .astype("float32")
513
- # drop column with label Agricultural Household indicator
514
- # because this is not really a numerical input but
515
- # rather a categorical/classification
516
- .drop(["Agricultural Household indicator"], axis=1, inplace=False)
517
- # this column is dropped because it depends on
518
- # Agricultural Household indicator
519
- .drop(["Crop Farming and Gardening expenses"], axis=1, inplace=False)
520
- # use 8 output variables and 24 input variables
521
- .drop(
522
- [
523
- "Total Rice Expenditure",
524
- "Total Fish and marine products Expenditure",
525
- "Fruit Expenditure",
526
- "Restaurant and hotels Expenditure",
527
- "Alcoholic Beverages Expenditure",
528
- "Tobacco Expenditure",
529
- "Clothing, Footwear and Other Wear Expenditure",
530
- "Imputed House Rental Value",
531
- "Transportation Expenditure",
532
- "Miscellaneous Goods and Services Expenditure",
533
- "Special Occasions Expenditure",
534
- ],
535
- axis=1,
536
- inplace=False,
537
- )
538
- # add input and output labels to each column
539
- .pipe(add_input_output_family_income)
540
- # normalize all the columns
541
- .pipe(normalize_columns_income)
542
- # remove all datapoints that do not satisfy the constraints
543
- .pipe(check_constraints_income)
544
- # sample 2500 examples
545
- .pipe(sample_2500_examples)
546
- )
547
-
548
-
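A hedged usage sketch; family_income.csv is a hypothetical local path to the Family Income and Expenditure data containing the columns dropped and relabeled above:

import pandas as pd
from congrads.utils import preprocess_FamilyIncome

raw = pd.read_csv("family_income.csv")        # hypothetical path
processed = preprocess_FamilyIncome(raw)
print(processed["Input"].shape, processed["Output"].shape)  # 2500 rows each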
549
- def preprocess_AdultCensusIncome(df: pd.DataFrame) -> pd.DataFrame: # noqa: N802
550
- """Preprocesses the Adult Census Income dataset for PyTorch ML.
551
-
552
- Sequential steps:
553
- - Drop rows with missing values.
554
- - Encode categorical variables to integer labels.
555
- - Map the target 'income' column to 0/1.
556
- - Convert all data to float32.
557
- - Add a multiindex to denote Input vs Output columns.
558
-
559
- Args:
560
- df (pd.DataFrame): Raw dataframe containing Adult Census Income data.
561
-
562
- Returns:
563
- pd.DataFrame: Preprocessed dataframe.
564
- """
565
-
566
- def drop_missing(df: pd.DataFrame) -> pd.DataFrame:
567
- """Drop rows with any missing values."""
568
- return df.dropna(how="any")
569
-
570
- def drop_columns(df: pd.DataFrame) -> pd.DataFrame:
571
- return df.drop(columns=["fnlwgt", "education.num"], errors="ignore")
572
-
573
- def label_encode_column(series: pd.Series, col_name: str = None) -> pd.Series:
574
- """Encode a pandas Series of categorical strings into integers."""
575
- categories = series.dropna().unique().tolist()
576
- cat_to_int = {cat: i for i, cat in enumerate(categories)}
577
- if col_name:
578
- print(f"Column '{col_name}' encoding:")
579
- for cat, idx in cat_to_int.items():
580
- print(f" {cat} -> {idx}")
581
- return series.map(cat_to_int).astype(int)
582
-
583
- def encode_categorical(df: pd.DataFrame) -> pd.DataFrame:
584
- """Convert categorical string columns to integer labels using label_encode_column."""
585
- df_temp = df.copy()
586
- categorical_cols = [
587
- "workclass",
588
- "education",
589
- "marital.status",
590
- "occupation",
591
- "relationship",
592
- "race",
593
- "sex",
594
- "native.country",
595
- ]
596
- for col in categorical_cols:
597
- df_temp[col] = label_encode_column(df_temp[col].astype(str), col_name=col)
598
- return df_temp
599
-
600
- def map_target(df: pd.DataFrame) -> pd.DataFrame:
601
- """Map income column to 0 (<=50K) and 1 (>50K)."""
602
- df_temp = df.copy()
603
- df_temp["income"] = df_temp["income"].map({"<=50K": 0, ">50K": 1})
604
- return df_temp
605
-
606
- def convert_float32(df: pd.DataFrame) -> pd.DataFrame:
607
- """Convert all data to float32 for PyTorch compatibility."""
608
- return df.astype("float32")
609
-
610
- def add_input_output_index(df: pd.DataFrame) -> pd.DataFrame:
611
- """Add a multiindex indicating input and output columns."""
612
- temp_df = df.copy()
613
- column_names = temp_df.columns.tolist()
614
- # Only the 'income' column is output
615
- input_list = ["Input"] * (len(column_names) - 1)
616
- output_list = ["Output"]
617
- multiindex_list = input_list + output_list
618
- multiindex = pd.MultiIndex.from_arrays([multiindex_list, column_names])
619
- return pd.DataFrame(temp_df.to_numpy(), columns=multiindex)
620
-
621
- return (
622
- df.pipe(drop_missing)
623
- .pipe(drop_columns)
624
- .pipe(encode_categorical)
625
- .pipe(map_target)
626
- .pipe(convert_float32)
627
- .pipe(add_input_output_index)
628
- )
629
-
630
-
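A hedged usage sketch; adult.csv is a hypothetical local path to the Adult Census Income data using the dotted column names referenced above (e.g. marital.status, native.country):

import pandas as pd
from congrads.utils import preprocess_AdultCensusIncome

raw = pd.read_csv("adult.csv", na_values="?")   # hypothetical path; "?" marks missing values
processed = preprocess_AdultCensusIncome(raw)
print(processed["Input"].shape, processed["Output"].shape)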
631
- def validate_type(name, value, expected_types, allow_none=False):
632
- """Validate that a value is of the specified type(s).
633
-
634
- Args:
635
- name (str): Name of the argument for error messages.
636
- value: Value to validate.
637
- expected_types (type or tuple of types): Expected type(s) for the value.
638
- allow_none (bool): Whether to allow the value to be None.
639
- Defaults to False.
640
-
641
- Raises:
642
- TypeError: If the value is not of the expected type(s).
643
- """
644
- if value is None:
645
- if not allow_none:
646
- raise TypeError(f"Argument {name} cannot be None.")
647
- return
648
-
649
- if not isinstance(value, expected_types):
650
- raise TypeError(
651
- f"Argument {name} '{str(value)}' is not supported. "
652
- f"Only values of type {str(expected_types)} are allowed."
653
- )
654
-
655
-
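A small sketch of how the validation helpers signal errors (the argument names are illustrative):

from congrads.utils import validate_type

validate_type("batch_size", 32, int)                        # passes silently
validate_type("generator", None, object, allow_none=True)   # None explicitly allowed
try:
    validate_type("batch_size", "32", int)
except TypeError as err:
    print(err)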
656
- def validate_iterable(
657
- name,
658
- value,
659
- expected_element_types,
660
- allowed_iterables=(list, set, tuple),
661
- allow_empty=False,
662
- allow_none=False,
663
- ):
664
- """Validate that a value is an iterable (e.g., list, set) with elements of the specified type(s).
665
-
666
- Args:
667
- name (str): Name of the argument for error messages.
668
- value: Value to validate.
669
- expected_element_types (type or tuple of types): Expected type(s)
670
- for the elements.
671
- allowed_iterables (tuple of types): Iterable types that are
672
- allowed. Defaults to (list, set, tuple).
673
- allow_empty (bool): Whether to allow empty iterables. Defaults to False.
674
- allow_none (bool): Whether to allow the value to be None.
675
- Defaults to False.
676
-
677
- Raises:
678
- TypeError: If the value is not an allowed iterable type or if
679
- any element is not of the expected type(s).
680
- """
681
- if value is None:
682
- if not allow_none:
683
- raise TypeError(f"Argument {name} cannot be None.")
684
- return
685
-
686
- if len(value) == 0:
687
- if not allow_empty:
688
- raise TypeError(f"Argument {name} cannot be an empty iterable.")
689
- return
690
-
691
- if not isinstance(value, allowed_iterables):
692
- raise TypeError(
693
- f"Argument {name} '{str(value)}' is not supported. "
694
- f"Only values of type {str(allowed_iterables)} are allowed."
695
- )
696
- if not all(isinstance(element, expected_element_types) for element in value):
697
- raise TypeError(
698
- f"Invalid elements in {name} '{str(value)}'. "
699
- f"Only elements of type {str(expected_element_types)} are allowed."
700
- )
701
-
702
-
703
- def validate_comparator_pytorch(name, value):
704
- """Validate that a value is a callable PyTorch comparator function.
705
-
706
- Args:
707
- name (str): Name of the argument for error messages.
708
- value: Value to validate.
709
-
710
- Raises:
711
- TypeError: If the value is not callable or not a PyTorch comparator.
712
- """
713
- # List of valid PyTorch comparator functions
714
- pytorch_comparators = {torch.gt, torch.lt, torch.ge, torch.le}
715
-
716
- # Check if value is callable and if it's one of
717
- # the PyTorch comparator functions
718
- if not callable(value):
719
- raise TypeError(
720
- f"Argument {name} '{str(value)}' is not supported. Only callable functions are allowed."
721
- )
722
-
723
- if value not in pytorch_comparators:
724
- raise TypeError(
725
- f"Argument {name} '{str(value)}' is not a valid PyTorch comparator "
726
- "function. Only PyTorch functions like torch.gt, torch.lt, "
727
- "torch.ge, torch.le are allowed."
728
- )
729
-
730
-
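An illustrative sketch of which comparator arguments pass this check:

import torch
from congrads.utils import validate_comparator_pytorch

validate_comparator_pytorch("comparator", torch.ge)   # passes: one of torch.gt/lt/ge/le
try:
    validate_comparator_pytorch("comparator", max)    # callable, but not a PyTorch comparator
except TypeError as err:
    print(err)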
731
- def validate_callable(name, value, allow_none=False):
732
- """Validate that a value is callable function.
733
-
734
- Args:
735
- name (str): Name of the argument for error messages.
736
- value: Value to validate.
737
- allow_none (bool): Whether to allow the value to be None.
738
- Defaults to False.
739
-
740
- Raises:
741
- TypeError: If the value is not callable.
742
- """
743
- if value is None:
744
- if not allow_none:
745
- raise TypeError(f"Argument {name} cannot be None.")
746
- return
747
-
748
- if not callable(value):
749
- raise TypeError(
750
- f"Argument {name} '{str(value)}' is not supported. Only callable functions are allowed."
751
- )
752
-
753
-
754
- def validate_callable_iterable(
755
- name,
756
- value,
757
- allowed_iterables=(list, set, tuple),
758
- allow_none=False,
759
- ):
760
- """Validate that a value is an iterable containing only callable elements.
761
-
762
- This function ensures that the given value is an iterable
763
- (e.g., list or set) and that all its elements are callable functions.
764
-
765
- Args:
766
- name (str): Name of the argument for error messages.
767
- value: The value to validate.
768
- allowed_iterables (tuple of types, optional): Iterable types that are
769
- allowed. Defaults to (list, set, tuple).
770
- allow_none (bool, optional): Whether to allow the value to be None.
771
- Defaults to False.
772
-
773
- Raises:
774
- TypeError: If the value is not an allowed iterable type or if any
775
- element is not callable.
776
- """
777
- if value is None:
778
- if not allow_none:
779
- raise TypeError(f"Argument {name} cannot be None.")
780
- return
781
-
782
- if not isinstance(value, allowed_iterables):
783
- raise TypeError(
784
- f"Argument {name} '{str(value)}' is not supported. "
785
- f"Only values of type {str(allowed_iterables)} are allowed."
786
- )
787
-
788
- if not all(callable(element) for element in value):
789
- raise TypeError(
790
- f"Invalid elements in {name} '{str(value)}'. Only callable functions are allowed."
791
- )
792
-
793
-
794
- def validate_loaders(name: str, loaders: tuple[DataLoader, DataLoader, DataLoader]):
795
- """Validates that `loaders` is a tuple of three DataLoader instances.
796
-
797
- Args:
798
- name (str): The name of the parameter being validated.
799
- loaders (tuple[DataLoader, DataLoader, DataLoader]): A tuple of
800
- three DataLoader instances.
801
-
802
- Raises:
803
- TypeError: If `loaders` is not a tuple of three DataLoader
804
- instances or contains invalid types.
805
- """
806
- if not isinstance(loaders, tuple) or len(loaders) != 3:
807
- raise TypeError(f"{name} must be a tuple of three DataLoader instances.")
808
-
809
- for i, loader in enumerate(loaders):
810
- if not isinstance(loader, DataLoader):
811
- raise TypeError(
812
- f"{name}[{i}] must be an instance of DataLoader, got {type(loader).__name__}."
813
- )
814
-
815
-
816
- class ZeroLoss(_Loss):
817
- """A loss function that always returns zero.
818
-
819
- This custom loss function ignores the input and target tensors
820
- and returns a constant zero loss, which can be useful for debugging
821
- or when no meaningful loss computation is required.
822
-
823
- Args:
824
- reduction (str, optional): Specifies the reduction to apply to
825
- the output. Defaults to "mean". Although specified, it has
826
- no effect as the loss is always zero.
827
- """
828
-
829
- def __init__(self, reduction: str = "mean"):
830
- """Initialize ZeroLoss with a specified reduction method.
831
-
832
- Args:
833
- reduction (str): Specifies the reduction to apply to the output. Defaults to "mean".
834
- """
835
- super().__init__(reduction=reduction)
836
-
837
- def forward(self, predictions: Tensor, target: Tensor, **kwargs) -> torch.Tensor:
838
- """Return a dummy loss of zero regardless of input and target."""
839
- return (predictions * 0).sum()
840
-
841
-
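A usage sketch for ZeroLoss (illustrative shapes): the loss value is zero but still connected to the autograd graph, so backward() works and yields zero gradients.

import torch
from congrads.utils import ZeroLoss

criterion = ZeroLoss()
predictions = torch.randn(8, 2, requires_grad=True)
target = torch.randn(8, 2)
loss = criterion(predictions, target)
loss.backward()
print(loss.item(), predictions.grad.abs().sum().item())  # 0.0 0.0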
842
- def is_torch_loss(criterion) -> bool:
843
- """Return True if the object is a PyTorch loss function."""
844
- type_ = str(type(criterion)).split("'")[1]
845
- parent = type_.rsplit(".", 1)[0]
846
-
847
- return parent == "torch.nn.modules.loss"
848
-
849
-
850
- def torch_loss_wrapper(criterion: _Loss) -> _Loss:
851
- """Wraps a PyTorch loss function to handle the case where the loss function forward pass does not allow **kwargs.
852
-
853
- Args:
854
- criterion (_Loss): The PyTorch loss function to wrap.
855
-
856
- Returns:
857
- _Loss: The wrapped criterion that allows **kwargs in the forward pass.
858
- """
859
-
860
- class WrappedCriterion(_Loss):
861
- def __init__(self, criterion):
862
- super().__init__()
863
- self.criterion = criterion
864
-
865
- def forward(self, *args, **kwargs):
866
- return self.criterion(*args)
867
-
868
- return WrappedCriterion(criterion)
869
-
870
-
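A sketch combining is_torch_loss and torch_loss_wrapper (illustrative tensors; the extra epoch keyword stands in for whatever a training loop might pass along):

import torch
from torch import nn
from congrads.utils import is_torch_loss, torch_loss_wrapper

criterion = nn.MSELoss()
print(is_torch_loss(criterion))                                 # True
wrapped = torch_loss_wrapper(criterion)
loss = wrapped(torch.randn(4, 1), torch.randn(4, 1), epoch=3)   # extra kwargs are ignored
print(float(loss))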
871
- def process_data_monotonicity_constraint(data: Tensor, ordering: Tensor, identifiers: Tensor):
872
- """Reorders input samples to support monotonicity checking.
873
-
874
- Reorders input samples such that:
875
- 1. Samples from the same run are grouped together.
876
- 2. Within each run, samples are sorted chronologically.
877
-
878
- Args:
879
- data (Tensor): The input data.
880
- ordering (Tensor): Values (e.g., timestamps) used to order the samples within each run.
881
- identifiers (Tensor): Identifiers specifying different runs.
882
-
883
- Returns:
884
- Tuple[Tensor, Tensor, Tensor]: Sorted data, ordering, and
885
- identifiers.
886
- """
887
- # Step 1: Sort by run identifiers
888
- sorted_indices = argsort(identifiers, stable=True, dim=0).reshape(-1)
889
- data_sorted, ordering_sorted, identifiers_sorted = (
890
- data[sorted_indices],
891
- ordering[sorted_indices],
892
- identifiers[sorted_indices],
893
- )
894
-
895
- # Step 2: Get unique runs and their counts
896
- _, counts = unique(identifiers, sorted=False, return_counts=True)
897
- counts = counts.to(int32) # Avoid repeated conversions
898
-
899
- sorted_data, sorted_ordering, sorted_identifiers = [], [], []
900
- index = 0 # Tracks the current batch element index
901
-
902
- # Step 3: Process each run independently
903
- for count in counts:
904
- end = index + count
905
- run_data, run_ordering, run_identifiers = (
906
- data_sorted[index:end],
907
- ordering_sorted[index:end],
908
- identifiers_sorted[index:end],
909
- )
910
-
911
- # Step 4: Sort within each run by time
912
- time_sorted_indices = argsort(run_ordering, stable=True, dim=0).reshape(-1)
913
- sorted_data.append(run_data[time_sorted_indices])
914
- sorted_ordering.append(run_ordering[time_sorted_indices])
915
- sorted_identifiers.append(run_identifiers[time_sorted_indices])
916
-
917
- index = end # Move to next run
918
-
919
- # Step 5: Concatenate results and return
920
- return (
921
- cat(sorted_data, dim=0),
922
- cat(sorted_ordering, dim=0),
923
- cat(sorted_identifiers, dim=0),
924
- )
925
-
926
-
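A small worked sketch (made-up values; the column-vector shapes match the argsort(..., dim=0).reshape(-1) convention used above):

import torch
from congrads.utils import process_data_monotonicity_constraint

data = torch.tensor([[0.2], [0.9], [0.5], [0.1]])
ordering = torch.tensor([[2], [1], [1], [3]])       # e.g. timestamps within a run
identifiers = torch.tensor([[1], [2], [1], [2]])    # run ids
sorted_data, sorted_ordering, sorted_ids = process_data_monotonicity_constraint(
    data, ordering, identifiers
)
print(sorted_ids.flatten().tolist())       # [1, 1, 2, 2]
print(sorted_ordering.flatten().tolist())  # [1, 2, 1, 3]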
927
- class DictDatasetWrapper(Dataset):
928
- """A wrapper for PyTorch datasets that converts each sample into a dictionary.
929
-
930
- This class takes any PyTorch dataset and returns its samples as dictionaries,
931
- where each element of the original sample is mapped to a key. This is useful
932
- for integration with the Congrads toolbox or other frameworks that expect
933
- dictionary-formatted data.
934
-
935
- Attributes:
936
- base_dataset (Dataset): The underlying PyTorch dataset being wrapped.
937
- field_names (list[str] | None): Names assigned to each field of a sample.
938
- If None, default names like 'field0', 'field1', ... are generated.
939
-
940
- Args:
941
- base_dataset (Dataset): The PyTorch dataset to wrap.
942
- field_names (list[str] | None, optional): Custom names for each field.
943
- If provided, the list is truncated or extended to match the number
944
- of elements in a sample. Defaults to None.
945
-
946
- Example:
947
- Wrapping a TensorDataset with custom field names:
948
-
949
- >>> from torch.utils.data import TensorDataset
950
- >>> import torch
951
- >>> dataset = TensorDataset(torch.randn(5, 3), torch.randint(0, 2, (5,)))
952
- >>> wrapped = DictDatasetWrapper(dataset, field_names=["features", "label"])
953
- >>> wrapped[0]
954
- {'features': tensor([...]), 'label': tensor(1)}
955
-
956
- Wrapping a built-in dataset like CIFAR10:
957
-
958
- >>> from torchvision.datasets import CIFAR10
959
- >>> from torchvision import transforms
960
- >>> cifar = CIFAR10(
961
- ... root="./data", train=True, download=True, transform=transforms.ToTensor()
962
- ... )
963
- >>> wrapped_cifar = DictDatasetWrapper(cifar, field_names=["input", "output"])
964
- >>> wrapped_cifar[0]
965
- {'input': tensor([...]), 'output': tensor(6)}
966
- """
967
-
968
- def __init__(self, base_dataset: Dataset, field_names: list[str] | None = None):
969
- """Initialize the DictDatasetWrapper.
970
-
971
- Args:
972
- base_dataset (Dataset): The PyTorch dataset to wrap.
973
- field_names (list[str] | None, optional): Optional list of field names
974
- for the dictionary output. Defaults to None, in which case
975
- automatic names 'field0', 'field1', ... are generated.
976
- """
977
- self.base_dataset = base_dataset
978
- self.field_names = field_names
979
-
980
- def __getitem__(self, idx: int):
981
- """Retrieve a sample from the dataset as a dictionary.
982
-
983
- Each element in the original sample is mapped to a key in the dictionary.
984
- If the sample is not a tuple or list, it is converted into a single-element
985
- tuple. Numerical values (int or float) are automatically converted to tensors.
986
-
987
- Args:
988
- idx (int): Index of the sample to retrieve.
989
-
990
- Returns:
991
- dict: A dictionary mapping field names to sample values.
992
- """
993
- sample = self.base_dataset[idx]
994
-
995
- # Ensure sample is always a tuple
996
- if not isinstance(sample, (tuple, list)):
997
- sample = (sample,)
998
-
999
- n_fields = len(sample)
1000
-
1001
- # Generate default field names if none are provided
1002
- if self.field_names is None:
1003
- names = [f"field{i}" for i in range(n_fields)]
1004
- else:
1005
- names = list(self.field_names)
1006
- if len(names) < n_fields:
1007
- names.extend([f"field{i}" for i in range(len(names), n_fields)])
1008
- names = names[:n_fields] # truncate if too long
1009
-
1010
- # Build dictionary
1011
- out = {}
1012
- for name, value in zip(names, sample, strict=False):
1013
- if isinstance(value, (int, float)):
1014
- value = torch.tensor(value)
1015
- out[name] = value
1016
-
1017
- return out
1018
-
1019
- def __len__(self):
1020
- """Return the number of samples in the dataset.
1021
-
1022
- Returns:
1023
- int: Length of the underlying dataset.
1024
- """
1025
- return len(self.base_dataset)
1026
-
1027
-
1028
- class Seeder:
1029
- """A deterministic seed manager for reproducible experiments.
1030
-
1031
- This class provides a way to consistently generate pseudo-random
1032
- seeds derived from a fixed base seed. It ensures that different
1033
- libraries (Python's `random`, NumPy, and PyTorch) are initialized
1034
- with reproducible seeds, making experiments deterministic across runs.
1035
- """
1036
-
1037
- def __init__(self, base_seed: int):
1038
- """Initialize the Seeder with a base seed.
1039
-
1040
- Args:
1041
- base_seed (int): The initial seed from which all subsequent
1042
- pseudo-random seeds are deterministically derived.
1043
- """
1044
- self._rng = random.Random(base_seed)
1045
-
1046
- def roll_seed(self) -> int:
1047
- """Generate a new deterministic pseudo-random seed.
1048
-
1049
- Each call returns an integer seed derived from the internal
1050
- pseudo-random generator, which itself is initialized by the
1051
- base seed.
1052
-
1053
- Returns:
1054
- int: A pseudo-random integer seed in the range [0, 2**31 - 1].
1055
- """
1056
- return self._rng.randint(0, 2**31 - 1)
1057
-
1058
- def set_reproducible(self) -> None:
1059
- """Configure global random states for reproducibility.
1060
-
1061
- Seeds the following libraries with deterministically generated
1062
- seeds based on the base seed:
1063
- - Python's built-in `random`
1064
- - NumPy's random number generator
1065
- - PyTorch (CPU and GPU)
1066
-
1067
- Also enforces deterministic behavior in PyTorch by:
1068
- - Seeding all CUDA devices
1069
- - Disabling CuDNN benchmarking
1070
- - Enabling CuDNN deterministic mode
1071
- """
1072
- random.seed(self.roll_seed())
1073
- np.random.seed(self.roll_seed())
1074
- torch.manual_seed(self.roll_seed())
1075
- torch.cuda.manual_seed_all(self.roll_seed())
1076
-
1077
- torch.backends.cudnn.deterministic = True
1078
- torch.backends.cudnn.benchmark = False
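Finally, a usage sketch for Seeder (the base seed is arbitrary):

from congrads.utils import Seeder

seeder = Seeder(base_seed=42)
seeder.set_reproducible()        # seeds random, numpy, and torch deterministically
split_seed = seeder.roll_seed()  # derive further seeds, e.g. for a torch.Generator
print(split_seed)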