workbench 0.8.162__py3-none-any.whl → 0.8.202__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of workbench might be problematic. Click here for more details.
- workbench/algorithms/dataframe/__init__.py +1 -2
- workbench/algorithms/dataframe/fingerprint_proximity.py +2 -2
- workbench/algorithms/dataframe/proximity.py +261 -235
- workbench/algorithms/graph/light/proximity_graph.py +10 -8
- workbench/api/__init__.py +2 -1
- workbench/api/compound.py +1 -1
- workbench/api/endpoint.py +11 -0
- workbench/api/feature_set.py +11 -8
- workbench/api/meta.py +5 -2
- workbench/api/model.py +16 -15
- workbench/api/monitor.py +1 -16
- workbench/core/artifacts/__init__.py +11 -2
- workbench/core/artifacts/artifact.py +11 -3
- workbench/core/artifacts/data_capture_core.py +355 -0
- workbench/core/artifacts/endpoint_core.py +256 -118
- workbench/core/artifacts/feature_set_core.py +265 -16
- workbench/core/artifacts/model_core.py +107 -60
- workbench/core/artifacts/monitor_core.py +33 -248
- workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
- workbench/core/cloud_platform/aws/aws_meta.py +12 -5
- workbench/core/cloud_platform/aws/aws_parameter_store.py +18 -2
- workbench/core/cloud_platform/aws/aws_session.py +4 -4
- workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
- workbench/core/transforms/features_to_model/features_to_model.py +42 -32
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +36 -6
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +27 -0
- workbench/core/views/training_view.py +113 -42
- workbench/core/views/view.py +53 -3
- workbench/core/views/view_utils.py +4 -4
- workbench/model_scripts/chemprop/chemprop.template +852 -0
- workbench/model_scripts/chemprop/generated_model_script.py +852 -0
- workbench/model_scripts/chemprop/requirements.txt +11 -0
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +134 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
- workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
- workbench/model_scripts/custom_models/proximity/proximity.py +261 -235
- workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
- workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +166 -62
- workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
- workbench/model_scripts/custom_models/uq_models/proximity.py +261 -235
- workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
- workbench/model_scripts/pytorch_model/generated_model_script.py +373 -190
- workbench/model_scripts/pytorch_model/pytorch.template +370 -187
- workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
- workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
- workbench/model_scripts/script_generation.py +17 -9
- workbench/model_scripts/uq_models/generated_model_script.py +605 -0
- workbench/model_scripts/uq_models/mapie.template +605 -0
- workbench/model_scripts/uq_models/requirements.txt +1 -0
- workbench/model_scripts/xgb_model/generated_model_script.py +37 -46
- workbench/model_scripts/xgb_model/xgb_model.template +44 -46
- workbench/repl/workbench_shell.py +28 -14
- workbench/scripts/endpoint_test.py +162 -0
- workbench/scripts/lambda_test.py +73 -0
- workbench/scripts/ml_pipeline_batch.py +137 -0
- workbench/scripts/ml_pipeline_sqs.py +186 -0
- workbench/scripts/monitor_cloud_watch.py +20 -100
- workbench/utils/aws_utils.py +4 -3
- workbench/utils/chem_utils/__init__.py +0 -0
- workbench/utils/chem_utils/fingerprints.py +134 -0
- workbench/utils/chem_utils/misc.py +194 -0
- workbench/utils/chem_utils/mol_descriptors.py +483 -0
- workbench/utils/chem_utils/mol_standardize.py +450 -0
- workbench/utils/chem_utils/mol_tagging.py +348 -0
- workbench/utils/chem_utils/projections.py +209 -0
- workbench/utils/chem_utils/salts.py +256 -0
- workbench/utils/chem_utils/sdf.py +292 -0
- workbench/utils/chem_utils/toxicity.py +250 -0
- workbench/utils/chem_utils/vis.py +253 -0
- workbench/utils/chemprop_utils.py +760 -0
- workbench/utils/cloudwatch_handler.py +1 -1
- workbench/utils/cloudwatch_utils.py +137 -0
- workbench/utils/config_manager.py +3 -7
- workbench/utils/endpoint_utils.py +5 -7
- workbench/utils/license_manager.py +2 -6
- workbench/utils/model_utils.py +95 -34
- workbench/utils/monitor_utils.py +44 -62
- workbench/utils/pandas_utils.py +3 -3
- workbench/utils/pytorch_utils.py +526 -0
- workbench/utils/shap_utils.py +10 -2
- workbench/utils/workbench_logging.py +0 -3
- workbench/utils/workbench_sqs.py +1 -1
- workbench/utils/xgboost_model_utils.py +371 -156
- workbench/web_interface/components/model_plot.py +7 -1
- workbench/web_interface/components/plugin_unit_test.py +5 -2
- workbench/web_interface/components/plugins/dashboard_status.py +3 -1
- workbench/web_interface/components/plugins/generated_compounds.py +1 -1
- workbench/web_interface/components/plugins/model_details.py +9 -7
- workbench/web_interface/components/plugins/scatter_plot.py +3 -3
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/METADATA +27 -6
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/RECORD +101 -85
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/entry_points.txt +4 -0
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/licenses/LICENSE +1 -1
- workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
- workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
- workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
- workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
- workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
- workbench/model_scripts/quant_regression/quant_regression.template +0 -279
- workbench/model_scripts/quant_regression/requirements.txt +0 -1
- workbench/utils/chem_utils.py +0 -1556
- workbench/utils/execution_environment.py +0 -211
- workbench/utils/fast_inference.py +0 -167
- workbench/utils/resource_utils.py +0 -39
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/WHEEL +0 -0
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/top_level.txt +0 -0
|
@@ -16,8 +16,9 @@ from sagemaker.feature_store.feature_store import FeatureStore
|
|
|
16
16
|
from workbench.core.artifacts.artifact import Artifact
|
|
17
17
|
from workbench.core.artifacts.data_source_factory import DataSourceFactory
|
|
18
18
|
from workbench.core.artifacts.athena_source import AthenaSource
|
|
19
|
+
from workbench.utils.deprecated_utils import deprecated
|
|
19
20
|
|
|
20
|
-
from typing import TYPE_CHECKING
|
|
21
|
+
from typing import TYPE_CHECKING, Optional, List, Dict, Union
|
|
21
22
|
|
|
22
23
|
from workbench.utils.aws_utils import aws_throttle
|
|
23
24
|
|
|
@@ -194,24 +195,24 @@ class FeatureSetCore(Artifact):
|
|
|
194
195
|
|
|
195
196
|
return View(self, view_name)
|
|
196
197
|
|
|
197
|
-
def set_display_columns(self, display_columns: list[str]):
    """Rebuild the display view so that it exposes the given columns

    Columns that are not present in the "computation" view are reported
    through the monitor log; the display view is still created with the
    full list that was passed in.

    Args:
        display_columns (list[str]): Columns the new display view should contain
    """
    # Compare the requested columns against what the computation view offers
    computation_view = self.view("computation")
    available_columns = computation_view.columns
    missing_columns = []
    for column in display_columns:
        if column not in available_columns:
            missing_columns.append(column)
    if missing_columns:
        self.log.monitor(f"Display View/Computation mismatch: {missing_columns}")

    self.log.important(f"Setting Display Columns...{display_columns}")
    from workbench.core.views import DisplayView

    # A brand new display view is created on top of the computation view's table
    DisplayView.create(self, source_table=computation_view.table, column_list=display_columns)
|
|
215
216
|
|
|
216
217
|
def set_computation_columns(self, computation_columns: list[str], reset_display: bool = True):
|
|
217
218
|
"""Set the computation columns for this FeatureSet
|
|
@@ -509,6 +510,184 @@ class FeatureSetCore(Artifact):
|
|
|
509
510
|
].tolist()
|
|
510
511
|
return hold_out_ids
|
|
511
512
|
|
|
513
|
+
def set_sample_weights(
    self,
    weight_dict: Dict[Union[str, int], float],
    default_weight: float = 1.0,
    exclude_zero_weights: bool = True,
):
    """Configure training view with sample weights for each ID.

    Builds a SQL query over this FeatureSet's table that adds a
    ``sample_weight`` column (a CASE expression keyed on the id column) and
    hands it to ``TrainingView.create_with_sql``.

    Args:
        weight_dict: Mapping of ID to sample weight
            - weight > 1.0: oversample/emphasize
            - weight = 1.0: normal (default)
            - 0 < weight < 1.0: downweight/de-emphasize
            - weight = 0.0: exclude from training
        default_weight: Weight for IDs not in weight_dict (default: 1.0)
        exclude_zero_weights: If True, filter out rows with sample_weight=0 (default: True)

    Raises:
        TypeError, ValueError: If a weight (or default_weight) cannot be converted to float

    Example:
        weights = {
            'compound_42': 3.0,  # oversample 3x
            'compound_99': 0.1,  # noisy, downweight
            'compound_123': 0.0,  # exclude from training
        }
        fs.set_sample_weights(weights)  # zeros automatically excluded
        fs.set_sample_weights(weights, exclude_zero_weights=False)  # keep zeros
    """
    from workbench.core.views import TrainingView

    if not weight_dict:
        self.log.important("Empty weight_dict, creating standard training view")
        TrainingView.create(self, id_column=self.id_column)
        return

    self.log.important(f"Setting sample weights for {len(weight_dict)} IDs")

    # Helper to format IDs for SQL
    def format_id(id_val):
        # Strings become single-quoted SQL literals with embedded quotes doubled.
        # repr() is not safe here: Python switches to double quotes when the
        # string contains a single quote (double quotes denote identifiers in
        # Athena/Trino), and NumPy 2 scalars repr as e.g. "np.int64(5)".
        if isinstance(id_val, str):
            return "'" + id_val.replace("'", "''") + "'"
        return str(id_val)

    # Coerce weights to float so only numeric literals are ever interpolated into the SQL
    default_weight = float(default_weight)

    # Build CASE statement for sample_weight
    case_conditions = [
        f"WHEN {self.id_column} = {format_id(id_val)} THEN {float(weight)}"
        for id_val, weight in weight_dict.items()
    ]
    case_statement = "\n        ".join(case_conditions)

    # Build inner query with sample weights
    inner_sql = (
        "SELECT\n"
        "    *,\n"
        "    CASE\n"
        f"        {case_statement}\n"
        f"        ELSE {default_weight}\n"
        "    END AS sample_weight\n"
        f"FROM {self.table}"
    )

    # Optionally filter out zero weights
    if exclude_zero_weights:
        # NOTE: this counts zero-weight entries in weight_dict, not rows in the table
        zero_count = sum(1 for weight in weight_dict.values() if weight == 0.0)
        custom_sql = f"SELECT * FROM ({inner_sql}) WHERE sample_weight > 0"
        self.log.important(f"Filtering out {zero_count} rows with sample_weight = 0")
    else:
        custom_sql = inner_sql

    TrainingView.create_with_sql(self, sql_query=custom_sql, id_column=self.id_column)
|
|
576
|
+
|
|
577
|
+
# Version is given as a string to match the other @deprecated uses in this file
@deprecated(version="0.9")
def set_training_filter(self, filter_expression: Optional[str] = None):
    """Set a filter expression for the training view for this FeatureSet

    The current training holdout ids are read first and passed along, so the
    newly created training view keeps the same holdouts.

    Args:
        filter_expression (Optional[str]): A SQL filter expression (e.g., "age > 25 AND status = 'active'")
                                           If None or empty string, will reset to training view with no filter
                                           (default: None)
    """
    from workbench.core.views import TrainingView

    # Grab the existing holdout ids
    holdout_ids = self.get_training_holdouts()

    # Create a NEW training view
    self.log.important(f"Setting Training Filter: {filter_expression}")
    TrainingView.create(
        self, id_column=self.id_column, holdout_ids=holdout_ids, filter_expression=filter_expression
    )
|
|
596
|
+
|
|
597
|
+
@deprecated(version="0.9")
def exclude_ids_from_training(self, ids: List[Union[str, int]], column_name: Optional[str] = None):
    """Exclude a list of IDs from the training view

    Builds a ``<column> NOT IN (...)`` expression and applies it through
    ``set_training_filter``.

    Args:
        ids (List[Union[str, int]]): List of IDs to exclude from training
        column_name (Optional[str]): Column name to filter on.
                                     If None, uses self.id_column (default: None)
    """
    # Use the default id_column if not specified
    column = column_name or self.id_column

    # Handle empty list case
    if not ids:
        self.log.warning("No IDs provided to exclude")
        return

    def sql_literal(value):
        # Strings become single-quoted SQL literals with embedded quotes doubled.
        # repr() is not safe here: Python switches to double quotes when the
        # string contains a single quote (double quotes denote identifiers in
        # Athena/Trino), and NumPy 2 scalars repr as e.g. "np.int64(5)".
        if isinstance(value, str):
            return "'" + value.replace("'", "''") + "'"
        return str(value)

    # Build the filter expression with proper SQL quoting
    quoted_ids = ", ".join(sql_literal(id_val) for id_val in ids)
    filter_expression = f"{column} NOT IN ({quoted_ids})"

    # Apply the filter
    self.set_training_filter(filter_expression)
|
|
620
|
+
|
|
621
|
+
@deprecated(version="0.9")
def set_training_sampling(
    self,
    exclude_ids: Optional[List[Union[str, int]]] = None,
    replicate_ids: Optional[List[Union[str, int]]] = None,
    replication_factor: int = 2,
):
    """Configure training view with ID exclusions and replications (oversampling).

    Args:
        exclude_ids: List of IDs to exclude from training view
        replicate_ids: List of IDs to replicate in training view for oversampling
        replication_factor: Total number of times each replicated ID should appear (default: 2).
            A factor of 1 adds no extra copies.

    Raises:
        ValueError: If replicate_ids are given and replication_factor is less than 1

    Note:
        If an ID appears in both lists, exclusion takes precedence.
    """
    from workbench.core.views import TrainingView

    # Normalize to empty lists if None
    exclude_ids = exclude_ids or []
    replicate_ids = replicate_ids or []

    # Remove any replicate_ids that are also in exclude_ids (exclusion wins)
    excluded = set(exclude_ids)
    replicate_ids = [rid for rid in replicate_ids if rid not in excluded]

    # A factor below 1 is meaningless and would yield an empty VALUES clause (invalid SQL)
    if replicate_ids and replication_factor < 1:
        raise ValueError(f"replication_factor must be >= 1, got {replication_factor}")

    # If no sampling needed, just create normal view
    if not exclude_ids and not replicate_ids:
        self.log.important("No sampling specified, creating standard training view")
        TrainingView.create(self, id_column=self.id_column)
        return

    # Build the custom SQL query
    self.log.important(
        f"Excluding {len(exclude_ids)} IDs, Replicating {len(replicate_ids)} IDs "
        f"(factor: {replication_factor}x)"
    )

    def sql_literal(value):
        # Strings become single-quoted SQL literals with embedded quotes doubled.
        # repr() is not safe here: Python switches to double quotes when the
        # string contains a single quote (double quotes denote identifiers in
        # Athena/Trino), and NumPy 2 scalars repr as e.g. "np.int64(5)".
        if isinstance(value, str):
            return "'" + value.replace("'", "''") + "'"
        return str(value)

    # Helper to format IDs for SQL
    def format_ids(ids):
        return ", ".join(sql_literal(id_val) for id_val in ids)

    # Start with base query
    base_query = f"SELECT * FROM {self.table}"

    # Add exclusions if needed
    if exclude_ids:
        base_query += f"\nWHERE {self.id_column} NOT IN ({format_ids(exclude_ids)})"

    # Build full query with replication
    if replicate_ids and replication_factor > 1:
        # Generate VALUES clause for CROSS JOIN: (1), (2), ..., (N-1)
        # We want N-1 additional copies since the original row is already in base_query
        values_clause = ", ".join(f"({i})" for i in range(1, replication_factor))

        custom_sql = (
            f"{base_query}\n"
            "\n"
            "UNION ALL\n"
            "\n"
            "SELECT t.*\n"
            f"FROM {self.table} t\n"
            f"CROSS JOIN (VALUES {values_clause}) AS n(num)\n"
            f"WHERE t.{self.id_column} IN ({format_ids(replicate_ids)})"
        )
    else:
        # Only exclusions (or a factor of 1), no UNION needed
        custom_sql = base_query

    # Create the training view with our custom SQL
    TrainingView.create_with_sql(self, sql_query=custom_sql, id_column=self.id_column)
|
|
690
|
+
|
|
512
691
|
@classmethod
|
|
513
692
|
def delete_views(cls, table: str, database: str):
|
|
514
693
|
"""Delete any views associated with this FeatureSet
|
|
@@ -667,7 +846,7 @@ if __name__ == "__main__":
|
|
|
667
846
|
pd.set_option("display.width", 1000)
|
|
668
847
|
|
|
669
848
|
# Grab a FeatureSet object and pull some information from it
|
|
670
|
-
my_features = LocalFeatureSetCore("
|
|
849
|
+
my_features = LocalFeatureSetCore("abalone_features")
|
|
671
850
|
if not my_features.exists():
|
|
672
851
|
print("FeatureSet not found!")
|
|
673
852
|
sys.exit(1)
|
|
@@ -707,7 +886,7 @@ if __name__ == "__main__":
|
|
|
707
886
|
|
|
708
887
|
# Test getting the holdout ids
|
|
709
888
|
print("Getting the hold out ids...")
|
|
710
|
-
holdout_ids = my_features.get_training_holdouts(
|
|
889
|
+
holdout_ids = my_features.get_training_holdouts()
|
|
711
890
|
print(f"Holdout IDs: {holdout_ids}")
|
|
712
891
|
|
|
713
892
|
# Get a sample of the data
|
|
@@ -727,20 +906,90 @@ if __name__ == "__main__":
|
|
|
727
906
|
# Set the holdout ids for the training view
|
|
728
907
|
print("Setting hold out ids...")
|
|
729
908
|
table = my_features.view("training").table
|
|
730
|
-
df = my_features.query(f'SELECT
|
|
731
|
-
my_holdout_ids = [id for id in df["
|
|
732
|
-
my_features.set_training_holdouts(
|
|
733
|
-
|
|
734
|
-
# Test the hold out set functionality with strings
|
|
735
|
-
print("Setting hold out ids (strings)...")
|
|
736
|
-
my_holdout_ids = [name for name in df["name"] if int(name.split(" ")[1]) > 80]
|
|
737
|
-
my_features.set_training_holdouts("name", my_holdout_ids)
|
|
909
|
+
df = my_features.query(f'SELECT auto_id, length FROM "{table}"')
|
|
910
|
+
my_holdout_ids = [id for id in df["auto_id"] if id < 20]
|
|
911
|
+
my_features.set_training_holdouts(my_holdout_ids)
|
|
738
912
|
|
|
739
913
|
# Get the training data
|
|
740
914
|
print("Getting the training data...")
|
|
741
915
|
training_data = my_features.get_training_data()
|
|
916
|
+
print(f"Training Data: {training_data.shape}")
|
|
917
|
+
|
|
918
|
+
# Test the filter expression functionality
|
|
919
|
+
print("Setting a filter expression...")
|
|
920
|
+
my_features.set_training_filter("auto_id < 50 AND length > 65.0")
|
|
921
|
+
training_data = my_features.get_training_data()
|
|
922
|
+
print(f"Training Data: {training_data.shape}")
|
|
923
|
+
print(training_data)
|
|
924
|
+
|
|
925
|
+
# Remove training filter
|
|
926
|
+
print("Removing the filter expression...")
|
|
927
|
+
my_features.set_training_filter(None)
|
|
928
|
+
training_data = my_features.get_training_data()
|
|
929
|
+
print(f"Training Data: {training_data.shape}")
|
|
930
|
+
print(training_data)
|
|
931
|
+
|
|
932
|
+
# Test excluding ids from training
|
|
933
|
+
print("Excluding ids from training...")
|
|
934
|
+
my_features.exclude_ids_from_training([1, 2, 3, 4, 5])
|
|
935
|
+
training_data = my_features.get_training_data()
|
|
936
|
+
print(f"Training Data: {training_data.shape}")
|
|
937
|
+
print(training_data)
|
|
742
938
|
|
|
743
939
|
# Now delete the AWS artifacts associated with this Feature Set
|
|
744
940
|
# print("Deleting Workbench Feature Set...")
|
|
745
941
|
# my_features.delete()
|
|
746
942
|
# print("Done")
|
|
943
|
+
|
|
944
|
+
# Test set_training_sampling with exclusions and replications
|
|
945
|
+
print("\n--- Testing set_training_sampling ---")
|
|
946
|
+
my_features.set_training_filter(None) # Reset any existing filters
|
|
947
|
+
original_count = num_rows
|
|
948
|
+
|
|
949
|
+
# Get valid IDs from the table
|
|
950
|
+
all_data = my_features.query(f'SELECT auto_id, length FROM "{table}"')
|
|
951
|
+
valid_ids = sorted(all_data["auto_id"].tolist())
|
|
952
|
+
print(f"Valid IDs range from {valid_ids[0]} to {valid_ids[-1]}")
|
|
953
|
+
|
|
954
|
+
exclude_list = valid_ids[0:3] # First 3 IDs
|
|
955
|
+
replicate_list = valid_ids[10:13] # IDs at positions 10, 11, 12
|
|
956
|
+
|
|
957
|
+
print(f"Original row count: {original_count}")
|
|
958
|
+
print(f"Excluding IDs: {exclude_list}")
|
|
959
|
+
print(f"Replicating IDs: {replicate_list}")
|
|
960
|
+
|
|
961
|
+
# Test with default replication factor (2x)
|
|
962
|
+
print("\n--- Testing with replication_factor=2 (default) ---")
|
|
963
|
+
my_features.set_training_sampling(exclude_ids=exclude_list, replicate_ids=replicate_list)
|
|
964
|
+
training_data = my_features.get_training_data()
|
|
965
|
+
print(f"Training Data after sampling: {training_data.shape}")
|
|
966
|
+
|
|
967
|
+
# Verify exclusions
|
|
968
|
+
for exc_id in exclude_list:
|
|
969
|
+
count = len(training_data[training_data["auto_id"] == exc_id])
|
|
970
|
+
print(f"Excluded ID {exc_id} appears {count} times (should be 0)")
|
|
971
|
+
|
|
972
|
+
# Verify replications
|
|
973
|
+
for rep_id in replicate_list:
|
|
974
|
+
count = len(training_data[training_data["auto_id"] == rep_id])
|
|
975
|
+
print(f"Replicated ID {rep_id} appears {count} times (should be 2)")
|
|
976
|
+
|
|
977
|
+
# Test with replication factor of 5
|
|
978
|
+
print("\n--- Testing with replication_factor=5 ---")
|
|
979
|
+
replicate_list_5x = [20, 21]
|
|
980
|
+
my_features.set_training_sampling(exclude_ids=exclude_list, replicate_ids=replicate_list_5x, replication_factor=5)
|
|
981
|
+
training_data = my_features.get_training_data()
|
|
982
|
+
print(f"Training Data after sampling: {training_data.shape}")
|
|
983
|
+
|
|
984
|
+
# Verify 5x replication
|
|
985
|
+
for rep_id in replicate_list_5x:
|
|
986
|
+
count = len(training_data[training_data["auto_id"] == rep_id])
|
|
987
|
+
print(f"Replicated ID {rep_id} appears {count} times (should be 5)")
|
|
988
|
+
|
|
989
|
+
# Test with large replication list (simulate 100 IDs)
print("\n--- Testing with large ID list (100 IDs) ---")
large_replicate_list = list(range(30, 130))  # 100 IDs
large_replication_factor = 3
my_features.set_training_sampling(replicate_ids=large_replicate_list, replication_factor=large_replication_factor)
training_data = my_features.get_training_data()
print(f"Training Data after sampling: {training_data.shape}")
# Each replicated ID gains (factor - 1) extra rows, since the original row is already present
print(f"Expected extra rows: {len(large_replicate_list) * (large_replication_factor - 1)}")
|
|
@@ -21,6 +21,7 @@ from workbench.utils.aws_utils import newest_path, pull_s3_data
|
|
|
21
21
|
from workbench.utils.s3_utils import compute_s3_object_hash
|
|
22
22
|
from workbench.utils.shap_utils import shap_values_data, shap_feature_importance
|
|
23
23
|
from workbench.utils.deprecated_utils import deprecated
|
|
24
|
+
from workbench.utils.model_utils import proximity_model
|
|
24
25
|
|
|
25
26
|
|
|
26
27
|
class ModelType(Enum):
|
|
@@ -29,69 +30,62 @@ class ModelType(Enum):
|
|
|
29
30
|
CLASSIFIER = "classifier"
|
|
30
31
|
REGRESSOR = "regressor"
|
|
31
32
|
CLUSTERER = "clusterer"
|
|
32
|
-
TRANSFORMER = "transformer"
|
|
33
33
|
PROXIMITY = "proximity"
|
|
34
34
|
PROJECTION = "projection"
|
|
35
35
|
UQ_REGRESSOR = "uq_regressor"
|
|
36
36
|
ENSEMBLE_REGRESSOR = "ensemble_regressor"
|
|
37
|
+
TRANSFORMER = "transformer"
|
|
38
|
+
UNKNOWN = "unknown"
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class ModelFramework(Enum):
    """Closed set of framework identifiers that can be recorded for a Workbench Model

    Each member's value is the plain string form of the framework name.
    """

    # Named frameworks
    SKLEARN = "sklearn"
    XGBOOST = "xgboost"
    LIGHTGBM = "lightgbm"
    PYTORCH_TABULAR = "pytorch_tabular"
    CHEMPROP = "chemprop"
    TRANSFORMER = "transformer"

    # Catch-all for a framework that could not be determined
    UNKNOWN = "unknown"
|
|
38
51
|
|
|
39
52
|
|
|
40
53
|
class ModelImages:
    """Builds ECR image URIs for the Workbench training/inference containers"""

    # AWS account that hosts the ECR repositories
    ACCOUNT_ID = "507740646243"

    # Logical image type -> ECR repository (image) name
    IMAGE_NAMES = {
        "training": "py312-general-ml-training",
        "inference": "py312-general-ml-inference",
        "pytorch_training": "py312-pytorch-training",
        "pytorch_inference": "py312-pytorch-inference",
        "meta-endpoint": "py312-meta-endpoint",
    }

    @classmethod
    def get_image_uri(cls, region, image_type, version="latest", architecture="x86_64"):
        """
        Assemble the ECR image URI for an image type in a given region.

        Args:
            region: AWS region (e.g., 'us-east-1', 'us-west-2')
            image_type: One of the keys of IMAGE_NAMES (e.g., 'training', 'inference')
            version: Image tag to use (e.g., '0.1'); defaults to 'latest'
            architecture: CPU architecture (default: 'x86_64'); accepted for
                compatibility but not used when building the URI

        Returns:
            ECR image URI string

        Raises:
            ValueError: If image_type is not a key of IMAGE_NAMES
        """
        image_name = cls.IMAGE_NAMES.get(image_type)
        if image_name is None:
            raise ValueError(f"Unknown image_type: {image_type}. Valid types: {list(cls.IMAGE_NAMES)}")

        registry = f"{cls.ACCOUNT_ID}.dkr.ecr.{region}.amazonaws.com"
        return f"{registry}/aws-ml-images/{image_name}:{version}"
|
|
95
89
|
|
|
96
90
|
|
|
97
91
|
class ModelCore(Artifact):
|
|
@@ -105,11 +99,10 @@ class ModelCore(Artifact):
|
|
|
105
99
|
```
|
|
106
100
|
"""
|
|
107
101
|
|
|
108
|
-
def __init__(self, model_name: str,
|
|
102
|
+
def __init__(self, model_name: str, **kwargs):
|
|
109
103
|
"""ModelCore Initialization
|
|
110
104
|
Args:
|
|
111
105
|
model_name (str): Name of Model in Workbench.
|
|
112
|
-
model_type (ModelType, optional): Set this for newly created Models. Defaults to None.
|
|
113
106
|
**kwargs: Additional keyword arguments
|
|
114
107
|
"""
|
|
115
108
|
|
|
@@ -143,10 +136,8 @@ class ModelCore(Artifact):
|
|
|
143
136
|
self.latest_model = self.model_meta["ModelPackageList"][0]
|
|
144
137
|
self.description = self.latest_model.get("ModelPackageDescription", "-")
|
|
145
138
|
self.training_job_name = self._extract_training_job_name()
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
else:
|
|
149
|
-
self.model_type = self._get_model_type()
|
|
139
|
+
self.model_type = self._get_model_type()
|
|
140
|
+
self.model_framework = self._get_model_framework()
|
|
150
141
|
except (IndexError, KeyError):
|
|
151
142
|
self.log.critical(f"Model {self.model_name} appears to be malformed. Delete and recreate it!")
|
|
152
143
|
return
|
|
@@ -597,6 +588,24 @@ class ModelCore(Artifact):
|
|
|
597
588
|
# Return the details
|
|
598
589
|
return details
|
|
599
590
|
|
|
591
|
+
# Training View for this model
|
|
592
|
+
def training_view(self):
    """Return the training view to use for this model

    A view named after this model (name lower-cased, dashes replaced by
    underscores, suffixed with "_training") is looked up on the model's input
    FeatureSet. If it does not exist, the FeatureSet's "training" view is
    returned instead.
    """
    from workbench.core.artifacts.feature_set_core import FeatureSetCore
    from workbench.core.views import View

    # The FeatureSet that this model takes as input
    feature_set = FeatureSetCore(self.get_input())

    # Look for a model-specific training view (without creating one)
    view_name = (self.name.replace("-", "_") + "_training").lower()
    model_view = View(feature_set, view_name, auto_create_view=False)
    if model_view.exists():
        return model_view

    # Nothing model-specific, so fall back to the FeatureSet's training view
    self.log.important(f"No specific training view {view_name}, returning default training view")
    return feature_set.view("training")
|
|
608
|
+
|
|
600
609
|
# Pipeline for this model
|
|
601
610
|
def get_pipeline(self) -> str:
|
|
602
611
|
"""Get the pipeline for this model"""
|
|
@@ -879,10 +888,24 @@ class ModelCore(Artifact):
|
|
|
879
888
|
except (KeyError, IndexError, TypeError):
|
|
880
889
|
return None
|
|
881
890
|
|
|
891
|
+
def publish_prox_model(self, prox_model_name: str = None, track_columns: list = None):
    """Build and publish a Proximity Model based on this Model

    Args:
        prox_model_name (str, optional): Name for the Proximity Model. When omitted,
            this model's name with a "-prox" suffix is used.
        track_columns (list, optional): Columns to track in the Proximity Model.

    Returns:
        Model: The published Proximity Model
    """
    # Only derive a name from this model when the caller did not supply one
    target_name = prox_model_name if prox_model_name is not None else self.model_name + "-prox"
    return proximity_model(self, target_name, track_columns=track_columns)
|
|
904
|
+
|
|
882
905
|
def delete(self):
|
|
883
906
|
"""Delete the Model Packages and the Model Group"""
|
|
884
907
|
if not self.exists():
|
|
885
|
-
self.log.warning(f"Trying to delete
|
|
908
|
+
self.log.warning(f"Trying to delete a Model that doesn't exist: {self.name}")
|
|
886
909
|
|
|
887
910
|
# Call the Class Method to delete the Model Group
|
|
888
911
|
ModelCore.managed_delete(model_group_name=self.name)
|
|
@@ -958,6 +981,27 @@ class ModelCore(Artifact):
|
|
|
958
981
|
self.log.warning(f"Could not determine model type for {self.model_name}!")
|
|
959
982
|
return ModelType.UNKNOWN
|
|
960
983
|
|
|
984
|
+
def _set_model_framework(self, model_framework: ModelFramework):
    """Internal: Record the Model Framework on this object and in the Workbench metadata"""
    self.model_framework = model_framework

    # Persist the enum's string value under the framework metadata key
    framework_value = self.model_framework.value
    self.upsert_workbench_meta({"workbench_model_framework": framework_value})

    # The framework is now known, so drop the corresponding health tag
    self.remove_health_tag("model_framework_unknown")
|
|
989
|
+
|
|
990
|
+
def _get_model_framework(self) -> ModelFramework:
    """Internal: Look up the model framework stored in the Workbench Metadata

    Returns:
        ModelFramework: The framework recorded for this Model, or
            ModelFramework.UNKNOWN when the stored value is not a valid member
    Notes:
        This is an internal method that should not be called directly
        Use the model_framework attribute instead
    """
    stored_value = self.workbench_meta().get("workbench_model_framework")
    try:
        framework = ModelFramework(stored_value)
    except ValueError:
        # Missing or unrecognized value: warn and fall back to UNKNOWN
        self.log.warning(f"Could not determine model framework for {self.model_name}!")
        framework = ModelFramework.UNKNOWN
    return framework
|
|
1004
|
+
|
|
961
1005
|
def _load_training_metrics(self):
|
|
962
1006
|
"""Internal: Retrieve the training metrics and Confusion Matrix for this model
|
|
963
1007
|
and load the data into the Workbench Metadata
|
|
@@ -1149,13 +1193,11 @@ if __name__ == "__main__":
|
|
|
1149
1193
|
# Grab a ModelCore object and pull some information from it
|
|
1150
1194
|
my_model = ModelCore("abalone-regression")
|
|
1151
1195
|
|
|
1152
|
-
# Call the various methods
|
|
1153
|
-
|
|
1154
1196
|
# Let's do a check/validation of the Model
|
|
1155
1197
|
print(f"Model Check: {my_model.exists()}")
|
|
1156
1198
|
|
|
1157
1199
|
# Make sure the model is 'ready'
|
|
1158
|
-
|
|
1200
|
+
my_model.onboard()
|
|
1159
1201
|
|
|
1160
1202
|
# Get the ARN of the Model Group
|
|
1161
1203
|
print(f"Model Group ARN: {my_model.group_arn()}")
|
|
@@ -1221,5 +1263,10 @@ if __name__ == "__main__":
|
|
|
1221
1263
|
# Delete the Model
|
|
1222
1264
|
# ModelCore.managed_delete("wine-classification")
|
|
1223
1265
|
|
|
1266
|
+
# Check the training view logic
|
|
1267
|
+
model = ModelCore("wine-class-test-251112-BW")
|
|
1268
|
+
training_view = model.training_view()
|
|
1269
|
+
print(f"Training View Name: {training_view.name}")
|
|
1270
|
+
|
|
1224
1271
|
# Check for a model that doesn't exist
|
|
1225
1272
|
my_model = ModelCore("empty-model-group")
|