workbench 0.8.162__py3-none-any.whl → 0.8.202__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of workbench might be problematic; see the file-change listing and diff hunks below for details.

Files changed (113)
  1. workbench/algorithms/dataframe/__init__.py +1 -2
  2. workbench/algorithms/dataframe/fingerprint_proximity.py +2 -2
  3. workbench/algorithms/dataframe/proximity.py +261 -235
  4. workbench/algorithms/graph/light/proximity_graph.py +10 -8
  5. workbench/api/__init__.py +2 -1
  6. workbench/api/compound.py +1 -1
  7. workbench/api/endpoint.py +11 -0
  8. workbench/api/feature_set.py +11 -8
  9. workbench/api/meta.py +5 -2
  10. workbench/api/model.py +16 -15
  11. workbench/api/monitor.py +1 -16
  12. workbench/core/artifacts/__init__.py +11 -2
  13. workbench/core/artifacts/artifact.py +11 -3
  14. workbench/core/artifacts/data_capture_core.py +355 -0
  15. workbench/core/artifacts/endpoint_core.py +256 -118
  16. workbench/core/artifacts/feature_set_core.py +265 -16
  17. workbench/core/artifacts/model_core.py +107 -60
  18. workbench/core/artifacts/monitor_core.py +33 -248
  19. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  20. workbench/core/cloud_platform/aws/aws_meta.py +12 -5
  21. workbench/core/cloud_platform/aws/aws_parameter_store.py +18 -2
  22. workbench/core/cloud_platform/aws/aws_session.py +4 -4
  23. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  24. workbench/core/transforms/features_to_model/features_to_model.py +42 -32
  25. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +36 -6
  26. workbench/core/transforms/pandas_transforms/pandas_to_features.py +27 -0
  27. workbench/core/views/training_view.py +113 -42
  28. workbench/core/views/view.py +53 -3
  29. workbench/core/views/view_utils.py +4 -4
  30. workbench/model_scripts/chemprop/chemprop.template +852 -0
  31. workbench/model_scripts/chemprop/generated_model_script.py +852 -0
  32. workbench/model_scripts/chemprop/requirements.txt +11 -0
  33. workbench/model_scripts/custom_models/chem_info/fingerprints.py +134 -0
  34. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  35. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  36. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  37. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
  38. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
  39. workbench/model_scripts/custom_models/proximity/proximity.py +261 -235
  40. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  41. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  42. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  43. workbench/model_scripts/custom_models/uq_models/meta_uq.template +166 -62
  44. workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
  45. workbench/model_scripts/custom_models/uq_models/proximity.py +261 -235
  46. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  47. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  48. workbench/model_scripts/pytorch_model/generated_model_script.py +373 -190
  49. workbench/model_scripts/pytorch_model/pytorch.template +370 -187
  50. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  51. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  52. workbench/model_scripts/script_generation.py +17 -9
  53. workbench/model_scripts/uq_models/generated_model_script.py +605 -0
  54. workbench/model_scripts/uq_models/mapie.template +605 -0
  55. workbench/model_scripts/uq_models/requirements.txt +1 -0
  56. workbench/model_scripts/xgb_model/generated_model_script.py +37 -46
  57. workbench/model_scripts/xgb_model/xgb_model.template +44 -46
  58. workbench/repl/workbench_shell.py +28 -14
  59. workbench/scripts/endpoint_test.py +162 -0
  60. workbench/scripts/lambda_test.py +73 -0
  61. workbench/scripts/ml_pipeline_batch.py +137 -0
  62. workbench/scripts/ml_pipeline_sqs.py +186 -0
  63. workbench/scripts/monitor_cloud_watch.py +20 -100
  64. workbench/utils/aws_utils.py +4 -3
  65. workbench/utils/chem_utils/__init__.py +0 -0
  66. workbench/utils/chem_utils/fingerprints.py +134 -0
  67. workbench/utils/chem_utils/misc.py +194 -0
  68. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  69. workbench/utils/chem_utils/mol_standardize.py +450 -0
  70. workbench/utils/chem_utils/mol_tagging.py +348 -0
  71. workbench/utils/chem_utils/projections.py +209 -0
  72. workbench/utils/chem_utils/salts.py +256 -0
  73. workbench/utils/chem_utils/sdf.py +292 -0
  74. workbench/utils/chem_utils/toxicity.py +250 -0
  75. workbench/utils/chem_utils/vis.py +253 -0
  76. workbench/utils/chemprop_utils.py +760 -0
  77. workbench/utils/cloudwatch_handler.py +1 -1
  78. workbench/utils/cloudwatch_utils.py +137 -0
  79. workbench/utils/config_manager.py +3 -7
  80. workbench/utils/endpoint_utils.py +5 -7
  81. workbench/utils/license_manager.py +2 -6
  82. workbench/utils/model_utils.py +95 -34
  83. workbench/utils/monitor_utils.py +44 -62
  84. workbench/utils/pandas_utils.py +3 -3
  85. workbench/utils/pytorch_utils.py +526 -0
  86. workbench/utils/shap_utils.py +10 -2
  87. workbench/utils/workbench_logging.py +0 -3
  88. workbench/utils/workbench_sqs.py +1 -1
  89. workbench/utils/xgboost_model_utils.py +371 -156
  90. workbench/web_interface/components/model_plot.py +7 -1
  91. workbench/web_interface/components/plugin_unit_test.py +5 -2
  92. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  93. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  94. workbench/web_interface/components/plugins/model_details.py +9 -7
  95. workbench/web_interface/components/plugins/scatter_plot.py +3 -3
  96. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/METADATA +27 -6
  97. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/RECORD +101 -85
  98. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/entry_points.txt +4 -0
  99. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/licenses/LICENSE +1 -1
  100. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  101. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  102. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  103. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  104. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  105. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  106. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  107. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  108. workbench/utils/chem_utils.py +0 -1556
  109. workbench/utils/execution_environment.py +0 -211
  110. workbench/utils/fast_inference.py +0 -167
  111. workbench/utils/resource_utils.py +0 -39
  112. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/WHEEL +0 -0
  113. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/top_level.txt +0 -0
@@ -16,8 +16,9 @@ from sagemaker.feature_store.feature_store import FeatureStore
16
16
  from workbench.core.artifacts.artifact import Artifact
17
17
  from workbench.core.artifacts.data_source_factory import DataSourceFactory
18
18
  from workbench.core.artifacts.athena_source import AthenaSource
19
+ from workbench.utils.deprecated_utils import deprecated
19
20
 
20
- from typing import TYPE_CHECKING
21
+ from typing import TYPE_CHECKING, Optional, List, Dict, Union
21
22
 
22
23
  from workbench.utils.aws_utils import aws_throttle
23
24
 
@@ -194,24 +195,24 @@ class FeatureSetCore(Artifact):
194
195
 
195
196
  return View(self, view_name)
196
197
 
197
- def set_display_columns(self, diplay_columns: list[str]):
198
+ def set_display_columns(self, display_columns: list[str]):
198
199
  """Set the display columns for this Data Source
199
200
 
200
201
  Args:
201
- diplay_columns (list[str]): The display columns for this Data Source
202
+ display_columns (list[str]): The display columns for this Data Source
202
203
  """
203
204
  # Check mismatch of display columns to computation columns
204
205
  c_view = self.view("computation")
205
206
  computation_columns = c_view.columns
206
- mismatch_columns = [col for col in diplay_columns if col not in computation_columns]
207
+ mismatch_columns = [col for col in display_columns if col not in computation_columns]
207
208
  if mismatch_columns:
208
209
  self.log.monitor(f"Display View/Computation mismatch: {mismatch_columns}")
209
210
 
210
- self.log.important(f"Setting Display Columns...{diplay_columns}")
211
+ self.log.important(f"Setting Display Columns...{display_columns}")
211
212
  from workbench.core.views import DisplayView
212
213
 
213
214
  # Create a NEW display view
214
- DisplayView.create(self, source_table=c_view.table, column_list=diplay_columns)
215
+ DisplayView.create(self, source_table=c_view.table, column_list=display_columns)
215
216
 
216
217
  def set_computation_columns(self, computation_columns: list[str], reset_display: bool = True):
217
218
  """Set the computation columns for this FeatureSet
@@ -509,6 +510,184 @@ class FeatureSetCore(Artifact):
509
510
  ].tolist()
510
511
  return hold_out_ids
511
512
 
513
+ def set_sample_weights(
514
+ self,
515
+ weight_dict: Dict[Union[str, int], float],
516
+ default_weight: float = 1.0,
517
+ exclude_zero_weights: bool = True,
518
+ ):
519
+ """Configure training view with sample weights for each ID.
520
+
521
+ Args:
522
+ weight_dict: Mapping of ID to sample weight
523
+ - weight > 1.0: oversample/emphasize
524
+ - weight = 1.0: normal (default)
525
+ - 0 < weight < 1.0: downweight/de-emphasize
526
+ - weight = 0.0: exclude from training
527
+ default_weight: Weight for IDs not in weight_dict (default: 1.0)
528
+ exclude_zero_weights: If True, filter out rows with sample_weight=0 (default: True)
529
+
530
+ Example:
531
+ weights = {
532
+ 'compound_42': 3.0, # oversample 3x
533
+ 'compound_99': 0.1, # noisy, downweight
534
+ 'compound_123': 0.0, # exclude from training
535
+ }
536
+ model.set_sample_weights(weights) # zeros automatically excluded
537
+ model.set_sample_weights(weights, exclude_zero_weights=False) # keep zeros
538
+ """
539
+ from workbench.core.views import TrainingView
540
+
541
+ if not weight_dict:
542
+ self.log.important("Empty weight_dict, creating standard training view")
543
+ TrainingView.create(self, id_column=self.id_column)
544
+ return
545
+
546
+ self.log.important(f"Setting sample weights for {len(weight_dict)} IDs")
547
+
548
+ # Helper to format IDs for SQL
549
+ def format_id(id_val):
550
+ return repr(id_val)
551
+
552
+ # Build CASE statement for sample_weight
553
+ case_conditions = [
554
+ f"WHEN {self.id_column} = {format_id(id_val)} THEN {weight}" for id_val, weight in weight_dict.items()
555
+ ]
556
+ case_statement = "\n ".join(case_conditions)
557
+
558
+ # Build inner query with sample weights
559
+ inner_sql = f"""SELECT
560
+ *,
561
+ CASE
562
+ {case_statement}
563
+ ELSE {default_weight}
564
+ END AS sample_weight
565
+ FROM {self.table}"""
566
+
567
+ # Optionally filter out zero weights
568
+ if exclude_zero_weights:
569
+ zero_count = sum(1 for weight in weight_dict.values() if weight == 0.0)
570
+ custom_sql = f"SELECT * FROM ({inner_sql}) WHERE sample_weight > 0"
571
+ self.log.important(f"Filtering out {zero_count} rows with sample_weight = 0")
572
+ else:
573
+ custom_sql = inner_sql
574
+
575
+ TrainingView.create_with_sql(self, sql_query=custom_sql, id_column=self.id_column)
576
+
577
+ @deprecated(version=0.9)
578
+ def set_training_filter(self, filter_expression: Optional[str] = None):
579
+ """Set a filter expression for the training view for this FeatureSet
580
+
581
+ Args:
582
+ filter_expression (Optional[str]): A SQL filter expression (e.g., "age > 25 AND status = 'active'")
583
+ If None or empty string, will reset to training view with no filter
584
+ (default: None)
585
+ """
586
+ from workbench.core.views import TrainingView
587
+
588
+ # Grab the existing holdout ids
589
+ holdout_ids = self.get_training_holdouts()
590
+
591
+ # Create a NEW training view
592
+ self.log.important(f"Setting Training Filter: {filter_expression}")
593
+ TrainingView.create(
594
+ self, id_column=self.id_column, holdout_ids=holdout_ids, filter_expression=filter_expression
595
+ )
596
+
597
+ @deprecated(version="0.9")
598
+ def exclude_ids_from_training(self, ids: List[Union[str, int]], column_name: Optional[str] = None):
599
+ """Exclude a list of IDs from the training view
600
+
601
+ Args:
602
+ ids (List[Union[str, int]],): List of IDs to exclude from training
603
+ column_name (Optional[str]): Column name to filter on.
604
+ If None, uses self.id_column (default: None)
605
+ """
606
+ # Use the default id_column if not specified
607
+ column = column_name or self.id_column
608
+
609
+ # Handle empty list case
610
+ if not ids:
611
+ self.log.warning("No IDs provided to exclude")
612
+ return
613
+
614
+ # Build the filter expression with proper SQL quoting
615
+ quoted_ids = ", ".join([repr(id) for id in ids])
616
+ filter_expression = f"{column} NOT IN ({quoted_ids})"
617
+
618
+ # Apply the filter
619
+ self.set_training_filter(filter_expression)
620
+
621
+ @deprecated(version="0.9")
622
+ def set_training_sampling(
623
+ self,
624
+ exclude_ids: Optional[List[Union[str, int]]] = None,
625
+ replicate_ids: Optional[List[Union[str, int]]] = None,
626
+ replication_factor: int = 2,
627
+ ):
628
+ """Configure training view with ID exclusions and replications (oversampling).
629
+
630
+ Args:
631
+ exclude_ids: List of IDs to exclude from training view
632
+ replicate_ids: List of IDs to replicate in training view for oversampling
633
+ replication_factor: Number of times to replicate each ID (default: 2)
634
+
635
+ Note:
636
+ If an ID appears in both lists, exclusion takes precedence.
637
+ """
638
+ from workbench.core.views import TrainingView
639
+
640
+ # Normalize to empty lists if None
641
+ exclude_ids = exclude_ids or []
642
+ replicate_ids = replicate_ids or []
643
+
644
+ # Remove any replicate_ids that are also in exclude_ids (exclusion wins)
645
+ replicate_ids = [rid for rid in replicate_ids if rid not in exclude_ids]
646
+
647
+ # If no sampling needed, just create normal view
648
+ if not exclude_ids and not replicate_ids:
649
+ self.log.important("No sampling specified, creating standard training view")
650
+ TrainingView.create(self, id_column=self.id_column)
651
+ return
652
+
653
+ # Build the custom SQL query
654
+ self.log.important(
655
+ f"Excluding {len(exclude_ids)} IDs, Replicating {len(replicate_ids)} IDs "
656
+ f"(factor: {replication_factor}x)"
657
+ )
658
+
659
+ # Helper to format IDs for SQL
660
+ def format_ids(ids):
661
+ return ", ".join([repr(id) for id in ids])
662
+
663
+ # Start with base query
664
+ base_query = f"SELECT * FROM {self.table}"
665
+
666
+ # Add exclusions if needed
667
+ if exclude_ids:
668
+ base_query += f"\nWHERE {self.id_column} NOT IN ({format_ids(exclude_ids)})"
669
+
670
+ # Build full query with replication
671
+ if replicate_ids:
672
+ # Generate VALUES clause for CROSS JOIN: (1), (2), ..., (N-1)
673
+ # We want N-1 additional copies since the original row is already in base_query
674
+ values_clause = ", ".join([f"({i})" for i in range(1, replication_factor)])
675
+
676
+ custom_sql = f"""{base_query}
677
+
678
+ UNION ALL
679
+
680
+ SELECT t.*
681
+ FROM {self.table} t
682
+ CROSS JOIN (VALUES {values_clause}) AS n(num)
683
+ WHERE t.{self.id_column} IN ({format_ids(replicate_ids)})"""
684
+ else:
685
+ # Only exclusions, no UNION needed
686
+ custom_sql = base_query
687
+
688
+ # Create the training view with our custom SQL
689
+ TrainingView.create_with_sql(self, sql_query=custom_sql, id_column=self.id_column)
690
+
512
691
  @classmethod
513
692
  def delete_views(cls, table: str, database: str):
514
693
  """Delete any views associated with this FeatureSet
@@ -667,7 +846,7 @@ if __name__ == "__main__":
667
846
  pd.set_option("display.width", 1000)
668
847
 
669
848
  # Grab a FeatureSet object and pull some information from it
670
- my_features = LocalFeatureSetCore("test_features")
849
+ my_features = LocalFeatureSetCore("abalone_features")
671
850
  if not my_features.exists():
672
851
  print("FeatureSet not found!")
673
852
  sys.exit(1)
@@ -707,7 +886,7 @@ if __name__ == "__main__":
707
886
 
708
887
  # Test getting the holdout ids
709
888
  print("Getting the hold out ids...")
710
- holdout_ids = my_features.get_training_holdouts("id")
889
+ holdout_ids = my_features.get_training_holdouts()
711
890
  print(f"Holdout IDs: {holdout_ids}")
712
891
 
713
892
  # Get a sample of the data
@@ -727,20 +906,90 @@ if __name__ == "__main__":
727
906
  # Set the holdout ids for the training view
728
907
  print("Setting hold out ids...")
729
908
  table = my_features.view("training").table
730
- df = my_features.query(f'SELECT id, name FROM "{table}"')
731
- my_holdout_ids = [id for id in df["id"] if id < 20]
732
- my_features.set_training_holdouts("id", my_holdout_ids)
733
-
734
- # Test the hold out set functionality with strings
735
- print("Setting hold out ids (strings)...")
736
- my_holdout_ids = [name for name in df["name"] if int(name.split(" ")[1]) > 80]
737
- my_features.set_training_holdouts("name", my_holdout_ids)
909
+ df = my_features.query(f'SELECT auto_id, length FROM "{table}"')
910
+ my_holdout_ids = [id for id in df["auto_id"] if id < 20]
911
+ my_features.set_training_holdouts(my_holdout_ids)
738
912
 
739
913
  # Get the training data
740
914
  print("Getting the training data...")
741
915
  training_data = my_features.get_training_data()
916
+ print(f"Training Data: {training_data.shape}")
917
+
918
+ # Test the filter expression functionality
919
+ print("Setting a filter expression...")
920
+ my_features.set_training_filter("auto_id < 50 AND length > 65.0")
921
+ training_data = my_features.get_training_data()
922
+ print(f"Training Data: {training_data.shape}")
923
+ print(training_data)
924
+
925
+ # Remove training filter
926
+ print("Removing the filter expression...")
927
+ my_features.set_training_filter(None)
928
+ training_data = my_features.get_training_data()
929
+ print(f"Training Data: {training_data.shape}")
930
+ print(training_data)
931
+
932
+ # Test excluding ids from training
933
+ print("Excluding ids from training...")
934
+ my_features.exclude_ids_from_training([1, 2, 3, 4, 5])
935
+ training_data = my_features.get_training_data()
936
+ print(f"Training Data: {training_data.shape}")
937
+ print(training_data)
742
938
 
743
939
  # Now delete the AWS artifacts associated with this Feature Set
744
940
  # print("Deleting Workbench Feature Set...")
745
941
  # my_features.delete()
746
942
  # print("Done")
943
+
944
+ # Test set_training_sampling with exclusions and replications
945
+ print("\n--- Testing set_training_sampling ---")
946
+ my_features.set_training_filter(None) # Reset any existing filters
947
+ original_count = num_rows
948
+
949
+ # Get valid IDs from the table
950
+ all_data = my_features.query(f'SELECT auto_id, length FROM "{table}"')
951
+ valid_ids = sorted(all_data["auto_id"].tolist())
952
+ print(f"Valid IDs range from {valid_ids[0]} to {valid_ids[-1]}")
953
+
954
+ exclude_list = valid_ids[0:3] # First 3 IDs
955
+ replicate_list = valid_ids[10:13] # IDs at positions 10, 11, 12
956
+
957
+ print(f"Original row count: {original_count}")
958
+ print(f"Excluding IDs: {exclude_list}")
959
+ print(f"Replicating IDs: {replicate_list}")
960
+
961
+ # Test with default replication factor (2x)
962
+ print("\n--- Testing with replication_factor=2 (default) ---")
963
+ my_features.set_training_sampling(exclude_ids=exclude_list, replicate_ids=replicate_list)
964
+ training_data = my_features.get_training_data()
965
+ print(f"Training Data after sampling: {training_data.shape}")
966
+
967
+ # Verify exclusions
968
+ for exc_id in exclude_list:
969
+ count = len(training_data[training_data["auto_id"] == exc_id])
970
+ print(f"Excluded ID {exc_id} appears {count} times (should be 0)")
971
+
972
+ # Verify replications
973
+ for rep_id in replicate_list:
974
+ count = len(training_data[training_data["auto_id"] == rep_id])
975
+ print(f"Replicated ID {rep_id} appears {count} times (should be 2)")
976
+
977
+ # Test with replication factor of 5
978
+ print("\n--- Testing with replication_factor=5 ---")
979
+ replicate_list_5x = [20, 21]
980
+ my_features.set_training_sampling(exclude_ids=exclude_list, replicate_ids=replicate_list_5x, replication_factor=5)
981
+ training_data = my_features.get_training_data()
982
+ print(f"Training Data after sampling: {training_data.shape}")
983
+
984
+ # Verify 5x replication
985
+ for rep_id in replicate_list_5x:
986
+ count = len(training_data[training_data["auto_id"] == rep_id])
987
+ print(f"Replicated ID {rep_id} appears {count} times (should be 5)")
988
+
989
+ # Test with large replication list (simulate 100 IDs)
990
+ print("\n--- Testing with large ID list (100 IDs) ---")
991
+ large_replicate_list = list(range(30, 130)) # 100 IDs
992
+ my_features.set_training_sampling(replicate_ids=large_replicate_list, replication_factor=3)
993
+ training_data = my_features.get_training_data()
994
+ print(f"Training Data after sampling: {training_data.shape}")
995
+ print(f"Expected extra rows: {len(large_replicate_list) * 3}")
@@ -21,6 +21,7 @@ from workbench.utils.aws_utils import newest_path, pull_s3_data
21
21
  from workbench.utils.s3_utils import compute_s3_object_hash
22
22
  from workbench.utils.shap_utils import shap_values_data, shap_feature_importance
23
23
  from workbench.utils.deprecated_utils import deprecated
24
+ from workbench.utils.model_utils import proximity_model
24
25
 
25
26
 
26
27
  class ModelType(Enum):
@@ -29,69 +30,62 @@ class ModelType(Enum):
29
30
  CLASSIFIER = "classifier"
30
31
  REGRESSOR = "regressor"
31
32
  CLUSTERER = "clusterer"
32
- TRANSFORMER = "transformer"
33
33
  PROXIMITY = "proximity"
34
34
  PROJECTION = "projection"
35
35
  UQ_REGRESSOR = "uq_regressor"
36
36
  ENSEMBLE_REGRESSOR = "ensemble_regressor"
37
+ TRANSFORMER = "transformer"
38
+ UNKNOWN = "unknown"
39
+
40
+
41
+ class ModelFramework(Enum):
42
+ """Enumerated Types for Workbench Model Frameworks"""
43
+
44
+ SKLEARN = "sklearn"
45
+ XGBOOST = "xgboost"
46
+ LIGHTGBM = "lightgbm"
47
+ PYTORCH_TABULAR = "pytorch_tabular"
48
+ CHEMPROP = "chemprop"
49
+ TRANSFORMER = "transformer"
37
50
  UNKNOWN = "unknown"
38
51
 
39
52
 
40
53
  class ModelImages:
41
54
  """Class for retrieving workbench inference images"""
42
55
 
43
- image_uris = {
44
- # US East 1 images
45
- ("us-east-1", "xgb_training", "0.1", "x86_64"): (
46
- "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-training:0.1"
47
- ),
48
- ("us-east-1", "xgb_inference", "0.1", "x86_64"): (
49
- "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1"
50
- ),
51
- ("us-east-1", "pytorch_training", "0.1", "x86_64"): (
52
- "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-pytorch-training:0.1"
53
- ),
54
- ("us-east-1", "pytorch_inference", "0.1", "x86_64"): (
55
- "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-pytorch-inference:0.1"
56
- ),
57
- # US West 2 images
58
- ("us-west-2", "xgb_training", "0.1", "x86_64"): (
59
- "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-training:0.1"
60
- ),
61
- ("us-west-2", "xgb_inference", "0.1", "x86_64"): (
62
- "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1"
63
- ),
64
- ("us-west-2", "pytorch_training", "0.1", "x86_64"): (
65
- "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-pytorch-training:0.1"
66
- ),
67
- ("us-west-2", "pytorch_inference", "0.1", "x86_64"): (
68
- "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-pytorch-inference:0.1"
69
- ),
70
- # ARM64 images
71
- ("us-east-1", "xgb_inference", "0.1", "arm64"): (
72
- "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1-arm64"
73
- ),
74
- ("us-west-2", "xgb_inference", "0.1", "arm64"): (
75
- "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1-arm64"
76
- ),
77
- # Meta Endpoint inference images
78
- ("us-east-1", "meta-endpoint", "0.1", "x86_64"): (
79
- "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-meta-endpoint:0.1"
80
- ),
81
- ("us-west-2", "meta-endpoint", "0.1", "x86_64"): (
82
- "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-meta-endpoint:0.1"
83
- ),
56
+ # Account ID
57
+ ACCOUNT_ID = "507740646243"
58
+
59
+ # Image name mappings
60
+ IMAGE_NAMES = {
61
+ "training": "py312-general-ml-training",
62
+ "inference": "py312-general-ml-inference",
63
+ "pytorch_training": "py312-pytorch-training",
64
+ "pytorch_inference": "py312-pytorch-inference",
65
+ "meta-endpoint": "py312-meta-endpoint",
84
66
  }
85
67
 
86
68
  @classmethod
87
- def get_image_uri(cls, region, image_type, version="0.1", architecture="x86_64"):
88
- key = (region, image_type, version, architecture)
89
- if key in cls.image_uris:
90
- return cls.image_uris[key]
91
- else:
92
- raise ValueError(
93
- f"No matching image found for region: {region}, image_type: {image_type}, version: {version}"
94
- )
69
+ def get_image_uri(cls, region, image_type, version="latest", architecture="x86_64"):
70
+ """
71
+ Dynamically construct ECR image URI.
72
+
73
+ Args:
74
+ region: AWS region (e.g., 'us-east-1', 'us-west-2')
75
+ image_type: Type of image (e.g., 'training', 'inference', 'pytorch_training')
76
+ version: Image version (e.g., '0.1', '0.2' defaults to 'latest')
77
+ architecture: CPU architecture (default: 'x86_64', currently unused but kept for compatibility)
78
+
79
+ Returns:
80
+ ECR image URI string
81
+ """
82
+ if image_type not in cls.IMAGE_NAMES:
83
+ raise ValueError(f"Unknown image_type: {image_type}. Valid types: {list(cls.IMAGE_NAMES.keys())}")
84
+
85
+ image_name = cls.IMAGE_NAMES[image_type]
86
+ uri = f"{cls.ACCOUNT_ID}.dkr.ecr.{region}.amazonaws.com/aws-ml-images/{image_name}:{version}"
87
+
88
+ return uri
95
89
 
96
90
 
97
91
  class ModelCore(Artifact):
@@ -105,11 +99,10 @@ class ModelCore(Artifact):
105
99
  ```
106
100
  """
107
101
 
108
- def __init__(self, model_name: str, model_type: ModelType = None, **kwargs):
102
+ def __init__(self, model_name: str, **kwargs):
109
103
  """ModelCore Initialization
110
104
  Args:
111
105
  model_name (str): Name of Model in Workbench.
112
- model_type (ModelType, optional): Set this for newly created Models. Defaults to None.
113
106
  **kwargs: Additional keyword arguments
114
107
  """
115
108
 
@@ -143,10 +136,8 @@ class ModelCore(Artifact):
143
136
  self.latest_model = self.model_meta["ModelPackageList"][0]
144
137
  self.description = self.latest_model.get("ModelPackageDescription", "-")
145
138
  self.training_job_name = self._extract_training_job_name()
146
- if model_type:
147
- self._set_model_type(model_type)
148
- else:
149
- self.model_type = self._get_model_type()
139
+ self.model_type = self._get_model_type()
140
+ self.model_framework = self._get_model_framework()
150
141
  except (IndexError, KeyError):
151
142
  self.log.critical(f"Model {self.model_name} appears to be malformed. Delete and recreate it!")
152
143
  return
@@ -597,6 +588,24 @@ class ModelCore(Artifact):
597
588
  # Return the details
598
589
  return details
599
590
 
591
+ # Training View for this model
592
+ def training_view(self):
593
+ """Get the training view for this model"""
594
+ from workbench.core.artifacts.feature_set_core import FeatureSetCore
595
+ from workbench.core.views import View
596
+
597
+ # Grab our FeatureSet
598
+ fs = FeatureSetCore(self.get_input())
599
+
600
+ # See if we have a training view for this model
601
+ my_model_training_view = f"{self.name.replace('-', '_')}_training".lower()
602
+ view = View(fs, my_model_training_view, auto_create_view=False)
603
+ if view.exists():
604
+ return view
605
+ else:
606
+ self.log.important(f"No specific training view {my_model_training_view}, returning default training view")
607
+ return fs.view("training")
608
+
600
609
  # Pipeline for this model
601
610
  def get_pipeline(self) -> str:
602
611
  """Get the pipeline for this model"""
@@ -879,10 +888,24 @@ class ModelCore(Artifact):
879
888
  except (KeyError, IndexError, TypeError):
880
889
  return None
881
890
 
891
+ def publish_prox_model(self, prox_model_name: str = None, track_columns: list = None):
892
+ """Create and publish a Proximity Model for this Model
893
+
894
+ Args:
895
+ prox_model_name (str, optional): Name of the Proximity Model (if not specified, a name will be generated)
896
+ track_columns (list, optional): List of columns to track in the Proximity Model.
897
+
898
+ Returns:
899
+ Model: The published Proximity Model
900
+ """
901
+ if prox_model_name is None:
902
+ prox_model_name = self.model_name + "-prox"
903
+ return proximity_model(self, prox_model_name, track_columns=track_columns)
904
+
882
905
  def delete(self):
883
906
  """Delete the Model Packages and the Model Group"""
884
907
  if not self.exists():
885
- self.log.warning(f"Trying to delete an Model that doesn't exist: {self.name}")
908
+ self.log.warning(f"Trying to delete a Model that doesn't exist: {self.name}")
886
909
 
887
910
  # Call the Class Method to delete the Model Group
888
911
  ModelCore.managed_delete(model_group_name=self.name)
@@ -958,6 +981,27 @@ class ModelCore(Artifact):
958
981
  self.log.warning(f"Could not determine model type for {self.model_name}!")
959
982
  return ModelType.UNKNOWN
960
983
 
984
+ def _set_model_framework(self, model_framework: ModelFramework):
985
+ """Internal: Set the Model Framework for this Model"""
986
+ self.model_framework = model_framework
987
+ self.upsert_workbench_meta({"workbench_model_framework": self.model_framework.value})
988
+ self.remove_health_tag("model_framework_unknown")
989
+
990
+ def _get_model_framework(self) -> ModelFramework:
991
+ """Internal: Query the Workbench Metadata to get the model framework
992
+ Returns:
993
+ ModelFramework: The ModelFramework of this Model
994
+ Notes:
995
+ This is an internal method that should not be called directly
996
+ Use the model_framework attribute instead
997
+ """
998
+ model_framework = self.workbench_meta().get("workbench_model_framework")
999
+ try:
1000
+ return ModelFramework(model_framework)
1001
+ except ValueError:
1002
+ self.log.warning(f"Could not determine model framework for {self.model_name}!")
1003
+ return ModelFramework.UNKNOWN
1004
+
961
1005
  def _load_training_metrics(self):
962
1006
  """Internal: Retrieve the training metrics and Confusion Matrix for this model
963
1007
  and load the data into the Workbench Metadata
@@ -1149,13 +1193,11 @@ if __name__ == "__main__":
1149
1193
  # Grab a ModelCore object and pull some information from it
1150
1194
  my_model = ModelCore("abalone-regression")
1151
1195
 
1152
- # Call the various methods
1153
-
1154
1196
  # Let's do a check/validation of the Model
1155
1197
  print(f"Model Check: {my_model.exists()}")
1156
1198
 
1157
1199
  # Make sure the model is 'ready'
1158
- # my_model.onboard()
1200
+ my_model.onboard()
1159
1201
 
1160
1202
  # Get the ARN of the Model Group
1161
1203
  print(f"Model Group ARN: {my_model.group_arn()}")
@@ -1221,5 +1263,10 @@ if __name__ == "__main__":
1221
1263
  # Delete the Model
1222
1264
  # ModelCore.managed_delete("wine-classification")
1223
1265
 
1266
+ # Check the training view logic
1267
+ model = ModelCore("wine-class-test-251112-BW")
1268
+ training_view = model.training_view()
1269
+ print(f"Training View Name: {training_view.name}")
1270
+
1224
1271
  # Check for a model that doesn't exist
1225
1272
  my_model = ModelCore("empty-model-group")