workbench 0.8.168__py3-none-any.whl → 0.8.192__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry and is provided for informational purposes only.
Files changed (88)
  1. workbench/algorithms/dataframe/proximity.py +143 -102
  2. workbench/algorithms/graph/light/proximity_graph.py +2 -1
  3. workbench/api/compound.py +1 -1
  4. workbench/api/endpoint.py +3 -2
  5. workbench/api/feature_set.py +4 -4
  6. workbench/api/model.py +16 -12
  7. workbench/api/monitor.py +1 -16
  8. workbench/core/artifacts/artifact.py +11 -3
  9. workbench/core/artifacts/data_capture_core.py +355 -0
  10. workbench/core/artifacts/endpoint_core.py +113 -27
  11. workbench/core/artifacts/feature_set_core.py +72 -13
  12. workbench/core/artifacts/model_core.py +50 -15
  13. workbench/core/artifacts/monitor_core.py +33 -249
  14. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  15. workbench/core/cloud_platform/aws/aws_meta.py +11 -4
  16. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  17. workbench/core/transforms/features_to_model/features_to_model.py +9 -4
  18. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +36 -6
  19. workbench/core/transforms/pandas_transforms/pandas_to_features.py +27 -0
  20. workbench/core/views/training_view.py +49 -53
  21. workbench/core/views/view.py +51 -1
  22. workbench/core/views/view_utils.py +4 -4
  23. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  24. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  25. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  26. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
  27. workbench/model_scripts/custom_models/proximity/proximity.py +143 -102
  28. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  29. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +10 -17
  30. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  31. workbench/model_scripts/custom_models/uq_models/meta_uq.template +156 -58
  32. workbench/model_scripts/custom_models/uq_models/ngboost.template +20 -14
  33. workbench/model_scripts/custom_models/uq_models/proximity.py +143 -102
  34. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  35. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +5 -13
  36. workbench/model_scripts/pytorch_model/pytorch.template +9 -18
  37. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  38. workbench/model_scripts/script_generation.py +7 -2
  39. workbench/model_scripts/uq_models/mapie.template +492 -0
  40. workbench/model_scripts/uq_models/requirements.txt +1 -0
  41. workbench/model_scripts/xgb_model/xgb_model.template +31 -40
  42. workbench/repl/workbench_shell.py +4 -4
  43. workbench/scripts/lambda_launcher.py +63 -0
  44. workbench/scripts/{ml_pipeline_launcher.py → ml_pipeline_batch.py} +49 -51
  45. workbench/scripts/ml_pipeline_sqs.py +186 -0
  46. workbench/utils/chem_utils/__init__.py +0 -0
  47. workbench/utils/chem_utils/fingerprints.py +134 -0
  48. workbench/utils/chem_utils/misc.py +194 -0
  49. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  50. workbench/utils/chem_utils/mol_standardize.py +450 -0
  51. workbench/utils/chem_utils/mol_tagging.py +348 -0
  52. workbench/utils/chem_utils/projections.py +209 -0
  53. workbench/utils/chem_utils/salts.py +256 -0
  54. workbench/utils/chem_utils/sdf.py +292 -0
  55. workbench/utils/chem_utils/toxicity.py +250 -0
  56. workbench/utils/chem_utils/vis.py +253 -0
  57. workbench/utils/config_manager.py +2 -6
  58. workbench/utils/endpoint_utils.py +5 -7
  59. workbench/utils/license_manager.py +2 -6
  60. workbench/utils/model_utils.py +76 -30
  61. workbench/utils/monitor_utils.py +44 -62
  62. workbench/utils/pandas_utils.py +3 -3
  63. workbench/utils/shap_utils.py +10 -2
  64. workbench/utils/workbench_sqs.py +1 -1
  65. workbench/utils/xgboost_model_utils.py +283 -145
  66. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  67. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  68. workbench/web_interface/components/plugins/scatter_plot.py +3 -3
  69. {workbench-0.8.168.dist-info → workbench-0.8.192.dist-info}/METADATA +2 -1
  70. {workbench-0.8.168.dist-info → workbench-0.8.192.dist-info}/RECORD +74 -70
  71. {workbench-0.8.168.dist-info → workbench-0.8.192.dist-info}/entry_points.txt +3 -1
  72. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  73. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  74. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  75. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  76. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  77. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  78. workbench/model_scripts/pytorch_model/generated_model_script.py +0 -576
  79. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  80. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  81. workbench/model_scripts/scikit_learn/generated_model_script.py +0 -307
  82. workbench/model_scripts/xgb_model/generated_model_script.py +0 -477
  83. workbench/utils/chem_utils.py +0 -1556
  84. workbench/utils/fast_inference.py +0 -167
  85. workbench/utils/resource_utils.py +0 -39
  86. {workbench-0.8.168.dist-info → workbench-0.8.192.dist-info}/WHEEL +0 -0
  87. {workbench-0.8.168.dist-info → workbench-0.8.192.dist-info}/licenses/LICENSE +0 -0
  88. {workbench-0.8.168.dist-info → workbench-0.8.192.dist-info}/top_level.txt +0 -0

workbench/core/artifacts/feature_set_core.py
@@ -17,7 +17,7 @@ from workbench.core.artifacts.artifact import Artifact
  from workbench.core.artifacts.data_source_factory import DataSourceFactory
  from workbench.core.artifacts.athena_source import AthenaSource
 
- from typing import TYPE_CHECKING
+ from typing import TYPE_CHECKING, Optional, List, Union
 
  from workbench.utils.aws_utils import aws_throttle
 
@@ -194,24 +194,24 @@ class FeatureSetCore(Artifact):
 
  return View(self, view_name)
 
- def set_display_columns(self, diplay_columns: list[str]):
+ def set_display_columns(self, display_columns: list[str]):
  """Set the display columns for this Data Source
 
  Args:
- diplay_columns (list[str]): The display columns for this Data Source
+ display_columns (list[str]): The display columns for this Data Source
  """
  # Check mismatch of display columns to computation columns
  c_view = self.view("computation")
  computation_columns = c_view.columns
- mismatch_columns = [col for col in diplay_columns if col not in computation_columns]
+ mismatch_columns = [col for col in display_columns if col not in computation_columns]
  if mismatch_columns:
  self.log.monitor(f"Display View/Computation mismatch: {mismatch_columns}")
 
- self.log.important(f"Setting Display Columns...{diplay_columns}")
+ self.log.important(f"Setting Display Columns...{display_columns}")
  from workbench.core.views import DisplayView
 
  # Create a NEW display view
- DisplayView.create(self, source_table=c_view.table, column_list=diplay_columns)
+ DisplayView.create(self, source_table=c_view.table, column_list=display_columns)
 
  def set_computation_columns(self, computation_columns: list[str], reset_display: bool = True):
  """Set the computation columns for this FeatureSet
@@ -509,6 +509,48 @@ class FeatureSetCore(Artifact):
  ].tolist()
  return hold_out_ids
 
+ def set_training_filter(self, filter_expression: Optional[str] = None):
+ """Set a filter expression for the training view for this FeatureSet
+
+ Args:
+ filter_expression (Optional[str]): A SQL filter expression (e.g., "age > 25 AND status = 'active'")
+ If None or empty string, will reset to training view with no filter
+ (default: None)
+ """
+ from workbench.core.views import TrainingView
+
+ # Grab the existing holdout ids
+ holdout_ids = self.get_training_holdouts()
+
+ # Create a NEW training view
+ self.log.important(f"Setting Training Filter: {filter_expression}")
+ TrainingView.create(
+ self, id_column=self.id_column, holdout_ids=holdout_ids, filter_expression=filter_expression
+ )
+
+ def exclude_ids_from_training(self, ids: List[Union[str, int]], column_name: Optional[str] = None):
+ """Exclude a list of IDs from the training view
+
+ Args:
+ ids (List[Union[str, int]]): List of IDs to exclude from training
+ column_name (Optional[str]): Column name to filter on.
+ If None, uses self.id_column (default: None)
+ """
+ # Use the default id_column if not specified
+ column = column_name or self.id_column
+
+ # Handle empty list case
+ if not ids:
+ self.log.warning("No IDs provided to exclude")
+ return
+
+ # Build the filter expression with proper SQL quoting
+ quoted_ids = ", ".join([repr(id) for id in ids])
+ filter_expression = f"{column} NOT IN ({quoted_ids})"
+
+ # Apply the filter
+ self.set_training_filter(filter_expression)
+
  @classmethod
  def delete_views(cls, table: str, database: str):
  """Delete any views associated with this FeatureSet
@@ -707,7 +749,7 @@ if __name__ == "__main__":
 
  # Test getting the holdout ids
  print("Getting the hold out ids...")
- holdout_ids = my_features.get_training_holdouts("id")
+ holdout_ids = my_features.get_training_holdouts()
  print(f"Holdout IDs: {holdout_ids}")
 
  # Get a sample of the data
@@ -729,16 +771,33 @@ if __name__ == "__main__":
  table = my_features.view("training").table
  df = my_features.query(f'SELECT id, name FROM "{table}"')
  my_holdout_ids = [id for id in df["id"] if id < 20]
- my_features.set_training_holdouts("id", my_holdout_ids)
-
- # Test the hold out set functionality with strings
- print("Setting hold out ids (strings)...")
- my_holdout_ids = [name for name in df["name"] if int(name.split(" ")[1]) > 80]
- my_features.set_training_holdouts("name", my_holdout_ids)
+ my_features.set_training_holdouts(my_holdout_ids)
 
  # Get the training data
  print("Getting the training data...")
  training_data = my_features.get_training_data()
+ print(f"Training Data: {training_data.shape}")
+
+ # Test the filter expression functionality
+ print("Setting a filter expression...")
+ my_features.set_training_filter("id < 50 AND height > 65.0")
+ training_data = my_features.get_training_data()
+ print(f"Training Data: {training_data.shape}")
+ print(training_data)
+
+ # Remove training filter
+ print("Removing the filter expression...")
+ my_features.set_training_filter(None)
+ training_data = my_features.get_training_data()
+ print(f"Training Data: {training_data.shape}")
+ print(training_data)
+
+ # Test excluding ids from training
+ print("Excluding ids from training...")
+ my_features.exclude_ids_from_training([1, 2, 3, 4, 5])
+ training_data = my_features.get_training_data()
+ print(f"Training Data: {training_data.shape}")
+ print(training_data)
 
  # Now delete the AWS artifacts associated with this Feature Set
  # print("Deleting Workbench Feature Set...")

workbench/core/artifacts/model_core.py
@@ -21,6 +21,7 @@ from workbench.utils.aws_utils import newest_path, pull_s3_data
  from workbench.utils.s3_utils import compute_s3_object_hash
  from workbench.utils.shap_utils import shap_values_data, shap_feature_importance
  from workbench.utils.deprecated_utils import deprecated
+ from workbench.utils.model_utils import proximity_model
 
 
  class ModelType(Enum):
@@ -42,11 +43,11 @@ class ModelImages:
 
  image_uris = {
  # US East 1 images
- ("us-east-1", "xgb_training", "0.1", "x86_64"): (
- "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-training:0.1"
+ ("us-east-1", "training", "0.1", "x86_64"): (
+ "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-general-ml-training:0.1"
  ),
- ("us-east-1", "xgb_inference", "0.1", "x86_64"): (
- "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1"
+ ("us-east-1", "inference", "0.1", "x86_64"): (
+ "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-general-ml-inference:0.1"
  ),
  ("us-east-1", "pytorch_training", "0.1", "x86_64"): (
  "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-pytorch-training:0.1"
@@ -55,11 +56,11 @@ class ModelImages:
  "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-pytorch-inference:0.1"
  ),
  # US West 2 images
- ("us-west-2", "xgb_training", "0.1", "x86_64"): (
- "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-training:0.1"
+ ("us-west-2", "training", "0.1", "x86_64"): (
+ "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-general-ml-training:0.1"
  ),
- ("us-west-2", "xgb_inference", "0.1", "x86_64"): (
- "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1"
+ ("us-west-2", "inference", "0.1", "x86_64"): (
+ "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-general-ml-inference:0.1"
  ),
  ("us-west-2", "pytorch_training", "0.1", "x86_64"): (
  "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-pytorch-training:0.1"
@@ -68,12 +69,6 @@ class ModelImages:
  "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-pytorch-inference:0.1"
  ),
  # ARM64 images
- ("us-east-1", "xgb_inference", "0.1", "arm64"): (
- "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1-arm64"
- ),
- ("us-west-2", "xgb_inference", "0.1", "arm64"): (
- "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1-arm64"
- ),
  # Meta Endpoint inference images
  ("us-east-1", "meta-endpoint", "0.1", "x86_64"): (
  "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-meta-endpoint:0.1"
@@ -597,6 +592,24 @@ class ModelCore(Artifact):
  # Return the details
  return details
 
+ # Training View for this model
+ def training_view(self):
+ """Get the training view for this model"""
+ from workbench.core.artifacts.feature_set_core import FeatureSetCore
+ from workbench.core.views import View
+
+ # Grab our FeatureSet
+ fs = FeatureSetCore(self.get_input())
+
+ # See if we have a training view for this model
+ my_model_training_view = f"{self.name.replace('-', '_')}_training"
+ view = View(fs, my_model_training_view, auto_create_view=False)
+ if view.exists():
+ return view
+ else:
+ self.log.important(f"No specific training view {my_model_training_view}, returning default training view")
+ return fs.view("training")
+
  # Pipeline for this model
  def get_pipeline(self) -> str:
  """Get the pipeline for this model"""
@@ -867,6 +880,14 @@ class ModelCore(Artifact):
  shap_data[key] = self.df_store.get(df_location)
  return shap_data or None
 
+ def cross_folds(self) -> dict:
+ """Retrieve the cross-fold inference results(only works for XGBoost models)
+
+ Returns:
+ dict: Dictionary with the cross-fold inference results
+ """
+ return self.param_store.get(f"/workbench/models/{self.name}/inference/cross_fold")
+
  def supported_inference_instances(self) -> Optional[list]:
  """Retrieve the supported endpoint inference instance types
 
@@ -879,10 +900,24 @@ class ModelCore(Artifact):
  except (KeyError, IndexError, TypeError):
  return None
 
+ def publish_prox_model(self, prox_model_name: str = None, track_columns: list = None):
+ """Create and publish a Proximity Model for this Model
+
+ Args:
+ prox_model_name (str, optional): Name of the Proximity Model (if not specified, a name will be generated)
+ track_columns (list, optional): List of columns to track in the Proximity Model.
+
+ Returns:
+ Model: The published Proximity Model
+ """
+ if prox_model_name is None:
+ prox_model_name = self.model_name + "-prox"
+ return proximity_model(self, prox_model_name, track_columns=track_columns)
+
  def delete(self):
  """Delete the Model Packages and the Model Group"""
  if not self.exists():
- self.log.warning(f"Trying to delete an Model that doesn't exist: {self.name}")
+ self.log.warning(f"Trying to delete a Model that doesn't exist: {self.name}")
 
  # Call the Class Method to delete the Model Group
  ModelCore.managed_delete(model_group_name=self.name)
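
A hedged sketch of publishing a proximity model with the new publish_prox_model() method; the model name and tracked columns are placeholders, and it assumes the api-level Model class exposes this ModelCore method:

    from workbench.api.model import Model

    model = Model("aqsol-regression")  # hypothetical model name
    prox = model.publish_prox_model(track_columns=["class_label"])  # name defaults to "<model_name>-prox"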