workbench 0.8.162__py3-none-any.whl → 0.8.220__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of workbench might be problematic. Click here for more details.

Files changed (147) hide show
  1. workbench/algorithms/dataframe/__init__.py +1 -2
  2. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  3. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  4. workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
  5. workbench/algorithms/dataframe/projection_2d.py +44 -21
  6. workbench/algorithms/dataframe/proximity.py +259 -305
  7. workbench/algorithms/graph/light/proximity_graph.py +14 -12
  8. workbench/algorithms/models/cleanlab_model.py +382 -0
  9. workbench/algorithms/models/noise_model.py +388 -0
  10. workbench/algorithms/sql/outliers.py +3 -3
  11. workbench/api/__init__.py +5 -1
  12. workbench/api/compound.py +1 -1
  13. workbench/api/df_store.py +17 -108
  14. workbench/api/endpoint.py +18 -5
  15. workbench/api/feature_set.py +121 -15
  16. workbench/api/meta.py +5 -2
  17. workbench/api/meta_model.py +289 -0
  18. workbench/api/model.py +55 -21
  19. workbench/api/monitor.py +1 -16
  20. workbench/api/parameter_store.py +3 -52
  21. workbench/cached/cached_model.py +4 -4
  22. workbench/core/artifacts/__init__.py +11 -2
  23. workbench/core/artifacts/artifact.py +16 -8
  24. workbench/core/artifacts/data_capture_core.py +355 -0
  25. workbench/core/artifacts/df_store_core.py +114 -0
  26. workbench/core/artifacts/endpoint_core.py +382 -253
  27. workbench/core/artifacts/feature_set_core.py +249 -45
  28. workbench/core/artifacts/model_core.py +135 -80
  29. workbench/core/artifacts/monitor_core.py +33 -248
  30. workbench/core/artifacts/parameter_store_core.py +98 -0
  31. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  32. workbench/core/cloud_platform/aws/aws_meta.py +12 -5
  33. workbench/core/cloud_platform/aws/aws_session.py +4 -4
  34. workbench/core/pipelines/pipeline_executor.py +1 -1
  35. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  36. workbench/core/transforms/features_to_model/features_to_model.py +62 -40
  37. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +76 -15
  38. workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
  39. workbench/core/views/training_view.py +113 -42
  40. workbench/core/views/view.py +53 -3
  41. workbench/core/views/view_utils.py +4 -4
  42. workbench/model_script_utils/model_script_utils.py +339 -0
  43. workbench/model_script_utils/pytorch_utils.py +405 -0
  44. workbench/model_script_utils/uq_harness.py +278 -0
  45. workbench/model_scripts/chemprop/chemprop.template +649 -0
  46. workbench/model_scripts/chemprop/generated_model_script.py +649 -0
  47. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  48. workbench/model_scripts/chemprop/requirements.txt +3 -0
  49. workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
  50. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  51. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  52. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  53. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
  54. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  55. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
  56. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  57. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  58. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  59. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  60. workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
  61. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  62. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  63. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  64. workbench/model_scripts/meta_model/meta_model.template +209 -0
  65. workbench/model_scripts/pytorch_model/generated_model_script.py +444 -500
  66. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  67. workbench/model_scripts/pytorch_model/pytorch.template +440 -496
  68. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  69. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  70. workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
  71. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  72. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  73. workbench/model_scripts/script_generation.py +20 -11
  74. workbench/model_scripts/uq_models/generated_model_script.py +248 -0
  75. workbench/model_scripts/xgb_model/generated_model_script.py +372 -404
  76. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  77. workbench/model_scripts/xgb_model/uq_harness.py +278 -0
  78. workbench/model_scripts/xgb_model/xgb_model.template +369 -401
  79. workbench/repl/workbench_shell.py +28 -19
  80. workbench/resources/open_source_api.key +1 -1
  81. workbench/scripts/endpoint_test.py +162 -0
  82. workbench/scripts/lambda_test.py +73 -0
  83. workbench/scripts/meta_model_sim.py +35 -0
  84. workbench/scripts/ml_pipeline_batch.py +137 -0
  85. workbench/scripts/ml_pipeline_sqs.py +186 -0
  86. workbench/scripts/monitor_cloud_watch.py +20 -100
  87. workbench/scripts/training_test.py +85 -0
  88. workbench/utils/aws_utils.py +4 -3
  89. workbench/utils/chem_utils/__init__.py +0 -0
  90. workbench/utils/chem_utils/fingerprints.py +175 -0
  91. workbench/utils/chem_utils/misc.py +194 -0
  92. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  93. workbench/utils/chem_utils/mol_standardize.py +450 -0
  94. workbench/utils/chem_utils/mol_tagging.py +348 -0
  95. workbench/utils/chem_utils/projections.py +219 -0
  96. workbench/utils/chem_utils/salts.py +256 -0
  97. workbench/utils/chem_utils/sdf.py +292 -0
  98. workbench/utils/chem_utils/toxicity.py +250 -0
  99. workbench/utils/chem_utils/vis.py +253 -0
  100. workbench/utils/chemprop_utils.py +141 -0
  101. workbench/utils/cloudwatch_handler.py +1 -1
  102. workbench/utils/cloudwatch_utils.py +137 -0
  103. workbench/utils/config_manager.py +3 -7
  104. workbench/utils/endpoint_utils.py +5 -7
  105. workbench/utils/license_manager.py +2 -6
  106. workbench/utils/meta_model_simulator.py +499 -0
  107. workbench/utils/metrics_utils.py +256 -0
  108. workbench/utils/model_utils.py +278 -79
  109. workbench/utils/monitor_utils.py +44 -62
  110. workbench/utils/pandas_utils.py +3 -3
  111. workbench/utils/pytorch_utils.py +87 -0
  112. workbench/utils/shap_utils.py +11 -57
  113. workbench/utils/workbench_logging.py +0 -3
  114. workbench/utils/workbench_sqs.py +1 -1
  115. workbench/utils/xgboost_local_crossfold.py +267 -0
  116. workbench/utils/xgboost_model_utils.py +127 -219
  117. workbench/web_interface/components/model_plot.py +14 -2
  118. workbench/web_interface/components/plugin_unit_test.py +5 -2
  119. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  120. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  121. workbench/web_interface/components/plugins/model_details.py +38 -74
  122. workbench/web_interface/components/plugins/scatter_plot.py +6 -10
  123. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/METADATA +31 -9
  124. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/RECORD +128 -96
  125. workbench-0.8.220.dist-info/entry_points.txt +11 -0
  126. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/licenses/LICENSE +1 -1
  127. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  128. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  129. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  130. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  131. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  132. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  133. workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
  134. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  135. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  136. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -273
  137. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
  138. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  139. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  140. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  141. workbench/utils/chem_utils.py +0 -1556
  142. workbench/utils/execution_environment.py +0 -211
  143. workbench/utils/fast_inference.py +0 -167
  144. workbench/utils/resource_utils.py +0 -39
  145. workbench-0.8.162.dist-info/entry_points.txt +0 -5
  146. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/WHEEL +0 -0
  147. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/top_level.txt +0 -0
@@ -1,14 +1,25 @@
1
+ # flake8: noqa: E402
2
+ import os
3
+ import sys
4
+ import logging
5
+ import importlib
6
+ import webbrowser
7
+ import readline # noqa: F401
8
+
9
+ # Disable OpenMP parallelism to avoid segfaults with PyTorch in iPython
10
+ # This is a known issue on macOS where libomp crashes during thread synchronization
11
+ # Must be set before importing numpy/pandas/torch or any library that uses OpenMP
12
+ os.environ.setdefault("OMP_NUM_THREADS", "1")
13
+ os.environ.setdefault("MKL_NUM_THREADS", "1")
14
+
15
+ import IPython
1
16
  from IPython import start_ipython
17
+ from distutils.version import LooseVersion
2
18
  from IPython.terminal.prompts import Prompts
3
19
  from IPython.terminal.ipapp import load_default_config
4
20
  from pygments.token import Token
5
- import sys
6
- import logging
7
- import importlib
8
21
  import botocore
9
- import webbrowser
10
22
  import pandas as pd
11
- import readline # noqa
12
23
 
13
24
  try:
14
25
  import matplotlib.pyplot as plt # noqa
@@ -39,7 +50,7 @@ from workbench.cached.cached_meta import CachedMeta
39
50
  try:
40
51
  import rdkit # noqa
41
52
  import mordred # noqa
42
- from workbench.utils import chem_utils
53
+ from workbench.utils.chem_utils import vis
43
54
 
44
55
  HAVE_CHEM_UTILS = True
45
56
  except ImportError:
@@ -70,7 +81,7 @@ if not ConfigManager().config_okay():
70
81
 
71
82
  # Set the log level to important
72
83
  log = logging.getLogger("workbench")
73
- log.setLevel(IMPORTANT_LEVEL_NUM)
84
+ log.setLevel(logging.INFO)
74
85
  log.addFilter(
75
86
  lambda record: not (
76
87
  record.getMessage().startswith("Async: Metadata") or record.getMessage().startswith("Updated Metadata")
@@ -176,12 +187,12 @@ class WorkbenchShell:
176
187
 
177
188
  # Add cheminformatics utils if available
178
189
  if HAVE_CHEM_UTILS:
179
- self.commands["show"] = chem_utils.show
190
+ self.commands["show"] = vis.show
180
191
 
181
192
  def start(self):
182
193
  """Start the Workbench IPython shell"""
183
194
  cprint("magenta", "\nWelcome to Workbench!")
184
- if self.aws_status is False:
195
+ if not self.aws_status:
185
196
  cprint("red", "AWS Account Connection Failed...Review/Fix the Workbench Config:")
186
197
  cprint("red", f"Path: {self.cm.site_config_path}")
187
198
  self.show_config()
@@ -202,7 +213,10 @@ class WorkbenchShell:
202
213
 
203
214
  # Start IPython with the config and commands in the namespace
204
215
  try:
205
- ipython_argv = ["--no-tip", "--theme", "linux"]
216
+ if LooseVersion(IPython.__version__) >= LooseVersion("9.0.0"):
217
+ ipython_argv = ["--no-tip", "--theme", "linux"]
218
+ else:
219
+ ipython_argv = []
206
220
  start_ipython(ipython_argv, user_ns=locs, config=config)
207
221
  finally:
208
222
  spinner = self.spinner_start("Goodbye to AWS:")
@@ -288,11 +302,6 @@ class WorkbenchShell:
288
302
  self.commands["PandasToView"] = importlib.import_module("workbench.core.views.pandas_to_view").PandasToView
289
303
  self.commands["Pipeline"] = importlib.import_module("workbench.api.pipeline").Pipeline
290
304
 
291
- # Algorithms
292
- self.commands["FSP"] = importlib.import_module(
293
- "workbench.algorithms.dataframe.feature_space_proximity"
294
- ).FeatureSpaceProximity
295
-
296
305
  # These are 'nice to have' imports
297
306
  self.commands["pd"] = importlib.import_module("pandas")
298
307
  self.commands["wr"] = importlib.import_module("awswrangler")
@@ -520,7 +529,7 @@ class WorkbenchShell:
520
529
  def get_meta(self):
521
530
  return self.meta
522
531
 
523
- def plot_manager(self, data, plot_type: str = "table", **kwargs):
532
+ def plot_manager(self, data, plot_type: str = "scatter", **kwargs):
524
533
  """Plot Manager for Workbench"""
525
534
  from workbench.web_interface.components.plugins import ag_table, graph_plot, scatter_plot
526
535
 
@@ -555,14 +564,14 @@ class WorkbenchShell:
555
564
  from workbench.web_interface.components.plugin_unit_test import PluginUnitTest
556
565
 
557
566
  # Get kwargs
558
- theme = kwargs.get("theme", "dark")
567
+ theme = kwargs.get("theme", "midnight_blue")
559
568
 
560
569
  plugin_test = PluginUnitTest(plugin_class, theme=theme, input_data=data, **kwargs)
561
570
 
562
- # Run the server and open in the browser
563
- plugin_test.run()
571
+ # Open the browser and run the dash server
564
572
  url = f"http://127.0.0.1:{plugin_test.port}"
565
573
  webbrowser.open(url)
574
+ plugin_test.run()
566
575
 
567
576
 
568
577
  # Launch Shell Entry Point
@@ -1 +1 @@
1
- eyJsaWNlbnNlX2lkIjogIk9wZW5fU291cmNlX0xpY2Vuc2UiLCAiY29tcGFueSI6ICIiLCAiYXdzX2FjY291bnRfaWQiOiAiIiwgInRpZXIiOiAiRW50ZXJwcmlzZSBQcm8iLCAiZmVhdHVyZXMiOiBbInBsdWdpbnMiLCAicGFnZXMiLCAidGhlbWVzIiwgInBpcGVsaW5lcyIsICJicmFuZGluZyJdLCAiZXhwaXJlcyI6ICIyMDI2LTAxLTE0In02zCDRy41wKRViRnGmodczFWexLyfXYrJWSuVQQbhbWeRttQRv6zpo9x4O2yBjdRfhb9E7mFUppNiOS_ZGK-bL71nGHt_Mc8niG8jkpvKX9qZ6BqkXF_vzDIOcI8iGiwB3wikeVO4zRLD1AI0U3cgYmIyGXI9QKJ9L7IHyQ0TWqw==
1
+ eyJsaWNlbnNlX2lkIjogIk9wZW5fU291cmNlX0xpY2Vuc2UiLCAiY29tcGFueSI6ICIiLCAiYXdzX2FjY291bnRfaWQiOiAiIiwgInRpZXIiOiAiRW50ZXJwcmlzZSBQcm8iLCAiZmVhdHVyZXMiOiBbInBsdWdpbnMiLCAicGFnZXMiLCAidGhlbWVzIiwgInBpcGVsaW5lcyIsICJicmFuZGluZyJdLCAiZXhwaXJlcyI6ICIyMDI2LTEyLTA1In1IsmpkuybFALADkRj_RfmkQ0LAIsQeXRE7Uoc3DL1UrDr-rSnwu-PDqsKBUkX6jPRFZV3DLxNjBapxPeEIFhfvxvjzz_sc6CwtxNpZ3bPmxSPs2W-j3xZS4-XyEqIilcwSkWh-NU1u27gCuuivn5eiUmIYJGAp0wdVkeE6_Z9dlg==
@@ -0,0 +1,162 @@
1
+ """
2
+ Local test harness for SageMaker model scripts.
3
+
4
+ Usage:
5
+ python model_script_harness.py <local_script.py> <model_name>
6
+
7
+ Example:
8
+ python model_script_harness.py pytorch.py aqsol-reg-pytorch
9
+
10
+ This allows you to test LOCAL changes to a model script against deployed model artifacts.
11
+ Evaluation data is automatically pulled from the FeatureSet (training = FALSE rows).
12
+
13
+ Optional: testing/env.json with additional environment variables
14
+ """
15
+
16
+ import os
17
+ import sys
18
+ import json
19
+ import importlib.util
20
+ import tempfile
21
+ import shutil
22
+ import pandas as pd
23
+ import torch
24
+
25
+ # Workbench Imports
26
+ from workbench.api import Model, FeatureSet
27
+ from workbench.utils.pytorch_utils import download_and_extract_model
28
+
29
+ # Force CPU mode BEFORE any PyTorch imports to avoid MPS/CUDA issues on Mac
30
+ os.environ["CUDA_VISIBLE_DEVICES"] = ""
31
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
32
+ torch.set_default_device("cpu")
33
+ # Disable MPS entirely
34
+ if hasattr(torch.backends, "mps"):
35
+ torch.backends.mps.is_available = lambda: False
36
+
37
+
38
def get_eval_data(workbench_model: Model) -> pd.DataFrame:
    """Pull the evaluation rows (training = FALSE) from the model's FeatureSet."""
    # Resolve the FeatureSet that fed this model
    feature_set_name = workbench_model.get_input()
    feature_set = FeatureSet(feature_set_name)
    if not feature_set.exists():
        raise ValueError(f"No FeatureSet found: {feature_set_name}")

    # Query only the hold-out split from the model's training view
    view_table = workbench_model.training_view().table
    print(f"Querying evaluation data from {view_table}...")
    holdout_df = feature_set.query(f'SELECT * FROM "{view_table}" WHERE training = FALSE')
    print(f"Retrieved {len(holdout_df)} evaluation rows")

    return holdout_df
53
+
54
+
55
def load_model_script(script_path: str):
    """Dynamically load the model script module."""
    if not os.path.exists(script_path):
        raise FileNotFoundError(f"Script not found: {script_path}")

    # Build a fresh module object from the file location
    module_spec = importlib.util.spec_from_file_location("model_script", script_path)
    loaded_module = importlib.util.module_from_spec(module_spec)

    # Register under the canonical name BEFORE executing, so any
    # intra-script imports of "model_script" resolve correctly
    sys.modules["model_script"] = loaded_module
    module_spec.loader.exec_module(loaded_module)
    return loaded_module
68
+
69
+
70
def _load_env_json() -> None:
    """Apply extra environment variables from testing/env.json, if present."""
    if os.path.exists("testing/env.json"):
        print("Loading environment variables from testing/env.json")
        with open("testing/env.json") as f:
            env_vars = json.load(f)
        for key, value in env_vars.items():
            # os.environ values must be strings; env.json may contain numbers/bools
            os.environ[key] = str(value)
            print(f" Set {key} = {value}")
        print()


def _require_entry_points(module) -> None:
    """Raise a clear error if the script lacks the SageMaker entry points."""
    if not hasattr(module, "model_fn"):
        raise AttributeError("Model script must have a model_fn function")
    if not hasattr(module, "predict_fn"):
        raise AttributeError("Model script must have a predict_fn function")


def main():
    """Run a LOCAL model script against deployed model artifacts.

    Downloads the model's artifacts into a temp dir, loads the local script,
    calls model_fn, then predict_fn on evaluation data pulled from the
    model's FeatureSet, and prints the prediction results.
    """
    if len(sys.argv) < 3:
        print("Usage: python model_script_harness.py <local_script.py> <model_name>")
        print("\nArguments:")
        print(" local_script.py - Path to your LOCAL model script to test")
        print(" model_name - Workbench model name (e.g., aqsol-reg-pytorch)")
        print("\nOptional: testing/env.json with additional environment variables")
        sys.exit(1)

    script_path = sys.argv[1]
    model_name = sys.argv[2]

    # Validate local script exists
    if not os.path.exists(script_path):
        print(f"Error: Local script not found: {script_path}")
        sys.exit(1)

    # Initialize Workbench model
    print(f"Loading Workbench model: {model_name}")
    workbench_model = Model(model_name)
    print(f"Model Framework: {workbench_model.model_framework}")
    print()

    # Create a temporary model directory
    model_dir = tempfile.mkdtemp(prefix="model_harness_")
    print(f"Using model directory: {model_dir}")

    try:
        # Load environment variables from env.json if it exists
        _load_env_json()

        # Set up SageMaker environment variables
        os.environ["SM_MODEL_DIR"] = model_dir
        print(f"Set SM_MODEL_DIR = {model_dir}")

        # Download and extract model artifacts
        s3_uri = workbench_model.model_data_url()
        download_and_extract_model(s3_uri, model_dir)
        print()

        # Load the LOCAL model script
        print(f"Loading LOCAL model script: {script_path}")
        module = load_model_script(script_path)
        print()

        # Check for required functions
        _require_entry_points(module)

        # Load the model
        print("Calling model_fn...")
        print("-" * 50)
        model = module.model_fn(model_dir)
        print("-" * 50)
        print(f"Model loaded: {type(model)}")
        print()

        # Get evaluation data from FeatureSet
        print("Pulling evaluation data from FeatureSet...")
        df = get_eval_data(workbench_model)
        print(f"Input shape: {df.shape}")
        print(f"Columns: {df.columns.tolist()}")
        print()

        print("Calling predict_fn...")
        print("-" * 50)
        result = module.predict_fn(df, model)
        print("-" * 50)
        print()

        print("Prediction result:")
        print(f"Output shape: {result.shape}")
        print(f"Output columns: {result.columns.tolist()}")
        print()
        print(result.head(10).to_string())

    finally:
        # Cleanup the temp artifacts even on failure
        print(f"\nCleaning up model directory: {model_dir}")
        shutil.rmtree(model_dir, ignore_errors=True)
159
+
160
+
161
+ if __name__ == "__main__":
162
+ main()
@@ -0,0 +1,73 @@
1
+ """
2
+ Local test harness for AWS Lambda scripts.
3
+
4
+ Usage:
5
+ lambda_test <lambda_script.py>
6
+
7
+ Required: testing/event.json with the event definition
8
+ Options: testing/env.json with a set of ENV vars
9
+ """
10
+
11
+ import sys
12
+ import os
13
+ import json
14
+ import importlib.util
15
+
16
+
17
def main():
    """Invoke a local Lambda handler script with a test event.

    Reads optional environment variables from testing/env.json and the event
    payload from testing/event.json, then calls the script's
    lambda_handler(event, {}) and pretty-prints the result.
    """
    if len(sys.argv) != 2:
        # BUG FIX: usage line previously said "lambda_launcher <handler_module_name>",
        # contradicting the module docstring's "lambda_test <lambda_script.py>"
        print("Usage: lambda_test <lambda_script.py>")
        print("\nOptional: testing/event.json with test event")
        print("Optional: testing/env.json with environment variables")
        sys.exit(1)

    handler_file = sys.argv[1]

    # Add .py if not present
    if not handler_file.endswith(".py"):
        handler_file += ".py"

    # Check if file exists
    if not os.path.exists(handler_file):
        print(f"Error: File '{handler_file}' not found")
        sys.exit(1)

    # Load environment variables from env.json if it exists
    if os.path.exists("testing/env.json"):
        print("Loading environment variables from testing/env.json")
        with open("testing/env.json") as f:
            env_vars = json.load(f)
        for key, value in env_vars.items():
            # os.environ values must be strings; env.json may contain numbers/bools
            os.environ[key] = str(value)
            print(f" Set {key} = {value}")
        print()

    # Load event configuration
    if os.path.exists("testing/event.json"):
        print("Loading event from testing/event.json")
        with open("testing/event.json") as f:
            event = json.load(f)
    else:
        print("No testing/event.json found, using empty event")
        event = {}

    # Load the module dynamically
    spec = importlib.util.spec_from_file_location("lambda_module", handler_file)
    lambda_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(lambda_module)

    # Fail with a clear message instead of a raw AttributeError below
    if not hasattr(lambda_module, "lambda_handler"):
        print(f"Error: '{handler_file}' does not define a lambda_handler function")
        sys.exit(1)

    # Call the lambda_handler
    print(f"Invoking lambda_handler from {handler_file}...")
    print("-" * 50)
    print(f"Event: {json.dumps(event, indent=2)}")
    print("-" * 50)

    result = lambda_module.lambda_handler(event, {})

    print("-" * 50)
    print("Result:")
    print(json.dumps(result, indent=2))
+
71
+
72
+ if __name__ == "__main__":
73
+ main()
@@ -0,0 +1,35 @@
1
+ """MetaModelSimulator: Simulate and analyze ensemble model performance.
2
+
3
+ This class helps evaluate whether a meta model (ensemble) would outperform
4
+ individual child models by analyzing endpoint inference predictions.
5
+ """
6
+
7
+ import argparse
8
+ from workbench.utils.meta_model_simulator import MetaModelSimulator
9
+
10
+
11
+ def main():
12
+ parser = argparse.ArgumentParser(
13
+ description="Simulate and analyze ensemble model performance using MetaModelSimulator."
14
+ )
15
+ parser.add_argument(
16
+ "models",
17
+ nargs="+",
18
+ help="List of model endpoint names to include in the ensemble simulation.",
19
+ )
20
+ parser.add_argument(
21
+ "--id-column",
22
+ default="molecule_name",
23
+ help="Name of the ID column (default: molecule_name)",
24
+ )
25
+ args = parser.parse_args()
26
+ models = args.models
27
+ id_column = args.id_column
28
+
29
+ # Create MetaModelSimulator instance and generate report
30
+ sim = MetaModelSimulator(models, id_column=id_column)
31
+ sim.report()
32
+
33
+
34
+ if __name__ == "__main__":
35
+ main()
@@ -0,0 +1,137 @@
1
+ import argparse
2
+ import logging
3
+ import time
4
+ from datetime import datetime
5
+ from pathlib import Path
6
+
7
+ # Workbench Imports
8
+ from workbench.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
9
+ from workbench.utils.config_manager import ConfigManager
10
+ from workbench.utils.s3_utils import upload_content_to_s3
11
+ from workbench.utils.cloudwatch_utils import get_cloudwatch_logs_url
12
+
13
+ log = logging.getLogger("workbench")
14
+ cm = ConfigManager()
15
+ workbench_bucket = cm.get_config("WORKBENCH_BUCKET")
16
+
17
+
18
+ def get_ecr_image_uri() -> str:
19
+ """Get the ECR image URI for the current region."""
20
+ region = AWSAccountClamp().region
21
+ return f"507740646243.dkr.ecr.{region}.amazonaws.com/aws-ml-images/py312-ml-pipelines:0.1"
22
+
23
+
24
+ def get_batch_role_arn() -> str:
25
+ """Get the Batch execution role ARN."""
26
+ account_id = AWSAccountClamp().account_id
27
+ return f"arn:aws:iam::{account_id}:role/Workbench-BatchRole"
28
+
29
+
30
+ def _log_cloudwatch_link(job: dict, message_prefix: str = "View logs") -> None:
31
+ """
32
+ Helper method to log CloudWatch logs link with clickable URL and full URL display.
33
+
34
+ Args:
35
+ job: Batch job description dictionary
36
+ message_prefix: Prefix for the log message (default: "View logs")
37
+ """
38
+ log_stream = job.get("container", {}).get("logStreamName")
39
+ logs_url = get_cloudwatch_logs_url(log_group="/aws/batch/job", log_stream=log_stream)
40
+ if logs_url:
41
+ clickable_url = f"\033]8;;{logs_url}\033\\{logs_url}\033]8;;\033\\"
42
+ log.info(f"{message_prefix}: {clickable_url}")
43
+ else:
44
+ log.info("Check AWS Batch console for logs")
45
+
46
+
47
+ def run_batch_job(script_path: str, size: str = "small") -> int:
48
+ """
49
+ Submit and monitor an AWS Batch job for ML pipeline execution.
50
+
51
+ Uploads script to S3, submits Batch job, monitors until completion or 2 minutes of RUNNING.
52
+
53
+ Args:
54
+ script_path: Local path to the ML pipeline script
55
+ size: Job size tier - "small" (default), "medium", or "large"
56
+ - small: 2 vCPU, 4GB RAM for lightweight processing
57
+ - medium: 4 vCPU, 8GB RAM for standard ML workloads
58
+ - large: 8 vCPU, 16GB RAM for heavy training/inference
59
+
60
+ Returns:
61
+ Exit code (0 for success/disconnected, non-zero for failure)
62
+ """
63
+ if size not in ["small", "medium", "large"]:
64
+ raise ValueError(f"Invalid size '{size}'. Must be 'small', 'medium', or 'large'")
65
+
66
+ batch = AWSAccountClamp().boto3_session.client("batch")
67
+ script_name = Path(script_path).stem
68
+
69
+ # Upload script to S3
70
+ s3_path = f"s3://{workbench_bucket}/batch-jobs/{Path(script_path).name}"
71
+ log.info(f"Uploading script to {s3_path}")
72
+ upload_content_to_s3(Path(script_path).read_text(), s3_path)
73
+
74
+ # Submit job
75
+ job_name = f"workbench_{script_name}_{datetime.now():%Y%m%d_%H%M%S}"
76
+ response = batch.submit_job(
77
+ jobName=job_name,
78
+ jobQueue="workbench-job-queue",
79
+ jobDefinition=f"workbench-batch-{size}",
80
+ containerOverrides={
81
+ "environment": [
82
+ {"name": "ML_PIPELINE_S3_PATH", "value": s3_path},
83
+ {"name": "WORKBENCH_BUCKET", "value": workbench_bucket},
84
+ ]
85
+ },
86
+ )
87
+ job_id = response["jobId"]
88
+ log.info(f"Submitted job: {job_name} ({job_id}) using {size} tier")
89
+
90
+ # Monitor job
91
+ last_status, running_start = None, None
92
+ while True:
93
+ job = batch.describe_jobs(jobs=[job_id])["jobs"][0]
94
+ status = job["status"]
95
+
96
+ if status != last_status:
97
+ log.info(f"Job status: {status}")
98
+ last_status = status
99
+ if status == "RUNNING":
100
+ running_start = time.time()
101
+
102
+ # Disconnect after 2 minutes of running
103
+ if status == "RUNNING" and running_start and (time.time() - running_start >= 120):
104
+ log.info("✅ ML Pipeline is running successfully!")
105
+ _log_cloudwatch_link(job, "📊 Monitor logs")
106
+ return 0
107
+
108
+ # Handle completion
109
+ if status in ["SUCCEEDED", "FAILED"]:
110
+ exit_code = job.get("attempts", [{}])[-1].get("exitCode", 1)
111
+ msg = (
112
+ "Job completed successfully"
113
+ if status == "SUCCEEDED"
114
+ else f"Job failed: {job.get('statusReason', 'Unknown')}"
115
+ )
116
+ log.info(msg) if status == "SUCCEEDED" else log.error(msg)
117
+ _log_cloudwatch_link(job)
118
+ return exit_code
119
+
120
+ time.sleep(10)
121
+
122
+
123
+ def main():
124
+ """CLI entry point for running ML pipelines on AWS Batch."""
125
+ parser = argparse.ArgumentParser(description="Run ML pipeline script on AWS Batch")
126
+ parser.add_argument("script_file", help="Local path to ML pipeline script")
127
+ args = parser.parse_args()
128
+ try:
129
+ exit_code = run_batch_job(args.script_file)
130
+ exit(exit_code)
131
+ except Exception as e:
132
+ log.error(f"Error: {e}")
133
+ exit(1)
134
+
135
+
136
+ if __name__ == "__main__":
137
+ main()