workbench 0.8.177__py3-none-any.whl → 0.8.179__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of workbench might be problematic.

Files changed (27)
  1. workbench/api/endpoint.py +3 -2
  2. workbench/core/artifacts/endpoint_core.py +5 -5
  3. workbench/core/artifacts/feature_set_core.py +67 -8
  4. workbench/core/views/training_view.py +38 -48
  5. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
  6. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  7. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +10 -17
  8. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  9. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +44 -45
  10. workbench/model_scripts/custom_models/uq_models/mapie.template +42 -43
  11. workbench/model_scripts/custom_models/uq_models/meta_uq.template +7 -22
  12. workbench/model_scripts/custom_models/uq_models/ngboost.template +5 -12
  13. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +5 -13
  14. workbench/model_scripts/pytorch_model/pytorch.template +9 -18
  15. workbench/model_scripts/quant_regression/quant_regression.template +5 -10
  16. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  17. workbench/model_scripts/xgb_model/generated_model_script.py +24 -33
  18. workbench/model_scripts/xgb_model/xgb_model.template +23 -32
  19. workbench/scripts/ml_pipeline_sqs.py +14 -2
  20. workbench/utils/model_utils.py +12 -2
  21. workbench/utils/xgboost_model_utils.py +161 -138
  22. {workbench-0.8.177.dist-info → workbench-0.8.179.dist-info}/METADATA +1 -1
  23. {workbench-0.8.177.dist-info → workbench-0.8.179.dist-info}/RECORD +27 -27
  24. {workbench-0.8.177.dist-info → workbench-0.8.179.dist-info}/WHEEL +0 -0
  25. {workbench-0.8.177.dist-info → workbench-0.8.179.dist-info}/entry_points.txt +0 -0
  26. {workbench-0.8.177.dist-info → workbench-0.8.179.dist-info}/licenses/LICENSE +0 -0
  27. {workbench-0.8.177.dist-info → workbench-0.8.179.dist-info}/top_level.txt +0 -0
workbench/api/endpoint.py CHANGED
@@ -4,6 +4,7 @@ Endpoints can be viewed in the AWS Sagemaker interfaces or in the Workbench
 Dashboard UI, which provides additional model details and performance metrics"""
 
 import pandas as pd
+from typing import Tuple
 
 # Workbench Imports
 from workbench.core.artifacts.endpoint_core import EndpointCore
@@ -70,14 +71,14 @@ class Endpoint(EndpointCore):
         """
         return super().fast_inference(eval_df, threads=threads)
 
-    def cross_fold_inference(self, nfolds: int = 5) -> dict:
+    def cross_fold_inference(self, nfolds: int = 5) -> Tuple[dict, pd.DataFrame]:
         """Run cross-fold inference (only works for XGBoost models)
 
         Args:
            nfolds (int): The number of folds to use for cross-validation (default: 5)
 
        Returns:
-           dict: A dictionary with fold results
+           Tuple(dict, pd.DataFrame): A tuple containing a dictionary of metrics and a DataFrame with predictions
        """
        return super().cross_fold_inference(nfolds)
 
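A minimal usage sketch (not part of the diff) showing how the new return signature changes caller code; the endpoint name is hypothetical:

from workbench.api.endpoint import Endpoint

end = Endpoint("my-regression-end")  # hypothetical endpoint name

# 0.8.177 returned only the metrics dict; 0.8.179 also returns the out-of-fold predictions
metrics, oof_df = end.cross_fold_inference(nfolds=5)
print(metrics)        # cross-fold metrics dictionary
print(oof_df.head())  # DataFrame with out-of-fold predictions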
workbench/core/artifacts/endpoint_core.py CHANGED
@@ -8,7 +8,7 @@ import pandas as pd
 import numpy as np
 from io import StringIO
 import awswrangler as wr
-from typing import Union, Optional
+from typing import Union, Optional, Tuple
 import hashlib
 
 # Model Performance Scores
@@ -436,24 +436,24 @@ class EndpointCore(Artifact):
         # Return the prediction DataFrame
         return prediction_df
 
-    def cross_fold_inference(self, nfolds: int = 5) -> dict:
+    def cross_fold_inference(self, nfolds: int = 5) -> Tuple[dict, pd.DataFrame]:
         """Run cross-fold inference (only works for XGBoost models)
 
         Args:
            nfolds (int): Number of folds to use for cross-fold (default: 5)
 
        Returns:
-           dict: Dictionary with the cross-fold inference results
+           Tuple[dict, pd.DataFrame]: Tuple of (cross_fold_metrics, out_of_fold_df)
        """
 
        # Grab our model
        model = ModelCore(self.model_name)
 
        # Compute CrossFold Metrics
-       cross_fold_metrics = cross_fold_inference(model, nfolds=nfolds)
+       cross_fold_metrics, out_of_fold_df = cross_fold_inference(model, nfolds=nfolds)
        if cross_fold_metrics:
            self.param_store.upsert(f"/workbench/models/{model.name}/inference/cross_fold", cross_fold_metrics)
-       return cross_fold_metrics
+       return cross_fold_metrics, out_of_fold_df
 
     def fast_inference(self, eval_df: pd.DataFrame, threads: int = 4) -> pd.DataFrame:
         """Run inference on the Endpoint using the provided DataFrame
workbench/core/artifacts/feature_set_core.py CHANGED
@@ -17,7 +17,7 @@ from workbench.core.artifacts.artifact import Artifact
 from workbench.core.artifacts.data_source_factory import DataSourceFactory
 from workbench.core.artifacts.athena_source import AthenaSource
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional, List, Union
 
 from workbench.utils.aws_utils import aws_throttle
 
@@ -509,6 +509,48 @@ class FeatureSetCore(Artifact):
         ].tolist()
         return hold_out_ids
 
+    def set_training_filter(self, filter_expression: Optional[str] = None):
+        """Set a filter expression for the training view for this FeatureSet
+
+        Args:
+            filter_expression (Optional[str]): A SQL filter expression (e.g., "age > 25 AND status = 'active'")
+                If None or empty string, will reset to training view with no filter
+                (default: None)
+        """
+        from workbench.core.views import TrainingView
+
+        # Grab the existing holdout ids
+        holdout_ids = self.get_training_holdouts()
+
+        # Create a NEW training view
+        self.log.important(f"Setting Training Filter: {filter_expression}")
+        TrainingView.create(
+            self, id_column=self.id_column, holdout_ids=holdout_ids, filter_expression=filter_expression
+        )
+
+    def exclude_ids_from_training(self, ids: List[Union[str, int]], column_name: Optional[str] = None):
+        """Exclude a list of IDs from the training view
+
+        Args:
+            ids (List[Union[str, int]],): List of IDs to exclude from training
+            column_name (Optional[str]): Column name to filter on.
+                If None, uses self.id_column (default: None)
+        """
+        # Use the default id_column if not specified
+        column = column_name or self.id_column
+
+        # Handle empty list case
+        if not ids:
+            self.log.warning("No IDs provided to exclude")
+            return
+
+        # Build the filter expression with proper SQL quoting
+        quoted_ids = ", ".join([repr(id) for id in ids])
+        filter_expression = f"{column} NOT IN ({quoted_ids})"
+
+        # Apply the filter
+        self.set_training_filter(filter_expression)
+
     @classmethod
     def delete_views(cls, table: str, database: str):
         """Delete any views associated with this FeatureSet
@@ -707,7 +749,7 @@ if __name__ == "__main__":
 
     # Test getting the holdout ids
     print("Getting the hold out ids...")
-    holdout_ids = my_features.get_training_holdouts("id")
+    holdout_ids = my_features.get_training_holdouts()
     print(f"Holdout IDs: {holdout_ids}")
 
     # Get a sample of the data
@@ -729,16 +771,33 @@ if __name__ == "__main__":
     table = my_features.view("training").table
     df = my_features.query(f'SELECT id, name FROM "{table}"')
     my_holdout_ids = [id for id in df["id"] if id < 20]
-    my_features.set_training_holdouts("id", my_holdout_ids)
-
-    # Test the hold out set functionality with strings
-    print("Setting hold out ids (strings)...")
-    my_holdout_ids = [name for name in df["name"] if int(name.split(" ")[1]) > 80]
-    my_features.set_training_holdouts("name", my_holdout_ids)
+    my_features.set_training_holdouts(my_holdout_ids)
 
     # Get the training data
     print("Getting the training data...")
     training_data = my_features.get_training_data()
+    print(f"Training Data: {training_data.shape}")
+
+    # Test the filter expression functionality
+    print("Setting a filter expression...")
+    my_features.set_training_filter("id < 50 AND height > 65.0")
+    training_data = my_features.get_training_data()
+    print(f"Training Data: {training_data.shape}")
+    print(training_data)
+
+    # Remove training filter
+    print("Removing the filter expression...")
+    my_features.set_training_filter(None)
+    training_data = my_features.get_training_data()
+    print(f"Training Data: {training_data.shape}")
+    print(training_data)
+
+    # Test excluding ids from training
+    print("Excluding ids from training...")
+    my_features.exclude_ids_from_training([1, 2, 3, 4, 5])
+    training_data = my_features.get_training_data()
+    print(f"Training Data: {training_data.shape}")
+    print(training_data)
 
     # Now delete the AWS artifacts associated with this Feature Set
     # print("Deleting Workbench Feature Set...")
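For reference, a short sketch (not part of the diff) of the filter string that exclude_ids_from_training builds via repr() before handing it to set_training_filter(); the FeatureSet name and id column here are hypothetical:

from workbench.api import FeatureSet

fs = FeatureSet("abalone_features")  # hypothetical FeatureSet whose id_column is "id"

# Integer ids: repr(1) -> "1", so the training view filter becomes "id NOT IN (1, 2, 3)"
fs.exclude_ids_from_training([1, 2, 3])

# String ids are quoted by repr(): filter becomes "compound_id NOT IN ('c-001', 'c-002')"
fs.exclude_ids_from_training(["c-001", "c-002"], column_name="compound_id")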
workbench/core/views/training_view.py CHANGED
@@ -3,7 +3,7 @@
 from typing import Union
 
 # Workbench Imports
-from workbench.api import DataSource, FeatureSet
+from workbench.api import FeatureSet
 from workbench.core.views.view import View
 from workbench.core.views.create_view import CreateView
 from workbench.core.views.view_utils import get_column_list
@@ -34,6 +34,7 @@ class TrainingView(CreateView):
         source_table: str = None,
         id_column: str = None,
         holdout_ids: Union[list[str], list[int], None] = None,
+        filter_expression: str = None,
     ) -> Union[View, None]:
         """Factory method to create and return a TrainingView instance.
 
@@ -42,6 +43,8 @@ class TrainingView(CreateView):
             source_table (str, optional): The table/view to create the view from. Defaults to None.
             id_column (str, optional): The name of the id column. Defaults to None.
             holdout_ids (Union[list[str], list[int], None], optional): A list of holdout ids. Defaults to None.
+            filter_expression (str, optional): SQL filter expression (e.g., "age > 25 AND status = 'active'").
+                Defaults to None.
 
         Returns:
             Union[View, None]: The created View object (or None if failed to create the view)
@@ -69,28 +72,36 @@ class TrainingView(CreateView):
         else:
             id_column = instance.auto_id_column
 
-        # If we don't have holdout ids, create a default training view
-        if not holdout_ids:
-            instance._default_training_view(instance.data_source, id_column)
-            return View(instance.data_source, instance.view_name, auto_create_view=False)
+        # Enclose each column name in double quotes
+        sql_columns = ", ".join([f'"{column}"' for column in column_list])
+
+        # Build the training assignment logic
+        if holdout_ids:
+            # Format the list of holdout ids for SQL IN clause
+            if all(isinstance(id, str) for id in holdout_ids):
+                formatted_holdout_ids = ", ".join(f"'{id}'" for id in holdout_ids)
+            else:
+                formatted_holdout_ids = ", ".join(map(str, holdout_ids))
 
-        # Format the list of holdout ids for SQL IN clause
-        if holdout_ids and all(isinstance(id, str) for id in holdout_ids):
-            formatted_holdout_ids = ", ".join(f"'{id}'" for id in holdout_ids)
+            training_logic = f"""CASE
+                WHEN {id_column} IN ({formatted_holdout_ids}) THEN False
+                ELSE True
+            END AS training"""
         else:
-            formatted_holdout_ids = ", ".join(map(str, holdout_ids))
+            # Default 80/20 split using modulo
+            training_logic = f"""CASE
+                WHEN MOD(ROW_NUMBER() OVER (ORDER BY {id_column}), 10) < 8 THEN True
+                ELSE False
+            END AS training"""
 
-        # Enclose each column name in double quotes
-        sql_columns = ", ".join([f'"{column}"' for column in column_list])
+        # Build WHERE clause if filter_expression is provided
+        where_clause = f"\nWHERE {filter_expression}" if filter_expression else ""
 
         # Construct the CREATE VIEW query
         create_view_query = f"""
            CREATE OR REPLACE VIEW {instance.table} AS
-           SELECT {sql_columns}, CASE
-               WHEN {id_column} IN ({formatted_holdout_ids}) THEN False
-               ELSE True
-           END AS training
-           FROM {instance.source_table}
+           SELECT {sql_columns}, {training_logic}
+           FROM {instance.source_table}{where_clause}
            """
 
        # Execute the CREATE VIEW query
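To make the reworked query construction concrete, here is a small standalone sketch (not part of the diff) that mirrors the logic in the hunk above with hypothetical table and column names and prints the SQL that would back the view:

id_column = "auto_id"                                # hypothetical id column
sql_columns = '"auto_id", "diameter", "height"'      # hypothetical column list
holdout_ids = [1, 2, 3]
filter_expression = "diameter > 0.5"

if holdout_ids:
    formatted_holdout_ids = ", ".join(map(str, holdout_ids))
    training_logic = f"""CASE
        WHEN {id_column} IN ({formatted_holdout_ids}) THEN False
        ELSE True
    END AS training"""
else:
    # Default 80/20 split using modulo on row number
    training_logic = f"""CASE
        WHEN MOD(ROW_NUMBER() OVER (ORDER BY {id_column}), 10) < 8 THEN True
        ELSE False
    END AS training"""

where_clause = f"\nWHERE {filter_expression}" if filter_expression else ""
create_view_query = f"""
CREATE OR REPLACE VIEW my_training_view AS
SELECT {sql_columns}, {training_logic}
FROM my_source_table{where_clause}
"""
print(create_view_query)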
@@ -99,43 +110,13 @@
         # Return the View
         return View(instance.data_source, instance.view_name, auto_create_view=False)
 
-    # This is an internal method that's used to create a default training view
-    def _default_training_view(self, data_source: DataSource, id_column: str):
-        """Create a default view in Athena that assigns roughly 80% of the data to training
-
-        Args:
-            data_source (DataSource): The Workbench DataSource object
-            id_column (str): The name of the id column
-        """
-        self.log.important(f"Creating default Training View {self.table}...")
-
-        # Drop any columns generated from AWS
-        aws_cols = ["write_time", "api_invocation_time", "is_deleted", "event_time"]
-        column_list = [col for col in data_source.columns if col not in aws_cols]
-
-        # Enclose each column name in double quotes
-        sql_columns = ", ".join([f'"{column}"' for column in column_list])
-
-        # Construct the CREATE VIEW query with a simple modulo operation for the 80/20 split
-        create_view_query = f"""
-           CREATE OR REPLACE VIEW "{self.table}" AS
-           SELECT {sql_columns}, CASE
-               WHEN MOD(ROW_NUMBER() OVER (ORDER BY {id_column}), 10) < 8 THEN True  -- Assign 80% to training
-               ELSE False  -- Assign roughly 20% to validation/test
-           END AS training
-           FROM {self.base_table_name}
-           """
-
-        # Execute the CREATE VIEW query
-        data_source.execute_statement(create_view_query)
-
 
 if __name__ == "__main__":
     """Exercise the Training View functionality"""
     from workbench.api import FeatureSet
 
     # Get the FeatureSet
-    fs = FeatureSet("test_features")
+    fs = FeatureSet("abalone_features")
 
     # Delete the existing training view
     training_view = TrainingView.create(fs)
@@ -152,9 +133,18 @@ if __name__ == "__main__":
 
     # Create a TrainingView with holdout ids
     my_holdout_ids = list(range(10))
-    training_view = TrainingView.create(fs, id_column="id", holdout_ids=my_holdout_ids)
+    training_view = TrainingView.create(fs, id_column="auto_id", holdout_ids=my_holdout_ids)
 
     # Pull the training data
     df = training_view.pull_dataframe()
     print(df.head())
     print(df["training"].value_counts())
+    print(f"Shape: {df.shape}")
+    print(f"Diameter min: {df['diameter'].min()}, max: {df['diameter'].max()}")
+
+    # Test the filter expression
+    training_view = TrainingView.create(fs, id_column="auto_id", filter_expression="diameter > 0.5")
+    df = training_view.pull_dataframe()
+    print(df.head())
+    print(f"Shape with filter: {df.shape}")
+    print(f"Diameter min: {df['diameter'].min()}, max: {df['diameter'].max()}")
workbench/model_scripts/custom_models/proximity/feature_space_proximity.template CHANGED
@@ -8,7 +8,7 @@ TEMPLATE_PARAMS = {
     "id_column": "{{id_column}}",
     "features": "{{feature_list}}",
     "target": "{{target_column}}",
-    "track_columns": "{{track_columns}}"
+    "track_columns": "{{track_columns}}",
 }
 
 from io import StringIO
@@ -73,10 +73,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     # Load training data from the specified directory
-    training_files = [
-        os.path.join(args.train, file)
-        for file in os.listdir(args.train) if file.endswith(".csv")
-    ]
+    training_files = [os.path.join(args.train, file) for file in os.listdir(args.train) if file.endswith(".csv")]
     all_df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])
 
     # Check if the DataFrame is empty
@@ -88,6 +85,7 @@ if __name__ == "__main__":
     # Now serialize the model
     model.serialize(args.model_dir)
 
+
 # Model loading and prediction functions
 def model_fn(model_dir):
 
workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template CHANGED
@@ -14,7 +14,7 @@ import pandas as pd
 TEMPLATE_PARAMS = {
     "features": "{{feature_list}}",
     "target": "{{target_column}}",
-    "train_all_data": "{{train_all_data}}"
+    "train_all_data": "{{train_all_data}}",
 }
 
 
@@ -37,7 +37,7 @@ def match_features_case_insensitive(df: pd.DataFrame, model_features: list) -> p
     """
     Matches and renames DataFrame columns to match model feature names (case-insensitive).
     Prioritizes exact matches, then case-insensitive matches.
-
+
     Raises ValueError if any model features cannot be matched.
     """
     df_columns_lower = {col.lower(): col for col in df.columns}
@@ -81,10 +81,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     # Load training data from the specified directory
-    training_files = [
-        os.path.join(args.train, file)
-        for file in os.listdir(args.train) if file.endswith(".csv")
-    ]
+    training_files = [os.path.join(args.train, file) for file in os.listdir(args.train) if file.endswith(".csv")]
     df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])
 
     # Check if the DataFrame is empty
@@ -109,8 +106,10 @@ if __name__ == "__main__":
     # Create and train the Regression/Confidence model
     # model = BayesianRidge()
     model = BayesianRidge(
-        alpha_1=1e-6, alpha_2=1e-6,  # Noise precision
-        lambda_1=1e-6, lambda_2=1e-6,  # Weight precision
+        alpha_1=1e-6,
+        alpha_2=1e-6,  # Noise precision
+        lambda_1=1e-6,
+        lambda_2=1e-6,  # Weight precision
         fit_intercept=True,
     )
 
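For context on the hyperparameters reformatted above: scikit-learn's BayesianRidge can return a per-prediction standard deviation via return_std=True, which is the kind of uncertainty output a UQ template like this one needs. A minimal sketch (not part of the package) with synthetic data:

import numpy as np
from sklearn.linear_model import BayesianRidge

rng = np.random.default_rng(42)
X = rng.normal(size=(200, 3))
y = X @ np.array([1.5, -2.0, 0.5]) + rng.normal(scale=0.3, size=200)

model = BayesianRidge(
    alpha_1=1e-6,
    alpha_2=1e-6,  # Noise precision priors
    lambda_1=1e-6,
    lambda_2=1e-6,  # Weight precision priors
    fit_intercept=True,
)
model.fit(X, y)

# return_std=True yields a per-row predictive standard deviation alongside the mean
y_pred, y_std = model.predict(X[:5], return_std=True)
print(y_pred, y_std)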
workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template CHANGED
@@ -4,11 +4,7 @@ import awswrangler as wr
 import numpy as np
 
 # Model Performance Scores
-from sklearn.metrics import (
-    mean_absolute_error,
-    r2_score,
-    root_mean_squared_error
-)
+from sklearn.metrics import mean_absolute_error, r2_score, root_mean_squared_error
 from sklearn.model_selection import KFold
 from scipy.optimize import minimize
 
@@ -23,7 +19,7 @@ TEMPLATE_PARAMS = {
     "features": "{{feature_list}}",
     "target": "{{target_column}}",
     "train_all_data": "{{train_all_data}}",
-    "model_metrics_s3_path": "{{model_metrics_s3_path}}"
+    "model_metrics_s3_path": "{{model_metrics_s3_path}}",
 }
 
 
@@ -47,7 +43,7 @@ def match_features_case_insensitive(df: pd.DataFrame, model_features: list) -> p
     """
     Matches and renames DataFrame columns to match model feature names (case-insensitive).
     Prioritizes exact matches, then case-insensitive matches.
-
+
     Raises ValueError if any model features cannot be matched.
     """
     df_columns_lower = {col.lower(): col for col in df.columns}
@@ -90,10 +86,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     # Load training data from the specified directory
-    training_files = [
-        os.path.join(args.train, file)
-        for file in os.listdir(args.train) if file.endswith(".csv")
-    ]
+    training_files = [os.path.join(args.train, file) for file in os.listdir(args.train) if file.endswith(".csv")]
     df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])
 
     # Check if the DataFrame is empty
@@ -172,16 +165,14 @@ if __name__ == "__main__":
     cv_residuals = np.array(cv_residuals)
     cv_uncertainties = np.array(cv_uncertainties)
 
-
     # Optimize calibration parameters: σ_cal = a * σ_uc + b
     def neg_log_likelihood(params):
         a, b = params
         sigma_cal = a * cv_uncertainties + b
         sigma_cal = np.maximum(sigma_cal, 1e-8)  # Prevent division by zero
-        return np.sum(0.5 * np.log(2 * np.pi * sigma_cal ** 2) + 0.5 * (cv_residuals ** 2) / (sigma_cal ** 2))
+        return np.sum(0.5 * np.log(2 * np.pi * sigma_cal**2) + 0.5 * (cv_residuals**2) / (sigma_cal**2))
 
-
-    result = minimize(neg_log_likelihood, x0=[1.0, 0.1], method='Nelder-Mead')
+    result = minimize(neg_log_likelihood, x0=[1.0, 0.1], method="Nelder-Mead")
     cal_a, cal_b = result.x
 
     print(f"Calibration parameters: a={cal_a:.4f}, b={cal_b:.4f}")
@@ -205,7 +196,9 @@ if __name__ == "__main__":
     result_df["prediction"] = result_df[[name for name in result_df.columns if name.startswith("m_")]].mean(axis=1)
 
     # Compute uncalibrated uncertainty
-    result_df["prediction_std_uc"] = result_df[[name for name in result_df.columns if name.startswith("m_")]].std(axis=1)
+    result_df["prediction_std_uc"] = result_df[[name for name in result_df.columns if name.startswith("m_")]].std(
+        axis=1
+    )
 
     # Apply calibration to uncertainty
     result_df["prediction_std"] = cal_a * result_df["prediction_std_uc"] + cal_b
@@ -352,4 +345,4 @@ def predict_fn(df, models) -> pd.DataFrame:
     df = df.reindex(sorted(df.columns), axis=1)
 
     # All done, return the DataFrame
-    return df
+    return df
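The calibration step reformatted above fits σ_cal = a * σ_uc + b by minimizing a Gaussian negative log-likelihood over the cross-validation residuals. A self-contained sketch (not part of the package) that mirrors that logic on synthetic residuals and uncertainties:

import numpy as np
from scipy.optimize import minimize

rng = np.random.default_rng(0)
cv_uncertainties = rng.uniform(0.5, 2.0, size=500)              # raw (uncalibrated) std estimates
cv_residuals = rng.normal(scale=1.5 * cv_uncertainties + 0.2)   # true spread differs by a linear factor

def neg_log_likelihood(params):
    a, b = params
    sigma_cal = np.maximum(a * cv_uncertainties + b, 1e-8)  # prevent division by zero
    return np.sum(0.5 * np.log(2 * np.pi * sigma_cal**2) + 0.5 * (cv_residuals**2) / (sigma_cal**2))

result = minimize(neg_log_likelihood, x0=[1.0, 0.1], method="Nelder-Mead")
cal_a, cal_b = result.x
print(f"Calibration parameters: a={cal_a:.4f}, b={cal_b:.4f}")  # should roughly recover a≈1.5, b≈0.2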
workbench/model_scripts/custom_models/uq_models/gaussian_process.template CHANGED
@@ -9,7 +9,7 @@ from sklearn.model_selection import train_test_split
 TEMPLATE_PARAMS = {
     "features": "{{feature_list}}",
     "target": "{{target_column}}",
-    "train_all_data": "{{train_all_data}}"
+    "train_all_data": "{{train_all_data}}",
 }
 
 from io import StringIO
@@ -33,7 +33,7 @@ def match_features_case_insensitive(df: pd.DataFrame, model_features: list) -> p
     """
     Matches and renames DataFrame columns to match model feature names (case-insensitive).
     Prioritizes exact matches, then case-insensitive matches.
-
+
     Raises ValueError if any model features cannot be matched.
     """
     df_columns_lower = {col.lower(): col for col in df.columns}
@@ -46,7 +46,7 @@ def match_features_case_insensitive(df: pd.DataFrame, model_features: list) -> p
             rename_dict[df_columns_lower[feature.lower()]] = feature
         else:
             missing.append(feature)
-
+
     if missing:
         raise ValueError(f"Features not found: {missing}")
 
@@ -76,10 +76,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     # Load training data from the specified directory
-    training_files = [
-        os.path.join(args.train, file)
-        for file in os.listdir(args.train) if file.endswith(".csv")
-    ]
+    training_files = [os.path.join(args.train, file) for file in os.listdir(args.train) if file.endswith(".csv")]
    df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])
 
    # Check if the DataFrame is empty
@@ -112,10 +109,7 @@ if __name__ == "__main__":
    )
 
    # Create a Pipeline with StandardScaler
-   model = Pipeline([
-       ("scaler", StandardScaler()),
-       ("model", model)
-   ])
+   model = Pipeline([("scaler", StandardScaler()), ("model", model)])
 
    # Prepare features and targets for training
    X_train = df_train[features]