workbench 0.8.169__py3-none-any.whl → 0.8.171__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of workbench has been flagged as possibly problematic.

workbench/api/model.py CHANGED
@@ -40,6 +40,7 @@ class Model(ModelCore):
          mem_size: int = 2048,
          max_concurrency: int = 5,
          instance: str = "ml.t2.medium",
+         data_capture: bool = False,
      ) -> Endpoint:
          """Create an Endpoint from the Model.

@@ -50,6 +51,7 @@ class Model(ModelCore):
              mem_size (int): The memory size for the Endpoint in MB (default: 2048)
              max_concurrency (int): The maximum concurrency for the Endpoint (default: 5)
              instance (str): The instance type to use for Realtime(serverless=False) Endpoints (default: "ml.t2.medium")
+             data_capture (bool): Enable data capture for the Endpoint (default: False)

          Returns:
              Endpoint: The Endpoint created from the Model
@@ -73,6 +75,7 @@ class Model(ModelCore):
          model_to_endpoint.transform(
              mem_size=mem_size,
              max_concurrency=max_concurrency,
+             data_capture=data_capture,
          )

          # Set the Endpoint Owner and Return the Endpoint
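
The net effect of this change is a new data_capture flag on the endpoint-creation call. A minimal usage sketch, not taken from the diff; the method name and model name below are assumptions for illustration:

from workbench.api.model import Model

model = Model("my-model")                  # hypothetical model name
endpoint = model.to_endpoint(              # assumed method name; the diff only shows its signature
    instance="ml.t2.medium",               # realtime instance type (default shown above)
    data_capture=True,                     # new in 0.8.171; defaults to False
)
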
@@ -972,12 +972,23 @@ class EndpointCore(Artifact):
              cls.log.info(f"Deleting Monitoring Schedule {schedule['MonitoringScheduleName']}...")
              cls.sm_client.delete_monitoring_schedule(MonitoringScheduleName=schedule["MonitoringScheduleName"])

-         # Recursively delete all endpoint S3 artifacts (inference, data capture, monitoring, etc)
+         # Recursively delete all endpoint S3 artifacts (inference, etc)
+         # Note: We do not want to delete the data_capture/ files since these
+         #       might be used for collection and data drift analysis
          base_endpoint_path = f"{cls.endpoints_s3_path}/{endpoint_name}"
-         s3_objects = wr.s3.list_objects(base_endpoint_path, boto3_session=cls.boto3_session)
-         cls.log.info(f"Deleting S3 Objects at {base_endpoint_path}...")
-         cls.log.info(f"{s3_objects}")
-         wr.s3.delete_objects(s3_objects, boto3_session=cls.boto3_session)
+         all_s3_objects = wr.s3.list_objects(base_endpoint_path, boto3_session=cls.boto3_session)
+
+         # Filter out objects that contain 'data_capture/' in their path
+         s3_objects_to_delete = [obj for obj in all_s3_objects if "/data_capture/" not in obj]
+         cls.log.info(f"Found {len(all_s3_objects)} total objects at {base_endpoint_path}")
+         cls.log.info(f"Filtering out data_capture files, will delete {len(s3_objects_to_delete)} objects...")
+         cls.log.info(f"Objects to delete: {s3_objects_to_delete}")
+
+         if s3_objects_to_delete:
+             wr.s3.delete_objects(s3_objects_to_delete, boto3_session=cls.boto3_session)
+             cls.log.info(f"Successfully deleted {len(s3_objects_to_delete)} objects")
+         else:
+             cls.log.info("No objects to delete (only data_capture files found)")

          # Delete any dataframes that were stored in the Dataframe Cache
          cls.log.info("Deleting Dataframe Cache...")
@@ -186,11 +186,11 @@ class MonitorCore:

          # Log the data capture operation
          self.log.important(f"Enabling Data Capture for {self.endpoint_name} --> {self.data_capture_path}")
-         self.log.important("This normally redeploys the endpoint...")
+         self.log.important("This will redeploy the endpoint...")

          # Create and apply the data capture configuration
          data_capture_config = DataCaptureConfig(
-             enable_capture=True,  # Required parameter
+             enable_capture=True,
              sampling_percentage=capture_percentage,
              destination_s3_uri=self.data_capture_path,
          )
@@ -196,7 +196,9 @@ class AWSMeta:

          # Return the summary as a DataFrame
          df = pd.DataFrame(data_summary).convert_dtypes()
-         return df.sort_values(by="Created", ascending=False)
+         if not df.empty:
+             df.sort_values(by="Created", ascending=False, inplace=True)
+         return df

      def models(self, details: bool = False) -> pd.DataFrame:
          """Get a summary of the Models in AWS.
@@ -256,7 +258,9 @@ class AWSMeta:

          # Return the summary as a DataFrame
          df = pd.DataFrame(model_summary).convert_dtypes()
-         return df.sort_values(by="Created", ascending=False)
+         if not df.empty:
+             df.sort_values(by="Created", ascending=False, inplace=True)
+         return df

      def endpoints(self, details: bool = False) -> pd.DataFrame:
          """Get a summary of the Endpoints in AWS.
@@ -317,7 +321,9 @@ class AWSMeta:

          # Return the summary as a DataFrame
          df = pd.DataFrame(data_summary).convert_dtypes()
-         return df.sort_values(by="Created", ascending=False)
+         if not df.empty:
+             df.sort_values(by="Created", ascending=False, inplace=True)
+         return df

      def _endpoint_config_info(self, endpoint_config_name: str) -> dict:
          """Internal: Get the Endpoint Configuration information for the given endpoint config name.
@@ -657,7 +663,8 @@ class AWSMeta:
          df = pd.DataFrame(data_summary).convert_dtypes()

          # Sort by the Modified column
-         df = df.sort_values(by="Modified", ascending=False)
+         if not df.empty:
+             df = df.sort_values(by="Modified", ascending=False)
          return df

      def _aws_pipelines(self) -> pd.DataFrame:
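
These four changes guard against the same failure mode: an empty summary has no "Created" (or "Modified") column, so an unconditional sort raises. A minimal sketch of why the guard is needed, assuming only pandas:

import pandas as pd

df = pd.DataFrame([]).convert_dtypes()   # empty summary: no rows, no "Created" column
try:
    df.sort_values(by="Created", ascending=False)
except KeyError as err:
    print(f"sort_values fails on an empty summary: {err}")

# The patched pattern: only sort when there is something to sort
if not df.empty:
    df = df.sort_values(by="Created", ascending=False)
print(df)   # empty DataFrame is returned unsorted
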
@@ -5,6 +5,7 @@ from sagemaker import ModelPackage
  from sagemaker.serializers import CSVSerializer
  from sagemaker.deserializers import CSVDeserializer
  from sagemaker.serverless import ServerlessInferenceConfig
+ from sagemaker.model_monitor import DataCaptureConfig

  # Local Imports
  from workbench.core.transforms.transform import Transform, TransformInput, TransformOutput
@@ -51,27 +52,38 @@ class ModelToEndpoint(Transform):
          EndpointCore.managed_delete(self.output_name)

          # Get the Model Package ARN for our input model
-         input_model = ModelCore(self.input_name)
-         model_package_arn = input_model.model_package_arn()
+         workbench_model = ModelCore(self.input_name)

          # Deploy the model
-         self._deploy_model(model_package_arn, **kwargs)
+         self._deploy_model(workbench_model, **kwargs)

          # Add this endpoint to the set of registered endpoints for the model
-         input_model.register_endpoint(self.output_name)
+         workbench_model.register_endpoint(self.output_name)

          # This ensures that the endpoint is ready for use
          time.sleep(5)  # We wait for AWS Lag
          end = EndpointCore(self.output_name)
          self.log.important(f"Endpoint {end.name} is ready for use")

-     def _deploy_model(self, model_package_arn: str, mem_size: int = 2048, max_concurrency: int = 5):
+     def _deploy_model(
+         self,
+         workbench_model: ModelCore,
+         mem_size: int = 2048,
+         max_concurrency: int = 5,
+         data_capture: bool = False,
+         capture_percentage: int = 100,
+     ):
          """Internal Method: Deploy the Model

          Args:
-             model_package_arn(str): The Model Package ARN used to deploy the Endpoint
+             workbench_model(ModelCore): The Workbench ModelCore object to deploy
+             mem_size(int): Memory size for serverless deployment
+             max_concurrency(int): Max concurrency for serverless deployment
+             data_capture(bool): Enable data capture during deployment
+             capture_percentage(int): Percentage of data to capture. Defaults to 100.
          """
          # Grab the specified Model Package
+         model_package_arn = workbench_model.model_package_arn()
          model_package = ModelPackage(
              role=self.workbench_role_arn,
              model_package_arn=model_package_arn,
@@ -95,6 +107,23 @@ class ModelToEndpoint(Transform):
              max_concurrency=max_concurrency,
          )

+         # Configure data capture if requested (and not serverless)
+         data_capture_config = None
+         if data_capture and not self.serverless:
+             # Set up the S3 path for data capture
+             base_endpoint_path = f"{workbench_model.endpoints_s3_path}/{self.output_name}"
+             data_capture_path = f"{base_endpoint_path}/data_capture"
+             self.log.important(f"Configuring Data Capture --> {data_capture_path}")
+             data_capture_config = DataCaptureConfig(
+                 enable_capture=True,
+                 sampling_percentage=capture_percentage,
+                 destination_s3_uri=data_capture_path,
+             )
+         elif data_capture and self.serverless:
+             self.log.warning(
+                 "Data capture is not supported for serverless endpoints. Skipping data capture configuration."
+             )
+
          # Deploy the Endpoint
          self.log.important(f"Deploying the Endpoint {self.output_name}...")
          model_package.deploy(
@@ -104,6 +133,7 @@ class ModelToEndpoint(Transform):
              endpoint_name=self.output_name,
              serializer=CSVSerializer(),
              deserializer=CSVDeserializer(),
+             data_capture_config=data_capture_config,
              tags=aws_tags,
          )
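
For reference, the SageMaker pieces involved here: a DataCaptureConfig is only built for realtime deployments and is handed to ModelPackage.deploy() via data_capture_config; serverless deployments log a warning and pass None. A minimal sketch of the config object on its own (the S3 URI is illustrative):

from sagemaker.model_monitor import DataCaptureConfig

capture_config = DataCaptureConfig(
    enable_capture=True,
    sampling_percentage=100,   # matches the capture_percentage default above
    destination_s3_uri="s3://example-bucket/endpoints/my-endpoint/data_capture",
)
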
@@ -2,7 +2,6 @@
  from ngboost import NGBRegressor
  from xgboost import XGBRegressor  # Base Estimator
  from sklearn.model_selection import train_test_split
- import numpy as np

  # Model Performance Scores
  from sklearn.metrics import (
@@ -16,7 +15,9 @@ import json
  import argparse
  import joblib
  import os
+ import numpy as np
  import pandas as pd
+ from typing import List, Tuple

  # Local Imports
  from proximity import Proximity
@@ -25,11 +26,12 @@ from proximity import Proximity

  # Template Placeholders
  TEMPLATE_PARAMS = {
- "id_column": "id",
- "features": ['molwt', 'mollogp', 'molmr', 'heavyatomcount', 'numhacceptors', 'numhdonors', 'numheteroatoms', 'numrotatablebonds', 'numvalenceelectrons', 'numaromaticrings', 'numsaturatedrings', 'numaliphaticrings', 'ringcount', 'tpsa', 'labuteasa', 'balabanj', 'bertzct'],
- "target": "solubility",
- "train_all_data": True,
- "track_columns": ['solubility']
+ "id_column": "udm_mol_bat_id",
+ "target": "udm_asy_res_intrinsic_clearance_ul_per_min_per_mg_protein",
+ "features": ['bcut2d_logplow', 'numradicalelectrons', 'smr_vsa5', 'fr_lactam', 'fr_morpholine', 'fr_aldehyde', 'slogp_vsa1', 'fr_amidine', 'bpol', 'fr_ester', 'fr_azo', 'kappa3', 'peoe_vsa5', 'fr_ketone_topliss', 'vsa_estate9', 'estate_vsa9', 'bcut2d_mrhi', 'fr_ndealkylation1', 'numrotatablebonds', 'minestateindex', 'fr_quatn', 'peoe_vsa3', 'fr_epoxide', 'fr_aniline', 'minpartialcharge', 'fr_nitroso', 'fpdensitymorgan2', 'fr_oxime', 'fr_sulfone', 'smr_vsa1', 'kappa1', 'fr_pyridine', 'numaromaticrings', 'vsa_estate6', 'molmr', 'estate_vsa1', 'fr_dihydropyridine', 'vsa_estate10', 'fr_alkyl_halide', 'chi2n', 'fr_thiocyan', 'fpdensitymorgan1', 'fr_unbrch_alkane', 'slogp_vsa9', 'chi4n', 'fr_nitro_arom', 'fr_al_oh', 'fr_furan', 'fr_c_s', 'peoe_vsa8', 'peoe_vsa14', 'numheteroatoms', 'fr_ndealkylation2', 'maxabspartialcharge', 'vsa_estate2', 'peoe_vsa7', 'apol', 'numhacceptors', 'fr_tetrazole', 'vsa_estate1', 'peoe_vsa9', 'naromatom', 'bcut2d_chghi', 'fr_sh', 'fr_halogen', 'slogp_vsa4', 'fr_benzodiazepine', 'molwt', 'fr_isocyan', 'fr_prisulfonamd', 'maxabsestateindex', 'minabsestateindex', 'peoe_vsa11', 'slogp_vsa12', 'estate_vsa5', 'numaliphaticcarbocycles', 'bcut2d_mwlow', 'slogp_vsa7', 'fr_allylic_oxid', 'fr_methoxy', 'fr_nh0', 'fr_coo2', 'fr_phenol', 'nacid', 'nbase', 'chi3v', 'fr_ar_nh', 'fr_nitrile', 'fr_imidazole', 'fr_urea', 'bcut2d_mrlow', 'chi1', 'smr_vsa6', 'fr_aryl_methyl', 'narombond', 'fr_alkyl_carbamate', 'fr_piperzine', 'exactmolwt', 'qed', 'chi0n', 'fr_sulfonamd', 'fr_thiazole', 'numvalenceelectrons', 'fr_phos_acid', 'peoe_vsa12', 'fr_nh1', 'fr_hdrzine', 'fr_c_o_nocoo', 'fr_lactone', 'estate_vsa6', 'bcut2d_logphi', 'vsa_estate7', 'peoe_vsa13', 'numsaturatedcarbocycles', 'fr_nitro', 'fr_phenol_noorthohbond', 'rotratio', 'fr_barbitur', 'fr_isothiocyan', 'balabanj', 'fr_arn', 'fr_imine', 'maxpartialcharge', 'fr_sulfide', 'slogp_vsa11', 'fr_hoccn', 'fr_n_o', 'peoe_vsa1', 'slogp_vsa6', 'heavyatommolwt', 'fractioncsp3', 'estate_vsa8', 'peoe_vsa10', 'numaliphaticrings', 'fr_thiophene', 'maxestateindex', 'smr_vsa10', 'labuteasa', 'smr_vsa2', 'fpdensitymorgan3', 'smr_vsa9', 'slogp_vsa10', 'numaromaticheterocycles', 'fr_nh2', 'fr_diazo', 'chi3n', 'fr_ar_coo', 'slogp_vsa5', 'fr_bicyclic', 'fr_amide', 'estate_vsa10', 'fr_guanido', 'chi1n', 'numsaturatedrings', 'fr_piperdine', 'fr_term_acetylene', 'estate_vsa4', 'slogp_vsa3', 'fr_coo', 'fr_ether', 'estate_vsa7', 'bcut2d_chglo', 'fr_oxazole', 'peoe_vsa6', 'hallkieralpha', 'peoe_vsa2', 'chi2v', 'nocount', 'vsa_estate5', 'fr_nhpyrrole', 'fr_al_coo', 'bertzct', 'estate_vsa11', 'minabspartialcharge', 'slogp_vsa8', 'fr_imide', 'kappa2', 'numaliphaticheterocycles', 'numsaturatedheterocycles', 'fr_hdrzone', 'smr_vsa4', 'fr_ar_n', 'nrot', 'smr_vsa8', 'slogp_vsa2', 'chi4v', 'fr_phos_ester', 'fr_para_hydroxylation', 'smr_vsa3', 'nhohcount', 'estate_vsa2', 'mollogp', 'tpsa', 'fr_azide', 'peoe_vsa4', 'numhdonors', 'fr_al_oh_notert', 'fr_c_o', 'chi0', 'fr_nitro_arom_nonortho', 'vsa_estate3', 'fr_benzene', 'fr_ketone', 'vsa_estate8', 'smr_vsa7', 'fr_ar_oh', 'fr_priamide', 'ringcount', 'estate_vsa3', 'numaromaticcarbocycles', 'bcut2d_mwhi', 'chi1v', 'heavyatomcount', 'vsa_estate4', 'chi0v'],
+ "compressed_features": [],
+ "train_all_data": False,
+ "track_columns": ['udm_asy_res_intrinsic_clearance_ul_per_min_per_mg_protein']
  }

@@ -73,136 +75,97 @@ def match_features_case_insensitive(df: pd.DataFrame, model_features: list) -> pd.DataFrame:
      return df.rename(columns=rename_dict)


- def distance_weighted_calibrated_intervals(
-     df_pred: pd.DataFrame,
-     prox_df: pd.DataFrame,
-     calibration_strength: float = 0.7,
-     distance_decay: float = 3.0,
- ) -> pd.DataFrame:
+ def convert_categorical_types(df: pd.DataFrame, features: list, category_mappings={}) -> tuple:
      """
-     Calibrate intervals using distance-weighted neighbor quantiles.
-     Uses all 10 neighbors with distance-based weighting.
+     Converts appropriate columns to categorical type with consistent mappings.
+
+     Args:
+         df (pd.DataFrame): The DataFrame to process.
+         features (list): List of feature names to consider for conversion.
+         category_mappings (dict, optional): Existing category mappings. If empty dict, we're in
+             training mode. If populated, we're in inference mode.
+
+     Returns:
+         tuple: (processed DataFrame, category mappings dictionary)
      """
-     id_column = TEMPLATE_PARAMS["id_column"]
-     target_column = TEMPLATE_PARAMS["target"]
-
-     # Distance-weighted neighbor statistics
-     def weighted_quantile(values, weights, q):
-         """Calculate weighted quantile"""
-         if len(values) == 0:
-             return np.nan
-         sorted_indices = np.argsort(values)
-         sorted_values = values[sorted_indices]
-         sorted_weights = weights[sorted_indices]
-         cumsum = np.cumsum(sorted_weights)
-         cutoff = q * cumsum[-1]
-         return np.interp(cutoff, cumsum, sorted_values)
-
-     # Calculate distance weights (closer neighbors get more weight)
-     prox_df = prox_df.copy()
-     prox_df['weight'] = 1 / (1 + prox_df['distance'] ** distance_decay)
-
-     # Get weighted quantiles and statistics for each ID
-     neighbor_stats = []
-     for id_val, group in prox_df.groupby(id_column):
-         values = group[target_column].values
-         weights = group['weight'].values
-
-         # Normalize weights
-         weights = weights / weights.sum()
-
-         stats = {
-             id_column: id_val,
-             'local_q025': weighted_quantile(values, weights, 0.025),
-             'local_q25': weighted_quantile(values, weights, 0.25),
-             'local_q75': weighted_quantile(values, weights, 0.75),
-             'local_q975': weighted_quantile(values, weights, 0.975),
-             'local_median': weighted_quantile(values, weights, 0.5),
-             'local_std': np.sqrt(np.average((values - np.average(values, weights=weights)) ** 2, weights=weights)),
-             'avg_distance': group['distance'].mean(),
-             'min_distance': group['distance'].min(),
-             'max_distance': group['distance'].max(),
-         }
-         neighbor_stats.append(stats)
-
-     neighbor_df = pd.DataFrame(neighbor_stats)
-     out = df_pred.merge(neighbor_df, on=id_column, how='left')
-
-     # Model disagreement score (normalized by prediction std)
-     model_disagreement = (out["prediction"] - out["prediction_uq"]).abs()
-     disagreement_score = (model_disagreement / out["prediction_std"]).clip(0, 2)
-
-     # Local confidence based on:
-     # 1. How close the neighbors are (closer = more confident)
-     # 2. How much local variance there is (less variance = more confident)
-     max_reasonable_distance = out['max_distance'].quantile(0.8)  # 80th percentile as reference
-     distance_confidence = (1 - (out['avg_distance'] / max_reasonable_distance)).clip(0.1, 1.0)
-
-     variance_confidence = (out["prediction_std"] / out["local_std"]).clip(0.5, 2.0)
-     local_confidence = distance_confidence * variance_confidence.clip(0.5, 1.5)
-
-     # Calibration weight: higher when models disagree and we have good local data
-     calibration_weight = (
-         calibration_strength *
-         local_confidence *  # Weight by local data quality
-         disagreement_score.clip(0.3, 1.0)  # More calibration when models disagree
-     )
+     # Training mode
+     if category_mappings == {}:
+         for col in df.select_dtypes(include=["object", "string"]):
+             if col in features and df[col].nunique() < 20:
+                 print(f"Training mode: Converting {col} to category")
+                 df[col] = df[col].astype("category")
+                 category_mappings[col] = df[col].cat.categories.tolist()  # Store category mappings
+
+     # Inference mode
+     else:
+         for col, categories in category_mappings.items():
+             if col in df.columns:
+                 print(f"Inference mode: Applying categorical mapping for {col}")
+                 df[col] = pd.Categorical(df[col], categories=categories)  # Apply consistent categorical mapping

-     # Consensus prediction (slight preference for NGBoost since it provides intervals)
-     consensus_pred = 0.65 * out["prediction_uq"] + 0.35 * out["prediction"]
+     return df, category_mappings

-     # Re-center local intervals around consensus prediction
-     local_center_offset = consensus_pred - out["local_median"]

-     # Apply calibration to each quantile
-     quantile_pairs = [
-         ("q_025", "local_q025"),
-         ("q_25", "local_q25"),
-         ("q_75", "local_q75"),
-         ("q_975", "local_q975")
-     ]
+ def decompress_features(df: pd.DataFrame, features: List[str], compressed_features: List[str]) -> Tuple[pd.DataFrame, List[str]]:
+     """Prepare features for the XGBoost model
+
+     Args:
+         df (pd.DataFrame): The features DataFrame
+         features (List[str]): Full list of feature names
+         compressed_features (List[str]): List of feature names to decompress (bitstrings)
+
+     Returns:
+         pd.DataFrame: DataFrame with the decompressed features
+         List[str]: Updated list of feature names after decompression

-     for model_q, local_q in quantile_pairs:
-         # Adjust local quantiles to be centered around consensus
-         adjusted_local_q = out[local_q] + local_center_offset
+     Raises:
+         ValueError: If any missing values are found in the specified features
+     """

-         # Blend model and local intervals
-         out[model_q] = (
-             (1 - calibration_weight) * out[model_q] +
-             calibration_weight * adjusted_local_q
+     # Check for any missing values in the required features
+     missing_counts = df[features].isna().sum()
+     if missing_counts.any():
+         missing_features = missing_counts[missing_counts > 0]
+         print(
+             f"WARNING: Found missing values in features: {missing_features.to_dict()}. "
+             "WARNING: You might want to remove/replace all NaN values before processing."
          )

-     # Ensure proper interval ordering and bounds using pandas
-     out["q_025"] = pd.concat([out["q_025"], consensus_pred], axis=1).min(axis=1)
-     out["q_975"] = pd.concat([out["q_975"], consensus_pred], axis=1).max(axis=1)
-     out["q_25"] = pd.concat([out["q_25"], out["q_75"]], axis=1).min(axis=1)
+     # Decompress the specified compressed features
+     decompressed_features = features
+     for feature in compressed_features:
+         if (feature not in df.columns) or (feature not in features):
+             print(f"Feature '{feature}' not in the features list, skipping decompression.")
+             continue

-     # Optional: Add some interval expansion when neighbors are very far
-     # (indicates we're in a sparse region of feature space)
-     sparse_region_mask = out['min_distance'] > out['min_distance'].quantile(0.9)
-     expansion_factor = 1 + 0.2 * sparse_region_mask  # 20% expansion in sparse regions
+         # Remove the feature from the list of features to avoid duplication
+         decompressed_features.remove(feature)

-     for q in ["q_025", "q_25", "q_75", "q_975"]:
-         interval_width = out[q] - consensus_pred
-         out[q] = consensus_pred + interval_width * expansion_factor
+         # Handle all compressed features as bitstrings
+         bit_matrix = np.array([list(bitstring) for bitstring in df[feature]], dtype=np.uint8)
+         prefix = feature[:3]

-     # Clean up temporary columns
-     cleanup_cols = [col for col in out.columns if col.startswith("local_")] + \
-         ['avg_distance', 'min_distance', 'max_distance']
+         # Create all new columns at once - avoids fragmentation
+         new_col_names = [f"{prefix}_{i}" for i in range(bit_matrix.shape[1])]
+         new_df = pd.DataFrame(bit_matrix, columns=new_col_names, index=df.index)

-     return out.drop(columns=cleanup_cols)
+         # Add to features list
+         decompressed_features.extend(new_col_names)
+
+         # Drop original column and concatenate new ones
+         df = df.drop(columns=[feature])
+         df = pd.concat([df, new_df], axis=1)
+
+     return df, decompressed_features


- # TRAINING SECTION
- #
- # This section (__main__) is where SageMaker will execute the training job
- # and save the model artifacts to the model directory.
- #
  if __name__ == "__main__":
      # Template Parameters
      id_column = TEMPLATE_PARAMS["id_column"]
-     features = TEMPLATE_PARAMS["features"]
      target = TEMPLATE_PARAMS["target"]
+     features = TEMPLATE_PARAMS["features"]
+     orig_features = features.copy()
+     compressed_features = TEMPLATE_PARAMS["compressed_features"]
      train_all_data = TEMPLATE_PARAMS["train_all_data"]
      track_columns = TEMPLATE_PARAMS["track_columns"]  # Can be None
      validation_split = 0.2
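
An illustrative call (not part of the diff) showing what the new decompress_features() helper does, assuming the function above and its numpy/pandas imports are in scope: a bitstring column is expanded into one uint8 column per bit, and the new column names reuse the first three characters of the feature name. The column name and values below are made up.

import pandas as pd

df = pd.DataFrame({"fp_bits": ["1010", "0111"], "molwt": [180.2, 212.3]})
df, feature_list = decompress_features(df, features=["fp_bits", "molwt"], compressed_features=["fp_bits"])
print(feature_list)           # ['molwt', 'fp__0', 'fp__1', 'fp__2', 'fp__3']
print(df.columns.tolist())    # ['molwt', 'fp__0', 'fp__1', 'fp__2', 'fp__3']
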
@@ -216,34 +179,51 @@ if __name__ == "__main__":
      )
      args = parser.parse_args()

-     # Load training data from the specified directory
+     # Read the training data into DataFrames
      training_files = [
          os.path.join(args.train, file)
-         for file in os.listdir(args.train) if file.endswith(".csv")
+         for file in os.listdir(args.train)
+         if file.endswith(".csv")
      ]
      print(f"Training Files: {training_files}")

      # Combine files and read them all into a single pandas dataframe
-     df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])
+     all_df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])

-     # Check if the DataFrame is empty
-     check_dataframe(df, "training_df")
+     # Check if the dataframe is empty
+     check_dataframe(all_df, "training_df")

-     # Training data split logic
+     # Features/Target output
+     print(f"Target: {target}")
+     print(f"Features: {str(features)}")
+
+     # Convert any features that might be categorical to 'category' type
+     all_df, category_mappings = convert_categorical_types(all_df, features)
+
+     # If we have compressed features, decompress them
+     if compressed_features:
+         print(f"Decompressing features {compressed_features}...")
+         all_df, features = decompress_features(all_df, features, compressed_features)
+
+     # Do we want to train on all the data?
      if train_all_data:
-         # Use all data for both training and validation
-         print("Training on all data...")
-         df_train = df.copy()
-         df_val = df.copy()
-     elif "training" in df.columns:
-         # Split data based on a 'training' column if it exists
-         print("Splitting data based on 'training' column...")
-         df_train = df[df["training"]].copy()
-         df_val = df[~df["training"]].copy()
+         print("Training on ALL of the data")
+         df_train = all_df.copy()
+         df_val = all_df.copy()
+
+     # Does the dataframe have a training column?
+     elif "training" in all_df.columns:
+         print("Found training column, splitting data based on training column")
+         df_train = all_df[all_df["training"]]
+         df_val = all_df[~all_df["training"]]
      else:
-         # Perform a random split if no 'training' column is found
-         print("Splitting data randomly...")
-         df_train, df_val = train_test_split(df, test_size=validation_split, random_state=42)
+         # Just do a random training Split
+         print("WARNING: No training column found, splitting data with random state=42")
+         df_train, df_val = train_test_split(
+             all_df, test_size=validation_split, random_state=42
+         )
+     print(f"FIT/TRAIN: {df_train.shape}")
+     print(f"VALIDATION: {df_val.shape}")

      # We're using XGBoost for point predictions and NGBoost for uncertainty quantification
      xgb_model = XGBRegressor()
@@ -251,18 +231,16 @@

      # Prepare features and targets for training
      X_train = df_train[features]
-     X_val = df_val[features]
+     X_validate = df_val[features]
      y_train = df_train[target]
-     y_val = df_val[target]
+     y_validate = df_val[target]

      # Train both models using the training data
      xgb_model.fit(X_train, y_train)
-     ngb_model.fit(X_train, y_train, X_val=X_val, Y_val=y_val)
+     ngb_model.fit(X_train, y_train, X_val=X_validate, Y_val=y_validate)

      # Make Predictions on the Validation Set
      print(f"Making Predictions on Validation Set...")
-     y_validate = df_val[target]
-     X_validate = df_val[features]
      preds = xgb_model.predict(X_validate)

      # Calculate various model performance metrics (regression)
@@ -280,9 +258,9 @@
      # Save the trained NGBoost model
      joblib.dump(ngb_model, os.path.join(args.model_dir, "ngb_model.joblib"))

-     # Save the feature list to validate input during predictions
+     # Save the features (this will validate input during predictions)
      with open(os.path.join(args.model_dir, "feature_columns.json"), "w") as fp:
-         json.dump(features, fp)
+         json.dump(orig_features, fp)  # We save the original features, not the decompressed ones

      # Now the Proximity model
      model = Proximity(df_train, id_column, features, target, track_columns=track_columns)
@@ -295,7 +273,7 @@
  # Inference Section
  #
  def model_fn(model_dir) -> dict:
-     """Load and return XGBoost and NGBoost regressors from model directory."""
+     """Load and return XGBoost, NGBoost, and Prox Model from model directory."""

      # Load XGBoost regressor
      xgb_path = os.path.join(model_dir, "xgb_model.json")
@@ -376,18 +354,30 @@ def predict_fn(df, models) -> pd.DataFrame:
      df["prediction_std"] = dist_params['scale']  # standard deviation

      # Add 95% prediction intervals using ppf (percent point function)
-     df["q_025"] = y_dists.ppf(0.025)  # 2.5th percentile
-     df["q_975"] = y_dists.ppf(0.975)  # 97.5th percentile
+     # Note: Our hybrid model uses XGB point prediction and NGBoost UQ
+     # so we need to adjust the bounds to include the point prediction
+     df["q_025"] = np.minimum(y_dists.ppf(0.025), df["prediction"])
+     df["q_975"] = np.maximum(y_dists.ppf(0.975), df["prediction"])
+
+     # Add 90% prediction intervals
+     df["q_05"] = y_dists.ppf(0.05)  # 5th percentile
+     df["q_95"] = y_dists.ppf(0.95)  # 95th percentile
+
+     # Add 80% prediction intervals
+     df["q_10"] = y_dists.ppf(0.10)  # 10th percentile
+     df["q_90"] = y_dists.ppf(0.90)  # 90th percentile

      # Add 50% prediction intervals
-     df["q_25"] = y_dists.ppf(0.25)  # 25th percentile
-     df["q_75"] = y_dists.ppf(0.75)  # 75th percentile
+     df["q_25"] = y_dists.ppf(0.25)  # 25th percentile
+     df["q_75"] = y_dists.ppf(0.75)  # 75th percentile

-     # Compute Nearest neighbors with Proximity model
-     prox_df = models["proximity"].neighbors(df)
+     # Reorder the quantile columns for easier reading
+     quantile_cols = ["q_025", "q_05", "q_10", "q_25", "q_75", "q_90", "q_95", "q_975"]
+     other_cols = [col for col in df.columns if col not in quantile_cols]
+     df = df[other_cols + quantile_cols]

-     # Shrink prediction intervals based on KNN variance
-     df = distance_weighted_calibrated_intervals(df, prox_df)
+     # Compute Nearest neighbors with Proximity model
+     models["proximity"].neighbors(df)

      # Return the modified DataFrame
      return df
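
A tiny numeric illustration (values made up) of the hybrid-bound adjustment above: the NGBoost interval is widened, never narrowed, so it always contains the XGBoost point prediction.

import numpy as np
import pandas as pd

df = pd.DataFrame({"prediction": [3.2, 7.5]})           # XGBoost point predictions
ngb_q025 = np.array([3.5, 6.0])                          # NGBoost 2.5th percentiles
ngb_q975 = np.array([5.0, 9.0])                          # NGBoost 97.5th percentiles
df["q_025"] = np.minimum(ngb_q025, df["prediction"])     # first row: 3.2 pulls the lower bound down
df["q_975"] = np.maximum(ngb_q975, df["prediction"])     # upper bounds already cover both predictions
print(df)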