workbench 0.8.174__py3-none-any.whl → 0.8.227__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of workbench might be problematic. Click here for more details.

Files changed (145)
  1. workbench/__init__.py +1 -0
  2. workbench/algorithms/dataframe/__init__.py +1 -2
  3. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  4. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  5. workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
  6. workbench/algorithms/dataframe/projection_2d.py +44 -21
  7. workbench/algorithms/dataframe/proximity.py +259 -305
  8. workbench/algorithms/graph/light/proximity_graph.py +12 -11
  9. workbench/algorithms/models/cleanlab_model.py +382 -0
  10. workbench/algorithms/models/noise_model.py +388 -0
  11. workbench/algorithms/sql/column_stats.py +0 -1
  12. workbench/algorithms/sql/correlations.py +0 -1
  13. workbench/algorithms/sql/descriptive_stats.py +0 -1
  14. workbench/algorithms/sql/outliers.py +3 -3
  15. workbench/api/__init__.py +5 -1
  16. workbench/api/df_store.py +17 -108
  17. workbench/api/endpoint.py +14 -12
  18. workbench/api/feature_set.py +117 -11
  19. workbench/api/meta.py +0 -1
  20. workbench/api/meta_model.py +289 -0
  21. workbench/api/model.py +52 -21
  22. workbench/api/parameter_store.py +3 -52
  23. workbench/cached/cached_meta.py +0 -1
  24. workbench/cached/cached_model.py +49 -11
  25. workbench/core/artifacts/__init__.py +11 -2
  26. workbench/core/artifacts/artifact.py +7 -7
  27. workbench/core/artifacts/data_capture_core.py +8 -1
  28. workbench/core/artifacts/df_store_core.py +114 -0
  29. workbench/core/artifacts/endpoint_core.py +323 -205
  30. workbench/core/artifacts/feature_set_core.py +249 -45
  31. workbench/core/artifacts/model_core.py +133 -101
  32. workbench/core/artifacts/parameter_store_core.py +98 -0
  33. workbench/core/cloud_platform/aws/aws_account_clamp.py +48 -2
  34. workbench/core/cloud_platform/cloud_meta.py +0 -1
  35. workbench/core/pipelines/pipeline_executor.py +1 -1
  36. workbench/core/transforms/features_to_model/features_to_model.py +60 -44
  37. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +43 -10
  38. workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
  39. workbench/core/views/training_view.py +113 -42
  40. workbench/core/views/view.py +53 -3
  41. workbench/core/views/view_utils.py +4 -4
  42. workbench/model_script_utils/model_script_utils.py +339 -0
  43. workbench/model_script_utils/pytorch_utils.py +405 -0
  44. workbench/model_script_utils/uq_harness.py +277 -0
  45. workbench/model_scripts/chemprop/chemprop.template +774 -0
  46. workbench/model_scripts/chemprop/generated_model_script.py +774 -0
  47. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  48. workbench/model_scripts/chemprop/requirements.txt +3 -0
  49. workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
  50. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +18 -7
  51. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +80 -58
  52. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
  53. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -2
  54. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  55. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
  56. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  57. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  58. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  59. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  60. workbench/model_scripts/custom_models/uq_models/ngboost.template +15 -16
  61. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  62. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  63. workbench/model_scripts/meta_model/meta_model.template +209 -0
  64. workbench/model_scripts/pytorch_model/generated_model_script.py +443 -499
  65. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  66. workbench/model_scripts/pytorch_model/pytorch.template +440 -496
  67. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  68. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  69. workbench/model_scripts/pytorch_model/uq_harness.py +277 -0
  70. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  71. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  72. workbench/model_scripts/script_generation.py +15 -12
  73. workbench/model_scripts/uq_models/generated_model_script.py +248 -0
  74. workbench/model_scripts/xgb_model/generated_model_script.py +371 -403
  75. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  76. workbench/model_scripts/xgb_model/uq_harness.py +277 -0
  77. workbench/model_scripts/xgb_model/xgb_model.template +367 -399
  78. workbench/repl/workbench_shell.py +18 -14
  79. workbench/resources/open_source_api.key +1 -1
  80. workbench/scripts/endpoint_test.py +162 -0
  81. workbench/scripts/lambda_test.py +73 -0
  82. workbench/scripts/meta_model_sim.py +35 -0
  83. workbench/scripts/ml_pipeline_sqs.py +122 -6
  84. workbench/scripts/training_test.py +85 -0
  85. workbench/themes/dark/custom.css +59 -0
  86. workbench/themes/dark/plotly.json +5 -5
  87. workbench/themes/light/custom.css +153 -40
  88. workbench/themes/light/plotly.json +9 -9
  89. workbench/themes/midnight_blue/custom.css +59 -0
  90. workbench/utils/aws_utils.py +0 -1
  91. workbench/utils/chem_utils/fingerprints.py +87 -46
  92. workbench/utils/chem_utils/mol_descriptors.py +18 -7
  93. workbench/utils/chem_utils/mol_standardize.py +80 -58
  94. workbench/utils/chem_utils/projections.py +16 -6
  95. workbench/utils/chem_utils/vis.py +25 -27
  96. workbench/utils/chemprop_utils.py +141 -0
  97. workbench/utils/config_manager.py +2 -6
  98. workbench/utils/endpoint_utils.py +5 -7
  99. workbench/utils/license_manager.py +2 -6
  100. workbench/utils/markdown_utils.py +57 -0
  101. workbench/utils/meta_model_simulator.py +499 -0
  102. workbench/utils/metrics_utils.py +256 -0
  103. workbench/utils/model_utils.py +274 -87
  104. workbench/utils/pipeline_utils.py +0 -1
  105. workbench/utils/plot_utils.py +159 -34
  106. workbench/utils/pytorch_utils.py +87 -0
  107. workbench/utils/shap_utils.py +11 -57
  108. workbench/utils/theme_manager.py +95 -30
  109. workbench/utils/xgboost_local_crossfold.py +267 -0
  110. workbench/utils/xgboost_model_utils.py +127 -220
  111. workbench/web_interface/components/experiments/outlier_plot.py +0 -1
  112. workbench/web_interface/components/model_plot.py +16 -2
  113. workbench/web_interface/components/plugin_unit_test.py +5 -3
  114. workbench/web_interface/components/plugins/ag_table.py +2 -4
  115. workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
  116. workbench/web_interface/components/plugins/model_details.py +48 -80
  117. workbench/web_interface/components/plugins/scatter_plot.py +192 -92
  118. workbench/web_interface/components/settings_menu.py +184 -0
  119. workbench/web_interface/page_views/main_page.py +0 -1
  120. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/METADATA +31 -17
  121. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/RECORD +125 -111
  122. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/entry_points.txt +4 -0
  123. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/licenses/LICENSE +1 -1
  124. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  125. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  126. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  127. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  128. workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
  129. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  130. workbench/model_scripts/custom_models/uq_models/mapie.template +0 -502
  131. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -386
  132. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
  133. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  134. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  135. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  136. workbench/themes/quartz/base_css.url +0 -1
  137. workbench/themes/quartz/custom.css +0 -117
  138. workbench/themes/quartz/plotly.json +0 -642
  139. workbench/themes/quartz_dark/base_css.url +0 -1
  140. workbench/themes/quartz_dark/custom.css +0 -131
  141. workbench/themes/quartz_dark/plotly.json +0 -642
  142. workbench/utils/fast_inference.py +0 -167
  143. workbench/utils/resource_utils.py +0 -39
  144. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/WHEEL +0 -0
  145. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/top_level.txt +0 -0
@@ -1,16 +1,25 @@
1
+ # flake8: noqa: E402
2
+ import os
3
+ import sys
4
+ import logging
5
+ import importlib
6
+ import webbrowser
7
+ import readline # noqa: F401
8
+
9
+ # Disable OpenMP parallelism to avoid segfaults with PyTorch in iPython
10
+ # This is a known issue on macOS where libomp crashes during thread synchronization
11
+ # Must be set before importing numpy/pandas/torch or any library that uses OpenMP
12
+ os.environ.setdefault("OMP_NUM_THREADS", "1")
13
+ os.environ.setdefault("MKL_NUM_THREADS", "1")
14
+
1
15
  import IPython
2
16
  from IPython import start_ipython
3
17
  from distutils.version import LooseVersion
4
18
  from IPython.terminal.prompts import Prompts
5
19
  from IPython.terminal.ipapp import load_default_config
6
20
  from pygments.token import Token
7
- import sys
8
- import logging
9
- import importlib
10
21
  import botocore
11
- import webbrowser
12
22
  import pandas as pd
13
- import readline # noqa
14
23
 
15
24
  try:
16
25
  import matplotlib.pyplot as plt # noqa
@@ -293,11 +302,6 @@ class WorkbenchShell:
293
302
  self.commands["PandasToView"] = importlib.import_module("workbench.core.views.pandas_to_view").PandasToView
294
303
  self.commands["Pipeline"] = importlib.import_module("workbench.api.pipeline").Pipeline
295
304
 
296
- # Algorithms
297
- self.commands["FSP"] = importlib.import_module(
298
- "workbench.algorithms.dataframe.feature_space_proximity"
299
- ).FeatureSpaceProximity
300
-
301
305
  # These are 'nice to have' imports
302
306
  self.commands["pd"] = importlib.import_module("pandas")
303
307
  self.commands["wr"] = importlib.import_module("awswrangler")
@@ -525,7 +529,7 @@ class WorkbenchShell:
525
529
  def get_meta(self):
526
530
  return self.meta
527
531
 
528
- def plot_manager(self, data, plot_type: str = "table", **kwargs):
532
+ def plot_manager(self, data, plot_type: str = "scatter", **kwargs):
529
533
  """Plot Manager for Workbench"""
530
534
  from workbench.web_interface.components.plugins import ag_table, graph_plot, scatter_plot
531
535
 
@@ -560,14 +564,14 @@ class WorkbenchShell:
560
564
  from workbench.web_interface.components.plugin_unit_test import PluginUnitTest
561
565
 
562
566
  # Get kwargs
563
- theme = kwargs.get("theme", "dark")
567
+ theme = kwargs.get("theme", "midnight_blue")
564
568
 
565
569
  plugin_test = PluginUnitTest(plugin_class, theme=theme, input_data=data, **kwargs)
566
570
 
567
- # Run the server and open in the browser
568
- plugin_test.run()
571
+ # Open the browser and run the dash server
569
572
  url = f"http://127.0.0.1:{plugin_test.port}"
570
573
  webbrowser.open(url)
574
+ plugin_test.run()
571
575
 
572
576
 
573
577
  # Launch Shell Entry Point
@@ -1 +1 @@
1
- eyJsaWNlbnNlX2lkIjogIk9wZW5fU291cmNlX0xpY2Vuc2UiLCAiY29tcGFueSI6ICIiLCAiYXdzX2FjY291bnRfaWQiOiAiIiwgInRpZXIiOiAiRW50ZXJwcmlzZSBQcm8iLCAiZmVhdHVyZXMiOiBbInBsdWdpbnMiLCAicGFnZXMiLCAidGhlbWVzIiwgInBpcGVsaW5lcyIsICJicmFuZGluZyJdLCAiZXhwaXJlcyI6ICIyMDI2LTAxLTE0In02zCDRy41wKRViRnGmodczFWexLyfXYrJWSuVQQbhbWeRttQRv6zpo9x4O2yBjdRfhb9E7mFUppNiOS_ZGK-bL71nGHt_Mc8niG8jkpvKX9qZ6BqkXF_vzDIOcI8iGiwB3wikeVO4zRLD1AI0U3cgYmIyGXI9QKJ9L7IHyQ0TWqw==
1
+ eyJsaWNlbnNlX2lkIjogIk9wZW5fU291cmNlX0xpY2Vuc2UiLCAiY29tcGFueSI6ICIiLCAiYXdzX2FjY291bnRfaWQiOiAiIiwgInRpZXIiOiAiRW50ZXJwcmlzZSBQcm8iLCAiZmVhdHVyZXMiOiBbInBsdWdpbnMiLCAicGFnZXMiLCAidGhlbWVzIiwgInBpcGVsaW5lcyIsICJicmFuZGluZyJdLCAiZXhwaXJlcyI6ICIyMDI2LTEyLTA1In1IsmpkuybFALADkRj_RfmkQ0LAIsQeXRE7Uoc3DL1UrDr-rSnwu-PDqsKBUkX6jPRFZV3DLxNjBapxPeEIFhfvxvjzz_sc6CwtxNpZ3bPmxSPs2W-j3xZS4-XyEqIilcwSkWh-NU1u27gCuuivn5eiUmIYJGAp0wdVkeE6_Z9dlg==
@@ -0,0 +1,162 @@
1
+ """
2
+ Local test harness for SageMaker model scripts.
3
+
4
+ Usage:
5
+ python model_script_harness.py <local_script.py> <model_name>
6
+
7
+ Example:
8
+ python model_script_harness.py pytorch.py aqsol-reg-pytorch
9
+
10
+ This allows you to test LOCAL changes to a model script against deployed model artifacts.
11
+ Evaluation data is automatically pulled from the FeatureSet (training = FALSE rows).
12
+
13
+ Optional: testing/env.json with additional environment variables
14
+ """
15
+
16
+ import os
17
+ import sys
18
+ import json
19
+ import importlib.util
20
+ import tempfile
21
+ import shutil
22
+ import pandas as pd
23
# Force CPU mode BEFORE importing torch so these env vars actually take
# effect — CUDA_VISIBLE_DEVICES must be in the environment when torch
# first discovers devices, so set it before the import below.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch  # noqa: E402

# Workbench Imports
from workbench.api import Model, FeatureSet  # noqa: E402
from workbench.utils.pytorch_utils import download_and_extract_model  # noqa: E402

torch.set_default_device("cpu")
# Disable MPS entirely to avoid MPS/CUDA issues on Mac
if hasattr(torch.backends, "mps"):
    torch.backends.mps.is_available = lambda: False
36
+
37
+
38
def get_eval_data(workbench_model: Model) -> pd.DataFrame:
    """Pull the hold-out evaluation rows for this model's input FeatureSet.

    The training view flags hold-out rows with training = FALSE; those
    rows are queried and returned.

    Args:
        workbench_model: The Workbench Model whose FeatureSet supplies the data.

    Returns:
        DataFrame of evaluation (non-training) rows.

    Raises:
        ValueError: If the model's input FeatureSet does not exist.
    """
    # Resolve the FeatureSet backing this model
    fs_name = workbench_model.get_input()
    feature_set = FeatureSet(fs_name)
    if not feature_set.exists():
        raise ValueError(f"No FeatureSet found: {fs_name}")

    # Query the training view for rows held out of training
    table = workbench_model.training_view().table
    print(f"Querying evaluation data from {table}...")
    holdout_df = feature_set.query(f'SELECT * FROM "{table}" WHERE training = FALSE')
    print(f"Retrieved {len(holdout_df)} evaluation rows")
    return holdout_df
53
+
54
+
55
def load_model_script(script_path: str):
    """Import a model script file as the module ``model_script``.

    The module is registered in ``sys.modules`` under the name
    ``model_script`` before execution so imports within the script resolve.

    Args:
        script_path: Path to the Python model script.

    Returns:
        The executed module object.

    Raises:
        FileNotFoundError: If ``script_path`` does not exist.
    """
    if not os.path.exists(script_path):
        raise FileNotFoundError(f"Script not found: {script_path}")

    # Build a module from the file location and execute it
    module_spec = importlib.util.spec_from_file_location("model_script", script_path)
    loaded = importlib.util.module_from_spec(module_spec)
    sys.modules["model_script"] = loaded  # register first so intra-script imports work
    module_spec.loader.exec_module(loaded)
    return loaded
68
+
69
+
70
def main():
    """Test a LOCAL model script against deployed model artifacts.

    Downloads the deployed model's artifacts, loads the local script,
    calls its model_fn/predict_fn on evaluation data pulled from the
    model's FeatureSet, and prints the prediction results.
    """
    if len(sys.argv) < 3:
        print("Usage: python model_script_harness.py <local_script.py> <model_name>")
        print("\nArguments:")
        print(" local_script.py - Path to your LOCAL model script to test")
        print(" model_name - Workbench model name (e.g., aqsol-reg-pytorch)")
        print("\nOptional: testing/env.json with additional environment variables")
        sys.exit(1)

    local_script, model_name = sys.argv[1], sys.argv[2]

    # Bail early if the local script isn't there
    if not os.path.exists(local_script):
        print(f"Error: Local script not found: {local_script}")
        sys.exit(1)

    # Pull up the deployed Workbench model
    print(f"Loading Workbench model: {model_name}")
    wb_model = Model(model_name)
    print(f"Model Framework: {wb_model.model_framework}")
    print()

    # Scratch directory that stands in for SageMaker's model dir
    model_dir = tempfile.mkdtemp(prefix="model_harness_")
    print(f"Using model directory: {model_dir}")

    try:
        # Optional extra environment variables for the script under test
        if os.path.exists("testing/env.json"):
            print("Loading environment variables from testing/env.json")
            with open("testing/env.json") as f:
                extra_env = json.load(f)
            for key, value in extra_env.items():
                os.environ[key] = value
                print(f" Set {key} = {value}")
            print()

        # Mimic the SageMaker container environment
        os.environ["SM_MODEL_DIR"] = model_dir
        print(f"Set SM_MODEL_DIR = {model_dir}")

        # Fetch the deployed artifacts into the scratch dir
        artifact_uri = wb_model.model_data_url()
        download_and_extract_model(artifact_uri, model_dir)
        print()

        # Bring in the LOCAL script under test
        print(f"Loading LOCAL model script: {local_script}")
        script_module = load_model_script(local_script)
        print()

        # The SageMaker contract requires both entry points
        for required_fn in ("model_fn", "predict_fn"):
            if not hasattr(script_module, required_fn):
                raise AttributeError(f"Model script must have a {required_fn} function")

        print("Calling model_fn...")
        print("-" * 50)
        loaded_model = script_module.model_fn(model_dir)
        print("-" * 50)
        print(f"Model loaded: {type(loaded_model)}")
        print()

        # Hold-out rows from the FeatureSet drive the prediction test
        print("Pulling evaluation data from FeatureSet...")
        eval_df = get_eval_data(wb_model)
        print(f"Input shape: {eval_df.shape}")
        print(f"Columns: {eval_df.columns.tolist()}")
        print()

        print("Calling predict_fn...")
        print("-" * 50)
        predictions = script_module.predict_fn(eval_df, loaded_model)
        print("-" * 50)
        print()

        print("Prediction result:")
        print(f"Output shape: {predictions.shape}")
        print(f"Output columns: {predictions.columns.tolist()}")
        print()
        print(predictions.head(10).to_string())

    finally:
        # Always remove the scratch directory, even on failure
        print(f"\nCleaning up model directory: {model_dir}")
        shutil.rmtree(model_dir, ignore_errors=True)


if __name__ == "__main__":
    main()
@@ -0,0 +1,73 @@
1
+ """
2
+ Local test harness for AWS Lambda scripts.
3
+
4
+ Usage:
5
+ lambda_test <lambda_script.py>
6
+
7
+ Required: testing/event.json with the event definition
8
+ Options: testing/env.json with a set of ENV vars
9
+ """
10
+
11
+ import sys
12
+ import os
13
+ import json
14
+ import importlib.util
15
+
16
+
17
def main():
    """Run a local Lambda handler with a test event.

    Loads optional testing/env.json (environment variables) and
    testing/event.json (event payload), dynamically imports the given
    script, and invokes its ``lambda_handler(event, {})``.

    Returns:
        Whatever the script's lambda_handler returns (also printed),
        so programmatic callers can inspect the result.
    """
    if len(sys.argv) != 2:
        # Usage text kept consistent with the module docstring
        print("Usage: lambda_test <lambda_script.py>")
        print("\nOptional: testing/event.json with test event")
        print("Optional: testing/env.json with environment variables")
        sys.exit(1)

    handler_file = sys.argv[1]

    # Add .py if not present
    if not handler_file.endswith(".py"):
        handler_file += ".py"

    # Check if file exists
    if not os.path.exists(handler_file):
        print(f"Error: File '{handler_file}' not found")
        sys.exit(1)

    # Load environment variables from env.json if it exists
    if os.path.exists("testing/env.json"):
        print("Loading environment variables from testing/env.json")
        with open("testing/env.json") as f:
            env_vars = json.load(f)
        for key, value in env_vars.items():
            os.environ[key] = value
            print(f" Set {key} = {value}")
        print()

    # Load event configuration (empty event if none provided)
    if os.path.exists("testing/event.json"):
        print("Loading event from testing/event.json")
        with open("testing/event.json") as f:
            event = json.load(f)
    else:
        print("No testing/event.json found, using empty event")
        event = {}

    # Load the handler module dynamically from the file path
    spec = importlib.util.spec_from_file_location("lambda_module", handler_file)
    lambda_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(lambda_module)

    # Call the lambda_handler with an empty context (as AWS would pass a context object)
    print(f"Invoking lambda_handler from {handler_file}...")
    print("-" * 50)
    print(f"Event: {json.dumps(event, indent=2)}")
    print("-" * 50)

    result = lambda_module.lambda_handler(event, {})

    print("-" * 50)
    print("Result:")
    print(json.dumps(result, indent=2))
    return result


if __name__ == "__main__":
    main()
@@ -0,0 +1,35 @@
1
+ """MetaModelSimulator: Simulate and analyze ensemble model performance.
2
+
3
+ This class helps evaluate whether a meta model (ensemble) would outperform
4
+ individual child models by analyzing endpoint inference predictions.
5
+ """
6
+
7
+ import argparse
8
+ from workbench.utils.meta_model_simulator import MetaModelSimulator
9
+
10
+
11
def main():
    """CLI entry point: run a MetaModelSimulator report over the given models."""
    arg_parser = argparse.ArgumentParser(
        description="Simulate and analyze ensemble model performance using MetaModelSimulator."
    )
    arg_parser.add_argument(
        "models",
        nargs="+",
        help="List of model endpoint names to include in the ensemble simulation.",
    )
    arg_parser.add_argument(
        "--id-column",
        default="molecule_name",
        help="Name of the ID column (default: molecule_name)",
    )
    cli_args = arg_parser.parse_args()

    # Build the simulator from the CLI options and print its report
    simulator = MetaModelSimulator(cli_args.models, id_column=cli_args.id_column)
    simulator.report()


if __name__ == "__main__":
    main()
@@ -1,6 +1,8 @@
1
1
  import argparse
2
+ import ast
2
3
  import logging
3
4
  import json
5
+ import re
4
6
  from pathlib import Path
5
7
 
6
8
  # Workbench Imports
@@ -13,27 +15,105 @@ cm = ConfigManager()
13
15
  workbench_bucket = cm.get_config("WORKBENCH_BUCKET")
14
16
 
15
17
 
16
- def submit_to_sqs(script_path: str, size: str = "small") -> None:
18
+ def parse_workbench_batch(script_content: str) -> dict | None:
19
+ """Parse WORKBENCH_BATCH config from a script.
20
+
21
+ Looks for a dictionary assignment like:
22
+ WORKBENCH_BATCH = {
23
+ "outputs": ["feature_set_xyz"],
24
+ }
25
+ or:
26
+ WORKBENCH_BATCH = {
27
+ "inputs": ["feature_set_xyz"],
28
+ }
29
+
30
+ Args:
31
+ script_content: The Python script content as a string
32
+
33
+ Returns:
34
+ The parsed dictionary or None if not found
35
+ """
36
+ pattern = r"WORKBENCH_BATCH\s*=\s*(\{[^}]+\})"
37
+ match = re.search(pattern, script_content, re.DOTALL)
38
+ if match:
39
+ try:
40
+ return ast.literal_eval(match.group(1))
41
+ except (ValueError, SyntaxError) as e:
42
+ print(f"⚠️ Warning: Failed to parse WORKBENCH_BATCH: {e}")
43
+ return None
44
+ return None
45
+
46
+
47
+ def get_message_group_id(batch_config: dict | None) -> str:
48
+ """Derive MessageGroupId from outputs or inputs.
49
+
50
+ - Scripts with outputs use first output as group
51
+ - Scripts with inputs use first input as group
52
+ - Default to "ml-pipeline-jobs" if no config
53
+ """
54
+ if not batch_config:
55
+ return "ml-pipeline-jobs"
56
+
57
+ outputs = batch_config.get("outputs", [])
58
+ inputs = batch_config.get("inputs", [])
59
+
60
+ if outputs:
61
+ return outputs[0]
62
+ elif inputs:
63
+ return inputs[0]
64
+ else:
65
+ return "ml-pipeline-jobs"
66
+
67
+
68
+ def submit_to_sqs(
69
+ script_path: str,
70
+ size: str = "small",
71
+ realtime: bool = False,
72
+ dt: bool = False,
73
+ promote: bool = False,
74
+ ) -> None:
17
75
  """
18
76
  Upload script to S3 and submit message to SQS queue for processing.
77
+
19
78
  Args:
20
79
  script_path: Local path to the ML pipeline script
21
80
  size: Job size tier - "small" (default), "medium", or "large"
81
+ realtime: If True, sets serverless=False for real-time processing (default: False)
82
+ dt: If True, sets DT=True in environment (default: False)
83
+ promote: If True, sets PROMOTE=True in environment (default: False)
84
+
85
+ Raises:
86
+ ValueError: If size is invalid or script file not found
22
87
  """
23
88
  print(f"\n{'=' * 60}")
24
89
  print("🚀 SUBMITTING ML PIPELINE JOB")
25
90
  print(f"{'=' * 60}")
26
-
27
91
  if size not in ["small", "medium", "large"]:
28
92
  raise ValueError(f"Invalid size '{size}'. Must be 'small', 'medium', or 'large'")
93
+
29
94
  # Validate script exists
30
95
  script_file = Path(script_path)
31
96
  if not script_file.exists():
32
97
  raise FileNotFoundError(f"Script not found: {script_path}")
33
98
 
99
+ # Read script content and parse WORKBENCH_BATCH config
100
+ script_content = script_file.read_text()
101
+ batch_config = parse_workbench_batch(script_content)
102
+ group_id = get_message_group_id(batch_config)
103
+ outputs = (batch_config or {}).get("outputs", [])
104
+ inputs = (batch_config or {}).get("inputs", [])
105
+
34
106
  print(f"📄 Script: {script_file.name}")
35
107
  print(f"📏 Size tier: {size}")
108
+ print(f"⚡ Mode: {'Real-time' if realtime else 'Serverless'} (serverless={'False' if realtime else 'True'})")
109
+ print(f"🔄 DynamicTraining: {dt}")
110
+ print(f"🆕 Promote: {promote}")
36
111
  print(f"🪣 Bucket: {workbench_bucket}")
112
+ if outputs:
113
+ print(f"📤 Outputs: {outputs}")
114
+ if inputs:
115
+ print(f"📥 Inputs: {inputs}")
116
+ print(f"📦 Batch Group: {group_id}")
37
117
  sqs = AWSAccountClamp().boto3_session.client("sqs")
38
118
  script_name = script_file.name
39
119
 
@@ -59,7 +139,7 @@ def submit_to_sqs(script_path: str, size: str = "small") -> None:
59
139
  print(f" Destination: {s3_path}")
60
140
 
61
141
  try:
62
- upload_content_to_s3(script_file.read_text(), s3_path)
142
+ upload_content_to_s3(script_content, s3_path)
63
143
  print("✅ Script uploaded successfully")
64
144
  except Exception as e:
65
145
  print(f"❌ Upload failed: {e}")
@@ -88,14 +168,21 @@ def submit_to_sqs(script_path: str, size: str = "small") -> None:
88
168
 
89
169
  # Prepare message
90
170
  message = {"script_path": s3_path, "size": size}
91
- print("\n📨 Sending message to SQS...")
171
+
172
+ # Set environment variables
173
+ message["environment"] = {
174
+ "SERVERLESS": "False" if realtime else "True",
175
+ "DT": str(dt),
176
+ "PROMOTE": str(promote),
177
+ }
92
178
 
93
179
  # Send the message to SQS
94
180
  try:
181
+ print("\n📨 Sending message to SQS...")
95
182
  response = sqs.send_message(
96
183
  QueueUrl=queue_url,
97
184
  MessageBody=json.dumps(message, indent=2),
98
- MessageGroupId="ml-pipeline-jobs", # Required for FIFO
185
+ MessageGroupId=group_id, # From WORKBENCH_BATCH or default
99
186
  )
100
187
  message_id = response["MessageId"]
101
188
  print("✅ Message sent successfully!")
@@ -110,6 +197,14 @@ def submit_to_sqs(script_path: str, size: str = "small") -> None:
110
197
  print(f"{'=' * 60}")
111
198
  print(f"📄 Script: {script_name}")
112
199
  print(f"📏 Size: {size}")
200
+ print(f"⚡ Mode: {'Real-time' if realtime else 'Serverless'} (SERVERLESS={'False' if realtime else 'True'})")
201
+ print(f"🔄 DynamicTraining: {dt}")
202
+ print(f"🆕 Promote: {promote}")
203
+ if outputs:
204
+ print(f"📤 Outputs: {outputs}")
205
+ if inputs:
206
+ print(f"📥 Inputs: {inputs}")
207
+ print(f"📦 Batch Group: {group_id}")
113
208
  print(f"🆔 Message ID: {message_id}")
114
209
  print("\n🔍 MONITORING LOCATIONS:")
115
210
  print(f" • SQS Queue: AWS Console → SQS → {queue_name}")
@@ -126,9 +221,30 @@ def main():
126
221
  parser.add_argument(
127
222
  "--size", default="small", choices=["small", "medium", "large"], help="Job size tier (default: small)"
128
223
  )
224
+ parser.add_argument(
225
+ "--realtime",
226
+ action="store_true",
227
+ help="Create realtime endpoints (default is serverless)",
228
+ )
229
+ parser.add_argument(
230
+ "--dt",
231
+ action="store_true",
232
+ help="Set DT=True (models and endpoints will have '-dt' suffix)",
233
+ )
234
+ parser.add_argument(
235
+ "--promote",
236
+ action="store_true",
237
+ help="Set Promote=True (models and endpoints will use promoted naming)",
238
+ )
129
239
  args = parser.parse_args()
130
240
  try:
131
- submit_to_sqs(args.script_file, args.size)
241
+ submit_to_sqs(
242
+ args.script_file,
243
+ args.size,
244
+ realtime=args.realtime,
245
+ dt=args.dt,
246
+ promote=args.promote,
247
+ )
132
248
  except Exception as e:
133
249
  print(f"\n❌ ERROR: {e}")
134
250
  log.error(f"Error: {e}")
@@ -0,0 +1,85 @@
1
+ """
2
+ Local test harness for SageMaker training scripts.
3
+
4
+ Usage:
5
+ python training_test.py <model_script.py> <featureset_name>
6
+
7
+ Example:
8
+ python training_test.py ../model_scripts/pytorch_model/generated_model_script.py caco2-class-features
9
+ """
10
+
11
+ import os
12
+ import shutil
13
+ import subprocess
14
+ import sys
15
+ import tempfile
16
+
17
+ import pandas as pd
18
+
19
+ from workbench.api import FeatureSet
20
+
21
+
22
def get_training_data(featureset_name: str) -> pd.DataFrame:
    """Pull the full DataFrame for the named FeatureSet."""
    return FeatureSet(featureset_name).pull_dataframe()
26
+
27
+
28
def main():
    """Run a model training script locally against FeatureSet data.

    Mimics the SageMaker training container: stages FeatureSet data as a
    CSV, sets SM_* environment variables, and runs the script as a
    subprocess.
    """
    if len(sys.argv) < 3:
        print("Usage: python training_test.py <model_script.py> <featureset_name>")
        sys.exit(1)

    model_script, fs_name = sys.argv[1], sys.argv[2]

    if not os.path.exists(model_script):
        print(f"Error: Script not found: {model_script}")
        sys.exit(1)

    # Scratch directories that stand in for the SageMaker container layout
    scratch = {
        "model": tempfile.mkdtemp(prefix="training_model_"),
        "train": tempfile.mkdtemp(prefix="training_data_"),
        "output": tempfile.mkdtemp(prefix="training_output_"),
    }
    print(f"Model dir: {scratch['model']}")
    print(f"Train dir: {scratch['train']}")

    try:
        # Stage the FeatureSet data as a CSV in the train channel
        print(f"Loading FeatureSet: {fs_name}")
        training_df = get_training_data(fs_name)
        print(f"Data shape: {training_df.shape}")
        training_df.to_csv(os.path.join(scratch["train"], "training_data.csv"), index=False)

        # SageMaker-style environment for the child process
        child_env = os.environ.copy()
        child_env["SM_MODEL_DIR"] = scratch["model"]
        child_env["SM_CHANNEL_TRAIN"] = scratch["train"]
        child_env["SM_OUTPUT_DATA_DIR"] = scratch["output"]

        print("\n" + "=" * 60)
        print("Starting training...")
        print("=" * 60 + "\n")

        # Invoke the training script the way SageMaker would
        completed = subprocess.run(
            [sys.executable, model_script, "--model-dir", scratch["model"], "--train", scratch["train"]],
            env=child_env,
        )

        print("\n" + "=" * 60)
        if completed.returncode == 0:
            print("Training completed successfully!")
        else:
            print(f"Training failed with return code: {completed.returncode}")
        print("=" * 60)

    finally:
        # Always clean up the scratch directories
        for path in scratch.values():
            shutil.rmtree(path, ignore_errors=True)


if __name__ == "__main__":
    main()
@@ -110,6 +110,40 @@ a:hover {
110
110
  color: rgb(100, 255, 100);
111
111
  }
112
112
 
113
+ /* Dropdown styling (dcc.Dropdown) - override Bootstrap's variables */
114
+ .dash-dropdown {
115
+ --bs-body-bg: rgb(35, 35, 35);
116
+ --bs-body-color: rgb(210, 210, 210);
117
+ --bs-border-color: rgb(60, 60, 60);
118
+ }
119
+
120
+ /* Bootstrap form controls (dbc components) */
121
+ .form-select, .form-control {
122
+ background-color: rgb(35, 35, 35) !important;
123
+ border: 1px solid rgb(60, 60, 60) !important;
124
+ color: rgb(210, 210, 210) !important;
125
+ }
126
+
127
+ .form-select:focus, .form-control:focus {
128
+ background-color: rgb(45, 45, 45) !important;
129
+ border-color: rgb(80, 80, 80) !important;
130
+ box-shadow: 0 0 0 0.2rem rgba(80, 80, 80, 0.25) !important;
131
+ }
132
+
133
+ .dropdown-menu {
134
+ background-color: rgb(35, 35, 35) !important;
135
+ border: 1px solid rgb(60, 60, 60) !important;
136
+ }
137
+
138
+ .dropdown-item {
139
+ color: rgb(210, 210, 210) !important;
140
+ }
141
+
142
+ .dropdown-item:hover, .dropdown-item:focus {
143
+ background-color: rgb(50, 50, 50) !important;
144
+ color: rgb(230, 230, 230) !important;
145
+ }
146
+
113
147
  /* Table styling */
114
148
  table {
115
149
  width: 100%;
@@ -128,4 +162,29 @@ td {
128
162
  padding: 5px;
129
163
  border: 0.5px solid #444;
130
164
  text-align: center !important;
165
+ }
166
+
167
+ /* AG Grid table header colors - gradient theme */
168
+ /* Data Sources tables - red gradient */
169
+ #main_data_sources .ag-header,
170
+ #data_sources_table .ag-header {
171
+ background: linear-gradient(180deg, rgb(140, 60, 60) 0%, rgb(80, 35, 35) 100%) !important;
172
+ }
173
+
174
+ /* Feature Sets tables - yellow/olive gradient */
175
+ #main_feature_sets .ag-header,
176
+ #feature_sets_table .ag-header {
177
+ background: linear-gradient(180deg, rgb(120, 115, 55) 0%, rgb(70, 65, 30) 100%) !important;
178
+ }
179
+
180
+ /* Models tables - green gradient */
181
+ #main_models .ag-header,
182
+ #models_table .ag-header {
183
+ background: linear-gradient(180deg, rgb(55, 110, 55) 0%, rgb(30, 60, 30) 100%) !important;
184
+ }
185
+
186
+ /* Endpoints tables - purple gradient */
187
+ #main_endpoints .ag-header,
188
+ #endpoints_table .ag-header {
189
+ background: linear-gradient(180deg, rgb(100, 60, 120) 0%, rgb(55, 30, 70) 100%) !important;
131
190
  }