workbench 0.8.192__py3-none-any.whl → 0.8.197__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. workbench/algorithms/dataframe/__init__.py +1 -2
  2. workbench/algorithms/dataframe/fingerprint_proximity.py +2 -2
  3. workbench/algorithms/dataframe/proximity.py +212 -234
  4. workbench/algorithms/graph/light/proximity_graph.py +8 -7
  5. workbench/api/endpoint.py +2 -3
  6. workbench/api/model.py +2 -5
  7. workbench/core/artifacts/endpoint_core.py +25 -16
  8. workbench/core/artifacts/feature_set_core.py +126 -4
  9. workbench/core/artifacts/model_core.py +37 -55
  10. workbench/core/transforms/features_to_model/features_to_model.py +3 -3
  11. workbench/core/views/training_view.py +75 -0
  12. workbench/core/views/view.py +1 -1
  13. workbench/model_scripts/custom_models/proximity/proximity.py +212 -234
  14. workbench/model_scripts/custom_models/uq_models/proximity.py +212 -234
  15. workbench/model_scripts/pytorch_model/generated_model_script.py +567 -0
  16. workbench/model_scripts/uq_models/generated_model_script.py +589 -0
  17. workbench/model_scripts/uq_models/mapie.template +103 -6
  18. workbench/model_scripts/xgb_model/generated_model_script.py +468 -0
  19. workbench/repl/workbench_shell.py +3 -3
  20. workbench/utils/model_utils.py +25 -10
  21. workbench/utils/xgboost_model_utils.py +117 -47
  22. workbench/web_interface/components/model_plot.py +7 -1
  23. workbench/web_interface/components/plugin_unit_test.py +5 -2
  24. workbench/web_interface/components/plugins/model_details.py +9 -7
  25. {workbench-0.8.192.dist-info → workbench-0.8.197.dist-info}/METADATA +23 -2
  26. {workbench-0.8.192.dist-info → workbench-0.8.197.dist-info}/RECORD +30 -27
  27. {workbench-0.8.192.dist-info → workbench-0.8.197.dist-info}/licenses/LICENSE +1 -1
  28. {workbench-0.8.192.dist-info → workbench-0.8.197.dist-info}/WHEEL +0 -0
  29. {workbench-0.8.192.dist-info → workbench-0.8.197.dist-info}/entry_points.txt +0 -0
  30. {workbench-0.8.192.dist-info → workbench-0.8.197.dist-info}/top_level.txt +0 -0
workbench/api/model.py CHANGED
@@ -83,16 +83,13 @@ class Model(ModelCore):
  end.set_owner(self.get_owner())
  return end
 
- def prox_model(self, filtered: bool = True):
+ def prox_model(self):
  """Create a local Proximity Model for this Model
 
- Args:
- filtered: bool, optional): Use filtered training data for the Proximity Model (default: True)
-
  Returns:
  Proximity: A local Proximity Model
  """
- return proximity_model_local(self, filtered=filtered)
+ return proximity_model_local(self)
 
  def uq_model(self, uq_model_name: str = None, train_all_data: bool = False) -> "Model":
  """Create a Uncertainty Quantification Model for this Model
workbench/core/artifacts/endpoint_core.py CHANGED
@@ -8,7 +8,7 @@ import pandas as pd
  import numpy as np
  from io import StringIO
  import awswrangler as wr
- from typing import Union, Optional, Tuple
+ from typing import Union, Optional
  import hashlib
 
  # Model Performance Scores
@@ -438,23 +438,27 @@ class EndpointCore(Artifact):
  # Return the prediction DataFrame
  return prediction_df
 
- def cross_fold_inference(self, nfolds: int = 5) -> Tuple[dict, pd.DataFrame]:
+ def cross_fold_inference(self, nfolds: int = 5) -> pd.DataFrame:
  """Run cross-fold inference (only works for XGBoost models)
 
  Args:
  nfolds (int): Number of folds to use for cross-fold (default: 5)
 
  Returns:
- Tuple[dict, pd.DataFrame]: Tuple of (cross_fold_metrics, out_of_fold_df)
+ pd.DataFrame: A DataFrame with cross fold predictions
  """
 
  # Grab our model
  model = ModelCore(self.model_name)
 
- # Compute CrossFold Metrics
+ # Compute CrossFold (Metrics and Prediction Dataframe)
  cross_fold_metrics, out_of_fold_df = cross_fold_inference(model, nfolds=nfolds)
- if cross_fold_metrics:
- self.param_store.upsert(f"/workbench/models/{model.name}/inference/cross_fold", cross_fold_metrics)
+
+ # If the metrics dataframe isn't empty save to the param store
+ if not cross_fold_metrics.empty:
+ # Convert to list of dictionaries
+ metrics = cross_fold_metrics.to_dict(orient="records")
+ self.param_store.upsert(f"/workbench/models/{model.name}/inference/cross_fold", metrics)
 
  # Capture the results
  capture_name = "full_cross_fold"
@@ -478,7 +482,9 @@ class EndpointCore(Artifact):
  uq_df = self.inference(training_df)
 
  # Identify UQ-specific columns (quantiles and prediction_std)
- uq_columns = [col for col in uq_df.columns if col.startswith("q_") or col == "prediction_std"]
+ uq_columns = [
+ col for col in uq_df.columns if col.startswith("q_") or col == "prediction_std" or col == "confidence"
+ ]
 
  # Merge UQ columns with out-of-fold predictions
  if uq_columns:
@@ -502,12 +508,12 @@ class EndpointCore(Artifact):
  out_of_fold_df,
  target_column,
  model_type,
- pd.DataFrame([cross_fold_metrics["summary_metrics"]]),
+ cross_fold_metrics,
  description,
  features=additional_columns,
  id_column=id_column,
  )
- return cross_fold_metrics, out_of_fold_df
+ return out_of_fold_df
 
  def fast_inference(self, eval_df: pd.DataFrame, threads: int = 4) -> pd.DataFrame:
  """Run inference on the Endpoint using the provided DataFrame
@@ -766,8 +772,8 @@ class EndpointCore(Artifact):
  # Add any _proba columns to the output columns
  output_columns += [col for col in pred_results_df.columns if col.endswith("_proba")]
 
- # Add any quantile columns to the output columns
- output_columns += [col for col in pred_results_df.columns if col.startswith("q_") or col.startswith("qr_")]
+ # Add any Uncertainty Quantile columns to the output columns
+ output_columns += [col for col in pred_results_df.columns if col.startswith("q_") or col == "confidence"]
 
  # Add the ID column
  if id_column and id_column in pred_results_df.columns:
@@ -896,7 +902,7 @@ class EndpointCore(Artifact):
  else:
  self.validate_proba_columns(prediction_df, class_labels)
 
- # Calculate precision, recall, fscore, and support, handling zero division
+ # Calculate precision, recall, f1, and support, handling zero division
  prediction_col = "prediction" if "prediction" in prediction_df.columns else "predictions"
  scores = precision_recall_fscore_support(
  prediction_df[target_column],
@@ -931,7 +937,7 @@ class EndpointCore(Artifact):
  target_column: class_labels,
  "precision": scores[0],
  "recall": scores[1],
- "fscore": scores[2],
+ "f1": scores[2],
  "roc_auc": roc_auc_per_label,
  "support": scores[3],
  }
@@ -1039,7 +1045,7 @@ class EndpointCore(Artifact):
  # Recursively delete all endpoint S3 artifacts (inference, etc)
  # Note: We do not want to delete the data_capture/ files since these
  # might be used for collection and data drift analysis
- base_endpoint_path = f"{cls.endpoints_s3_path}/{endpoint_name}"
+ base_endpoint_path = f"{cls.endpoints_s3_path}/{endpoint_name}/"
  all_s3_objects = wr.s3.list_objects(base_endpoint_path, boto3_session=cls.boto3_session)
 
  # Filter out objects that contain 'data_capture/' in their path
@@ -1194,7 +1200,8 @@ if __name__ == "__main__":
 
  # Test the cross_fold_inference method
  print("Running Cross-Fold Inference...")
- metrics, all_results = my_endpoint.cross_fold_inference()
+ all_results = my_endpoint.cross_fold_inference()
+ print(all_results)
 
  # Run Inference and metrics for a Classification Endpoint
  class_endpoint = EndpointCore("wine-classification")
@@ -1206,7 +1213,9 @@ if __name__ == "__main__":
 
  # Test the cross_fold_inference method
  print("Running Cross-Fold Inference...")
- metrics, all_results = class_endpoint.cross_fold_inference()
+ all_results = class_endpoint.cross_fold_inference()
+ print(all_results)
+ print("All done...")
 
  # Test the class method delete (commented out for now)
  # from workbench.api import Model
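
Because `cross_fold_inference()` now returns only the out-of-fold DataFrame (the metrics are instead upserted to the Parameter Store as a list of records under `/workbench/models/<model>/inference/cross_fold`), any caller that unpacked the old tuple needs a one-line change. A hedged before/after sketch, reusing the `wine-classification` endpoint from the test block above:

    from workbench.core.artifacts.endpoint_core import EndpointCore

    end = EndpointCore("wine-classification")

    # 0.8.192 (old): returned (cross_fold_metrics, out_of_fold_df)
    # metrics, oof_df = end.cross_fold_inference(nfolds=5)

    # 0.8.197 (new): returns just the out-of-fold prediction DataFrame
    oof_df = end.cross_fold_inference(nfolds=5)
    print(oof_df.head())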
workbench/core/artifacts/feature_set_core.py CHANGED
@@ -551,6 +551,75 @@ class FeatureSetCore(Artifact):
  # Apply the filter
  self.set_training_filter(filter_expression)
 
+ def set_training_sampling(
+ self,
+ exclude_ids: Optional[List[Union[str, int]]] = None,
+ replicate_ids: Optional[List[Union[str, int]]] = None,
+ replication_factor: int = 2,
+ ):
+ """Configure training view with ID exclusions and replications (oversampling).
+
+ Args:
+ exclude_ids: List of IDs to exclude from training view
+ replicate_ids: List of IDs to replicate in training view for oversampling
+ replication_factor: Number of times to replicate each ID (default: 2)
+
+ Note:
+ If an ID appears in both lists, exclusion takes precedence.
+ """
+ from workbench.core.views import TrainingView
+
+ # Normalize to empty lists if None
+ exclude_ids = exclude_ids or []
+ replicate_ids = replicate_ids or []
+
+ # Remove any replicate_ids that are also in exclude_ids (exclusion wins)
+ replicate_ids = [rid for rid in replicate_ids if rid not in exclude_ids]
+
+ # If no sampling needed, just create normal view
+ if not exclude_ids and not replicate_ids:
+ self.log.important("No sampling specified, creating standard training view")
+ TrainingView.create(self, id_column=self.id_column)
+ return
+
+ # Build the custom SQL query
+ self.log.important(
+ f"Excluding {len(exclude_ids)} IDs, Replicating {len(replicate_ids)} IDs "
+ f"(factor: {replication_factor}x)"
+ )
+
+ # Helper to format IDs for SQL
+ def format_ids(ids):
+ return ", ".join([repr(id) for id in ids])
+
+ # Start with base query
+ base_query = f"SELECT * FROM {self.table}"
+
+ # Add exclusions if needed
+ if exclude_ids:
+ base_query += f"\nWHERE {self.id_column} NOT IN ({format_ids(exclude_ids)})"
+
+ # Build full query with replication
+ if replicate_ids:
+ # Generate VALUES clause for CROSS JOIN: (1), (2), ..., (N-1)
+ # We want N-1 additional copies since the original row is already in base_query
+ values_clause = ", ".join([f"({i})" for i in range(1, replication_factor)])
+
+ custom_sql = f"""{base_query}
+
+ UNION ALL
+
+ SELECT t.*
+ FROM {self.table} t
+ CROSS JOIN (VALUES {values_clause}) AS n(num)
+ WHERE t.{self.id_column} IN ({format_ids(replicate_ids)})"""
+ else:
+ # Only exclusions, no UNION needed
+ custom_sql = base_query
+
+ # Create the training view with our custom SQL
+ TrainingView.create_with_sql(self, sql_query=custom_sql, id_column=self.id_column)
+
  @classmethod
  def delete_views(cls, table: str, database: str):
  """Delete any views associated with this FeatureSet
@@ -709,7 +778,7 @@ if __name__ == "__main__":
  pd.set_option("display.width", 1000)
 
  # Grab a FeatureSet object and pull some information from it
- my_features = LocalFeatureSetCore("test_features")
+ my_features = LocalFeatureSetCore("abalone_features")
  if not my_features.exists():
  print("FeatureSet not found!")
  sys.exit(1)
@@ -769,8 +838,8 @@ if __name__ == "__main__":
  # Set the holdout ids for the training view
  print("Setting hold out ids...")
  table = my_features.view("training").table
- df = my_features.query(f'SELECT id, name FROM "{table}"')
- my_holdout_ids = [id for id in df["id"] if id < 20]
+ df = my_features.query(f'SELECT auto_id, length FROM "{table}"')
+ my_holdout_ids = [id for id in df["auto_id"] if id < 20]
  my_features.set_training_holdouts(my_holdout_ids)
 
  # Get the training data
@@ -780,7 +849,7 @@ if __name__ == "__main__":
 
  # Test the filter expression functionality
  print("Setting a filter expression...")
- my_features.set_training_filter("id < 50 AND height > 65.0")
+ my_features.set_training_filter("auto_id < 50 AND length > 65.0")
  training_data = my_features.get_training_data()
  print(f"Training Data: {training_data.shape}")
  print(training_data)
@@ -803,3 +872,56 @@ if __name__ == "__main__":
  # print("Deleting Workbench Feature Set...")
  # my_features.delete()
  # print("Done")
+
+ # Test set_training_sampling with exclusions and replications
+ print("\n--- Testing set_training_sampling ---")
+ my_features.set_training_filter(None) # Reset any existing filters
+ original_count = num_rows
+
+ # Get valid IDs from the table
+ all_data = my_features.query(f'SELECT auto_id, length FROM "{table}"')
+ valid_ids = sorted(all_data["auto_id"].tolist())
+ print(f"Valid IDs range from {valid_ids[0]} to {valid_ids[-1]}")
+
+ exclude_list = valid_ids[0:3] # First 3 IDs
+ replicate_list = valid_ids[10:13] # IDs at positions 10, 11, 12
+
+ print(f"Original row count: {original_count}")
+ print(f"Excluding IDs: {exclude_list}")
+ print(f"Replicating IDs: {replicate_list}")
+
+ # Test with default replication factor (2x)
+ print("\n--- Testing with replication_factor=2 (default) ---")
+ my_features.set_training_sampling(exclude_ids=exclude_list, replicate_ids=replicate_list)
+ training_data = my_features.get_training_data()
+ print(f"Training Data after sampling: {training_data.shape}")
+
+ # Verify exclusions
+ for exc_id in exclude_list:
+ count = len(training_data[training_data["auto_id"] == exc_id])
+ print(f"Excluded ID {exc_id} appears {count} times (should be 0)")
+
+ # Verify replications
+ for rep_id in replicate_list:
+ count = len(training_data[training_data["auto_id"] == rep_id])
+ print(f"Replicated ID {rep_id} appears {count} times (should be 2)")
+
+ # Test with replication factor of 5
+ print("\n--- Testing with replication_factor=5 ---")
+ replicate_list_5x = [20, 21]
+ my_features.set_training_sampling(exclude_ids=exclude_list, replicate_ids=replicate_list_5x, replication_factor=5)
+ training_data = my_features.get_training_data()
+ print(f"Training Data after sampling: {training_data.shape}")
+
+ # Verify 5x replication
+ for rep_id in replicate_list_5x:
+ count = len(training_data[training_data["auto_id"] == rep_id])
+ print(f"Replicated ID {rep_id} appears {count} times (should be 5)")
+
+ # Test with large replication list (simulate 100 IDs)
+ print("\n--- Testing with large ID list (100 IDs) ---")
+ large_replicate_list = list(range(30, 130)) # 100 IDs
+ my_features.set_training_sampling(replicate_ids=large_replicate_list, replication_factor=3)
+ training_data = my_features.get_training_data()
+ print(f"Training Data after sampling: {training_data.shape}")
+ print(f"Expected extra rows: {len(large_replicate_list) * 3}")
workbench/core/artifacts/model_core.py CHANGED
@@ -41,52 +41,39 @@ class ModelType(Enum):
  class ModelImages:
  """Class for retrieving workbench inference images"""
 
- image_uris = {
- # US East 1 images
- ("us-east-1", "training", "0.1", "x86_64"): (
- "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-general-ml-training:0.1"
- ),
- ("us-east-1", "inference", "0.1", "x86_64"): (
- "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-general-ml-inference:0.1"
- ),
- ("us-east-1", "pytorch_training", "0.1", "x86_64"): (
- "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-pytorch-training:0.1"
- ),
- ("us-east-1", "pytorch_inference", "0.1", "x86_64"): (
- "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-pytorch-inference:0.1"
- ),
- # US West 2 images
- ("us-west-2", "training", "0.1", "x86_64"): (
- "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-general-ml-training:0.1"
- ),
- ("us-west-2", "inference", "0.1", "x86_64"): (
- "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-general-ml-inference:0.1"
- ),
- ("us-west-2", "pytorch_training", "0.1", "x86_64"): (
- "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-pytorch-training:0.1"
- ),
- ("us-west-2", "pytorch_inference", "0.1", "x86_64"): (
- "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-pytorch-inference:0.1"
- ),
- # ARM64 images
- # Meta Endpoint inference images
- ("us-east-1", "meta-endpoint", "0.1", "x86_64"): (
- "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-meta-endpoint:0.1"
- ),
- ("us-west-2", "meta-endpoint", "0.1", "x86_64"): (
- "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-meta-endpoint:0.1"
- ),
+ # Account ID
+ ACCOUNT_ID = "507740646243"
+
+ # Image name mappings
+ IMAGE_NAMES = {
+ "training": "py312-general-ml-training",
+ "inference": "py312-general-ml-inference",
+ "pytorch_training": "py312-pytorch-training",
+ "pytorch_inference": "py312-pytorch-inference",
+ "meta-endpoint": "py312-meta-endpoint",
  }
 
  @classmethod
- def get_image_uri(cls, region, image_type, version="0.1", architecture="x86_64"):
- key = (region, image_type, version, architecture)
- if key in cls.image_uris:
- return cls.image_uris[key]
- else:
- raise ValueError(
- f"No matching image found for region: {region}, image_type: {image_type}, version: {version}"
- )
+ def get_image_uri(cls, region, image_type, version="latest", architecture="x86_64"):
+ """
+ Dynamically construct ECR image URI.
+
+ Args:
+ region: AWS region (e.g., 'us-east-1', 'us-west-2')
+ image_type: Type of image (e.g., 'training', 'inference', 'pytorch_training')
+ version: Image version (e.g., '0.1', '0.2' defaults to 'latest')
+ architecture: CPU architecture (default: 'x86_64', currently unused but kept for compatibility)
+
+ Returns:
+ ECR image URI string
+ """
+ if image_type not in cls.IMAGE_NAMES:
+ raise ValueError(f"Unknown image_type: {image_type}. Valid types: {list(cls.IMAGE_NAMES.keys())}")
+
+ image_name = cls.IMAGE_NAMES[image_type]
+ uri = f"{cls.ACCOUNT_ID}.dkr.ecr.{region}.amazonaws.com/aws-ml-images/{image_name}:{version}"
+
+ return uri
 
 
  class ModelCore(Artifact):
@@ -602,7 +589,7 @@ class ModelCore(Artifact):
  fs = FeatureSetCore(self.get_input())
 
  # See if we have a training view for this model
- my_model_training_view = f"{self.name.replace('-', '_')}_training"
+ my_model_training_view = f"{self.name.replace('-', '_')}_training".lower()
  view = View(fs, my_model_training_view, auto_create_view=False)
  if view.exists():
  return view
@@ -880,14 +867,6 @@ class ModelCore(Artifact):
  shap_data[key] = self.df_store.get(df_location)
  return shap_data or None
 
- def cross_folds(self) -> dict:
- """Retrieve the cross-fold inference results(only works for XGBoost models)
-
- Returns:
- dict: Dictionary with the cross-fold inference results
- """
- return self.param_store.get(f"/workbench/models/{self.name}/inference/cross_fold")
-
  def supported_inference_instances(self) -> Optional[list]:
  """Retrieve the supported endpoint inference instance types
 
@@ -1184,13 +1163,11 @@ if __name__ == "__main__":
  # Grab a ModelCore object and pull some information from it
  my_model = ModelCore("abalone-regression")
 
- # Call the various methods
-
  # Let's do a check/validation of the Model
  print(f"Model Check: {my_model.exists()}")
 
  # Make sure the model is 'ready'
- # my_model.onboard()
+ my_model.onboard()
 
  # Get the ARN of the Model Group
  print(f"Model Group ARN: {my_model.group_arn()}")
@@ -1256,5 +1233,10 @@ if __name__ == "__main__":
  # Delete the Model
  # ModelCore.managed_delete("wine-classification")
 
+ # Check the training view logic
+ model = ModelCore("wine-class-test-251112-BW")
+ training_view = model.training_view()
+ print(f"Training View Name: {training_view.name}")
+
  # Check for a model that doesn't exist
  my_model = ModelCore("empty-model-group")
workbench/core/transforms/features_to_model/features_to_model.py CHANGED
@@ -210,7 +210,7 @@ class FeaturesToModel(Transform):
  raise ValueError(msg)
 
  # Dynamically create the metric definitions
- metrics = ["precision", "recall", "fscore"]
+ metrics = ["precision", "recall", "f1"]
  metric_definitions = []
  for t in self.class_labels:
  for m in metrics:
@@ -233,7 +233,7 @@ class FeaturesToModel(Transform):
  source_dir = str(Path(script_path).parent)
 
  # Create a Sagemaker Model with our script
- image = ModelImages.get_image_uri(self.sm_session.boto_region_name, self.training_image, "0.1")
+ image = ModelImages.get_image_uri(self.sm_session.boto_region_name, self.training_image)
  self.estimator = Estimator(
  entry_point=entry_point,
  source_dir=source_dir,
@@ -306,7 +306,7 @@ class FeaturesToModel(Transform):
 
  # Register our model
  image = ModelImages.get_image_uri(
- self.sm_session.boto_region_name, self.inference_image, "0.1", self.inference_arch
+ self.sm_session.boto_region_name, self.inference_image, architecture=self.inference_arch
  )
  self.log.important(f"Registering model {self.output_name} with Inference Image {image}...")
  model = self.estimator.create_model(role=self.workbench_role_arn)
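
The `fscore` to `f1` rename has to line up with the metric names the training script logs, since SageMaker metric definitions are name/regex pairs scraped from the job output. The actual regex lives outside this hunk, so the sketch below only illustrates the shape of the list the loop would build; the pattern and labels shown are assumptions, not the package's real values:

    metrics = ["precision", "recall", "f1"]
    class_labels = ["low", "med", "high"]  # illustrative labels only

    metric_definitions = []
    for t in class_labels:
        for m in metrics:
            # Hypothetical regex; the real pattern is defined elsewhere in features_to_model.py
            metric_definitions.append({"Name": f"{t}:{m}", "Regex": f"{t}:{m} ([0-9.]+)"})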
workbench/core/views/training_view.py CHANGED
@@ -116,6 +116,57 @@ class TrainingView(CreateView):
  # Return the View
  return View(instance.data_source, instance.view_name, auto_create_view=False)
 
+ @classmethod
+ def create_with_sql(
+ cls,
+ feature_set: FeatureSet,
+ *,
+ sql_query: str,
+ id_column: str = None,
+ ) -> Union[View, None]:
+ """Factory method to create a TrainingView from a custom SQL query.
+
+ This method takes a complete SQL query and adds the default 80/20 training split.
+ Use this when you need complex queries like UNION ALL for oversampling.
+
+ Args:
+ feature_set (FeatureSet): A FeatureSet object
+ sql_query (str): Complete SELECT query (without the final semicolon)
+ id_column (str, optional): The name of the id column for training split. Defaults to None.
+
+ Returns:
+ Union[View, None]: The created View object (or None if failed)
+ """
+ # Instantiate the TrainingView
+ instance = cls("training", feature_set)
+
+ # Sanity check on the id column
+ if not id_column:
+ instance.log.important("No id column specified, using auto_id_column")
+ if not instance.auto_id_column:
+ instance.log.error("No id column specified and no auto_id_column found, aborting")
+ return None
+ id_column = instance.auto_id_column
+
+ # Default 80/20 split using modulo
+ training_logic = f"""CASE
+ WHEN MOD(ROW_NUMBER() OVER (ORDER BY {id_column}), 10) < 8 THEN True
+ ELSE False
+ END AS training"""
+
+ # Wrap the custom query and add training column
+ create_view_query = f"""
+ CREATE OR REPLACE VIEW {instance.table} AS
+ SELECT *, {training_logic}
+ FROM ({sql_query}) AS custom_source
+ """
+
+ # Execute the CREATE VIEW query
+ instance.data_source.execute_statement(create_view_query)
+
+ # Return the View
+ return View(instance.data_source, instance.view_name, auto_create_view=False)
+
 
  if __name__ == "__main__":
  """Exercise the Training View functionality"""
@@ -154,3 +205,27 @@ if __name__ == "__main__":
  print(df.head())
  print(f"Shape with filter: {df.shape}")
  print(f"Diameter min: {df['diameter'].min()}, max: {df['diameter'].max()}")
+
+ # Test create_with_sql with a custom query (UNION ALL for oversampling)
+ print("\n--- Testing create_with_sql with oversampling ---")
+ base_table = fs.table
+ replicate_ids = [0, 1, 2] # Oversample these IDs
+
+ custom_sql = f"""
+ SELECT * FROM {base_table}
+
+ UNION ALL
+
+ SELECT * FROM {base_table}
+ WHERE auto_id IN ({', '.join(map(str, replicate_ids))})
+ """
+
+ training_view = TrainingView.create_with_sql(fs, sql_query=custom_sql, id_column="auto_id")
+ df = training_view.pull_dataframe()
+ print(f"Shape with custom SQL: {df.shape}")
+ print(df["training"].value_counts())
+
+ # Verify oversampling - check if replicated IDs appear twice
+ for rep_id in replicate_ids:
+ count = len(df[df["auto_id"] == rep_id])
+ print(f"ID {rep_id} appears {count} times")
workbench/core/views/view.py CHANGED
@@ -232,7 +232,7 @@ class View:
  view_definition = df.iloc[0]["view_definition"]
 
  # Create the new view with the destination name
- dest_table = f"{self.base_table_name}___{dest_view_name}"
+ dest_table = f"{self.base_table_name}___{dest_view_name.lower()}"
  create_view_query = f'CREATE OR REPLACE VIEW "{dest_table}" AS {view_definition}'
 
  self.log.important(f"Copying view {self.table} to {dest_table}...")