workbench 0.8.168__py3-none-any.whl → 0.8.193__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/algorithms/dataframe/proximity.py +143 -102
- workbench/algorithms/graph/light/proximity_graph.py +2 -1
- workbench/api/compound.py +1 -1
- workbench/api/endpoint.py +3 -2
- workbench/api/feature_set.py +4 -4
- workbench/api/model.py +16 -12
- workbench/api/monitor.py +1 -16
- workbench/core/artifacts/artifact.py +11 -3
- workbench/core/artifacts/data_capture_core.py +355 -0
- workbench/core/artifacts/endpoint_core.py +113 -27
- workbench/core/artifacts/feature_set_core.py +72 -13
- workbench/core/artifacts/model_core.py +71 -49
- workbench/core/artifacts/monitor_core.py +33 -249
- workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
- workbench/core/cloud_platform/aws/aws_meta.py +11 -4
- workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
- workbench/core/transforms/features_to_model/features_to_model.py +11 -6
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +36 -6
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +27 -0
- workbench/core/views/training_view.py +49 -53
- workbench/core/views/view.py +51 -1
- workbench/core/views/view_utils.py +4 -4
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
- workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
- workbench/model_scripts/custom_models/proximity/proximity.py +143 -102
- workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +10 -17
- workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +156 -58
- workbench/model_scripts/custom_models/uq_models/ngboost.template +20 -14
- workbench/model_scripts/custom_models/uq_models/proximity.py +143 -102
- workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +5 -13
- workbench/model_scripts/pytorch_model/pytorch.template +9 -18
- workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
- workbench/model_scripts/script_generation.py +7 -2
- workbench/model_scripts/uq_models/mapie.template +492 -0
- workbench/model_scripts/uq_models/requirements.txt +1 -0
- workbench/model_scripts/xgb_model/generated_model_script.py +34 -43
- workbench/model_scripts/xgb_model/xgb_model.template +31 -40
- workbench/repl/workbench_shell.py +4 -4
- workbench/scripts/lambda_launcher.py +63 -0
- workbench/scripts/{ml_pipeline_launcher.py → ml_pipeline_batch.py} +49 -51
- workbench/scripts/ml_pipeline_sqs.py +186 -0
- workbench/utils/chem_utils/__init__.py +0 -0
- workbench/utils/chem_utils/fingerprints.py +134 -0
- workbench/utils/chem_utils/misc.py +194 -0
- workbench/utils/chem_utils/mol_descriptors.py +483 -0
- workbench/utils/chem_utils/mol_standardize.py +450 -0
- workbench/utils/chem_utils/mol_tagging.py +348 -0
- workbench/utils/chem_utils/projections.py +209 -0
- workbench/utils/chem_utils/salts.py +256 -0
- workbench/utils/chem_utils/sdf.py +292 -0
- workbench/utils/chem_utils/toxicity.py +250 -0
- workbench/utils/chem_utils/vis.py +253 -0
- workbench/utils/config_manager.py +2 -6
- workbench/utils/endpoint_utils.py +5 -7
- workbench/utils/license_manager.py +2 -6
- workbench/utils/model_utils.py +89 -31
- workbench/utils/monitor_utils.py +44 -62
- workbench/utils/pandas_utils.py +3 -3
- workbench/utils/shap_utils.py +10 -2
- workbench/utils/workbench_sqs.py +1 -1
- workbench/utils/xgboost_model_utils.py +300 -151
- workbench/web_interface/components/model_plot.py +7 -1
- workbench/web_interface/components/plugins/dashboard_status.py +3 -1
- workbench/web_interface/components/plugins/generated_compounds.py +1 -1
- workbench/web_interface/components/plugins/model_details.py +7 -2
- workbench/web_interface/components/plugins/scatter_plot.py +3 -3
- {workbench-0.8.168.dist-info → workbench-0.8.193.dist-info}/METADATA +24 -2
- {workbench-0.8.168.dist-info → workbench-0.8.193.dist-info}/RECORD +77 -72
- {workbench-0.8.168.dist-info → workbench-0.8.193.dist-info}/entry_points.txt +3 -1
- {workbench-0.8.168.dist-info → workbench-0.8.193.dist-info}/licenses/LICENSE +1 -1
- workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
- workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
- workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
- workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
- workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
- workbench/model_scripts/pytorch_model/generated_model_script.py +0 -576
- workbench/model_scripts/quant_regression/quant_regression.template +0 -279
- workbench/model_scripts/quant_regression/requirements.txt +0 -1
- workbench/model_scripts/scikit_learn/generated_model_script.py +0 -307
- workbench/utils/chem_utils.py +0 -1556
- workbench/utils/fast_inference.py +0 -167
- workbench/utils/resource_utils.py +0 -39
- {workbench-0.8.168.dist-info → workbench-0.8.193.dist-info}/WHEEL +0 -0
- {workbench-0.8.168.dist-info → workbench-0.8.193.dist-info}/top_level.txt +0 -0
|
@@ -17,7 +17,7 @@ from workbench.core.artifacts.artifact import Artifact
|
|
|
17
17
|
from workbench.core.artifacts.data_source_factory import DataSourceFactory
|
|
18
18
|
from workbench.core.artifacts.athena_source import AthenaSource
|
|
19
19
|
|
|
20
|
-
from typing import TYPE_CHECKING
|
|
20
|
+
from typing import TYPE_CHECKING, Optional, List, Union
|
|
21
21
|
|
|
22
22
|
from workbench.utils.aws_utils import aws_throttle
|
|
23
23
|
|
|
@@ -194,24 +194,24 @@ class FeatureSetCore(Artifact):
|
|
|
194
194
|
|
|
195
195
|
return View(self, view_name)
|
|
196
196
|
|
|
197
|
-
def set_display_columns(self,
|
|
197
|
+
def set_display_columns(self, display_columns: list[str]):
|
|
198
198
|
"""Set the display columns for this FeatureSet
|
|
199
199
|
|
|
200
200
|
Args:
|
|
201
|
-
|
|
201
|
+
display_columns (list[str]): The display columns for this FeatureSet
|
|
202
202
|
"""
|
|
203
203
|
# Check mismatch of display columns to computation columns
|
|
204
204
|
c_view = self.view("computation")
|
|
205
205
|
computation_columns = c_view.columns
|
|
206
|
-
mismatch_columns = [col for col in
|
|
206
|
+
mismatch_columns = [col for col in display_columns if col not in computation_columns]
|
|
207
207
|
if mismatch_columns:
|
|
208
208
|
self.log.monitor(f"Display View/Computation mismatch: {mismatch_columns}")
|
|
209
209
|
|
|
210
|
-
self.log.important(f"Setting Display Columns...{
|
|
210
|
+
self.log.important(f"Setting Display Columns...{display_columns}")
|
|
211
211
|
from workbench.core.views import DisplayView
|
|
212
212
|
|
|
213
213
|
# Create a NEW display view
|
|
214
|
-
DisplayView.create(self, source_table=c_view.table, column_list=
|
|
214
|
+
DisplayView.create(self, source_table=c_view.table, column_list=display_columns)
|
|
215
215
|
|
|
216
216
|
def set_computation_columns(self, computation_columns: list[str], reset_display: bool = True):
|
|
217
217
|
"""Set the computation columns for this FeatureSet
|
|
@@ -509,6 +509,48 @@ class FeatureSetCore(Artifact):
|
|
|
509
509
|
].tolist()
|
|
510
510
|
return hold_out_ids
|
|
511
511
|
|
|
512
|
+
def set_training_filter(self, filter_expression: Optional[str] = None):
|
|
513
|
+
"""Set a filter expression for the training view for this FeatureSet
|
|
514
|
+
|
|
515
|
+
Args:
|
|
516
|
+
filter_expression (Optional[str]): A SQL filter expression (e.g., "age > 25 AND status = 'active'")
|
|
517
|
+
If None or empty string, will reset to training view with no filter
|
|
518
|
+
(default: None)
|
|
519
|
+
"""
|
|
520
|
+
from workbench.core.views import TrainingView
|
|
521
|
+
|
|
522
|
+
# Grab the existing holdout ids
|
|
523
|
+
holdout_ids = self.get_training_holdouts()
|
|
524
|
+
|
|
525
|
+
# Create a NEW training view
|
|
526
|
+
self.log.important(f"Setting Training Filter: {filter_expression}")
|
|
527
|
+
TrainingView.create(
|
|
528
|
+
self, id_column=self.id_column, holdout_ids=holdout_ids, filter_expression=filter_expression
|
|
529
|
+
)
|
|
530
|
+
|
|
531
|
+
def exclude_ids_from_training(self, ids: List[Union[str, int]], column_name: Optional[str] = None):
|
|
532
|
+
"""Exclude a list of IDs from the training view
|
|
533
|
+
|
|
534
|
+
Args:
|
|
535
|
+
ids (List[Union[str, int]]): List of IDs to exclude from training
|
|
536
|
+
column_name (Optional[str]): Column name to filter on.
|
|
537
|
+
If None, uses self.id_column (default: None)
|
|
538
|
+
"""
|
|
539
|
+
# Use the default id_column if not specified
|
|
540
|
+
column = column_name or self.id_column
|
|
541
|
+
|
|
542
|
+
# Handle empty list case
|
|
543
|
+
if not ids:
|
|
544
|
+
self.log.warning("No IDs provided to exclude")
|
|
545
|
+
return
|
|
546
|
+
|
|
547
|
+
# Build the filter expression (NOTE: repr() yields Python-style quoting, not SQL escaping; verify for string IDs)
|
|
548
|
+
quoted_ids = ", ".join([repr(id) for id in ids])
|
|
549
|
+
filter_expression = f"{column} NOT IN ({quoted_ids})"
|
|
550
|
+
|
|
551
|
+
# Apply the filter
|
|
552
|
+
self.set_training_filter(filter_expression)
|
|
553
|
+
|
|
512
554
|
@classmethod
|
|
513
555
|
def delete_views(cls, table: str, database: str):
|
|
514
556
|
"""Delete any views associated with this FeatureSet
|
|
@@ -707,7 +749,7 @@ if __name__ == "__main__":
|
|
|
707
749
|
|
|
708
750
|
# Test getting the holdout ids
|
|
709
751
|
print("Getting the hold out ids...")
|
|
710
|
-
holdout_ids = my_features.get_training_holdouts(
|
|
752
|
+
holdout_ids = my_features.get_training_holdouts()
|
|
711
753
|
print(f"Holdout IDs: {holdout_ids}")
|
|
712
754
|
|
|
713
755
|
# Get a sample of the data
|
|
@@ -729,16 +771,33 @@ if __name__ == "__main__":
|
|
|
729
771
|
table = my_features.view("training").table
|
|
730
772
|
df = my_features.query(f'SELECT id, name FROM "{table}"')
|
|
731
773
|
my_holdout_ids = [id for id in df["id"] if id < 20]
|
|
732
|
-
my_features.set_training_holdouts(
|
|
733
|
-
|
|
734
|
-
# Test the hold out set functionality with strings
|
|
735
|
-
print("Setting hold out ids (strings)...")
|
|
736
|
-
my_holdout_ids = [name for name in df["name"] if int(name.split(" ")[1]) > 80]
|
|
737
|
-
my_features.set_training_holdouts("name", my_holdout_ids)
|
|
774
|
+
my_features.set_training_holdouts(my_holdout_ids)
|
|
738
775
|
|
|
739
776
|
# Get the training data
|
|
740
777
|
print("Getting the training data...")
|
|
741
778
|
training_data = my_features.get_training_data()
|
|
779
|
+
print(f"Training Data: {training_data.shape}")
|
|
780
|
+
|
|
781
|
+
# Test the filter expression functionality
|
|
782
|
+
print("Setting a filter expression...")
|
|
783
|
+
my_features.set_training_filter("id < 50 AND height > 65.0")
|
|
784
|
+
training_data = my_features.get_training_data()
|
|
785
|
+
print(f"Training Data: {training_data.shape}")
|
|
786
|
+
print(training_data)
|
|
787
|
+
|
|
788
|
+
# Remove training filter
|
|
789
|
+
print("Removing the filter expression...")
|
|
790
|
+
my_features.set_training_filter(None)
|
|
791
|
+
training_data = my_features.get_training_data()
|
|
792
|
+
print(f"Training Data: {training_data.shape}")
|
|
793
|
+
print(training_data)
|
|
794
|
+
|
|
795
|
+
# Test excluding ids from training
|
|
796
|
+
print("Excluding ids from training...")
|
|
797
|
+
my_features.exclude_ids_from_training([1, 2, 3, 4, 5])
|
|
798
|
+
training_data = my_features.get_training_data()
|
|
799
|
+
print(f"Training Data: {training_data.shape}")
|
|
800
|
+
print(training_data)
|
|
742
801
|
|
|
743
802
|
# Now delete the AWS artifacts associated with this Feature Set
|
|
744
803
|
# print("Deleting Workbench Feature Set...")
|
|
@@ -21,6 +21,7 @@ from workbench.utils.aws_utils import newest_path, pull_s3_data
|
|
|
21
21
|
from workbench.utils.s3_utils import compute_s3_object_hash
|
|
22
22
|
from workbench.utils.shap_utils import shap_values_data, shap_feature_importance
|
|
23
23
|
from workbench.utils.deprecated_utils import deprecated
|
|
24
|
+
from workbench.utils.model_utils import proximity_model
|
|
24
25
|
|
|
25
26
|
|
|
26
27
|
class ModelType(Enum):
|
|
@@ -40,58 +41,39 @@ class ModelType(Enum):
|
|
|
40
41
|
class ModelImages:
|
|
41
42
|
"""Class for retrieving workbench inference images"""
|
|
42
43
|
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
),
|
|
54
|
-
("us-east-1", "pytorch_inference", "0.1", "x86_64"): (
|
|
55
|
-
"507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-pytorch-inference:0.1"
|
|
56
|
-
),
|
|
57
|
-
# US West 2 images
|
|
58
|
-
("us-west-2", "xgb_training", "0.1", "x86_64"): (
|
|
59
|
-
"507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-training:0.1"
|
|
60
|
-
),
|
|
61
|
-
("us-west-2", "xgb_inference", "0.1", "x86_64"): (
|
|
62
|
-
"507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1"
|
|
63
|
-
),
|
|
64
|
-
("us-west-2", "pytorch_training", "0.1", "x86_64"): (
|
|
65
|
-
"507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-pytorch-training:0.1"
|
|
66
|
-
),
|
|
67
|
-
("us-west-2", "pytorch_inference", "0.1", "x86_64"): (
|
|
68
|
-
"507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-pytorch-inference:0.1"
|
|
69
|
-
),
|
|
70
|
-
# ARM64 images
|
|
71
|
-
("us-east-1", "xgb_inference", "0.1", "arm64"): (
|
|
72
|
-
"507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1-arm64"
|
|
73
|
-
),
|
|
74
|
-
("us-west-2", "xgb_inference", "0.1", "arm64"): (
|
|
75
|
-
"507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1-arm64"
|
|
76
|
-
),
|
|
77
|
-
# Meta Endpoint inference images
|
|
78
|
-
("us-east-1", "meta-endpoint", "0.1", "x86_64"): (
|
|
79
|
-
"507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-meta-endpoint:0.1"
|
|
80
|
-
),
|
|
81
|
-
("us-west-2", "meta-endpoint", "0.1", "x86_64"): (
|
|
82
|
-
"507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-meta-endpoint:0.1"
|
|
83
|
-
),
|
|
44
|
+
# Account ID
|
|
45
|
+
ACCOUNT_ID = "507740646243"
|
|
46
|
+
|
|
47
|
+
# Image name mappings
|
|
48
|
+
IMAGE_NAMES = {
|
|
49
|
+
"training": "py312-general-ml-training",
|
|
50
|
+
"inference": "py312-general-ml-inference",
|
|
51
|
+
"pytorch_training": "py312-pytorch-training",
|
|
52
|
+
"pytorch_inference": "py312-pytorch-inference",
|
|
53
|
+
"meta-endpoint": "py312-meta-endpoint",
|
|
84
54
|
}
|
|
85
55
|
|
|
86
56
|
@classmethod
|
|
87
57
|
def get_image_uri(cls, region, image_type, version="0.1", architecture="x86_64"):
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
)
|
|
58
|
+
"""
|
|
59
|
+
Dynamically construct ECR image URI.
|
|
60
|
+
|
|
61
|
+
Args:
|
|
62
|
+
region: AWS region (e.g., 'us-east-1', 'us-west-2')
|
|
63
|
+
image_type: Type of image (e.g., 'training', 'inference', 'pytorch_training')
|
|
64
|
+
version: Image version (e.g., '0.1', '0.2')
|
|
65
|
+
architecture: CPU architecture (default: 'x86_64', currently unused but kept for compatibility)
|
|
66
|
+
|
|
67
|
+
Returns:
|
|
68
|
+
ECR image URI string
|
|
69
|
+
"""
|
|
70
|
+
if image_type not in cls.IMAGE_NAMES:
|
|
71
|
+
raise ValueError(f"Unknown image_type: {image_type}. Valid types: {list(cls.IMAGE_NAMES.keys())}")
|
|
72
|
+
|
|
73
|
+
image_name = cls.IMAGE_NAMES[image_type]
|
|
74
|
+
uri = f"{cls.ACCOUNT_ID}.dkr.ecr.{region}.amazonaws.com/aws-ml-images/{image_name}:{version}"
|
|
75
|
+
|
|
76
|
+
return uri
|
|
95
77
|
|
|
96
78
|
|
|
97
79
|
class ModelCore(Artifact):
|
|
@@ -597,6 +579,24 @@ class ModelCore(Artifact):
|
|
|
597
579
|
# Return the details
|
|
598
580
|
return details
|
|
599
581
|
|
|
582
|
+
# Training View for this model
|
|
583
|
+
def training_view(self):
|
|
584
|
+
"""Get the training view for this model"""
|
|
585
|
+
from workbench.core.artifacts.feature_set_core import FeatureSetCore
|
|
586
|
+
from workbench.core.views import View
|
|
587
|
+
|
|
588
|
+
# Grab our FeatureSet
|
|
589
|
+
fs = FeatureSetCore(self.get_input())
|
|
590
|
+
|
|
591
|
+
# See if we have a training view for this model
|
|
592
|
+
my_model_training_view = f"{self.name.replace('-', '_')}_training"
|
|
593
|
+
view = View(fs, my_model_training_view, auto_create_view=False)
|
|
594
|
+
if view.exists():
|
|
595
|
+
return view
|
|
596
|
+
else:
|
|
597
|
+
self.log.important(f"No specific training view {my_model_training_view}, returning default training view")
|
|
598
|
+
return fs.view("training")
|
|
599
|
+
|
|
600
600
|
# Pipeline for this model
|
|
601
601
|
def get_pipeline(self) -> str:
|
|
602
602
|
"""Get the pipeline for this model"""
|
|
@@ -867,6 +867,14 @@ class ModelCore(Artifact):
|
|
|
867
867
|
shap_data[key] = self.df_store.get(df_location)
|
|
868
868
|
return shap_data or None
|
|
869
869
|
|
|
870
|
+
def cross_folds(self) -> dict:
|
|
871
|
+
"""Retrieve the cross-fold inference results (only works for XGBoost models)
|
|
872
|
+
|
|
873
|
+
Returns:
|
|
874
|
+
dict: Dictionary with the cross-fold inference results
|
|
875
|
+
"""
|
|
876
|
+
return self.param_store.get(f"/workbench/models/{self.name}/inference/cross_fold")
|
|
877
|
+
|
|
870
878
|
def supported_inference_instances(self) -> Optional[list]:
|
|
871
879
|
"""Retrieve the supported endpoint inference instance types
|
|
872
880
|
|
|
@@ -879,10 +887,24 @@ class ModelCore(Artifact):
|
|
|
879
887
|
except (KeyError, IndexError, TypeError):
|
|
880
888
|
return None
|
|
881
889
|
|
|
890
|
+
def publish_prox_model(self, prox_model_name: str = None, track_columns: list = None):
|
|
891
|
+
"""Create and publish a Proximity Model for this Model
|
|
892
|
+
|
|
893
|
+
Args:
|
|
894
|
+
prox_model_name (str, optional): Name of the Proximity Model (if not specified, a name will be generated)
|
|
895
|
+
track_columns (list, optional): List of columns to track in the Proximity Model.
|
|
896
|
+
|
|
897
|
+
Returns:
|
|
898
|
+
Model: The published Proximity Model
|
|
899
|
+
"""
|
|
900
|
+
if prox_model_name is None:
|
|
901
|
+
prox_model_name = self.model_name + "-prox"
|
|
902
|
+
return proximity_model(self, prox_model_name, track_columns=track_columns)
|
|
903
|
+
|
|
882
904
|
def delete(self):
|
|
883
905
|
"""Delete the Model Packages and the Model Group"""
|
|
884
906
|
if not self.exists():
|
|
885
|
-
self.log.warning(f"Trying to delete
|
|
907
|
+
self.log.warning(f"Trying to delete a Model that doesn't exist: {self.name}")
|
|
886
908
|
|
|
887
909
|
# Call the Class Method to delete the Model Group
|
|
888
910
|
ModelCore.managed_delete(model_group_name=self.name)
|