workbench 0.8.168__py3-none-any.whl → 0.8.193__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. workbench/algorithms/dataframe/proximity.py +143 -102
  2. workbench/algorithms/graph/light/proximity_graph.py +2 -1
  3. workbench/api/compound.py +1 -1
  4. workbench/api/endpoint.py +3 -2
  5. workbench/api/feature_set.py +4 -4
  6. workbench/api/model.py +16 -12
  7. workbench/api/monitor.py +1 -16
  8. workbench/core/artifacts/artifact.py +11 -3
  9. workbench/core/artifacts/data_capture_core.py +355 -0
  10. workbench/core/artifacts/endpoint_core.py +113 -27
  11. workbench/core/artifacts/feature_set_core.py +72 -13
  12. workbench/core/artifacts/model_core.py +71 -49
  13. workbench/core/artifacts/monitor_core.py +33 -249
  14. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  15. workbench/core/cloud_platform/aws/aws_meta.py +11 -4
  16. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  17. workbench/core/transforms/features_to_model/features_to_model.py +11 -6
  18. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +36 -6
  19. workbench/core/transforms/pandas_transforms/pandas_to_features.py +27 -0
  20. workbench/core/views/training_view.py +49 -53
  21. workbench/core/views/view.py +51 -1
  22. workbench/core/views/view_utils.py +4 -4
  23. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  24. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  25. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  26. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
  27. workbench/model_scripts/custom_models/proximity/proximity.py +143 -102
  28. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  29. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +10 -17
  30. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  31. workbench/model_scripts/custom_models/uq_models/meta_uq.template +156 -58
  32. workbench/model_scripts/custom_models/uq_models/ngboost.template +20 -14
  33. workbench/model_scripts/custom_models/uq_models/proximity.py +143 -102
  34. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  35. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +5 -13
  36. workbench/model_scripts/pytorch_model/pytorch.template +9 -18
  37. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  38. workbench/model_scripts/script_generation.py +7 -2
  39. workbench/model_scripts/uq_models/mapie.template +492 -0
  40. workbench/model_scripts/uq_models/requirements.txt +1 -0
  41. workbench/model_scripts/xgb_model/generated_model_script.py +34 -43
  42. workbench/model_scripts/xgb_model/xgb_model.template +31 -40
  43. workbench/repl/workbench_shell.py +4 -4
  44. workbench/scripts/lambda_launcher.py +63 -0
  45. workbench/scripts/{ml_pipeline_launcher.py → ml_pipeline_batch.py} +49 -51
  46. workbench/scripts/ml_pipeline_sqs.py +186 -0
  47. workbench/utils/chem_utils/__init__.py +0 -0
  48. workbench/utils/chem_utils/fingerprints.py +134 -0
  49. workbench/utils/chem_utils/misc.py +194 -0
  50. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  51. workbench/utils/chem_utils/mol_standardize.py +450 -0
  52. workbench/utils/chem_utils/mol_tagging.py +348 -0
  53. workbench/utils/chem_utils/projections.py +209 -0
  54. workbench/utils/chem_utils/salts.py +256 -0
  55. workbench/utils/chem_utils/sdf.py +292 -0
  56. workbench/utils/chem_utils/toxicity.py +250 -0
  57. workbench/utils/chem_utils/vis.py +253 -0
  58. workbench/utils/config_manager.py +2 -6
  59. workbench/utils/endpoint_utils.py +5 -7
  60. workbench/utils/license_manager.py +2 -6
  61. workbench/utils/model_utils.py +89 -31
  62. workbench/utils/monitor_utils.py +44 -62
  63. workbench/utils/pandas_utils.py +3 -3
  64. workbench/utils/shap_utils.py +10 -2
  65. workbench/utils/workbench_sqs.py +1 -1
  66. workbench/utils/xgboost_model_utils.py +300 -151
  67. workbench/web_interface/components/model_plot.py +7 -1
  68. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  69. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  70. workbench/web_interface/components/plugins/model_details.py +7 -2
  71. workbench/web_interface/components/plugins/scatter_plot.py +3 -3
  72. {workbench-0.8.168.dist-info → workbench-0.8.193.dist-info}/METADATA +24 -2
  73. {workbench-0.8.168.dist-info → workbench-0.8.193.dist-info}/RECORD +77 -72
  74. {workbench-0.8.168.dist-info → workbench-0.8.193.dist-info}/entry_points.txt +3 -1
  75. {workbench-0.8.168.dist-info → workbench-0.8.193.dist-info}/licenses/LICENSE +1 -1
  76. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  77. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  78. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  79. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  80. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  81. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  82. workbench/model_scripts/pytorch_model/generated_model_script.py +0 -576
  83. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  84. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  85. workbench/model_scripts/scikit_learn/generated_model_script.py +0 -307
  86. workbench/utils/chem_utils.py +0 -1556
  87. workbench/utils/fast_inference.py +0 -167
  88. workbench/utils/resource_utils.py +0 -39
  89. {workbench-0.8.168.dist-info → workbench-0.8.193.dist-info}/WHEEL +0 -0
  90. {workbench-0.8.168.dist-info → workbench-0.8.193.dist-info}/top_level.txt +0 -0
@@ -17,7 +17,7 @@ from workbench.core.artifacts.artifact import Artifact
 from workbench.core.artifacts.data_source_factory import DataSourceFactory
 from workbench.core.artifacts.athena_source import AthenaSource

-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional, List, Union

 from workbench.utils.aws_utils import aws_throttle

@@ -194,24 +194,24 @@ class FeatureSetCore(Artifact):

         return View(self, view_name)

-    def set_display_columns(self, diplay_columns: list[str]):
+    def set_display_columns(self, display_columns: list[str]):
         """Set the display columns for this Data Source

         Args:
-            diplay_columns (list[str]): The display columns for this Data Source
+            display_columns (list[str]): The display columns for this Data Source
         """
         # Check mismatch of display columns to computation columns
         c_view = self.view("computation")
         computation_columns = c_view.columns
-        mismatch_columns = [col for col in diplay_columns if col not in computation_columns]
+        mismatch_columns = [col for col in display_columns if col not in computation_columns]
         if mismatch_columns:
             self.log.monitor(f"Display View/Computation mismatch: {mismatch_columns}")

-        self.log.important(f"Setting Display Columns...{diplay_columns}")
+        self.log.important(f"Setting Display Columns...{display_columns}")
         from workbench.core.views import DisplayView

         # Create a NEW display view
-        DisplayView.create(self, source_table=c_view.table, column_list=diplay_columns)
+        DisplayView.create(self, source_table=c_view.table, column_list=display_columns)

     def set_computation_columns(self, computation_columns: list[str], reset_display: bool = True):
         """Set the computation columns for this FeatureSet
@@ -509,6 +509,48 @@ class FeatureSetCore(Artifact):
         ].tolist()
         return hold_out_ids

+    def set_training_filter(self, filter_expression: Optional[str] = None):
+        """Set a filter expression for the training view for this FeatureSet
+
+        Args:
+            filter_expression (Optional[str]): A SQL filter expression (e.g., "age > 25 AND status = 'active'")
+                If None or empty string, will reset to training view with no filter
+                (default: None)
+        """
+        from workbench.core.views import TrainingView
+
+        # Grab the existing holdout ids
+        holdout_ids = self.get_training_holdouts()
+
+        # Create a NEW training view
+        self.log.important(f"Setting Training Filter: {filter_expression}")
+        TrainingView.create(
+            self, id_column=self.id_column, holdout_ids=holdout_ids, filter_expression=filter_expression
+        )
+
+    def exclude_ids_from_training(self, ids: List[Union[str, int]], column_name: Optional[str] = None):
+        """Exclude a list of IDs from the training view
+
+        Args:
+            ids (List[Union[str, int]]): List of IDs to exclude from training
+            column_name (Optional[str]): Column name to filter on.
+                If None, uses self.id_column (default: None)
+        """
+        # Use the default id_column if not specified
+        column = column_name or self.id_column
+
+        # Handle empty list case
+        if not ids:
+            self.log.warning("No IDs provided to exclude")
+            return
+
+        # Build the filter expression with proper SQL quoting
+        quoted_ids = ", ".join([repr(id) for id in ids])
+        filter_expression = f"{column} NOT IN ({quoted_ids})"
+
+        # Apply the filter
+        self.set_training_filter(filter_expression)
+
     @classmethod
     def delete_views(cls, table: str, database: str):
         """Delete any views associated with this FeatureSet
@@ -707,7 +749,7 @@ if __name__ == "__main__":

     # Test getting the holdout ids
     print("Getting the hold out ids...")
-    holdout_ids = my_features.get_training_holdouts("id")
+    holdout_ids = my_features.get_training_holdouts()
     print(f"Holdout IDs: {holdout_ids}")

     # Get a sample of the data
@@ -729,16 +771,33 @@ if __name__ == "__main__":
     table = my_features.view("training").table
     df = my_features.query(f'SELECT id, name FROM "{table}"')
     my_holdout_ids = [id for id in df["id"] if id < 20]
-    my_features.set_training_holdouts("id", my_holdout_ids)
-
-    # Test the hold out set functionality with strings
-    print("Setting hold out ids (strings)...")
-    my_holdout_ids = [name for name in df["name"] if int(name.split(" ")[1]) > 80]
-    my_features.set_training_holdouts("name", my_holdout_ids)
+    my_features.set_training_holdouts(my_holdout_ids)

     # Get the training data
     print("Getting the training data...")
     training_data = my_features.get_training_data()
+    print(f"Training Data: {training_data.shape}")
+
+    # Test the filter expression functionality
+    print("Setting a filter expression...")
+    my_features.set_training_filter("id < 50 AND height > 65.0")
+    training_data = my_features.get_training_data()
+    print(f"Training Data: {training_data.shape}")
+    print(training_data)
+
+    # Remove training filter
+    print("Removing the filter expression...")
+    my_features.set_training_filter(None)
+    training_data = my_features.get_training_data()
+    print(f"Training Data: {training_data.shape}")
+    print(training_data)
+
+    # Test excluding ids from training
+    print("Excluding ids from training...")
+    my_features.exclude_ids_from_training([1, 2, 3, 4, 5])
+    training_data = my_features.get_training_data()
+    print(f"Training Data: {training_data.shape}")
+    print(training_data)

     # Now delete the AWS artifacts associated with this Feature Set
     # print("Deleting Workbench Feature Set...")
@@ -21,6 +21,7 @@ from workbench.utils.aws_utils import newest_path, pull_s3_data
 from workbench.utils.s3_utils import compute_s3_object_hash
 from workbench.utils.shap_utils import shap_values_data, shap_feature_importance
 from workbench.utils.deprecated_utils import deprecated
+from workbench.utils.model_utils import proximity_model


 class ModelType(Enum):
@@ -40,58 +41,39 @@ class ModelType(Enum):
 class ModelImages:
     """Class for retrieving workbench inference images"""

-    image_uris = {
-        # US East 1 images
-        ("us-east-1", "xgb_training", "0.1", "x86_64"): (
-            "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-training:0.1"
-        ),
-        ("us-east-1", "xgb_inference", "0.1", "x86_64"): (
-            "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1"
-        ),
-        ("us-east-1", "pytorch_training", "0.1", "x86_64"): (
-            "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-pytorch-training:0.1"
-        ),
-        ("us-east-1", "pytorch_inference", "0.1", "x86_64"): (
-            "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-pytorch-inference:0.1"
-        ),
-        # US West 2 images
-        ("us-west-2", "xgb_training", "0.1", "x86_64"): (
-            "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-training:0.1"
-        ),
-        ("us-west-2", "xgb_inference", "0.1", "x86_64"): (
-            "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1"
-        ),
-        ("us-west-2", "pytorch_training", "0.1", "x86_64"): (
-            "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-pytorch-training:0.1"
-        ),
-        ("us-west-2", "pytorch_inference", "0.1", "x86_64"): (
-            "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-pytorch-inference:0.1"
-        ),
-        # ARM64 images
-        ("us-east-1", "xgb_inference", "0.1", "arm64"): (
-            "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1-arm64"
-        ),
-        ("us-west-2", "xgb_inference", "0.1", "arm64"): (
-            "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1-arm64"
-        ),
-        # Meta Endpoint inference images
-        ("us-east-1", "meta-endpoint", "0.1", "x86_64"): (
-            "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-meta-endpoint:0.1"
-        ),
-        ("us-west-2", "meta-endpoint", "0.1", "x86_64"): (
-            "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-meta-endpoint:0.1"
-        ),
+    # Account ID
+    ACCOUNT_ID = "507740646243"
+
+    # Image name mappings
+    IMAGE_NAMES = {
+        "training": "py312-general-ml-training",
+        "inference": "py312-general-ml-inference",
+        "pytorch_training": "py312-pytorch-training",
+        "pytorch_inference": "py312-pytorch-inference",
+        "meta-endpoint": "py312-meta-endpoint",
     }

     @classmethod
     def get_image_uri(cls, region, image_type, version="0.1", architecture="x86_64"):
-        key = (region, image_type, version, architecture)
-        if key in cls.image_uris:
-            return cls.image_uris[key]
-        else:
-            raise ValueError(
-                f"No matching image found for region: {region}, image_type: {image_type}, version: {version}"
-            )
+        """
+        Dynamically construct ECR image URI.
+
+        Args:
+            region: AWS region (e.g., 'us-east-1', 'us-west-2')
+            image_type: Type of image (e.g., 'training', 'inference', 'pytorch_training')
+            version: Image version (e.g., '0.1', '0.2')
+            architecture: CPU architecture (default: 'x86_64', currently unused but kept for compatibility)
+
+        Returns:
+            ECR image URI string
+        """
+        if image_type not in cls.IMAGE_NAMES:
+            raise ValueError(f"Unknown image_type: {image_type}. Valid types: {list(cls.IMAGE_NAMES.keys())}")
+
+        image_name = cls.IMAGE_NAMES[image_type]
+        uri = f"{cls.ACCOUNT_ID}.dkr.ecr.{region}.amazonaws.com/aws-ml-images/{image_name}:{version}"
+
+        return uri


 class ModelCore(Artifact):
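To make the rewritten lookup concrete, a short sketch of what get_image_uri() now returns; the region and image type are illustrative, and the URI is composed exactly as in the f-string above.

    from workbench.core.artifacts.model_core import ModelImages

    uri = ModelImages.get_image_uri("us-east-1", "inference", version="0.1")
    # -> "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-general-ml-inference:0.1"

    # Image types not listed in IMAGE_NAMES (e.g. the old per-region "xgb_inference" key)
    # now raise a ValueError instead of failing a dictionary lookup.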
@@ -597,6 +579,24 @@ class ModelCore(Artifact):
         # Return the details
         return details

+    # Training View for this model
+    def training_view(self):
+        """Get the training view for this model"""
+        from workbench.core.artifacts.feature_set_core import FeatureSetCore
+        from workbench.core.views import View
+
+        # Grab our FeatureSet
+        fs = FeatureSetCore(self.get_input())
+
+        # See if we have a training view for this model
+        my_model_training_view = f"{self.name.replace('-', '_')}_training"
+        view = View(fs, my_model_training_view, auto_create_view=False)
+        if view.exists():
+            return view
+        else:
+            self.log.important(f"No specific training view {my_model_training_view}, returning default training view")
+            return fs.view("training")
+
     # Pipeline for this model
     def get_pipeline(self) -> str:
         """Get the pipeline for this model"""
@@ -867,6 +867,14 @@ class ModelCore(Artifact):
             shap_data[key] = self.df_store.get(df_location)
         return shap_data or None

+    def cross_folds(self) -> dict:
+        """Retrieve the cross-fold inference results(only works for XGBoost models)
+
+        Returns:
+            dict: Dictionary with the cross-fold inference results
+        """
+        return self.param_store.get(f"/workbench/models/{self.name}/inference/cross_fold")
+
     def supported_inference_instances(self) -> Optional[list]:
         """Retrieve the supported endpoint inference instance types

879
887
  except (KeyError, IndexError, TypeError):
880
888
  return None
881
889
 
890
+ def publish_prox_model(self, prox_model_name: str = None, track_columns: list = None):
891
+ """Create and publish a Proximity Model for this Model
892
+
893
+ Args:
894
+ prox_model_name (str, optional): Name of the Proximity Model (if not specified, a name will be generated)
895
+ track_columns (list, optional): List of columns to track in the Proximity Model.
896
+
897
+ Returns:
898
+ Model: The published Proximity Model
899
+ """
900
+ if prox_model_name is None:
901
+ prox_model_name = self.model_name + "-prox"
902
+ return proximity_model(self, prox_model_name, track_columns=track_columns)
903
+
882
904
  def delete(self):
883
905
  """Delete the Model Packages and the Model Group"""
884
906
  if not self.exists():
885
- self.log.warning(f"Trying to delete an Model that doesn't exist: {self.name}")
907
+ self.log.warning(f"Trying to delete a Model that doesn't exist: {self.name}")
886
908
 
887
909
  # Call the Class Method to delete the Model Group
888
910
  ModelCore.managed_delete(model_group_name=self.name)
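Finally, a hedged sketch of the two other ModelCore additions above (cross_folds and publish_prox_model); the model name and tracked column are illustrative assumptions.

    from workbench.api import Model

    model = Model("aqsol-regression")                       # hypothetical model name
    cv_results = model.cross_folds()
    # Reads the dict stored at "/workbench/models/aqsol-regression/inference/cross_fold"
    # (per the docstring, only populated for XGBoost models).

    prox = model.publish_prox_model(track_columns=["solubility"])
    # With no prox_model_name given, the name defaults to "aqsol-regression-prox";
    # the work is delegated to workbench.utils.model_utils.proximity_model().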