congrads 0.1.0__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
congrads/utils.py ADDED
@@ -0,0 +1,710 @@
+ """
+ This module provides utility functions and classes for managing data,
+ logging, and preprocessing in machine learning workflows. The functionalities
+ include logging key-value pairs to CSV files, splitting datasets into
+ training, validation, and test sets, preprocessing functions for various
+ datasets, and validator functions for type checking.
+ """
+
+ import os
+
+ import numpy as np
+ import pandas as pd
+ import torch
+ from torch import Generator
+ from torch.utils.data import DataLoader, Dataset, random_split
+
+
+ class CSVLogger:
+     """
+     A utility class for logging key-value pairs to a CSV file, organized by
+     epochs. Supports merging with existing logs or overwriting them.
+
+     Args:
+         file_path (str): The path to the CSV file for logging.
+         overwrite (bool): If True, overwrites any existing file at
+             the file_path.
+         merge (bool): If True, merges new values with existing data
+             in the file.
+
+     Raises:
+         ValueError: If both overwrite and merge are True.
+         FileExistsError: If the file already exists and neither
+             overwrite nor merge is True.
+     """
+
+     def __init__(
+         self, file_path: str, overwrite: bool = False, merge: bool = True
+     ):
+         """
+         Initializes the CSVLogger.
+         """
+
+         self.file_path = file_path
+         self.values: dict[tuple[int, str], float] = {}
+
+         if merge and overwrite:
+             raise ValueError(
+                 "The attributes overwrite and merge cannot be True at the "
+                 "same time. Either specify overwrite=True or merge=True."
+             )
+
+         if not os.path.exists(file_path):
+             pass
+         elif merge:
+             self.load()
+         elif overwrite:
+             pass
+         else:
+             raise FileExistsError(
+                 f"A CSV file already exists at {file_path}. Specify "
+                 "CSVLogger(..., overwrite=True) to overwrite the file."
+             )
+
+     def add_value(self, name: str, value: float, epoch: int):
+         """
+         Adds a value to the logger for a specific epoch and name.
+
+         Args:
+             name (str): The name of the metric or value to log.
+             value (float): The value to log.
+             epoch (int): The epoch associated with the value.
+         """
+
+         self.values[epoch, name] = value
+
+     def save(self):
+         """
+         Saves the logged values to the specified CSV file.
+
+         If merge was enabled and an existing file was loaded, the previously
+         logged values are written back together with any newly added values.
+         """
+
+         data = self.to_dataframe(self.values)
+         data.to_csv(self.file_path, index=False)
+
+     def load(self):
+         """
+         Loads data from the CSV file into the logger.
+
+         Converts the CSV data into the internal dictionary format for
+         further updates or operations.
+         """
+
+         df = pd.read_csv(self.file_path)
+         self.values = self.to_dict(df)
+
+     @staticmethod
+     def to_dataframe(values: dict[tuple[int, str], float]) -> pd.DataFrame:
+         """
+         Converts a dictionary of values into a DataFrame.
+
+         Args:
+             values (dict[tuple[int, str], float]): A dictionary of values
+                 keyed by (epoch, name).
+
+         Returns:
+             pd.DataFrame: A DataFrame where epochs are rows, names are
+                 columns, and values are the cell data.
+         """
+
+         # Convert to a DataFrame
+         df = pd.DataFrame.from_dict(values, orient="index", columns=["value"])
+
+         # Reset the index to separate epoch and name into columns
+         df.index = pd.MultiIndex.from_tuples(df.index, names=["epoch", "name"])
+         df = df.reset_index()
+
+         # Pivot the DataFrame so epochs are rows and names are columns
+         result = df.pivot(index="epoch", columns="name", values="value")
+
+         # Optional: Reset the column names for a cleaner look
+         result = result.reset_index().rename_axis(columns=None)
+
+         return result
+
+     @staticmethod
+     def to_dict(df: pd.DataFrame) -> dict[tuple[int, str], float]:
+         """
+         Converts a DataFrame with epochs as rows and names as columns
+         back into a dictionary of the format {(epoch, name): value}.
+         """
+         # Set the epoch column as the index (if not already)
+         df = df.set_index("epoch")
+
+         # Stack the DataFrame to create a multi-index series
+         stacked = df.stack()
+
+         # Convert the multi-index series to a dictionary
+         result = stacked.to_dict()
+
+         return result
+
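A minimal usage sketch for the logger (illustrative only, not part of the diff; the file name and metric names are made up):

    from congrads.utils import CSVLogger

    logger = CSVLogger("metrics.csv")              # merge=True by default; an existing file is loaded first
    logger.add_value("train_loss", 0.42, epoch=1)  # one value per (epoch, name) pair
    logger.add_value("valid_loss", 0.51, epoch=1)
    logger.save()                                  # writes one row per epoch, one column per metric name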
+
+ def split_data_loaders(
+     data: Dataset,
+     loader_args: dict = None,
+     train_loader_args: dict = None,
+     valid_loader_args: dict = None,
+     test_loader_args: dict = None,
+     train_size: float = 0.8,
+     valid_size: float = 0.1,
+     test_size: float = 0.1,
+     split_generator: Generator = None,
+ ) -> tuple[DataLoader, DataLoader, DataLoader]:
+     """
+     Splits a dataset into training, validation, and test sets,
+     and returns corresponding DataLoader objects.
+
+     Args:
+         data (Dataset): The dataset to be split.
+         loader_args (dict, optional): Default DataLoader arguments shared by
+             all loaders; overlapping keys are overridden by the
+             loader-specific arguments.
+         train_loader_args (dict, optional): Training DataLoader arguments,
+             merges with `loader_args`, overriding overlapping keys.
+         valid_loader_args (dict, optional): Validation DataLoader arguments,
+             merges with `loader_args`, overriding overlapping keys.
+         test_loader_args (dict, optional): Test DataLoader arguments,
+             merges with `loader_args`, overriding overlapping keys.
+         train_size (float, optional): Proportion of data to be used for
+             training. Defaults to 0.8.
+         valid_size (float, optional): Proportion of data to be used for
+             validation. Defaults to 0.1.
+         test_size (float, optional): Proportion of data to be used for
+             testing. Defaults to 0.1.
+         split_generator (Generator, optional): Optional torch random number
+             generator used to control the splitting of the dataset.
+
+     Returns:
+         tuple: A tuple containing three DataLoader objects: one each for the
+             training, validation, and test sets.
+
+     Raises:
+         ValueError: If train_size, valid_size, and test_size are not
+             between 0 and 1, or if their sum does not equal 1.
+     """
+
+     # Validate split sizes
+     if not (0 < train_size < 1 and 0 < valid_size < 1 and 0 < test_size < 1):
+         raise ValueError(
+             "train_size, valid_size, and test_size must be between 0 and 1."
+         )
+     if not abs(train_size + valid_size + test_size - 1.0) < 1e-6:
+         raise ValueError("train_size, valid_size, and test_size must sum to 1.")
+
+     # Perform the splits
+     train_val_data, test_data = random_split(
+         data, [1 - test_size, test_size], generator=split_generator
+     )
+     train_data, valid_data = random_split(
+         train_val_data,
+         [
+             train_size / (1 - test_size),
+             valid_size / (1 - test_size),
+         ],
+         generator=split_generator,
+     )
+
+     # Set default arguments for each loader
+     train_loader_args = dict(loader_args or {}, **(train_loader_args or {}))
+     valid_loader_args = dict(loader_args or {}, **(valid_loader_args or {}))
+     test_loader_args = dict(loader_args or {}, **(test_loader_args or {}))
+
+     # Create the DataLoaders
+     train_generator = DataLoader(train_data, **train_loader_args)
+     valid_generator = DataLoader(valid_data, **valid_loader_args)
+     test_generator = DataLoader(test_data, **test_loader_args)
+
+     return train_generator, valid_generator, test_generator
+
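A hedged usage sketch (not part of the diff; `my_dataset` stands for any torch Dataset and the arguments shown are assumptions):

    from torch import Generator
    from congrads.utils import split_data_loaders

    train_loader, valid_loader, test_loader = split_data_loaders(
        my_dataset,                                   # any torch Dataset (placeholder)
        loader_args={"batch_size": 64},               # shared by all three loaders
        train_loader_args={"shuffle": True},          # extends/overrides the shared arguments
        train_size=0.8,
        valid_size=0.1,
        test_size=0.1,
        split_generator=Generator().manual_seed(42),  # reproducible split
    )

Note that the fractional lengths passed to random_split require PyTorch 1.13 or newer.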
+
+ # pylint: disable-next=invalid-name
+ def preprocess_BiasCorrection(df: pd.DataFrame) -> pd.DataFrame:
+     """
+     Preprocesses the given dataframe for bias correction by
+     performing a series of transformations.
+
+     The function sequentially:
+
+     - Drops rows with missing values.
+     - Converts a date string to datetime format and adds year, month,
+       and day columns.
+     - Normalizes the columns with specific logic for input and output variables.
+     - Adds a multi-index indicating which columns are input or output variables.
+     - Samples 2500 examples from the dataset without replacement.
+
+     Args:
+         df (pd.DataFrame): The input dataframe containing the data
+             to be processed.
+
+     Returns:
+         pd.DataFrame: The processed dataframe after applying
+             the transformations.
+     """
+
+     def date_to_datetime(df: pd.DataFrame) -> pd.DataFrame:
+         """Transform the string that denotes the date to
+         the datetime format in pandas."""
+         # make copy of dataframe
+         df_temp = df.copy()
+         # add new column at the front where the date string is
+         # transformed to the datetime format
+         df_temp.insert(0, "DateTransformed", pd.to_datetime(df_temp["Date"]))
+         return df_temp
+
+     def add_year(df: pd.DataFrame) -> pd.DataFrame:
+         """Extract the year from the datetime cell and add it
+         as a new column to the dataframe at the front."""
+         # make copy of dataframe
+         df_temp = df.copy()
+         # extract year and add new column at the front containing these numbers
+         df_temp.insert(0, "Year", df_temp["DateTransformed"].dt.year)
+         return df_temp
+
+     def add_month(df: pd.DataFrame) -> pd.DataFrame:
+         """Extract the month from the datetime cell and add it
+         as a new column to the dataframe at the front."""
+         # make copy of dataframe
+         df_temp = df.copy()
+         # extract month and add new column at index 1 containing these numbers
+         df_temp.insert(1, "Month", df_temp["DateTransformed"].dt.month)
+         return df_temp
+
+     def add_day(df: pd.DataFrame) -> pd.DataFrame:
+         """Extract the day from the datetime cell and add it
+         as a new column to the dataframe at the front."""
+         # make copy of dataframe
+         df_temp = df.copy()
+         # extract day and add new column at index 2 containing these numbers
+         df_temp.insert(2, "Day", df_temp["DateTransformed"].dt.day)
+         return df_temp
+
+     def add_input_output_temperature(df: pd.DataFrame) -> pd.DataFrame:
+         """Add a multiindex denoting if the column
+         is an input or output variable."""
+         # copy the dataframe
+         temp_df = df.copy()
+         # extract all the column names
+         column_names = temp_df.columns.tolist()
+         # only the last 2 columns are output variables, all others are input
+         # variables. So make list of corresponding lengths of
+         # 'Input' and 'Output'
+         input_list = ["Input"] * (len(column_names) - 2)
+         output_list = ["Output"] * 2
+         # concat both lists
+         input_output_list = input_list + output_list
+         # define multi index for attaching this 'Input' and 'Output' list with
+         # the column names already existing
+         multiindex_bias = pd.MultiIndex.from_arrays(
+             [input_output_list, column_names]
+         )
+         # transpose such that index can be adjusted to multi index
+         new_df = pd.DataFrame(df.transpose().to_numpy(), index=multiindex_bias)
+         # transpose back such that columns are the same as before
+         # except with different labels
+         return new_df.transpose()
+
+     def normalize_columns_bias(df: pd.DataFrame) -> pd.DataFrame:
+         """Normalize the columns for the bias correction dataset.
+         This is different from normalizing all the columns separately
+         because the upper and lower bounds for the output variables
+         are assumed to be the same."""
+         # copy the dataframe
+         temp_df = df.copy()
+         # normalize each column
+         for feature_name in df.columns:
+             # the output columns are normalized using the same upper and
+             # lower bound for more efficient check of the inequality
+             if feature_name == "Next_Tmax" or feature_name == "Next_Tmin":
+                 max_value = 38.9
+                 min_value = 11.3
+             # the input columns are normalized using their respective
+             # upper and lower bounds
+             else:
+                 max_value = df[feature_name].max()
+                 min_value = df[feature_name].min()
+             temp_df[feature_name] = (df[feature_name] - min_value) / (
+                 max_value - min_value
+             )
+         return temp_df
+
+     def sample_2500_examples(df: pd.DataFrame) -> pd.DataFrame:
+         """Sample 2500 examples from the dataframe without replacement."""
+         temp_df = df.copy()
+         sample_df = temp_df.sample(
+             n=2500, replace=False, random_state=3, axis=0
+         )
+         return sample_df
+
+     return (
+         # drop missing values
+         df.dropna(how="any")
+         # transform string date to datetime format
+         .pipe(date_to_datetime)
+         # add year as a single column
+         .pipe(add_year)
+         # add month as a single column
+         .pipe(add_month)
+         # add day as a single column
+         .pipe(add_day)
+         # remove original date string and the datetime format
+         .drop(["Date", "DateTransformed"], axis=1, inplace=False)
+         # convert all numbers to float32
+         .astype("float32")
+         # normalize columns
+         .pipe(normalize_columns_bias)
+         # add multi index indicating which columns are corresponding
+         # to input and output variables
+         .pipe(add_input_output_temperature)
+         # sample 2500 examples out of the dataset
+         .pipe(sample_2500_examples)
+     )
+
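A hedged usage sketch (not part of the diff; the CSV path is hypothetical and assumes a bias-correction dataset that provides the `Date`, `Next_Tmax`, and `Next_Tmin` columns referenced above):

    import pandas as pd
    from congrads.utils import preprocess_BiasCorrection

    raw = pd.read_csv("bias_correction.csv")    # hypothetical path to the raw dataset
    processed = preprocess_BiasCorrection(raw)
    inputs = processed["Input"].to_numpy()      # columns labelled 'Input' by the multi-index
    outputs = processed["Output"].to_numpy()    # the two temperature targets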
+
+ # pylint: disable-next=invalid-name
+ def preprocess_FamilyIncome(df: pd.DataFrame) -> pd.DataFrame:
+     """
+     Preprocesses the given Family Income dataframe by applying a
+     series of transformations and constraints.
+
+     The function sequentially:
+
+     - Drops rows with missing values.
+     - Converts object columns to appropriate data types and
+       removes string columns.
+     - Removes certain unnecessary columns like
+       'Agricultural Household indicator' and related features.
+     - Adds labels to columns indicating whether they are
+       input or output variables.
+     - Normalizes the columns individually.
+     - Checks and removes rows that do not satisfy predefined constraints
+       (household income > expenses, food expenses > sub-expenses).
+     - Samples 2500 examples from the dataset without replacement.
+
+     Args:
+         df (pd.DataFrame): The input Family Income dataframe containing
+             the data to be processed.
+
+     Returns:
+         pd.DataFrame: The processed dataframe after applying the
+             transformations and constraints.
+     """
+
+     def normalize_columns_income(df: pd.DataFrame) -> pd.DataFrame:
+         """Normalize the columns for the Family Income dataframe.
+         This can also be applied to other dataframes because this
+         function normalizes all columns individually."""
+         # copy the dataframe
+         temp_df = df.copy()
+         # normalize each column
+         for feature_name in df.columns:
+             max_value = df[feature_name].max()
+             min_value = df[feature_name].min()
+             temp_df[feature_name] = (df[feature_name] - min_value) / (
+                 max_value - min_value
+             )
+         return temp_df
+
+     def check_constraints_income(df: pd.DataFrame) -> pd.DataFrame:
+         """Check if all the constraints are satisfied for the dataframe
+         and remove the examples that do not satisfy the constraint.
+         This function only works for the Family Income dataset and the
+         constraints are that the household income is larger than all the
+         expenses and the food expense is larger than the sum of
+         the other (more detailed) food expenses.
+         """
+         temp_df = df.copy()
+         # check that household income is larger than expenses in the output
+         input_array = temp_df["Input"].to_numpy()
+         income_array = np.add(
+             np.multiply(
+                 input_array[:, [0, 1]],
+                 np.subtract(
+                     np.asarray([11815988, 9234485]), np.asarray([11285, 0])
+                 ),
+             ),
+             np.asarray([11285, 0]),
+         )
+         expense_array = temp_df["Output"].to_numpy()
+         expense_array = np.add(
+             np.multiply(
+                 expense_array,
+                 np.subtract(
+                     np.asarray(
+                         [
+                             791848,
+                             437467,
+                             140992,
+                             74800,
+                             2188560,
+                             1049275,
+                             149940,
+                             731000,
+                         ]
+                     ),
+                     np.asarray([3704, 0, 0, 0, 1950, 0, 0, 0]),
+                 ),
+             ),
+             np.asarray([3704, 0, 0, 0, 1950, 0, 0, 0]),
+         )
+         expense_array_without_dup = expense_array[:, [0, 4, 5, 6, 7]]
+         sum_expenses = np.sum(expense_array_without_dup, axis=1)
+         total_income = np.sum(income_array, axis=1)
+         sanity_check_array = np.greater_equal(total_income, sum_expenses)
+         temp_df["Unimportant"] = sanity_check_array.tolist()
+         reduction = temp_df[temp_df.Unimportant]
+         drop_reduction = reduction.drop("Unimportant", axis=1)
+
+         # check that the food expense is larger than all the sub expenses
+         expense_reduced_array = drop_reduction["Output"].to_numpy()
+         expense_reduced_array = np.add(
+             np.multiply(
+                 expense_reduced_array,
+                 np.subtract(
+                     np.asarray(
+                         [
+                             791848,
+                             437467,
+                             140992,
+                             74800,
+                             2188560,
+                             1049275,
+                             149940,
+                             731000,
+                         ]
+                     ),
+                     np.asarray([3704, 0, 0, 0, 1950, 0, 0, 0]),
+                 ),
+             ),
+             np.asarray([3704, 0, 0, 0, 1950, 0, 0, 0]),
+         )
+         food_mul_expense_array = expense_reduced_array[:, [1, 2, 3]]
+         food_mul_expense_array_sum = np.sum(food_mul_expense_array, axis=1)
+         food_expense_array = expense_reduced_array[:, 0]
+         sanity_check_array = np.greater_equal(
+             food_expense_array, food_mul_expense_array_sum
+         )
+         drop_reduction["Unimportant"] = sanity_check_array.tolist()
+         new_reduction = drop_reduction[drop_reduction.Unimportant]
+         satisfied_constraints_df = new_reduction.drop("Unimportant", axis=1)
+
+         return satisfied_constraints_df
+
+     def add_input_output_family_income(df: pd.DataFrame) -> pd.DataFrame:
+         """Add a multiindex denoting if the column is
+         an input or output variable."""
+
+         # copy the dataframe
+         temp_df = df.copy()
+         # extract all the column names
+         column_names = temp_df.columns.tolist()
+         # the 2nd-9th columns correspond to output variables and all
+         # others to input variables. So make list of corresponding
+         # lengths of 'Input' and 'Output'
+         input_list_start = ["Input"]
+         input_list_end = ["Input"] * (len(column_names) - 9)
+         output_list = ["Output"] * 8
+         # concat both lists
+         input_output_list = input_list_start + output_list + input_list_end
+         # define multi index for attaching this 'Input' and
+         # 'Output' list with the column names already existing
+         multiindex_bias = pd.MultiIndex.from_arrays(
+             [input_output_list, column_names]
+         )
+         # transpose such that index can be adjusted to multi index
+         new_df = pd.DataFrame(df.transpose().to_numpy(), index=multiindex_bias)
+         # transpose back such that columns are the same as
+         # before except with different labels
+         return new_df.transpose()
+
+     def sample_2500_examples(df: pd.DataFrame) -> pd.DataFrame:
+         """Sample 2500 examples from the dataframe without replacement."""
+         temp_df = df.copy()
+         sample_df = temp_df.sample(
+             n=2500, replace=False, random_state=3, axis=0
+         )
+         return sample_df
+
+     return (
+         # drop missing values
+         df.dropna(how="any")
+         # convert object to fitting dtype
+         .convert_dtypes()
+         # remove all strings (no other dtypes are present
+         # except for integers and floats)
+         .select_dtypes(exclude=["string"])
+         # transform all numbers to same dtype
+         .astype("float32")
+         # drop column with label Agricultural Household indicator
+         # because this is not really a numerical input but
+         # rather a categorical/classification
+         .drop(["Agricultural Household indicator"], axis=1, inplace=False)
+         # this column is dropped because it depends on
+         # Agricultural Household indicator
+         .drop(["Crop Farming and Gardening expenses"], axis=1, inplace=False)
+         # use 8 output variables and 24 input variables
+         .drop(
+             [
+                 "Total Rice Expenditure",
+                 "Total Fish and marine products Expenditure",
+                 "Fruit Expenditure",
+                 "Restaurant and hotels Expenditure",
+                 "Alcoholic Beverages Expenditure",
+                 "Tobacco Expenditure",
+                 "Clothing, Footwear and Other Wear Expenditure",
+                 "Imputed House Rental Value",
+                 "Transportation Expenditure",
+                 "Miscellaneous Goods and Services Expenditure",
+                 "Special Occasions Expenditure",
+             ],
+             axis=1,
+             inplace=False,
+         )
+         # add input and output labels to each column
+         .pipe(add_input_output_family_income)
+         # normalize all the columns
+         .pipe(normalize_columns_income)
+         # remove all datapoints that do not satisfy the constraints
+         .pipe(check_constraints_income)
+         # sample 2500 examples
+         .pipe(sample_2500_examples)
+     )
+
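A similar hedged sketch for the Family Income data (illustrative; the CSV path is hypothetical and the raw file must contain the column names dropped and checked above):

    import pandas as pd
    from congrads.utils import preprocess_FamilyIncome

    raw = pd.read_csv("family_income.csv")      # hypothetical path to the raw dataset
    processed = preprocess_FamilyIncome(raw)
    x = processed["Input"].to_numpy()           # 24 normalized input columns
    y = processed["Output"].to_numpy()          # 8 normalized expenditure columns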
+
+ def validate_type(name, value, expected_types, allow_none=False):
+     """
+     Validate that a value is of the specified type(s).
+
+     Args:
+         name (str): Name of the argument for error messages.
+         value: Value to validate.
+         expected_types (type or tuple of types): Expected type(s) for the value.
+         allow_none (bool): Whether to allow the value to be None.
+             Defaults to False.
+
+     Raises:
+         TypeError: If the value is not of the expected type(s).
+     """
+
+     if value is None:
+         if not allow_none:
+             raise TypeError(f"Argument {name} cannot be None.")
+         return
+
+     if not isinstance(value, expected_types):
+         raise TypeError(
+             f"Argument {name} '{str(value)}' is not supported. "
+             f"Only values of type {str(expected_types)} are allowed."
+         )
+
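A brief illustration of the type validator (argument names are arbitrary):

    validate_type("max_epochs", 10, int)                   # passes
    validate_type("device", None, str, allow_none=True)    # passes, None explicitly allowed
    validate_type("learning_rate", "0.01", (int, float))   # raises TypeError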
+
+ def validate_iterable(
+     name,
+     value,
+     expected_element_types,
+     allowed_iterables=(list, set),
+     allow_none=False,
+ ):
+     """
+     Validate that a value is an iterable (e.g., list, set) with elements of
+     the specified type(s).
+
+     Args:
+         name (str): Name of the argument for error messages.
+         value: Value to validate.
+         expected_element_types (type or tuple of types): Expected type(s)
+             for the elements.
+         allowed_iterables (tuple of types): Iterable types that are
+             allowed (default: list and set).
+         allow_none (bool): Whether to allow the value to be None.
+             Defaults to False.
+
+     Raises:
+         TypeError: If the value is not an allowed iterable type or if
+             any element is not of the expected type(s).
+     """
+
+     if value is None:
+         if not allow_none:
+             raise TypeError(f"Argument {name} cannot be None.")
+         return
+
+     if not isinstance(value, allowed_iterables):
+         raise TypeError(
+             f"Argument {name} '{str(value)}' is not supported. "
+             f"Only values of type {str(allowed_iterables)} are allowed."
+         )
+     if not all(
+         isinstance(element, expected_element_types) for element in value
+     ):
+         raise TypeError(
+             f"Invalid elements in {name} '{str(value)}'. "
+             f"Only elements of type {str(expected_element_types)} are allowed."
+         )
+
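For illustration (names are arbitrary); note that tuples are rejected unless added to allowed_iterables:

    validate_iterable("constraint_names", ["mono", "bound"], str)   # passes
    validate_iterable("constraint_names", {"mono", "bound"}, str)   # passes, sets allowed by default
    validate_iterable("constraint_names", ("mono",), str)           # raises TypeError
    validate_iterable("constraint_names", ("mono",), str,
                      allowed_iterables=(list, set, tuple))         # passes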
+
+ def validate_comparator_pytorch(name, value):
+     """
+     Validate that a value is a callable PyTorch comparator function.
+
+     Args:
+         name (str): Name of the argument for error messages.
+         value: Value to validate.
+
+     Raises:
+         TypeError: If the value is not callable or not a PyTorch comparator.
+     """
+
+     # Set of valid PyTorch comparator functions
+     pytorch_comparators = {torch.gt, torch.lt, torch.ge, torch.le}
+
+     # Check if value is callable and if it's one of
+     # the PyTorch comparator functions
+     if not callable(value):
+         raise TypeError(
+             f"Argument {name} '{str(value)}' is not supported. "
+             "Only callable functions are allowed."
+         )
+
+     if value not in pytorch_comparators:
+         raise TypeError(
+             f"Argument {name} '{str(value)}' is not a valid PyTorch comparator "
+             "function. Only PyTorch functions like torch.gt, torch.lt, "
+             "torch.ge, torch.le are allowed."
+         )
+
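For illustration:

    import torch
    from congrads.utils import validate_comparator_pytorch

    validate_comparator_pytorch("comparator", torch.le)   # passes
    validate_comparator_pytorch("comparator", max)        # raises TypeError: callable, but not a PyTorch comparator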
+
+ def validate_callable(name, value, allow_none=False):
+     """
+     Validate that a value is a callable function.
+
+     Args:
+         name (str): Name of the argument for error messages.
+         value: Value to validate.
+         allow_none (bool): Whether to allow the value to be None.
+             Defaults to False.
+
+     Raises:
+         TypeError: If the value is not callable.
+     """
+
+     if value is None:
+         if not allow_none:
+             raise TypeError(f"Argument {name} cannot be None.")
+         return
+
+     if not callable(value):
+         raise TypeError(
+             f"Argument {name} '{str(value)}' is not supported. "
+             "Only callable functions are allowed."
+         )
+
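And the generic callable check:

    validate_callable("transform", lambda x: x * 2)          # passes
    validate_callable("transform", None, allow_none=True)    # passes
    validate_callable("transform", 5)                        # raises TypeError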
+
+ def validate_loaders() -> None:
+     """
+     TODO: implement function
+     """
+     # TODO complete
+     pass