congrads 1.0.7__py3-none-any.whl → 1.1.0__py3-none-any.whl

congrads/utils.py CHANGED
@@ -1,45 +1,45 @@
- """
- This module provides utility functions and classes for managing data,
- logging, and preprocessing in machine learning workflows. The functionalities
- include logging key-value pairs to CSV files, splitting datasets into
- training, validation, and test sets, preprocessing functions for various
- data sets and validator functions for type checking.
- """
+ """This module holds utility functions and classes for the congrads package."""

  import os
+ import random

  import numpy as np
  import pandas as pd
  import torch
- from torch import Generator
+ from torch import Generator, Tensor, argsort, cat, int32, unique
+ from torch.nn.modules.loss import _Loss
  from torch.utils.data import DataLoader, Dataset, random_split


  class CSVLogger:
- """
- A utility class for logging key-value pairs to a CSV file, organized by
- epochs. Supports merging with existing logs or overwriting them.
+ """A utility class for logging key-value pairs to a CSV file, organized by epochs.
+
+ Supports merging with existing logs or overwriting them.

  Args:
  file_path (str): The path to the CSV file for logging.
- overwrite (bool): If True, overwrites any existing file at
- the file_path.
- merge (bool): If True, merges new values with existing data
- in the file.
+ overwrite (bool): If True, overwrites any existing file at the file_path.
+ merge (bool): If True, merges new values with existing data in the file.

  Raises:
  ValueError: If both overwrite and merge are True.
- FileExistsError: If the file already exists and neither
- overwrite nor merge is True.
+ FileExistsError: If the file already exists and neither overwrite nor merge is True.
  """

- def __init__(
- self, file_path: str, overwrite: bool = False, merge: bool = True
- ):
- """
- Initializes the CSVLogger.
- """
+ def __init__(self, file_path: str, overwrite: bool = False, merge: bool = True):
+ """Initializes the CSVLogger.
+
+ Supports merging with existing logs or overwriting them.
+
+ Args:
+ file_path (str): The path to the CSV file for logging.
+ overwrite (optional, bool): If True, overwrites any existing file at the file_path. Defaults to False.
+ merge (optional, bool): If True, merges new values with existing data in the file. Defaults to True.

+ Raises:
+ ValueError: If both overwrite and merge are True.
+ FileExistsError: If the file already exists and neither overwrite nor merge is True.
+ """
  self.file_path = file_path
  self.values: dict[tuple[int, str], float] = {}

@@ -62,53 +62,43 @@ class CSVLogger:
  )

  def add_value(self, name: str, value: float, epoch: int):
- """
- Adds a value to the logger for a specific epoch and name.
+ """Adds a value to the logger for a specific epoch and name.

  Args:
  name (str): The name of the metric or value to log.
  value (float): The value to log.
  epoch (int): The epoch associated with the value.
  """
-
  self.values[epoch, name] = value

  def save(self):
- """
- Saves the logged values to the specified CSV file.
+ """Saves the logged values to the specified CSV file.

  If the file exists and merge is enabled, merges the current data
  with the existing file.
  """
-
  data = self.to_dataframe(self.values)
  data.to_csv(self.file_path, index=False)

  def load(self):
- """
- Loads data from the CSV file into the logger.
+ """Loads data from the CSV file into the logger.

  Converts the CSV data into the internal dictionary format for
  further updates or operations.
  """
-
  df = pd.read_csv(self.file_path)
  self.values = self.to_dict(df)

  @staticmethod
  def to_dataframe(values: dict[tuple[int, str], float]) -> pd.DataFrame:
- """
- Converts a dictionary of values into a DataFrame.
+ """Converts a dictionary of values into a DataFrame.

  Args:
- values (dict[tuple[int, str], float]): A dictionary of values
- keyed by (epoch, name).
+ values (dict[tuple[int, str], float]): A dictionary of values keyed by (epoch, name).

  Returns:
- pd.DataFrame: A DataFrame where epochs are rows, names are
- columns, and values are the cell data.
+ pd.DataFrame: A DataFrame where epochs are rows, names are columns, and values are the cell data.
  """
-
  # Convert to a DataFrame
  df = pd.DataFrame.from_dict(values, orient="index", columns=["value"])

@@ -126,10 +116,7 @@ class CSVLogger:

  @staticmethod
  def to_dict(df: pd.DataFrame) -> dict[tuple[int, str], float]:
- """
- Converts a DataFrame with epochs as rows and names as columns
- back into a dictionary of the format {(epoch, name): value}.
- """
+ """Converts a CSVLogger DataFrame to a dictionary in the format {(epoch, name): value}."""
  # Set the epoch column as the index (if not already)
  df = df.set_index("epoch")

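As a point of reference for the CSVLogger API reworked above, here is a minimal usage sketch; the constructor arguments and method names are taken from the docstrings in this diff, while the file path and metric names are purely illustrative:

    from congrads.utils import CSVLogger

    logger = CSVLogger("runs/metrics.csv", merge=True)
    for epoch in range(3):
        logger.add_value("train_loss", 0.30 - 0.05 * epoch, epoch)
        logger.add_value("valid_loss", 0.35 - 0.04 * epoch, epoch)
    logger.save()  # one row per epoch, one column per metric name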
@@ -153,9 +140,7 @@ def split_data_loaders(
  test_size: float = 0.1,
  split_generator: Generator = None,
  ) -> tuple[DataLoader, DataLoader, DataLoader]:
- """
- Splits a dataset into training, validation, and test sets,
- and returns corresponding DataLoader objects.
+ """Splits a dataset into training, validation, and test sets, and returns corresponding DataLoader objects.

  Args:
  data (Dataset): The dataset to be split.
@@ -185,12 +170,9 @@ def split_data_loaders(
  ValueError: If the train_size, valid_size, and test_size are not
  between 0 and 1, or if their sum does not equal 1.
  """
-
  # Validate split sizes
  if not (0 < train_size < 1 and 0 < valid_size < 1 and 0 < test_size < 1):
- raise ValueError(
- "train_size, valid_size, and test_size must be between 0 and 1."
- )
+ raise ValueError("train_size, valid_size, and test_size must be between 0 and 1.")
  if not abs(train_size + valid_size + test_size - 1.0) < 1e-6:
  raise ValueError("train_size, valid_size, and test_size must sum to 1.")

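Based on the signature fragments and validation logic visible in this hunk, a hedged sketch of a call to split_data_loaders; only the parameters shown here are used, any further arguments the full signature may take are omitted, and my_dataset is a placeholder:

    from torch import Generator
    from congrads.utils import split_data_loaders

    train_loader, valid_loader, test_loader = split_data_loaders(
        my_dataset,                # any torch Dataset (placeholder)
        train_size=0.8,
        valid_size=0.1,
        test_size=0.1,             # each size in (0, 1), together summing to 1
        split_generator=Generator().manual_seed(42),
    )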
@@ -220,11 +202,8 @@ def split_data_loaders(
  return train_generator, valid_generator, test_generator


- # pylint: disable-next=invalid-name
- def preprocess_BiasCorrection(df: pd.DataFrame) -> pd.DataFrame:
- """
- Preprocesses the given dataframe for bias correction by
- performing a series of transformations.
+ def preprocess_BiasCorrection(df: pd.DataFrame) -> pd.DataFrame: # noqa: N802
+ """Preprocesses the given dataframe for bias correction by performing a series of transformations.

  The function sequentially:

@@ -245,8 +224,7 @@ def preprocess_BiasCorrection(df: pd.DataFrame) -> pd.DataFrame:
  """

  def date_to_datetime(df: pd.DataFrame) -> pd.DataFrame:
- """Transform the string that denotes the date to
- the datetime format in pandas."""
+ """Transform the string that denotes the date to the datetime format in pandas."""
  # make copy of dataframe
  df_temp = df.copy()
  # add new column at the front where the date string is
@@ -255,8 +233,7 @@ def preprocess_BiasCorrection(df: pd.DataFrame) -> pd.DataFrame:
  return df_temp

  def add_year(df: pd.DataFrame) -> pd.DataFrame:
- """Extract the year from the datetime cell and add it
- as a new column to the dataframe at the front."""
+ """Extract the year from the datetime cell and add it as a new column to the dataframe at the front."""
  # make copy of dataframe
  df_temp = df.copy()
  # extract year and add new column at the front containing these numbers
@@ -264,8 +241,7 @@ def preprocess_BiasCorrection(df: pd.DataFrame) -> pd.DataFrame:
  return df_temp

  def add_month(df: pd.DataFrame) -> pd.DataFrame:
- """Extract the month from the datetime cell and add it
- as a new column to the dataframe at the front."""
+ """Extract the month from the datetime cell and add it as a new column to the dataframe at the front."""
  # make copy of dataframe
  df_temp = df.copy()
  # extract month and add new column at index 1 containing these numbers
@@ -273,8 +249,7 @@ def preprocess_BiasCorrection(df: pd.DataFrame) -> pd.DataFrame:
  return df_temp

  def add_day(df: pd.DataFrame) -> pd.DataFrame:
- """Extract the day from the datetime cell and add it
- as a new column to the dataframe at the front."""
+ """Extract the day from the datetime cell and add it as a new column to the dataframe at the front."""
  # make copy of dataframe
  df_temp = df.copy()
  # extract day and add new column at index 2 containing these numbers
@@ -282,8 +257,7 @@ def preprocess_BiasCorrection(df: pd.DataFrame) -> pd.DataFrame:
  return df_temp

  def add_input_output_temperature(df: pd.DataFrame) -> pd.DataFrame:
- """Add a multiindex denoting if the column
- is an input or output variable."""
+ """Add a multiindex denoting if the column is an input or output variable."""
  # copy the dataframe
  temp_df = df.copy()
  # extract all the column names
@@ -297,9 +271,7 @@ def preprocess_BiasCorrection(df: pd.DataFrame) -> pd.DataFrame:
  input_output_list = input_list + output_list
  # define multi index for attaching this 'Input' and 'Output' list with
  # the column names already existing
- multiindex_bias = pd.MultiIndex.from_arrays(
- [input_output_list, column_names]
- )
+ multiindex_bias = pd.MultiIndex.from_arrays([input_output_list, column_names])
  # transpose such that index can be adjusted to multi index
  new_df = pd.DataFrame(df.transpose().to_numpy(), index=multiindex_bias)
  # transpose back such that columns are the same as before
@@ -308,9 +280,11 @@ def preprocess_BiasCorrection(df: pd.DataFrame) -> pd.DataFrame:

  def normalize_columns_bias(df: pd.DataFrame) -> pd.DataFrame:
  """Normalize the columns for the bias correction dataset.
+
  This is different from normalizing all the columns separately
  because the upper and lower bounds for the output variables
- are assumed to be the same."""
+ are assumed to be the same.
+ """
  # copy the dataframe
  temp_df = df.copy()
  # normalize each column
@@ -325,17 +299,13 @@ def preprocess_BiasCorrection(df: pd.DataFrame) -> pd.DataFrame:
  else:
  max_value = df[feature_name].max()
  min_value = df[feature_name].min()
- temp_df[feature_name] = (df[feature_name] - min_value) / (
- max_value - min_value
- )
+ temp_df[feature_name] = (df[feature_name] - min_value) / (max_value - min_value)
  return temp_df

  def sample_2500_examples(df: pd.DataFrame) -> pd.DataFrame:
  """Sample 2500 examples from the dataframe without replacement."""
  temp_df = df.copy()
- sample_df = temp_df.sample(
- n=2500, replace=False, random_state=3, axis=0
- )
+ sample_df = temp_df.sample(n=2500, replace=False, random_state=3, axis=0)
  return sample_df

  return (
@@ -363,11 +333,8 @@ def preprocess_BiasCorrection(df: pd.DataFrame) -> pd.DataFrame:
  )


- # pylint: disable-next=invalid-name
- def preprocess_FamilyIncome(df: pd.DataFrame) -> pd.DataFrame:
- """
- Preprocesses the given Family Income dataframe by applying a
- series of transformations and constraints.
+ def preprocess_FamilyIncome(df: pd.DataFrame) -> pd.DataFrame: # noqa: N802
+ """Preprocesses the given Family Income dataframe.

  The function sequentially:

@@ -393,27 +360,43 @@ def preprocess_FamilyIncome(df: pd.DataFrame) -> pd.DataFrame:
  """

  def normalize_columns_income(df: pd.DataFrame) -> pd.DataFrame:
- """Normalize the columns for the Family Income dataframe.
- This can also be applied to other dataframes because this
- function normalizes all columns individually."""
+ """Normalize each column of the dataframe independently.
+
+ This function scales each column to have values between 0 and 1
+ (or another standard normalization, depending on implementation),
+ making it suitable for numerical processing. While designed for
+ the Family Income dataset, it can be applied to any dataframe
+ with numeric columns.
+
+ Args:
+ df (pd.DataFrame): Input dataframe to normalize.
+
+ Returns:
+ pd.DataFrame: Dataframe with each column normalized independently.
+ """
  # copy the dataframe
  temp_df = df.copy()
  # normalize each column
  for feature_name in df.columns:
  max_value = df[feature_name].max()
  min_value = df[feature_name].min()
- temp_df[feature_name] = (df[feature_name] - min_value) / (
- max_value - min_value
- )
+ temp_df[feature_name] = (df[feature_name] - min_value) / (max_value - min_value)
  return temp_df

  def check_constraints_income(df: pd.DataFrame) -> pd.DataFrame:
- """Check if all the constraints are satisfied for the dataframe
- and remove the examples that do not satisfy the constraint.
- This function only works for the Family Income dataset and the
- constraints are that the household income is larger than all the
- expenses and the food expense is larger than the sum of
- the other (more detailed) food expenses.
+ """Filter rows that violate income-related constraints.
+
+ This function is specific to the Family Income dataset. It removes rows
+ that do not satisfy the following constraints:
+ 1. Household income must be greater than all expenses.
+ 2. Food expense must be greater than the sum of detailed food expenses.
+
+ Args:
+ df (pd.DataFrame): Input dataframe containing income and expense data.
+
+ Returns:
+ pd.DataFrame: Filtered dataframe containing only rows that satisfy
+ all constraints.
  """
  temp_df = df.copy()
  # check that household income is larger than expenses in the output
@@ -421,9 +404,7 @@ def preprocess_FamilyIncome(df: pd.DataFrame) -> pd.DataFrame:
  income_array = np.add(
  np.multiply(
  input_array[:, [0, 1]],
- np.subtract(
- np.asarray([11815988, 9234485]), np.asarray([11285, 0])
- ),
+ np.subtract(np.asarray([11815988, 9234485]), np.asarray([11285, 0])),
  ),
  np.asarray([11285, 0]),
  )
@@ -483,9 +464,7 @@ def preprocess_FamilyIncome(df: pd.DataFrame) -> pd.DataFrame:
  food_mul_expense_array = expense_reduced_array[:, [1, 2, 3]]
  food_mul_expense_array_sum = np.sum(food_mul_expense_array, axis=1)
  food_expense_array = expense_reduced_array[:, 0]
- sanity_check_array = np.greater_equal(
- food_expense_array, food_mul_expense_array_sum
- )
+ sanity_check_array = np.greater_equal(food_expense_array, food_mul_expense_array_sum)
  drop_reduction["Unimportant"] = sanity_check_array.tolist()
  new_reduction = drop_reduction[drop_reduction.Unimportant]
  satisfied_constraints_df = new_reduction.drop("Unimportant", axis=1)
@@ -493,9 +472,7 @@ def preprocess_FamilyIncome(df: pd.DataFrame) -> pd.DataFrame:
  return satisfied_constraints_df

  def add_input_output_family_income(df: pd.DataFrame) -> pd.DataFrame:
- """Add a multiindex denoting if the column is
- an input or output variable."""
-
+ """Add a multiindex denoting if the column is an input or output variable."""
  # copy the dataframe
  temp_df = df.copy()
  # extract all the column names
@@ -510,9 +487,7 @@ def preprocess_FamilyIncome(df: pd.DataFrame) -> pd.DataFrame:
  input_output_list = input_list_start + output_list + input_list_end
  # define multi index for attaching this 'Input' and
  # 'Output' list with the column names already existing
- multiindex_bias = pd.MultiIndex.from_arrays(
- [input_output_list, column_names]
- )
+ multiindex_bias = pd.MultiIndex.from_arrays([input_output_list, column_names])
  # transpose such that index can be adjusted to multi index
  new_df = pd.DataFrame(df.transpose().to_numpy(), index=multiindex_bias)
  # transpose back such that columns are the same as
@@ -522,9 +497,7 @@ def preprocess_FamilyIncome(df: pd.DataFrame) -> pd.DataFrame:
  def sample_2500_examples(df: pd.DataFrame) -> pd.DataFrame:
  """Sample 2500 examples from the dataframe without replacement."""
  temp_df = df.copy()
- sample_df = temp_df.sample(
- n=2500, replace=False, random_state=3, axis=0
- )
+ sample_df = temp_df.sample(n=2500, replace=False, random_state=3, axis=0)
  return sample_df

  return (
@@ -573,9 +546,90 @@ def preprocess_FamilyIncome(df: pd.DataFrame) -> pd.DataFrame:
  )


- def validate_type(name, value, expected_types, allow_none=False):
+ def preprocess_AdultCensusIncome(df: pd.DataFrame) -> pd.DataFrame: # noqa: N802
+ """Preprocesses the Adult Census Income dataset for PyTorch ML.
+
+ Sequential steps:
+ - Drop rows with missing values.
+ - Encode categorical variables to integer labels.
+ - Map the target 'income' column to 0/1.
+ - Convert all data to float32.
+ - Add a multiindex to denote Input vs Output columns.
+
+ Args:
+ df (pd.DataFrame): Raw dataframe containing Adult Census Income data.
+
+ Returns:
+ pd.DataFrame: Preprocessed dataframe.
  """
- Validate that a value is of the specified type(s).
+
+ def drop_missing(df: pd.DataFrame) -> pd.DataFrame:
+ """Drop rows with any missing values."""
+ return df.dropna(how="any")
+
+ def drop_columns(df: pd.DataFrame) -> pd.DataFrame:
+ return df.drop(columns=["fnlwgt", "education.num"], errors="ignore")
+
+ def label_encode_column(series: pd.Series, col_name: str = None) -> pd.Series:
+ """Encode a pandas Series of categorical strings into integers."""
+ categories = series.dropna().unique().tolist()
+ cat_to_int = {cat: i for i, cat in enumerate(categories)}
+ if col_name:
+ print(f"Column '{col_name}' encoding:")
+ for cat, idx in cat_to_int.items():
+ print(f" {cat} -> {idx}")
+ return series.map(cat_to_int).astype(int)
+
+ def encode_categorical(df: pd.DataFrame) -> pd.DataFrame:
+ """Convert categorical string columns to integer labels using label_encode_column."""
+ df_temp = df.copy()
+ categorical_cols = [
+ "workclass",
+ "education",
+ "marital.status",
+ "occupation",
+ "relationship",
+ "race",
+ "sex",
+ "native.country",
+ ]
+ for col in categorical_cols:
+ df_temp[col] = label_encode_column(df_temp[col].astype(str), col_name=col)
+ return df_temp
+
+ def map_target(df: pd.DataFrame) -> pd.DataFrame:
+ """Map income column to 0 (<=50K) and 1 (>50K)."""
+ df_temp = df.copy()
+ df_temp["income"] = df_temp["income"].map({"<=50K": 0, ">50K": 1})
+ return df_temp
+
+ def convert_float32(df: pd.DataFrame) -> pd.DataFrame:
+ """Convert all data to float32 for PyTorch compatibility."""
+ return df.astype("float32")
+
+ def add_input_output_index(df: pd.DataFrame) -> pd.DataFrame:
+ """Add a multiindex indicating input and output columns."""
+ temp_df = df.copy()
+ column_names = temp_df.columns.tolist()
+ # Only the 'income' column is output
+ input_list = ["Input"] * (len(column_names) - 1)
+ output_list = ["Output"]
+ multiindex_list = input_list + output_list
+ multiindex = pd.MultiIndex.from_arrays([multiindex_list, column_names])
+ return pd.DataFrame(temp_df.to_numpy(), columns=multiindex)
+
+ return (
+ df.pipe(drop_missing)
+ .pipe(drop_columns)
+ .pipe(encode_categorical)
+ .pipe(map_target)
+ .pipe(convert_float32)
+ .pipe(add_input_output_index)
+ )
+
+
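A hedged sketch of driving the new preprocess_AdultCensusIncome entry point end to end; the CSV path and raw column layout are assumptions based on the common 'adult.csv' distribution of this dataset, not something the diff itself specifies:

    import pandas as pd
    from congrads.utils import preprocess_AdultCensusIncome

    raw = pd.read_csv("adult.csv")  # assumed file with 'income', 'workclass', ... columns
    processed = preprocess_AdultCensusIncome(raw)
    print(processed.columns[:3])    # MultiIndex with 'Input'/'Output' as the first level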
+ def validate_type(name, value, expected_types, allow_none=False):
+ """Validate that a value is of the specified type(s).

  Args:
  name (str): Name of the argument for error messages.
@@ -587,7 +641,6 @@ def validate_type(name, value, expected_types, allow_none=False):
  Raises:
  TypeError: If the value is not of the expected type(s).
  """
-
  if value is None:
  if not allow_none:
  raise TypeError(f"Argument {name} cannot be None.")
@@ -604,12 +657,11 @@ def validate_iterable(
  name,
  value,
  expected_element_types,
- allowed_iterables=(list, set),
+ allowed_iterables=(list, set, tuple),
+ allow_empty=False,
  allow_none=False,
  ):
- """
- Validate that a value is an iterable (e.g., list, set) with elements of
- the specified type(s).
+ """Validate that a value is an iterable (e.g., list, set) with elements of the specified type(s).

  Args:
  name (str): Name of the argument for error messages.
@@ -618,6 +670,7 @@ def validate_iterable(
  for the elements.
  allowed_iterables (tuple of types): Iterable types that are
  allowed (default: list and set).
+ allow_empty (bool): Whether to allow empty iterables. Defaults to False.
  allow_none (bool): Whether to allow the value to be None.
  Defaults to False.

@@ -625,20 +678,22 @@ def validate_iterable(
  TypeError: If the value is not an allowed iterable type or if
  any element is not of the expected type(s).
  """
-
  if value is None:
  if not allow_none:
  raise TypeError(f"Argument {name} cannot be None.")
  return

+ if len(value) == 0:
+ if not allow_empty:
+ raise TypeError(f"Argument {name} cannot be an empty iterable.")
+ return
+
  if not isinstance(value, allowed_iterables):
  raise TypeError(
  f"Argument {name} '{str(value)}' is not supported. "
  f"Only values of type {str(allowed_iterables)} are allowed."
  )
- if not all(
- isinstance(element, expected_element_types) for element in value
- ):
+ if not all(isinstance(element, expected_element_types) for element in value):
  raise TypeError(
  f"Invalid elements in {name} '{str(value)}'. "
  f"Only elements of type {str(expected_element_types)} are allowed."
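To illustrate the allow_empty flag introduced in this version, a short sketch using only names taken from the hunk above:

    from congrads.utils import validate_iterable

    validate_iterable("metrics", ["loss", "accuracy"], str)   # passes
    validate_iterable("metrics", (), str, allow_empty=True)   # empty tuple now tolerated
    validate_iterable("metrics", [], str)                     # raises TypeError: empty iterable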
@@ -646,8 +701,7 @@ def validate_iterable(
  )

  def validate_comparator_pytorch(name, value):
- """
- Validate that a value is a callable PyTorch comparator function.
+ """Validate that a value is a callable PyTorch comparator function.

  Args:
  name (str): Name of the argument for error messages.
@@ -656,7 +710,6 @@ def validate_comparator_pytorch(name, value):
  Raises:
  TypeError: If the value is not callable or not a PyTorch comparator.
  """
-
  # List of valid PyTorch comparator functions
  pytorch_comparators = {torch.gt, torch.lt, torch.ge, torch.le}

@@ -664,8 +717,7 @@ def validate_comparator_pytorch(name, value):
  # the PyTorch comparator functions
  if not callable(value):
  raise TypeError(
- f"Argument {name} '{str(value)}' is not supported. "
- "Only callable functions are allowed."
+ f"Argument {name} '{str(value)}' is not supported. Only callable functions are allowed."
  )

  if value not in pytorch_comparators:
@@ -677,8 +729,7 @@ def validate_comparator_pytorch(name, value):
  )

  def validate_callable(name, value, allow_none=False):
- """
- Validate that a value is callable function.
+ """Validate that a value is a callable function.

  Args:
  name (str): Name of the argument for error messages.
@@ -689,22 +740,339 @@ def validate_callable(name, value, allow_none=False):
  Raises:
  TypeError: If the value is not callable.
  """
-
  if value is None:
  if not allow_none:
  raise TypeError(f"Argument {name} cannot be None.")
  return

  if not callable(value):
+ raise TypeError(
+ f"Argument {name} '{str(value)}' is not supported. Only callable functions are allowed."
+ )
+
+
+ def validate_callable_iterable(
+ name,
+ value,
+ allowed_iterables=(list, set, tuple),
+ allow_none=False,
+ ):
+ """Validate that a value is an iterable containing only callable elements.
+
+ This function ensures that the given value is an iterable
+ (e.g., list or set) and that all its elements are callable functions.
+
+ Args:
+ name (str): Name of the argument for error messages.
+ value: The value to validate.
+ allowed_iterables (tuple of types, optional): Iterable types that are
+ allowed. Defaults to (list, set).
+ allow_none (bool, optional): Whether to allow the value to be None.
+ Defaults to False.
+
+ Raises:
+ TypeError: If the value is not an allowed iterable type or if any
+ element is not callable.
+ """
+ if value is None:
+ if not allow_none:
+ raise TypeError(f"Argument {name} cannot be None.")
+ return
+
+ if not isinstance(value, allowed_iterables):
  raise TypeError(
  f"Argument {name} '{str(value)}' is not supported. "
- "Only callable functions are allowed."
+ f"Only values of type {str(allowed_iterables)} are allowed."
+ )
+
+ if not all(callable(element) for element in value):
+ raise TypeError(
+ f"Invalid elements in {name} '{str(value)}'. Only callable functions are allowed."
+ )
+
+
+ def validate_loaders(name: str, loaders: tuple[DataLoader, DataLoader, DataLoader]):
+ """Validates that `loaders` is a tuple of three DataLoader instances.
+
+ Args:
+ name (str): The name of the parameter being validated.
+ loaders (tuple[DataLoader, DataLoader, DataLoader]): A tuple of
+ three DataLoader instances.
+
+ Raises:
+ TypeError: If `loaders` is not a tuple of three DataLoader
+ instances or contains invalid types.
+ """
+ if not isinstance(loaders, tuple) or len(loaders) != 3:
+ raise TypeError(f"{name} must be a tuple of three DataLoader instances.")
+
+ for i, loader in enumerate(loaders):
+ if not isinstance(loader, DataLoader):
+ raise TypeError(
+ f"{name}[{i}] must be an instance of DataLoader, got {type(loader).__name__}."
+ )
+
+
+ class ZeroLoss(_Loss):
+ """A loss function that always returns zero.
+
+ This custom loss function ignores the input and target tensors
+ and returns a constant zero loss, which can be useful for debugging
+ or when no meaningful loss computation is required.
+
+ Args:
+ reduction (str, optional): Specifies the reduction to apply to
+ the output. Defaults to "mean". Although specified, it has
+ no effect as the loss is always zero.
+ """
+
+ def __init__(self, reduction: str = "mean"):
+ """Initialize ZeroLoss with a specified reduction method.
+
+ Args:
+ reduction (str): Specifies the reduction to apply to the output. Defaults to "mean".
+ """
+ super().__init__(reduction=reduction)
+
+ def forward(self, predictions: Tensor, target: Tensor, **kwargs) -> torch.Tensor:
+ """Return a dummy loss of zero regardless of input and target."""
+ return (predictions * 0).sum()
+
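A quick sketch of what ZeroLoss produces in practice (shapes chosen arbitrarily); because the result is built from the predictions tensor, it stays connected to the autograd graph:

    import torch
    from congrads.utils import ZeroLoss

    criterion = ZeroLoss()
    predictions = torch.randn(4, 2, requires_grad=True)
    target = torch.randn(4, 2)
    loss = criterion(predictions, target)
    print(loss)  # a zero tensor that still carries a grad_fn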
+
+ def is_torch_loss(criterion) -> bool:
+ """Return True if the object is a PyTorch loss function."""
+ type_ = str(type(criterion)).split("'")[1]
+ parent = type_.rsplit(".", 1)[0]
+
+ return parent == "torch.nn.modules.loss"
+
+
+ def torch_loss_wrapper(criterion: _Loss) -> _Loss:
+ """Wraps a PyTorch loss function to handle the case where the loss function forward pass does not allow **kwargs.
+
+ Args:
+ criterion (_Loss): The PyTorch loss function to wrap.
+
+ Returns:
+ _Loss: The wrapped criterion that allows **kwargs in the forward pass.
+ """
+
+ class WrappedCriterion(_Loss):
+ def __init__(self, criterion):
+ super().__init__()
+ self.criterion = criterion
+
+ def forward(self, *args, **kwargs):
+ return self.criterion(*args)
+
+ return WrappedCriterion(criterion)
+
+
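A sketch of is_torch_loss and torch_loss_wrapper working together, assuming a stock nn.MSELoss criterion; the wrapper simply drops any extra keyword arguments before delegating:

    import torch
    from torch import nn
    from congrads.utils import is_torch_loss, torch_loss_wrapper

    criterion = nn.MSELoss()
    if is_torch_loss(criterion):           # True: its type lives in torch.nn.modules.loss
        criterion = torch_loss_wrapper(criterion)

    pred, target = torch.randn(8, 1), torch.randn(8, 1)
    loss = criterion(pred, target, epoch=3)  # extra kwargs are accepted and ignored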
+ def process_data_monotonicity_constraint(data: Tensor, ordering: Tensor, identifiers: Tensor):
+ """Reorders input samples to support monotonicity checking.
+
+ Reorders input samples such that:
+ 1. Samples from the same run are grouped together.
+ 2. Within each run, samples are sorted chronologically.
+
+ Args:
+ data (Tensor): The input data.
+ ordering (Tensor): On what to order the data.
+ identifiers (Tensor): Identifiers specifying different runs.
+
+ Returns:
+ Tuple[Tensor, Tensor, Tensor]: Sorted data, ordering, and
+ identifiers.
+ """
+ # Step 1: Sort by run identifiers
+ sorted_indices = argsort(identifiers, stable=True, dim=0).reshape(-1)
+ data_sorted, ordering_sorted, identifiers_sorted = (
+ data[sorted_indices],
+ ordering[sorted_indices],
+ identifiers[sorted_indices],
+ )
+
+ # Step 2: Get unique runs and their counts
+ _, counts = unique(identifiers, sorted=False, return_counts=True)
+ counts = counts.to(int32) # Avoid repeated conversions
+
+ sorted_data, sorted_ordering, sorted_identifiers = [], [], []
+ index = 0 # Tracks the current batch element index
+
+ # Step 3: Process each run independently
+ for count in counts:
+ end = index + count
+ run_data, run_ordering, run_identifiers = (
+ data_sorted[index:end],
+ ordering_sorted[index:end],
+ identifiers_sorted[index:end],
  )

+ # Step 4: Sort within each run by time
+ time_sorted_indices = argsort(run_ordering, stable=True, dim=0).reshape(-1)
+ sorted_data.append(run_data[time_sorted_indices])
+ sorted_ordering.append(run_ordering[time_sorted_indices])
+ sorted_identifiers.append(run_identifiers[time_sorted_indices])
+
+ index = end # Move to next run
+
+ # Step 5: Concatenate results and return
+ return (
+ cat(sorted_data, dim=0),
+ cat(sorted_ordering, dim=0),
+ cat(sorted_identifiers, dim=0),
+ )
+
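A small, hedged example of the reordering this helper performs; the tensor shapes are assumptions, the only requirement visible here being that data, ordering, and identifiers share their first dimension:

    import torch
    from congrads.utils import process_data_monotonicity_constraint

    data = torch.arange(8.0).reshape(4, 2)            # 4 samples, 2 features
    ordering = torch.tensor([[1], [0], [1], [0]])     # e.g. time step within a run
    identifiers = torch.tensor([[1], [0], [0], [1]])  # run id per sample

    data_s, ordering_s, ids_s = process_data_monotonicity_constraint(data, ordering, identifiers)
    # samples are now grouped by run id and sorted chronologically within each run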

- def validate_loaders() -> None:
+ class DictDatasetWrapper(Dataset):
+ """A wrapper for PyTorch datasets that converts each sample into a dictionary.
+
+ This class takes any PyTorch dataset and returns its samples as dictionaries,
+ where each element of the original sample is mapped to a key. This is useful
+ for integration with the Congrads toolbox or other frameworks that expect
+ dictionary-formatted data.
+
+ Attributes:
+ base_dataset (Dataset): The underlying PyTorch dataset being wrapped.
+ field_names (list[str] | None): Names assigned to each field of a sample.
+ If None, default names like 'field0', 'field1', ... are generated.
+
+ Args:
+ base_dataset (Dataset): The PyTorch dataset to wrap.
+ field_names (list[str] | None, optional): Custom names for each field.
+ If provided, the list is truncated or extended to match the number
+ of elements in a sample. Defaults to None.
+
+ Example:
+ Wrapping a TensorDataset with custom field names:
+
+ >>> from torch.utils.data import TensorDataset
+ >>> import torch
+ >>> dataset = TensorDataset(torch.randn(5, 3), torch.randint(0, 2, (5,)))
+ >>> wrapped = DictDatasetWrapper(dataset, field_names=["features", "label"])
+ >>> wrapped[0]
+ {'features': tensor([...]), 'label': tensor(1)}
+
+ Wrapping a built-in dataset like CIFAR10:
+
+ >>> from torchvision.datasets import CIFAR10
+ >>> from torchvision import transforms
+ >>> cifar = CIFAR10(
+ ... root="./data", train=True, download=True, transform=transforms.ToTensor()
+ ... )
+ >>> wrapped_cifar = DictDatasetWrapper(cifar, field_names=["input", "output"])
+ >>> wrapped_cifar[0]
+ {'input': tensor([...]), 'output': tensor(6)}
  """
- TODO: implement function
+
+ def __init__(self, base_dataset: Dataset, field_names: list[str] | None = None):
+ """Initialize the DictDatasetWrapper.
+
+ Args:
+ base_dataset (Dataset): The PyTorch dataset to wrap.
+ field_names (list[str] | None, optional): Optional list of field names
+ for the dictionary output. Defaults to None, in which case
+ automatic names 'field0', 'field1', ... are generated.
+ """
+ self.base_dataset = base_dataset
+ self.field_names = field_names
+
+ def __getitem__(self, idx: int):
+ """Retrieve a sample from the dataset as a dictionary.
+
+ Each element in the original sample is mapped to a key in the dictionary.
+ If the sample is not a tuple or list, it is converted into a single-element
+ tuple. Numerical values (int or float) are automatically converted to tensors.
+
+ Args:
+ idx (int): Index of the sample to retrieve.
+
+ Returns:
+ dict: A dictionary mapping field names to sample values.
+ """
+ sample = self.base_dataset[idx]
+
+ # Ensure sample is always a tuple
+ if not isinstance(sample, (tuple, list)):
+ sample = (sample,)
+
+ n_fields = len(sample)
+
+ # Generate default field names if none are provided
+ if self.field_names is None:
+ names = [f"field{i}" for i in range(n_fields)]
+ else:
+ names = list(self.field_names)
+ if len(names) < n_fields:
+ names.extend([f"field{i}" for i in range(len(names), n_fields)])
+ names = names[:n_fields] # truncate if too long
+
+ # Build dictionary
+ out = {}
+ for name, value in zip(names, sample, strict=False):
+ if isinstance(value, (int, float)):
+ value = torch.tensor(value)
+ out[name] = value
+
+ return out
+
+ def __len__(self):
+ """Return the number of samples in the dataset.
+
+ Returns:
+ int: Length of the underlying dataset.
+ """
+ return len(self.base_dataset)
+
+
+ class Seeder:
+ """A deterministic seed manager for reproducible experiments.
+
+ This class provides a way to consistently generate pseudo-random
+ seeds derived from a fixed base seed. It ensures that different
+ libraries (Python's `random`, NumPy, and PyTorch) are initialized
+ with reproducible seeds, making experiments deterministic across runs.
  """
- # TODO complete
- pass
+
+ def __init__(self, base_seed: int):
+ """Initialize the Seeder with a base seed.
+
+ Args:
+ base_seed (int): The initial seed from which all subsequent
+ pseudo-random seeds are deterministically derived.
+ """
+ self._rng = random.Random(base_seed)
+
+ def roll_seed(self) -> int:
+ """Generate a new deterministic pseudo-random seed.
+
+ Each call returns an integer seed derived from the internal
+ pseudo-random generator, which itself is initialized by the
+ base seed.
+
+ Returns:
+ int: A pseudo-random integer seed in the range [0, 2**31 - 1].
+ """
+ return self._rng.randint(0, 2**31 - 1)
+
+ def set_reproducible(self) -> None:
+ """Configure global random states for reproducibility.
+
+ Seeds the following libraries with deterministically generated
+ seeds based on the base seed:
+ - Python's built-in `random`
+ - NumPy's random number generator
+ - PyTorch (CPU and GPU)
+
+ Also enforces deterministic behavior in PyTorch by:
+ - Seeding all CUDA devices
+ - Disabling CuDNN benchmarking
+ - Enabling CuDNN deterministic mode
+ """
+ random.seed(self.roll_seed())
+ np.random.seed(self.roll_seed())
+ torch.manual_seed(self.roll_seed())
+ torch.cuda.manual_seed_all(self.roll_seed())
+
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
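Finally, a short usage sketch for the new Seeder class, following its docstrings above:

    from congrads.utils import Seeder

    seeder = Seeder(base_seed=1234)
    seeder.set_reproducible()         # seeds random, NumPy and PyTorch, enables deterministic cuDNN
    worker_seed = seeder.roll_seed()  # deterministic per-call seed, e.g. for DataLoader workers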