workbench 0.8.193__py3-none-any.whl → 0.8.198__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/algorithms/dataframe/__init__.py +1 -2
- workbench/algorithms/dataframe/fingerprint_proximity.py +2 -2
- workbench/algorithms/dataframe/proximity.py +212 -234
- workbench/algorithms/graph/light/proximity_graph.py +8 -7
- workbench/api/endpoint.py +2 -3
- workbench/api/model.py +2 -5
- workbench/core/artifacts/endpoint_core.py +25 -16
- workbench/core/artifacts/feature_set_core.py +126 -4
- workbench/core/artifacts/model_core.py +9 -14
- workbench/core/transforms/features_to_model/features_to_model.py +3 -3
- workbench/core/views/training_view.py +75 -0
- workbench/core/views/view.py +1 -1
- workbench/model_scripts/__pycache__/script_generation.cpython-312.pyc +0 -0
- workbench/model_scripts/__pycache__/script_generation.cpython-313.pyc +0 -0
- workbench/model_scripts/custom_models/proximity/proximity.py +212 -234
- workbench/model_scripts/custom_models/uq_models/proximity.py +212 -234
- workbench/model_scripts/pytorch_model/generated_model_script.py +567 -0
- workbench/model_scripts/uq_models/generated_model_script.py +589 -0
- workbench/model_scripts/uq_models/mapie.template +103 -6
- workbench/model_scripts/xgb_model/generated_model_script.py +4 -4
- workbench/repl/workbench_shell.py +3 -3
- workbench/utils/model_utils.py +10 -7
- workbench/utils/xgboost_model_utils.py +95 -34
- workbench/web_interface/components/plugin_unit_test.py +5 -2
- workbench/web_interface/components/plugins/model_details.py +2 -5
- {workbench-0.8.193.dist-info → workbench-0.8.198.dist-info}/METADATA +1 -1
- {workbench-0.8.193.dist-info → workbench-0.8.198.dist-info}/RECORD +31 -27
- {workbench-0.8.193.dist-info → workbench-0.8.198.dist-info}/WHEEL +0 -0
- {workbench-0.8.193.dist-info → workbench-0.8.198.dist-info}/entry_points.txt +0 -0
- {workbench-0.8.193.dist-info → workbench-0.8.198.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.193.dist-info → workbench-0.8.198.dist-info}/top_level.txt +0 -0
workbench/api/model.py
CHANGED
@@ -83,16 +83,13 @@ class Model(ModelCore):
         end.set_owner(self.get_owner())
         return end

-    def prox_model(self
+    def prox_model(self):
         """Create a local Proximity Model for this Model

-        Args:
-            filtered: bool, optional): Use filtered training data for the Proximity Model (default: True)
-
         Returns:
             Proximity: A local Proximity Model
         """
-        return proximity_model_local(self
+        return proximity_model_local(self)

     def uq_model(self, uq_model_name: str = None, train_all_data: bool = False) -> "Model":
         """Create a Uncertainty Quantification Model for this Model
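The change drops the old keyword argument from prox_model(). A minimal usage sketch against the new zero-argument signature (the model name is hypothetical):

from workbench.api import Model

# Build a local Proximity Model with the new zero-argument signature
model = Model("abalone-regression")
prox = model.prox_model()  # previously accepted a "filtered" keyword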
workbench/core/artifacts/endpoint_core.py
CHANGED
@@ -8,7 +8,7 @@ import pandas as pd
 import numpy as np
 from io import StringIO
 import awswrangler as wr
-from typing import Union, Optional
+from typing import Union, Optional
 import hashlib

 # Model Performance Scores
@@ -438,23 +438,27 @@ class EndpointCore(Artifact):
         # Return the prediction DataFrame
         return prediction_df

-    def cross_fold_inference(self, nfolds: int = 5) ->
+    def cross_fold_inference(self, nfolds: int = 5) -> pd.DataFrame:
         """Run cross-fold inference (only works for XGBoost models)

         Args:
             nfolds (int): Number of folds to use for cross-fold (default: 5)

         Returns:
-
+            pd.DataFrame: A DataFrame with cross fold predictions
         """

         # Grab our model
         model = ModelCore(self.model_name)

-        # Compute CrossFold Metrics
+        # Compute CrossFold (Metrics and Prediction Dataframe)
         cross_fold_metrics, out_of_fold_df = cross_fold_inference(model, nfolds=nfolds)
-
-
+
+        # If the metrics dataframe isn't empty save to the param store
+        if not cross_fold_metrics.empty:
+            # Convert to list of dictionaries
+            metrics = cross_fold_metrics.to_dict(orient="records")
+            self.param_store.upsert(f"/workbench/models/{model.name}/inference/cross_fold", metrics)

         # Capture the results
         capture_name = "full_cross_fold"
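The rewritten method now persists per-fold metrics to the Parameter Store and returns the out-of-fold predictions instead of None. A minimal usage sketch (the endpoint name is hypothetical; the param_store.get path mirrors the upsert path above, keyed by the model name):

from workbench.core.artifacts.endpoint_core import EndpointCore

# Run cross-fold inference; the method now returns the out-of-fold DataFrame
end = EndpointCore("abalone-regression-end")
oof_df = end.cross_fold_inference(nfolds=5)

# Per-fold metrics were upserted to the Parameter Store and can be read back
metrics = end.param_store.get("/workbench/models/abalone-regression/inference/cross_fold")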
@@ -478,7 +482,9 @@ class EndpointCore(Artifact):
         uq_df = self.inference(training_df)

         # Identify UQ-specific columns (quantiles and prediction_std)
-        uq_columns = [
+        uq_columns = [
+            col for col in uq_df.columns if col.startswith("q_") or col == "prediction_std" or col == "confidence"
+        ]

         # Merge UQ columns with out-of-fold predictions
         if uq_columns:
@@ -502,12 +508,12 @@ class EndpointCore(Artifact):
             out_of_fold_df,
             target_column,
             model_type,
-
+            cross_fold_metrics,
             description,
             features=additional_columns,
             id_column=id_column,
         )
-        return
+        return out_of_fold_df

     def fast_inference(self, eval_df: pd.DataFrame, threads: int = 4) -> pd.DataFrame:
         """Run inference on the Endpoint using the provided DataFrame
@@ -766,8 +772,8 @@ class EndpointCore(Artifact):
         # Add any _proba columns to the output columns
         output_columns += [col for col in pred_results_df.columns if col.endswith("_proba")]

-        # Add any
-        output_columns += [col for col in pred_results_df.columns if col.startswith("q_") or col
+        # Add any Uncertainty Quantile columns to the output columns
+        output_columns += [col for col in pred_results_df.columns if col.startswith("q_") or col == "confidence"]

         # Add the ID column
         if id_column and id_column in pred_results_df.columns:
@@ -896,7 +902,7 @@ class EndpointCore(Artifact):
         else:
             self.validate_proba_columns(prediction_df, class_labels)

-        # Calculate precision, recall,
+        # Calculate precision, recall, f1, and support, handling zero division
         prediction_col = "prediction" if "prediction" in prediction_df.columns else "predictions"
         scores = precision_recall_fscore_support(
             prediction_df[target_column],
@@ -931,7 +937,7 @@ class EndpointCore(Artifact):
             target_column: class_labels,
             "precision": scores[0],
             "recall": scores[1],
-            "
+            "f1": scores[2],
             "roc_auc": roc_auc_per_label,
             "support": scores[3],
         }
@@ -1039,7 +1045,7 @@ class EndpointCore(Artifact):
         # Recursively delete all endpoint S3 artifacts (inference, etc)
         # Note: We do not want to delete the data_capture/ files since these
         # might be used for collection and data drift analysis
-        base_endpoint_path = f"{cls.endpoints_s3_path}/{endpoint_name}"
+        base_endpoint_path = f"{cls.endpoints_s3_path}/{endpoint_name}/"
         all_s3_objects = wr.s3.list_objects(base_endpoint_path, boto3_session=cls.boto3_session)

         # Filter out objects that contain 'data_capture/' in their path
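The trailing slash added to the S3 prefix matters: without it, an object listing for one endpoint can also match sibling endpoints whose names merely share the prefix. A small self-contained illustration (bucket and endpoint names are hypothetical):

# Hypothetical S3 keys for two endpoints with a shared name prefix
paths = [
    "s3://bucket/endpoints/abalone-end/inference/results.csv",
    "s3://bucket/endpoints/abalone-end-v2/inference/results.csv",
]
# Old prefix ".../abalone-end" matches both keys; new ".../abalone-end/" matches only the first
print([p for p in paths if p.startswith("s3://bucket/endpoints/abalone-end/")])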
@@ -1194,7 +1200,8 @@ if __name__ == "__main__":

     # Test the cross_fold_inference method
     print("Running Cross-Fold Inference...")
-
+    all_results = my_endpoint.cross_fold_inference()
+    print(all_results)

     # Run Inference and metrics for a Classification Endpoint
     class_endpoint = EndpointCore("wine-classification")
@@ -1206,7 +1213,9 @@ if __name__ == "__main__":

     # Test the cross_fold_inference method
     print("Running Cross-Fold Inference...")
-
+    all_results = class_endpoint.cross_fold_inference()
+    print(all_results)
+    print("All done...")

     # Test the class method delete (commented out for now)
     # from workbench.api import Model
workbench/core/artifacts/feature_set_core.py
CHANGED
@@ -551,6 +551,75 @@ class FeatureSetCore(Artifact):
         # Apply the filter
         self.set_training_filter(filter_expression)

+    def set_training_sampling(
+        self,
+        exclude_ids: Optional[List[Union[str, int]]] = None,
+        replicate_ids: Optional[List[Union[str, int]]] = None,
+        replication_factor: int = 2,
+    ):
+        """Configure training view with ID exclusions and replications (oversampling).
+
+        Args:
+            exclude_ids: List of IDs to exclude from training view
+            replicate_ids: List of IDs to replicate in training view for oversampling
+            replication_factor: Number of times to replicate each ID (default: 2)
+
+        Note:
+            If an ID appears in both lists, exclusion takes precedence.
+        """
+        from workbench.core.views import TrainingView
+
+        # Normalize to empty lists if None
+        exclude_ids = exclude_ids or []
+        replicate_ids = replicate_ids or []
+
+        # Remove any replicate_ids that are also in exclude_ids (exclusion wins)
+        replicate_ids = [rid for rid in replicate_ids if rid not in exclude_ids]
+
+        # If no sampling needed, just create normal view
+        if not exclude_ids and not replicate_ids:
+            self.log.important("No sampling specified, creating standard training view")
+            TrainingView.create(self, id_column=self.id_column)
+            return
+
+        # Build the custom SQL query
+        self.log.important(
+            f"Excluding {len(exclude_ids)} IDs, Replicating {len(replicate_ids)} IDs "
+            f"(factor: {replication_factor}x)"
+        )
+
+        # Helper to format IDs for SQL
+        def format_ids(ids):
+            return ", ".join([repr(id) for id in ids])
+
+        # Start with base query
+        base_query = f"SELECT * FROM {self.table}"
+
+        # Add exclusions if needed
+        if exclude_ids:
+            base_query += f"\nWHERE {self.id_column} NOT IN ({format_ids(exclude_ids)})"
+
+        # Build full query with replication
+        if replicate_ids:
+            # Generate VALUES clause for CROSS JOIN: (1), (2), ..., (N-1)
+            # We want N-1 additional copies since the original row is already in base_query
+            values_clause = ", ".join([f"({i})" for i in range(1, replication_factor)])
+
+            custom_sql = f"""{base_query}
+
+            UNION ALL
+
+            SELECT t.*
+            FROM {self.table} t
+            CROSS JOIN (VALUES {values_clause}) AS n(num)
+            WHERE t.{self.id_column} IN ({format_ids(replicate_ids)})"""
+        else:
+            # Only exclusions, no UNION needed
+            custom_sql = base_query
+
+        # Create the training view with our custom SQL
+        TrainingView.create_with_sql(self, sql_query=custom_sql, id_column=self.id_column)
+
     @classmethod
     def delete_views(cls, table: str, database: str):
         """Delete any views associated with this FeatureSet
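For concreteness, this sketch assembles the same kind of SQL the new method builds, for exclude_ids=[1, 2], replicate_ids=[7], replication_factor=3 (table and column names are illustrative):

# Mirrors set_training_sampling's query construction
table, id_column = "abalone_features", "auto_id"
values_clause = ", ".join(f"({i})" for i in range(1, 3))  # "(1), (2)" -> two extra copies
custom_sql = f"""SELECT * FROM {table}
WHERE {id_column} NOT IN (1, 2)

UNION ALL

SELECT t.*
FROM {table} t
CROSS JOIN (VALUES {values_clause}) AS n(num)
WHERE t.{id_column} IN (7)"""
print(custom_sql)  # rows with auto_id = 7 appear 3x total in the training view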
@@ -709,7 +778,7 @@ if __name__ == "__main__":
     pd.set_option("display.width", 1000)

     # Grab a FeatureSet object and pull some information from it
-    my_features = LocalFeatureSetCore("
+    my_features = LocalFeatureSetCore("abalone_features")
     if not my_features.exists():
         print("FeatureSet not found!")
         sys.exit(1)
@@ -769,8 +838,8 @@ if __name__ == "__main__":
     # Set the holdout ids for the training view
     print("Setting hold out ids...")
     table = my_features.view("training").table
-    df = my_features.query(f'SELECT
-    my_holdout_ids = [id for id in df["
+    df = my_features.query(f'SELECT auto_id, length FROM "{table}"')
+    my_holdout_ids = [id for id in df["auto_id"] if id < 20]
     my_features.set_training_holdouts(my_holdout_ids)

     # Get the training data
@@ -780,7 +849,7 @@ if __name__ == "__main__":

     # Test the filter expression functionality
     print("Setting a filter expression...")
-    my_features.set_training_filter("
+    my_features.set_training_filter("auto_id < 50 AND length > 65.0")
     training_data = my_features.get_training_data()
     print(f"Training Data: {training_data.shape}")
     print(training_data)
@@ -803,3 +872,56 @@ if __name__ == "__main__":
     # print("Deleting Workbench Feature Set...")
     # my_features.delete()
     # print("Done")
+
+    # Test set_training_sampling with exclusions and replications
+    print("\n--- Testing set_training_sampling ---")
+    my_features.set_training_filter(None)  # Reset any existing filters
+    original_count = num_rows
+
+    # Get valid IDs from the table
+    all_data = my_features.query(f'SELECT auto_id, length FROM "{table}"')
+    valid_ids = sorted(all_data["auto_id"].tolist())
+    print(f"Valid IDs range from {valid_ids[0]} to {valid_ids[-1]}")
+
+    exclude_list = valid_ids[0:3]  # First 3 IDs
+    replicate_list = valid_ids[10:13]  # IDs at positions 10, 11, 12
+
+    print(f"Original row count: {original_count}")
+    print(f"Excluding IDs: {exclude_list}")
+    print(f"Replicating IDs: {replicate_list}")
+
+    # Test with default replication factor (2x)
+    print("\n--- Testing with replication_factor=2 (default) ---")
+    my_features.set_training_sampling(exclude_ids=exclude_list, replicate_ids=replicate_list)
+    training_data = my_features.get_training_data()
+    print(f"Training Data after sampling: {training_data.shape}")
+
+    # Verify exclusions
+    for exc_id in exclude_list:
+        count = len(training_data[training_data["auto_id"] == exc_id])
+        print(f"Excluded ID {exc_id} appears {count} times (should be 0)")
+
+    # Verify replications
+    for rep_id in replicate_list:
+        count = len(training_data[training_data["auto_id"] == rep_id])
+        print(f"Replicated ID {rep_id} appears {count} times (should be 2)")
+
+    # Test with replication factor of 5
+    print("\n--- Testing with replication_factor=5 ---")
+    replicate_list_5x = [20, 21]
+    my_features.set_training_sampling(exclude_ids=exclude_list, replicate_ids=replicate_list_5x, replication_factor=5)
+    training_data = my_features.get_training_data()
+    print(f"Training Data after sampling: {training_data.shape}")
+
+    # Verify 5x replication
+    for rep_id in replicate_list_5x:
+        count = len(training_data[training_data["auto_id"] == rep_id])
+        print(f"Replicated ID {rep_id} appears {count} times (should be 5)")
+
+    # Test with large replication list (simulate 100 IDs)
+    print("\n--- Testing with large ID list (100 IDs) ---")
+    large_replicate_list = list(range(30, 130))  # 100 IDs
+    my_features.set_training_sampling(replicate_ids=large_replicate_list, replication_factor=3)
+    training_data = my_features.get_training_data()
+    print(f"Training Data after sampling: {training_data.shape}")
+    print(f"Expected extra rows: {len(large_replicate_list) * 3}")
workbench/core/artifacts/model_core.py
CHANGED
@@ -54,14 +54,14 @@ class ModelImages:
     }

     @classmethod
-    def get_image_uri(cls, region, image_type, version="
+    def get_image_uri(cls, region, image_type, version="latest", architecture="x86_64"):
         """
         Dynamically construct ECR image URI.

         Args:
             region: AWS region (e.g., 'us-east-1', 'us-west-2')
             image_type: Type of image (e.g., 'training', 'inference', 'pytorch_training')
-            version: Image version (e.g., '0.1', '0.2')
+            version: Image version (e.g., '0.1', '0.2' defaults to 'latest')
             architecture: CPU architecture (default: 'x86_64', currently unused but kept for compatibility)

         Returns:
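A minimal call sketch for the updated signature (region and image type values are illustrative; version now defaults to "latest"):

from workbench.core.artifacts.model_core import ModelImages

# Unpinned: resolves to the "latest" image version
uri = ModelImages.get_image_uri("us-east-1", "inference")
# Pinned: explicit version; architecture is accepted but currently unused
uri_pinned = ModelImages.get_image_uri("us-east-1", "training", version="0.2", architecture="x86_64")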
@@ -589,7 +589,7 @@ class ModelCore(Artifact):
         fs = FeatureSetCore(self.get_input())

         # See if we have a training view for this model
-        my_model_training_view = f"{self.name.replace('-', '_')}_training"
+        my_model_training_view = f"{self.name.replace('-', '_')}_training".lower()
         view = View(fs, my_model_training_view, auto_create_view=False)
         if view.exists():
             return view
@@ -867,14 +867,6 @@ class ModelCore(Artifact):
             shap_data[key] = self.df_store.get(df_location)
         return shap_data or None

-    def cross_folds(self) -> dict:
-        """Retrieve the cross-fold inference results(only works for XGBoost models)
-
-        Returns:
-            dict: Dictionary with the cross-fold inference results
-        """
-        return self.param_store.get(f"/workbench/models/{self.name}/inference/cross_fold")
-
     def supported_inference_instances(self) -> Optional[list]:
         """Retrieve the supported endpoint inference instance types

@@ -1171,13 +1163,11 @@ if __name__ == "__main__":
     # Grab a ModelCore object and pull some information from it
     my_model = ModelCore("abalone-regression")

-    # Call the various methods
-
     # Let's do a check/validation of the Model
     print(f"Model Check: {my_model.exists()}")

     # Make sure the model is 'ready'
-
+    my_model.onboard()

     # Get the ARN of the Model Group
     print(f"Model Group ARN: {my_model.group_arn()}")
@@ -1243,5 +1233,10 @@ if __name__ == "__main__":
     # Delete the Model
     # ModelCore.managed_delete("wine-classification")

+    # Check the training view logic
+    model = ModelCore("wine-class-test-251112-BW")
+    training_view = model.training_view()
+    print(f"Training View Name: {training_view.name}")
+
     # Check for a model that doesn't exist
     my_model = ModelCore("empty-model-group")
workbench/core/transforms/features_to_model/features_to_model.py
CHANGED
@@ -210,7 +210,7 @@ class FeaturesToModel(Transform):
             raise ValueError(msg)

         # Dynamically create the metric definitions
-        metrics = ["precision", "recall", "
+        metrics = ["precision", "recall", "f1"]
         metric_definitions = []
         for t in self.class_labels:
             for m in metrics:
@@ -233,7 +233,7 @@ class FeaturesToModel(Transform):
         source_dir = str(Path(script_path).parent)

         # Create a Sagemaker Model with our script
-        image = ModelImages.get_image_uri(self.sm_session.boto_region_name, self.training_image
+        image = ModelImages.get_image_uri(self.sm_session.boto_region_name, self.training_image)
         self.estimator = Estimator(
             entry_point=entry_point,
             source_dir=source_dir,
@@ -306,7 +306,7 @@ class FeaturesToModel(Transform):

         # Register our model
         image = ModelImages.get_image_uri(
-            self.sm_session.boto_region_name, self.inference_image,
+            self.sm_session.boto_region_name, self.inference_image, architecture=self.inference_arch
         )
         self.log.important(f"Registering model {self.output_name} with Inference Image {image}...")
         model = self.estimator.create_model(role=self.workbench_role_arn)
workbench/core/views/training_view.py
CHANGED
@@ -116,6 +116,57 @@ class TrainingView(CreateView):
         # Return the View
         return View(instance.data_source, instance.view_name, auto_create_view=False)

+    @classmethod
+    def create_with_sql(
+        cls,
+        feature_set: FeatureSet,
+        *,
+        sql_query: str,
+        id_column: str = None,
+    ) -> Union[View, None]:
+        """Factory method to create a TrainingView from a custom SQL query.
+
+        This method takes a complete SQL query and adds the default 80/20 training split.
+        Use this when you need complex queries like UNION ALL for oversampling.
+
+        Args:
+            feature_set (FeatureSet): A FeatureSet object
+            sql_query (str): Complete SELECT query (without the final semicolon)
+            id_column (str, optional): The name of the id column for training split. Defaults to None.
+
+        Returns:
+            Union[View, None]: The created View object (or None if failed)
+        """
+        # Instantiate the TrainingView
+        instance = cls("training", feature_set)
+
+        # Sanity check on the id column
+        if not id_column:
+            instance.log.important("No id column specified, using auto_id_column")
+            if not instance.auto_id_column:
+                instance.log.error("No id column specified and no auto_id_column found, aborting")
+                return None
+            id_column = instance.auto_id_column
+
+        # Default 80/20 split using modulo
+        training_logic = f"""CASE
+            WHEN MOD(ROW_NUMBER() OVER (ORDER BY {id_column}), 10) < 8 THEN True
+            ELSE False
+        END AS training"""
+
+        # Wrap the custom query and add training column
+        create_view_query = f"""
+        CREATE OR REPLACE VIEW {instance.table} AS
+        SELECT *, {training_logic}
+        FROM ({sql_query}) AS custom_source
+        """
+
+        # Execute the CREATE VIEW query
+        instance.data_source.execute_statement(create_view_query)
+
+        # Return the View
+        return View(instance.data_source, instance.view_name, auto_create_view=False)
+

 if __name__ == "__main__":
     """Exercise the Training View functionality"""
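To make the wrapping concrete, this sketch renders the view definition the method would emit for a trivial sql_query (table and view names are illustrative); a fuller oversampling example appears in the __main__ hunk below:

# Mirrors create_with_sql's query wrapping
view_table = "abalone_features___training"
sql_query = "SELECT * FROM abalone_features"
training_logic = """CASE
    WHEN MOD(ROW_NUMBER() OVER (ORDER BY auto_id), 10) < 8 THEN True
    ELSE False
END AS training"""
create_view_query = f"""
CREATE OR REPLACE VIEW {view_table} AS
SELECT *, {training_logic}
FROM ({sql_query}) AS custom_source
"""
print(create_view_query)  # roughly an 80/20 deterministic train/holdout split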
@@ -154,3 +205,27 @@ if __name__ == "__main__":
     print(df.head())
     print(f"Shape with filter: {df.shape}")
     print(f"Diameter min: {df['diameter'].min()}, max: {df['diameter'].max()}")
+
+    # Test create_with_sql with a custom query (UNION ALL for oversampling)
+    print("\n--- Testing create_with_sql with oversampling ---")
+    base_table = fs.table
+    replicate_ids = [0, 1, 2]  # Oversample these IDs
+
+    custom_sql = f"""
+    SELECT * FROM {base_table}
+
+    UNION ALL
+
+    SELECT * FROM {base_table}
+    WHERE auto_id IN ({', '.join(map(str, replicate_ids))})
+    """
+
+    training_view = TrainingView.create_with_sql(fs, sql_query=custom_sql, id_column="auto_id")
+    df = training_view.pull_dataframe()
+    print(f"Shape with custom SQL: {df.shape}")
+    print(df["training"].value_counts())
+
+    # Verify oversampling - check if replicated IDs appear twice
+    for rep_id in replicate_ids:
+        count = len(df[df["auto_id"] == rep_id])
+        print(f"ID {rep_id} appears {count} times")
workbench/core/views/view.py
CHANGED
@@ -232,7 +232,7 @@ class View:
         view_definition = df.iloc[0]["view_definition"]

         # Create the new view with the destination name
-        dest_table = f"{self.base_table_name}___{dest_view_name}"
+        dest_table = f"{self.base_table_name}___{dest_view_name.lower()}"
         create_view_query = f'CREATE OR REPLACE VIEW "{dest_table}" AS {view_definition}'

         self.log.important(f"Copying view {self.table} to {dest_table}...")
workbench/model_scripts/__pycache__/script_generation.cpython-312.pyc
CHANGED
Binary file

workbench/model_scripts/__pycache__/script_generation.cpython-313.pyc
CHANGED
Binary file