congrads 0.2.0__py3-none-any.whl → 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
congrads/utils.py CHANGED
@@ -1,9 +1,148 @@
-import pandas as pd
+"""
+This module provides utility functions and classes for managing data,
+logging, and preprocessing in machine learning workflows. The functionalities
+include logging key-value pairs to CSV files, splitting datasets into
+training, validation, and test sets, preprocessing functions for various
+data sets and validator functions for type checking.
+"""
+
+import os
+
 import numpy as np
-from torch.utils.data import Dataset, random_split, DataLoader
+import pandas as pd
+import torch
+from torch import Generator
+from torch.utils.data import DataLoader, Dataset, random_split
+
+
+class CSVLogger:
+    """
+    A utility class for logging key-value pairs to a CSV file, organized by
+    epochs. Supports merging with existing logs or overwriting them.
+
+    Args:
+        file_path (str): The path to the CSV file for logging.
+        overwrite (bool): If True, overwrites any existing file at
+            the file_path.
+        merge (bool): If True, merges new values with existing data
+            in the file.
+
+    Raises:
+        ValueError: If both overwrite and merge are True.
+        FileExistsError: If the file already exists and neither
+            overwrite nor merge is True.
+    """
+
+    def __init__(
+        self, file_path: str, overwrite: bool = False, merge: bool = True
+    ):
+        """
+        Initializes the CSVLogger.
+        """
+
+        self.file_path = file_path
+        self.values: dict[tuple[int, str], float] = {}
+
+        if merge and overwrite:
+            raise ValueError(
+                "The attributes overwrite and merge cannot be True at the "
+                "same time. Either specify overwrite=True or merge=True."
+            )
+
+        if not os.path.exists(file_path):
+            pass
+        elif merge:
+            self.load()
+        elif overwrite:
+            pass
+        else:
+            raise FileExistsError(
+                f"A CSV file already exists at {file_path}. Specify "
+                "CSVLogger(..., overwrite=True) to overwrite the file."
+            )
+
+    def add_value(self, name: str, value: float, epoch: int):
+        """
+        Adds a value to the logger for a specific epoch and name.
+
+        Args:
+            name (str): The name of the metric or value to log.
+            value (float): The value to log.
+            epoch (int): The epoch associated with the value.
+        """
+
+        self.values[epoch, name] = value
+
+    def save(self):
+        """
+        Saves the logged values to the specified CSV file.
+
+        If the file exists and merge is enabled, merges the current data
+        with the existing file.
+        """
 
+        data = self.to_dataframe(self.values)
+        data.to_csv(self.file_path, index=False)
 
-def splitDataLoaders(
+    def load(self):
+        """
+        Loads data from the CSV file into the logger.
+
+        Converts the CSV data into the internal dictionary format for
+        further updates or operations.
+        """
+
+        df = pd.read_csv(self.file_path)
+        self.values = self.to_dict(df)
+
+    @staticmethod
+    def to_dataframe(values: dict[tuple[int, str], float]) -> pd.DataFrame:
+        """
+        Converts a dictionary of values into a DataFrame.
+
+        Args:
+            values (dict[tuple[int, str], float]): A dictionary of values
+                keyed by (epoch, name).
+
+        Returns:
+            pd.DataFrame: A DataFrame where epochs are rows, names are
+                columns, and values are the cell data.
+        """
+
+        # Convert to a DataFrame
+        df = pd.DataFrame.from_dict(values, orient="index", columns=["value"])
+
+        # Reset the index to separate epoch and name into columns
+        df.index = pd.MultiIndex.from_tuples(df.index, names=["epoch", "name"])
+        df = df.reset_index()
+
+        # Pivot the DataFrame so epochs are rows and names are columns
+        result = df.pivot(index="epoch", columns="name", values="value")
+
+        # Optional: Reset the column names for a cleaner look
+        result = result.reset_index().rename_axis(columns=None)
+
+        return result
+
+    @staticmethod
+    def to_dict(df: pd.DataFrame) -> dict[tuple[int, str], float]:
+        """
+        Converts a DataFrame with epochs as rows and names as columns
+        back into a dictionary of the format {(epoch, name): value}.
+        """
+        # Set the epoch column as the index (if not already)
+        df = df.set_index("epoch")
+
+        # Stack the DataFrame to create a multi-index series
+        stacked = df.stack()
+
+        # Convert the multi-index series to a dictionary
+        result = stacked.to_dict()
+
+        return result
+
+
+def split_data_loaders(
     data: Dataset,
     loader_args: dict = None,
     train_loader_args: dict = None,
@@ -12,7 +151,40 @@ def splitDataLoaders(
     train_size: float = 0.8,
     valid_size: float = 0.1,
     test_size: float = 0.1,
+    split_generator: Generator = None,
 ) -> tuple[DataLoader, DataLoader, DataLoader]:
+    """
+    Splits a dataset into training, validation, and test sets,
+    and returns corresponding DataLoader objects.
+
+    Args:
+        data (Dataset): The dataset to be split.
+        loader_args (dict, optional): Default DataLoader arguments, merges
+            with loader-specific arguments, overlapping keys from
+            loader-specific arguments are superseded.
+        train_loader_args (dict, optional): Training DataLoader arguments,
+            merges with `loader_args`, overriding overlapping keys.
+        valid_loader_args (dict, optional): Validation DataLoader arguments,
+            merges with `loader_args`, overriding overlapping keys.
+        test_loader_args (dict, optional): Test DataLoader arguments,
+            merges with `loader_args`, overriding overlapping keys.
+        train_size (float, optional): Proportion of data to be used for
+            training. Defaults to 0.8.
+        valid_size (float, optional): Proportion of data to be used for
+            validation. Defaults to 0.1.
+        test_size (float, optional): Proportion of data to be used for
+            testing. Defaults to 0.1.
+        split_generator (Generator, optional): Optional random seed generator
+            to control the splitting of the dataset.
+
+    Returns:
+        tuple: A tuple containing three DataLoader objects: one for the
+            training, validation and test set.
+
+    Raises:
+        ValueError: If the train_size, valid_size, and test_size are not
+            between 0 and 1, or if their sum does not equal 1.
+    """
 
     # Validate split sizes
     if not (0 < train_size < 1 and 0 < valid_size < 1 and 0 < test_size < 1):
@@ -23,19 +195,22 @@ def splitDataLoaders(
         raise ValueError("train_size, valid_size, and test_size must sum to 1.")
 
     # Perform the splits
-    train_val_data, test_data = random_split(data, [1 - test_size, test_size])
+    train_val_data, test_data = random_split(
+        data, [1 - test_size, test_size], generator=split_generator
+    )
     train_data, valid_data = random_split(
         train_val_data,
         [
             train_size / (1 - test_size),
             valid_size / (1 - test_size),
         ],
+        generator=split_generator,
     )
 
     # Set default arguments for each loader
-    train_loader_args = train_loader_args or loader_args or {}
-    valid_loader_args = valid_loader_args or loader_args or {}
-    test_loader_args = test_loader_args or loader_args or {}
+    train_loader_args = dict(loader_args or {}, **(train_loader_args or {}))
+    valid_loader_args = dict(loader_args or {}, **(valid_loader_args or {}))
+    test_loader_args = dict(loader_args or {}, **(test_loader_args or {}))
 
     # Create the DataLoaders
     train_generator = DataLoader(train_data, **train_loader_args)
@@ -45,18 +220,43 @@ def splitDataLoaders(
     return train_generator, valid_generator, test_generator
 
 
+# pylint: disable-next=invalid-name
 def preprocess_BiasCorrection(df: pd.DataFrame) -> pd.DataFrame:
+    """
+    Preprocesses the given dataframe for bias correction by
+    performing a series of transformations.
+
+    The function sequentially:
+
+    - Drops rows with missing values.
+    - Converts a date string to datetime format and adds year, month,
+      and day columns.
+    - Normalizes the columns with specific logic for input and output variables.
+    - Adds a multi-index indicating which columns are input or output variables.
+    - Samples 2500 examples from the dataset without replacement.
+
+    Args:
+        df (pd.DataFrame): The input dataframe containing the data
+            to be processed.
+
+    Returns:
+        pd.DataFrame: The processed dataframe after applying
+            the transformations.
+    """
 
     def date_to_datetime(df: pd.DataFrame) -> pd.DataFrame:
-        """Transform the string that denotes the date to the datetime format in pandas."""
+        """Transform the string that denotes the date to
+        the datetime format in pandas."""
         # make copy of dataframe
         df_temp = df.copy()
-        # add new column at the front where the date string is transformed to the datetime format
+        # add new column at the front where the date string is
+        # transformed to the datetime format
         df_temp.insert(0, "DateTransformed", pd.to_datetime(df_temp["Date"]))
         return df_temp
 
     def add_year(df: pd.DataFrame) -> pd.DataFrame:
-        """Extract the year from the datetime cell and add it as a new column to the dataframe at the front."""
+        """Extract the year from the datetime cell and add it
+        as a new column to the dataframe at the front."""
         # make copy of dataframe
         df_temp = df.copy()
         # extract year and add new column at the front containing these numbers
@@ -64,7 +264,8 @@ def preprocess_BiasCorrection(df: pd.DataFrame) -> pd.DataFrame:
         return df_temp
 
     def add_month(df: pd.DataFrame) -> pd.DataFrame:
-        """Extract the month from the datetime cell and add it as a new column to the dataframe at the front."""
+        """Extract the month from the datetime cell and add it
+        as a new column to the dataframe at the front."""
         # make copy of dataframe
         df_temp = df.copy()
         # extract month and add new column at index 1 containing these numbers
@@ -72,7 +273,8 @@ def preprocess_BiasCorrection(df: pd.DataFrame) -> pd.DataFrame:
         return df_temp
 
     def add_day(df: pd.DataFrame) -> pd.DataFrame:
-        """Extract the day from the datetime cell and add it as a new column to the dataframe at the front."""
+        """Extract the day from the datetime cell and add it
+        as a new column to the dataframe at the front."""
         # make copy of dataframe
         df_temp = df.copy()
         # extract day and add new column at index 2 containing these numbers
@@ -80,35 +282,46 @@ def preprocess_BiasCorrection(df: pd.DataFrame) -> pd.DataFrame:
         return df_temp
 
     def add_input_output_temperature(df: pd.DataFrame) -> pd.DataFrame:
-        """Add a multiindex denoting if the column is an input or output variable."""
+        """Add a multiindex denoting if the column
+        is an input or output variable."""
         # copy the dataframe
         temp_df = df.copy()
         # extract all the column names
         column_names = temp_df.columns.tolist()
-        # only the last 2 columns are output variables, all others are input variables. So make list of corresponding lengths of 'Input' and 'Output'
+        # only the last 2 columns are output variables, all others are input
+        # variables. So make list of corresponding lengths of
+        # 'Input' and 'Output'
         input_list = ["Input"] * (len(column_names) - 2)
         output_list = ["Output"] * 2
         # concat both lists
         input_output_list = input_list + output_list
-        # define multi index for attaching this 'Input' and 'Output' list with the column names already existing
-        multiindex_bias = pd.MultiIndex.from_arrays([input_output_list, column_names])
+        # define multi index for attaching this 'Input' and 'Output' list with
+        # the column names already existing
+        multiindex_bias = pd.MultiIndex.from_arrays(
+            [input_output_list, column_names]
+        )
         # transpose such that index can be adjusted to multi index
         new_df = pd.DataFrame(df.transpose().to_numpy(), index=multiindex_bias)
-        # transpose back such that columns are the same as before except with different labels
+        # transpose back such that columns are the same as before
+        # except with different labels
         return new_df.transpose()
 
     def normalize_columns_bias(df: pd.DataFrame) -> pd.DataFrame:
-        """Normalize the columns for the bias correction dataset. This is different from normalizing all the columns separately because the
-        upper and lower bounds for the output variables are assumed to be the same."""
+        """Normalize the columns for the bias correction dataset.
+        This is different from normalizing all the columns separately
+        because the upper and lower bounds for the output variables
+        are assumed to be the same."""
         # copy the dataframe
         temp_df = df.copy()
         # normalize each column
         for feature_name in df.columns:
-            # the output columns are normalized using the same upper and lower bound for more efficient check of the inequality
+            # the output columns are normalized using the same upper and
+            # lower bound for more efficient check of the inequality
             if feature_name == "Next_Tmax" or feature_name == "Next_Tmin":
                 max_value = 38.9
                 min_value = 11.3
-            # the input columns are normalized using their respective upper and lower bounds
+            # the input columns are normalized using their respective
+            # upper and lower bounds
             else:
                 max_value = df[feature_name].max()
                 min_value = df[feature_name].min()
@@ -120,7 +333,9 @@ def preprocess_BiasCorrection(df: pd.DataFrame) -> pd.DataFrame:
     def sample_2500_examples(df: pd.DataFrame) -> pd.DataFrame:
         """Sample 2500 examples from the dataframe without replacement."""
         temp_df = df.copy()
-        sample_df = temp_df.sample(n=2500, replace=False, random_state=3, axis=0)
+        sample_df = temp_df.sample(
+            n=2500, replace=False, random_state=3, axis=0
+        )
         return sample_df
 
     return (
@@ -140,18 +355,47 @@ def preprocess_BiasCorrection(df: pd.DataFrame) -> pd.DataFrame:
         .astype("float32")
         # normalize columns
         .pipe(normalize_columns_bias)
-        # add multi index indicating which columns are corresponding to input and output variables
+        # add multi index indicating which columns are corresponding
+        # to input and output variables
         .pipe(add_input_output_temperature)
         # sample 2500 examples out of the dataset
         .pipe(sample_2500_examples)
     )
 
 
-def preprocess_FiniteIncome(df: pd.DataFrame) -> pd.DataFrame:
+# pylint: disable-next=invalid-name
+def preprocess_FamilyIncome(df: pd.DataFrame) -> pd.DataFrame:
+    """
+    Preprocesses the given Family Income dataframe by applying a
+    series of transformations and constraints.
+
+    The function sequentially:
+
+    - Drops rows with missing values.
+    - Converts object columns to appropriate data types and
+      removes string columns.
+    - Removes certain unnecessary columns like
+      'Agricultural Household indicator' and related features.
+    - Adds labels to columns indicating whether they are
+      input or output variables.
+    - Normalizes the columns individually.
+    - Checks and removes rows that do not satisfy predefined constraints
+      (household income > expenses, food expenses > sub-expenses).
+    - Samples 2500 examples from the dataset without replacement.
+
+    Args:
+        df (pd.DataFrame): The input Family Income dataframe containing
+            the data to be processed.
+
+    Returns:
+        pd.DataFrame: The processed dataframe after applying the
+            transformations and constraints.
+    """
 
     def normalize_columns_income(df: pd.DataFrame) -> pd.DataFrame:
-        """Normalize the columns for the Family Income dataframe. This can also be applied to other dataframes because this function normalizes
-        all columns individually."""
+        """Normalize the columns for the Family Income dataframe.
+        This can also be applied to other dataframes because this
+        function normalizes all columns individually."""
         # copy the dataframe
         temp_df = df.copy()
         # normalize each column
@@ -164,9 +408,12 @@ def preprocess_FiniteIncome(df: pd.DataFrame) -> pd.DataFrame:
         return temp_df
 
     def check_constraints_income(df: pd.DataFrame) -> pd.DataFrame:
-        """Check if all the constraints are satisfied for the dataframe and remove the examples that do not satisfy the constraint. This
-        function only works for the Family Income dataset and the constraints are that the household income is larger than all the expenses
-        and the food expense is larger than the sum of the other (more detailed) food expenses.
+        """Check if all the constraints are satisfied for the dataframe
+        and remove the examples that do not satisfy the constraint.
+        This function only works for the Family Income dataset and the
+        constraints are that the household income is larger than all the
+        expenses and the food expense is larger than the sum of
+        the other (more detailed) food expenses.
         """
         temp_df = df.copy()
         # check that household income is larger than expenses in the output
@@ -174,7 +421,9 @@ def preprocess_FiniteIncome(df: pd.DataFrame) -> pd.DataFrame:
         income_array = np.add(
             np.multiply(
                 input_array[:, [0, 1]],
-                np.subtract(np.asarray([11815988, 9234485]), np.asarray([11285, 0])),
+                np.subtract(
+                    np.asarray([11815988, 9234485]), np.asarray([11285, 0])
+                ),
             ),
             np.asarray([11285, 0]),
         )
@@ -244,28 +493,38 @@ def preprocess_FiniteIncome(df: pd.DataFrame) -> pd.DataFrame:
         return satisfied_constraints_df
 
     def add_input_output_family_income(df: pd.DataFrame) -> pd.DataFrame:
-        """Add a multiindex denoting if the column is an input or output variable."""
+        """Add a multiindex denoting if the column is
+        an input or output variable."""
+
         # copy the dataframe
         temp_df = df.copy()
         # extract all the column names
         column_names = temp_df.columns.tolist()
-        # the 2nd-9th columns correspond to output variables and all others to input variables. So make list of corresponding lengths of 'Input' and 'Output'
+        # the 2nd-9th columns correspond to output variables and all
+        # others to input variables. So make list of corresponding
+        # lengths of 'Input' and 'Output'
        input_list_start = ["Input"]
         input_list_end = ["Input"] * (len(column_names) - 9)
         output_list = ["Output"] * 8
         # concat both lists
         input_output_list = input_list_start + output_list + input_list_end
-        # define multi index for attaching this 'Input' and 'Output' list with the column names already existing
-        multiindex_bias = pd.MultiIndex.from_arrays([input_output_list, column_names])
+        # define multi index for attaching this 'Input' and
+        # 'Output' list with the column names already existing
+        multiindex_bias = pd.MultiIndex.from_arrays(
+            [input_output_list, column_names]
+        )
         # transpose such that index can be adjusted to multi index
         new_df = pd.DataFrame(df.transpose().to_numpy(), index=multiindex_bias)
-        # transpose back such that columns are the same as before except with different labels
+        # transpose back such that columns are the same as
+        # before except with different labels
         return new_df.transpose()
 
     def sample_2500_examples(df: pd.DataFrame) -> pd.DataFrame:
         """Sample 2500 examples from the dataframe without replacement."""
         temp_df = df.copy()
-        sample_df = temp_df.sample(n=2500, replace=False, random_state=3, axis=0)
+        sample_df = temp_df.sample(
+            n=2500, replace=False, random_state=3, axis=0
+        )
         return sample_df
 
     return (
@@ -273,13 +532,17 @@ def preprocess_FiniteIncome(df: pd.DataFrame) -> pd.DataFrame:
         df.dropna(how="any")
         # convert object to fitting dtype
         .convert_dtypes()
-        # remove all strings (no other dtypes are present except for integers and floats)
+        # remove all strings (no other dtypes are present
+        # except for integers and floats)
         .select_dtypes(exclude=["string"])
         # transform all numbers to same dtype
         .astype("float32")
-        # drop column with label Agricultural Household indicator because this is not really a numerical input but rather a categorical/classification
+        # drop column with label Agricultural Household indicator
+        # because this is not really a numerical input but
+        # rather a categorical/classification
         .drop(["Agricultural Household indicator"], axis=1, inplace=False)
-        # this column is dropped because it depends on Agricultural Household indicator
+        # this column is dropped because it depends on
+        # Agricultural Household indicator
        .drop(["Crop Farming and Gardening expenses"], axis=1, inplace=False)
         # use 8 output variables and 24 input variables
         .drop(
@@ -308,3 +571,140 @@ def preprocess_FiniteIncome(df: pd.DataFrame) -> pd.DataFrame:
         # sample 2500 examples
         .pipe(sample_2500_examples)
     )
+
+
+def validate_type(name, value, expected_types, allow_none=False):
+    """
+    Validate that a value is of the specified type(s).
+
+    Args:
+        name (str): Name of the argument for error messages.
+        value: Value to validate.
+        expected_types (type or tuple of types): Expected type(s) for the value.
+        allow_none (bool): Whether to allow the value to be None.
+            Defaults to False.
+
+    Raises:
+        TypeError: If the value is not of the expected type(s).
+    """
+
+    if value is None:
+        if not allow_none:
+            raise TypeError(f"Argument {name} cannot be None.")
+        return
+
+    if not isinstance(value, expected_types):
+        raise TypeError(
+            f"Argument {name} '{str(value)}' is not supported. "
+            f"Only values of type {str(expected_types)} are allowed."
+        )
+
+
+def validate_iterable(
+    name,
+    value,
+    expected_element_types,
+    allowed_iterables=(list, set),
+    allow_none=False,
+):
+    """
+    Validate that a value is an iterable (e.g., list, set) with elements of
+    the specified type(s).
+
+    Args:
+        name (str): Name of the argument for error messages.
+        value: Value to validate.
+        expected_element_types (type or tuple of types): Expected type(s)
+            for the elements.
+        allowed_iterables (tuple of types): Iterable types that are
+            allowed (default: list and set).
+        allow_none (bool): Whether to allow the value to be None.
+            Defaults to False.
+
+    Raises:
+        TypeError: If the value is not an allowed iterable type or if
+            any element is not of the expected type(s).
+    """
+
+    if value is None:
+        if not allow_none:
+            raise TypeError(f"Argument {name} cannot be None.")
+        return
+
+    if not isinstance(value, allowed_iterables):
+        raise TypeError(
+            f"Argument {name} '{str(value)}' is not supported. "
+            f"Only values of type {str(allowed_iterables)} are allowed."
+        )
+    if not all(
+        isinstance(element, expected_element_types) for element in value
+    ):
+        raise TypeError(
+            f"Invalid elements in {name} '{str(value)}'. "
+            f"Only elements of type {str(expected_element_types)} are allowed."
+        )
+
+
+def validate_comparator_pytorch(name, value):
+    """
+    Validate that a value is a callable PyTorch comparator function.
+
+    Args:
+        name (str): Name of the argument for error messages.
+        value: Value to validate.
+
+    Raises:
+        TypeError: If the value is not callable or not a PyTorch comparator.
+    """
+
+    # List of valid PyTorch comparator functions
+    pytorch_comparators = {torch.gt, torch.lt, torch.ge, torch.le}
+
+    # Check if value is callable and if it's one of
+    # the PyTorch comparator functions
+    if not callable(value):
+        raise TypeError(
+            f"Argument {name} '{str(value)}' is not supported. "
+            "Only callable functions are allowed."
+        )
+
+    if value not in pytorch_comparators:
+        raise TypeError(
+            f"Argument {name} '{str(value)}' is not a valid PyTorch comparator "
+            "function. Only PyTorch functions like torch.gt, torch.lt, "
+            "torch.ge, torch.le are allowed."
+        )
+
+
+def validate_callable(name, value, allow_none=False):
+    """
+    Validate that a value is callable function.
+
+    Args:
+        name (str): Name of the argument for error messages.
+        value: Value to validate.
+        allow_none (bool): Whether to allow the value to be None.
+            Defaults to False.
+
+    Raises:
+        TypeError: If the value is not callable.
+    """
+
+    if value is None:
+        if not allow_none:
+            raise TypeError(f"Argument {name} cannot be None.")
+        return
+
+    if not callable(value):
+        raise TypeError(
+            f"Argument {name} '{str(value)}' is not supported. "
+            "Only callable functions are allowed."
+        )
+
+
+def validate_loaders() -> None:
+    """
+    TODO: implement function
+    """
+    # TODO complete
+    pass
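
For orientation, the following is a minimal usage sketch of the utilities introduced in this version, based only on the signatures and docstrings visible in the diff above. The toy TensorDataset, the batch size, the seed, and the metrics.csv path are illustrative assumptions, not part of the package.

import torch
from torch.utils.data import TensorDataset

from congrads.utils import CSVLogger, split_data_loaders

# Toy dataset: 100 samples with 4 features and 1 target each (assumption).
data = TensorDataset(torch.randn(100, 4), torch.randn(100, 1))

# Shared loader_args are merged with loader-specific args; overlapping keys
# from the specific dict take precedence (dict(loader_args, **specific)).
train_loader, valid_loader, test_loader = split_data_loaders(
    data,
    loader_args={"batch_size": 16},
    train_loader_args={"shuffle": True},
    split_generator=torch.Generator().manual_seed(42),
)

# Log one value per (epoch, name) pair and write a CSV with an "epoch"
# column plus one column per logged name.
logger = CSVLogger("metrics.csv", merge=True)
for epoch in range(3):
    logger.add_value("train_loss", 1.0 / (epoch + 1), epoch)  # placeholder metric
    logger.add_value("valid_loss", 1.2 / (epoch + 1), epoch)  # placeholder metric
logger.save()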