workbench 0.8.234-py3-none-any.whl → 0.8.239-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/algorithms/dataframe/smart_aggregator.py +17 -12
- workbench/api/endpoint.py +13 -4
- workbench/api/model.py +2 -2
- workbench/cached/cached_model.py +2 -2
- workbench/core/artifacts/athena_source.py +5 -3
- workbench/core/artifacts/endpoint_core.py +30 -5
- workbench/core/cloud_platform/aws/aws_meta.py +2 -1
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +27 -14
- workbench/model_script_utils/model_script_utils.py +225 -0
- workbench/model_script_utils/uq_harness.py +39 -21
- workbench/model_scripts/chemprop/chemprop.template +30 -15
- workbench/model_scripts/chemprop/generated_model_script.py +35 -18
- workbench/model_scripts/chemprop/model_script_utils.py +225 -0
- workbench/model_scripts/pytorch_model/generated_model_script.py +29 -15
- workbench/model_scripts/pytorch_model/model_script_utils.py +225 -0
- workbench/model_scripts/pytorch_model/pytorch.template +28 -14
- workbench/model_scripts/pytorch_model/uq_harness.py +39 -21
- workbench/model_scripts/xgb_model/generated_model_script.py +35 -22
- workbench/model_scripts/xgb_model/model_script_utils.py +225 -0
- workbench/model_scripts/xgb_model/uq_harness.py +39 -21
- workbench/model_scripts/xgb_model/xgb_model.template +29 -18
- workbench/scripts/ml_pipeline_batch.py +47 -2
- workbench/scripts/ml_pipeline_launcher.py +410 -0
- workbench/scripts/ml_pipeline_sqs.py +22 -2
- workbench/themes/dark/custom.css +29 -0
- workbench/themes/light/custom.css +29 -0
- workbench/themes/midnight_blue/custom.css +28 -0
- workbench/utils/model_utils.py +9 -0
- workbench/utils/theme_manager.py +95 -0
- workbench/web_interface/components/component_interface.py +3 -0
- workbench/web_interface/components/plugin_interface.py +26 -0
- workbench/web_interface/components/plugins/ag_table.py +4 -11
- workbench/web_interface/components/plugins/confusion_matrix.py +14 -8
- workbench/web_interface/components/plugins/model_plot.py +156 -0
- workbench/web_interface/components/plugins/scatter_plot.py +9 -2
- workbench/web_interface/components/plugins/shap_summary_plot.py +12 -4
- workbench/web_interface/components/settings_menu.py +10 -49
- {workbench-0.8.234.dist-info → workbench-0.8.239.dist-info}/METADATA +2 -2
- {workbench-0.8.234.dist-info → workbench-0.8.239.dist-info}/RECORD +43 -42
- {workbench-0.8.234.dist-info → workbench-0.8.239.dist-info}/WHEEL +1 -1
- {workbench-0.8.234.dist-info → workbench-0.8.239.dist-info}/entry_points.txt +1 -0
- workbench/web_interface/components/model_plot.py +0 -75
- {workbench-0.8.234.dist-info → workbench-0.8.239.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.234.dist-info → workbench-0.8.239.dist-info}/top_level.txt +0 -0
workbench/scripts/ml_pipeline_launcher.py
ADDED

@@ -0,0 +1,410 @@
+"""Launch ML pipelines via SQS for testing.
+
+Run this from a directory containing pipeline subdirectories (e.g., ml_pipelines/).
+
+Usage:
+    ml_pipeline_launcher --dt                  # Launch 1 random pipeline group (all scripts in a directory)
+    ml_pipeline_launcher --dt -n 3             # Launch 3 random pipeline groups
+    ml_pipeline_launcher --dt --all            # Launch ALL pipelines
+    ml_pipeline_launcher --dt caco2            # Launch pipelines matching 'caco2'
+    ml_pipeline_launcher --dt caco2 ppb        # Launch pipelines matching 'caco2' or 'ppb'
+    ml_pipeline_launcher --promote --all       # Promote ALL pipelines
+    ml_pipeline_launcher --test-promote --all  # Test-promote ALL pipelines
+    ml_pipeline_launcher --dt --dry-run        # Show what would be launched without launching
+"""
+
+import argparse
+import ast
+import random
+import re
+import subprocess
+import time
+from pathlib import Path
+
+
+def parse_workbench_batch(script_path: Path) -> dict | None:
+    """Parse WORKBENCH_BATCH config from a script file."""
+    content = script_path.read_text()
+    match = re.search(r"WORKBENCH_BATCH\s*=\s*(\{[^}]+\})", content, re.DOTALL)
+    if match:
+        try:
+            return ast.literal_eval(match.group(1))
+        except (ValueError, SyntaxError):
+            return None
+    return None
+
+
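To make the regex concrete: each pipeline script declares its data dependencies in a WORKBENCH_BATCH dict literal near the top of the file. The script and artifact names below are illustrative only, not from the package:

# caco2_model.py (hypothetical pipeline script)
WORKBENCH_BATCH = {
    "inputs": ["caco2_features"],
    "outputs": ["caco2-regression-model"],
}

config = parse_workbench_batch(Path("caco2_model.py"))
# -> {'inputs': ['caco2_features'], 'outputs': ['caco2-regression-model']}
# Returns None if the block is missing or is not a valid dict literal.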
+def build_dependency_graph(configs: dict[Path, dict]) -> dict[str, str]:
+    """Build a mapping from each output to its root producer.
+
+    For a chain like A -> B -> C (where B depends on A, C depends on B),
+    this returns {A: A, B: A, C: A} so all are in the same message group.
+    """
+    # Build output -> input mapping (what does each output depend on?)
+    output_to_input = {}
+    for config in configs.values():
+        if not config:
+            continue
+        outputs = config.get("outputs", [])
+        inputs = config.get("inputs", [])
+        for output in outputs:
+            output_to_input[output] = inputs[0] if inputs else None
+
+    # Walk chain to find root
+    def find_root(output: str, visited: set = None) -> str:
+        if visited is None:
+            visited = set()
+        if output in visited:
+            return output
+        visited.add(output)
+        parent = output_to_input.get(output)
+        if parent is None:
+            return output
+        return find_root(parent, visited)
+
+    return {output: find_root(output) for output in output_to_input}
+
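A quick worked example of the root mapping, using made-up configs shaped like the WORKBENCH_BATCH dicts above (build_dependency_graph only reads the values, so the keys just need to be Paths):

configs = {
    Path("a.py"): {"inputs": [], "outputs": ["A"]},
    Path("b.py"): {"inputs": ["A"], "outputs": ["B"]},
    Path("c.py"): {"inputs": ["B"], "outputs": ["C"]},
}
print(build_dependency_graph(configs))
# {'A': 'A', 'B': 'A', 'C': 'A'}
# Every output resolves to the root producer 'A', so all three scripts
# land in the same SQS message group and execute in chain order.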
+def get_group_id(config: dict | None, root_map: dict[str, str]) -> str | None:
+    """Get the root group_id for a pipeline based on its config and root_map."""
+    if not config:
+        return None
+    outputs = config.get("outputs", [])
+    inputs = config.get("inputs", [])
+    # Check inputs first (this script depends on something)
+    if inputs and inputs[0] in root_map:
+        return root_map[inputs[0]]
+    # Check outputs (this script produces something)
+    if outputs and outputs[0] in root_map:
+        return root_map[outputs[0]]
+    return None
+
+
+def sort_by_dependencies(pipelines: list[Path]) -> tuple[list[Path], dict[Path, dict], dict[str, str]]:
+    """Sort pipelines by dependency chains. Returns (sorted_list, configs, root_map)."""
+    # Parse all configs
+    configs = {}
+    for pipeline in pipelines:
+        configs[pipeline] = parse_workbench_batch(pipeline)
+
+    # Build root map for group_id resolution
+    root_map = build_dependency_graph(configs)
+
+    # Build output -> pipeline mapping
+    output_to_pipeline = {}
+    for pipeline, config in configs.items():
+        if config and config.get("outputs"):
+            for output in config["outputs"]:
+                output_to_pipeline[output] = pipeline
+
+    # Build chains by walking from root producers
+    sorted_pipelines = []
+    used = set()
+
+    for pipeline in sorted(pipelines):
+        config = configs.get(pipeline)
+
+        # Skip if already used or has inputs (not a root)
+        if pipeline in used:
+            continue
+        if config and config.get("inputs"):
+            continue
+
+        # Walk the chain from this root
+        chain = [pipeline]
+        used.add(pipeline)
+
+        current = pipeline
+        while True:
+            current_config = configs.get(current)
+            if not current_config or not current_config.get("outputs"):
+                break
+
+            current_output = current_config["outputs"][0]
+            # Find pipeline that consumes this output
+            next_pipeline = None
+            for p, c in configs.items():
+                if p in used or p not in pipelines:
+                    continue
+                if c and c.get("inputs") and current_output in c["inputs"]:
+                    next_pipeline = p
+                    break
+
+            if next_pipeline:
+                chain.append(next_pipeline)
+                used.add(next_pipeline)
+                current = next_pipeline
+            else:
+                break
+
+        sorted_pipelines.extend(chain)
+
+    # Add any remaining pipelines not in chains
+    for pipeline in sorted(pipelines):
+        if pipeline not in used:
+            sorted_pipelines.append(pipeline)
+
+    return sorted_pipelines, configs, root_map
+
+
+def format_dependency_chains(pipelines: list[Path], configs: dict[Path, dict]) -> list[str]:
+    """Format pipelines as dependency chains for display."""
+    # Build output -> pipeline mapping
+    output_to_pipeline = {}
+    for pipeline, config in configs.items():
+        if config and config.get("outputs"):
+            for output in config["outputs"]:
+                output_to_pipeline[output] = pipeline
+
+    # Build chains by walking from root producers
+    chains = []
+    used = set()
+
+    for pipeline in pipelines:
+        config = configs.get(pipeline)
+
+        # Skip if already part of a chain or has inputs (not a root)
+        if pipeline in used:
+            continue
+        if config and config.get("inputs"):
+            continue
+
+        # Start a new chain from this root producer (or standalone)
+        chain = [pipeline]
+        used.add(pipeline)
+
+        # Walk the chain: find who consumes our output
+        current = pipeline
+        while True:
+            current_config = configs.get(current)
+            if not current_config or not current_config.get("outputs"):
+                break
+
+            current_output = current_config["outputs"][0]
+            # Find a pipeline that takes this output as input
+            next_pipeline = None
+            for p, c in configs.items():
+                if p in used or p not in pipelines:
+                    continue
+                if c and c.get("inputs") and current_output in c["inputs"]:
+                    next_pipeline = p
+                    break
+
+            if next_pipeline:
+                chain.append(next_pipeline)
+                used.add(next_pipeline)
+                current = next_pipeline
+            else:
+                break
+
+        chains.append(chain)
+
+    # Add any remaining pipelines not in chains (shouldn't happen but just in case)
+    for pipeline in pipelines:
+        if pipeline not in used:
+            chains.append([pipeline])
+
+    # Format chains as strings
+    lines = []
+    for chain in chains:
+        names = [p.stem for p in chain]
+        lines.append(" " + " --> ".join(names))
+
+    return lines
+
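Continuing the hypothetical a.py -> b.py -> c.py chain (assuming those files exist on disk with WORKBENCH_BATCH blocks like the ones sketched earlier), the ordering and display helpers compose like this:

ordered, configs, root_map = sort_by_dependencies([Path("c.py"), Path("b.py"), Path("a.py")])
# ordered: [Path('a.py'), Path('b.py'), Path('c.py')], producers before consumers
for line in format_dependency_chains(ordered, configs):
    print(line)
# a --> b --> c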
+def get_all_pipelines() -> list[Path]:
+    """Get all ML pipeline scripts from subdirectories of current working directory."""
+    cwd = Path.cwd()
+    # Find all .py files in subdirectories (not in cwd itself)
+    pipelines = []
+    for subdir in cwd.iterdir():
+        if subdir.is_dir():
+            pipelines.extend(subdir.rglob("*.py"))
+    return pipelines
+
+
+def get_pipeline_groups(pipelines: list[Path]) -> dict[Path, list[Path]]:
+    """Group pipelines by their parent directory (leaf directories)."""
+    groups = {}
+    for pipeline in pipelines:
+        parent = pipeline.parent
+        groups.setdefault(parent, []).append(pipeline)
+    return groups
+
+
+def select_random_groups(pipelines: list[Path], num_groups: int) -> list[Path]:
+    """Select pipelines from n random leaf directories."""
+    groups = get_pipeline_groups(pipelines)
+    if not groups:
+        return []
+
+    # Select up to num_groups random directories
+    dirs = list(groups.keys())
+    selected_dirs = random.sample(dirs, min(num_groups, len(dirs)))
+
+    # Return all pipelines from those directories
+    selected = []
+    for d in selected_dirs:
+        selected.extend(groups[d])
+    return selected
+
+
+def filter_pipelines_by_patterns(pipelines: list[Path], patterns: list[str]) -> list[Path]:
+    """Filter pipelines by substring patterns matching the basename."""
+    if not patterns:
+        return pipelines
+
+    matched = []
+    for pipeline in pipelines:
+        basename = pipeline.stem.lower()
+        if any(pattern.lower() in basename for pattern in patterns):
+            matched.append(pipeline)
+    return matched
+
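The selection helpers behave as follows for a hypothetical layout (directory and script names invented for illustration):

# ml_pipelines/
#   caco2/   caco2_features.py, caco2_model.py
#   ppb/     ppb_features.py, ppb_model.py
pipelines = get_all_pipelines()                      # all four scripts
filter_pipelines_by_patterns(pipelines, ["caco2"])   # just the two caco2 scripts
select_random_groups(pipelines, 1)                   # both scripts from one random directory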
+def main():
+    parser = argparse.ArgumentParser(description="Launch ML pipelines via SQS for testing")
+    parser.add_argument(
+        "patterns",
+        nargs="*",
+        help="Substring patterns to filter pipelines by basename (e.g., 'caco2' 'ppb')",
+    )
+    parser.add_argument(
+        "-n",
+        "--num-groups",
+        type=int,
+        default=1,
+        help="Number of random pipeline groups to launch (default: 1, ignored if --all or patterns specified)",
+    )
+    parser.add_argument(
+        "--all",
+        action="store_true",
+        help="Launch ALL pipelines (ignores -n)",
+    )
+    parser.add_argument(
+        "--realtime",
+        action="store_true",
+        help="Create realtime endpoints (default is serverless)",
+    )
+    parser.add_argument(
+        "--dry-run",
+        action="store_true",
+        help="Show what would be launched without actually launching",
+    )
+
+    # Mode flags (mutually exclusive)
+    mode_group = parser.add_mutually_exclusive_group(required=True)
+    mode_group.add_argument(
+        "--dt",
+        action="store_true",
+        help="Launch with DT=True (dynamic training mode)",
+    )
+    mode_group.add_argument(
+        "--promote",
+        action="store_true",
+        help="Launch with PROMOTE=True (promotion mode)",
+    )
+    mode_group.add_argument(
+        "--test-promote",
+        action="store_true",
+        help="Launch with TEST_PROMOTE=True (test promotion mode)",
+    )
+
+    args = parser.parse_args()
+
+    # Get all pipelines from subdirectories of current working directory
+    all_pipelines = get_all_pipelines()
+    if not all_pipelines:
+        print(f"No pipeline scripts found in subdirectories of {Path.cwd()}")
+        exit(1)
+
+    # Determine which pipelines to run
+    if args.patterns:
+        # Filter by patterns
+        selected_pipelines = filter_pipelines_by_patterns(all_pipelines, args.patterns)
+        if not selected_pipelines:
+            print(f"No pipelines matching patterns: {args.patterns}")
+            exit(1)
+        selection_mode = f"matching {args.patterns}"
+    elif args.all:
+        # Run all pipelines
+        selected_pipelines = all_pipelines
+        selection_mode = "ALL"
+    else:
+        # Random group selection
+        selected_pipelines = select_random_groups(all_pipelines, args.num_groups)
+        if not selected_pipelines:
+            print("No pipeline groups found")
+            exit(1)
+        # Get the directory names for display
+        groups = get_pipeline_groups(selected_pipelines)
+        group_names = [d.name for d in groups.keys()]
+        selection_mode = f"RANDOM {args.num_groups} group(s): {group_names}"
+
+    # Sort by dependencies (producers before consumers)
+    selected_pipelines, configs, root_map = sort_by_dependencies(selected_pipelines)
+
+    # Determine mode for display and CLI flag
+    if args.dt:
+        mode_name = "DT (Dynamic Training)"
+        mode_flag = "--dt"
+    elif args.promote:
+        mode_name = "PROMOTE"
+        mode_flag = "--promote"
+    else:
+        mode_name = "TEST_PROMOTE"
+        mode_flag = "--test-promote"
+
+    print(f"\n{'=' * 60}")
+    print(f"{'DRY RUN - ' if args.dry_run else ''}LAUNCHING {len(selected_pipelines)} PIPELINES")
+    print(f"{'=' * 60}")
+    print(f"Source: {Path.cwd()}")
+    print(f"Selection: {selection_mode}")
+    print(f"Mode: {mode_name}")
+    print(f"Endpoint: {'Realtime' if args.realtime else 'Serverless'}")
+    print("\nPipeline Chains:")
+    for line in format_dependency_chains(selected_pipelines, configs):
+        print(line)
+    print()
+
+    # Dry run - just show what would be launched
+    if args.dry_run:
+        print("Dry run complete. No pipelines were launched.\n")
+        return
+
+    # Countdown before launching
+    print("Launching in ", end="", flush=True)
+    for i in range(10, 0, -1):
+        print(f"{i}...", end="", flush=True)
+        time.sleep(1)
+    print(" GO!\n")
+
+    # Launch each pipeline using the CLI
+    for i, pipeline in enumerate(selected_pipelines, 1):
+        print(f"\n{'─' * 60}")
+        print(f"Launching pipeline {i}/{len(selected_pipelines)}: {pipeline.name}")
+        print(f"{'─' * 60}")
+
+        # Build the command
+        cmd = ["ml_pipeline_sqs", str(pipeline), mode_flag]
+        if args.realtime:
+            cmd.append("--realtime")
+
+        # Pass root group_id for dependency chain ordering
+        group_id = get_group_id(configs.get(pipeline), root_map)
+        if group_id:
+            cmd.extend(["--group-id", group_id])
+
+        print(f"Running: {' '.join(cmd)}\n")
+        result = subprocess.run(cmd)
+        if result.returncode != 0:
+            print(f"Failed to launch {pipeline.name} (exit code: {result.returncode})")
+
+    print(f"\n{'=' * 60}")
+    print(f"FINISHED LAUNCHING {len(selected_pipelines)} PIPELINES")
+    print(f"{'=' * 60}\n")
+
+
+if __name__ == "__main__":
+    main()
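Tracing the prints in main(), a --dry-run invocation produces output along these lines (paths, names, and counts illustrative):

$ ml_pipeline_launcher --dt caco2 --dry-run

============================================================
DRY RUN - LAUNCHING 2 PIPELINES
============================================================
Source: /home/user/ml_pipelines
Selection: matching ['caco2']
Mode: DT (Dynamic Training)
Endpoint: Serverless

Pipeline Chains:
 caco2_features --> caco2_model

Dry run complete. No pipelines were launched.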
workbench/scripts/ml_pipeline_sqs.py
CHANGED

@@ -71,6 +71,8 @@ def submit_to_sqs(
     realtime: bool = False,
     dt: bool = False,
     promote: bool = False,
+    test_promote: bool = False,
+    group_id: str | None = None,
 ) -> None:
     """
     Upload script to S3 and submit message to SQS queue for processing.
@@ -81,6 +83,8 @@ def submit_to_sqs(
         realtime: If True, sets serverless=False for real-time processing (default: False)
         dt: If True, sets DT=True in environment (default: False)
        promote: If True, sets PROMOTE=True in environment (default: False)
+        test_promote: If True, sets TEST_PROMOTE=True in environment (default: False)
+        group_id: Optional MessageGroupId override for dependency chains (default: derived from script)

     Raises:
         ValueError: If size is invalid or script file not found
@@ -99,7 +103,8 @@ def submit_to_sqs(
     # Read script content and parse WORKBENCH_BATCH config
     script_content = script_file.read_text()
     batch_config = parse_workbench_batch(script_content)
-    group_id = get_message_group_id(batch_config)
+    if group_id is None:
+        group_id = get_message_group_id(batch_config)
     outputs = (batch_config or {}).get("outputs", [])
     inputs = (batch_config or {}).get("inputs", [])

@@ -108,6 +113,7 @@ def submit_to_sqs(
     print(f"⚡ Mode: {'Real-time' if realtime else 'Serverless'} (serverless={'False' if realtime else 'True'})")
     print(f"🔄 DynamicTraining: {dt}")
     print(f"🆕 Promote: {promote}")
+    print(f"🧪 Test Promote: {test_promote}")
     print(f"🪣 Bucket: {workbench_bucket}")
     if outputs:
         print(f"📤 Outputs: {outputs}")
@@ -174,6 +180,7 @@ def submit_to_sqs(
         "SERVERLESS": "False" if realtime else "True",
         "DT": str(dt),
         "PROMOTE": str(promote),
+        "TEST_PROMOTE": str(test_promote),
     }

     # Send the message to SQS
@@ -200,6 +207,7 @@ def submit_to_sqs(
     print(f"⚡ Mode: {'Real-time' if realtime else 'Serverless'} (SERVERLESS={'False' if realtime else 'True'})")
     print(f"🔄 DynamicTraining: {dt}")
     print(f"🆕 Promote: {promote}")
+    print(f"🧪 Test Promote: {test_promote}")
     if outputs:
         print(f"📤 Outputs: {outputs}")
     if inputs:
@@ -234,7 +242,17 @@ def main():
     parser.add_argument(
         "--promote",
         action="store_true",
-        help="Set Promote=True (models and endpoints will use promoted naming",
+        help="Set Promote=True (models and endpoints will use promoted naming)",
+    )
+    parser.add_argument(
+        "--test-promote",
+        action="store_true",
+        help="Set TEST_PROMOTE=True (creates test endpoint with '-test' suffix)",
+    )
+    parser.add_argument(
+        "--group-id",
+        default=None,
+        help="Override MessageGroupId for SQS (used for dependency chain ordering)",
     )
     args = parser.parse_args()
     try:
@@ -244,6 +262,8 @@ def main():
             realtime=args.realtime,
             dt=args.dt,
             promote=args.promote,
+            test_promote=args.test_promote,
+            group_id=args.group_id,
         )
     except Exception as e:
         print(f"\n❌ ERROR: {e}")
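Why the --group-id override exists: SQS FIFO queues deliver messages that share a MessageGroupId strictly in order, one at a time, so tagging every script in a chain with the root producer's id is what enforces producer-before-consumer execution. A minimal sketch of such a send (boto3; the queue URL and payload are placeholders, not Workbench's actual message format):

import boto3

sqs = boto3.client("sqs")
sqs.send_message(
    QueueUrl="https://sqs.us-east-1.amazonaws.com/123456789012/ml-pipelines.fifo",  # placeholder
    MessageBody='{"script": "caco2_model.py", "env": {"DT": "True"}}',  # placeholder payload
    MessageGroupId="caco2_features",  # root of the dependency chain
    MessageDeduplicationId="caco2_model-1700000000",  # required unless content-based dedup is on
)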
workbench/themes/dark/custom.css
CHANGED
@@ -135,6 +135,35 @@ div:has(> [class*="ag-theme-"]) {
     --bs-border-color: rgb(60, 60, 60);
 }

+/* React-select dropdown styling - target actual rendered elements */
+.Select-control {
+    background-color: rgb(35, 35, 35) !important;
+    border-color: rgb(60, 60, 60) !important;
+}
+
+.Select-value-label, .Select-input input {
+    color: rgb(210, 210, 210) !important;
+}
+
+.Select-placeholder {
+    color: rgb(150, 150, 150) !important;
+}
+
+.Select-menu-outer {
+    background-color: rgb(35, 35, 35) !important;
+    border-color: rgb(60, 60, 60) !important;
+}
+
+.VirtualizedSelectOption {
+    background-color: rgb(35, 35, 35) !important;
+    color: rgb(210, 210, 210) !important;
+}
+
+.VirtualizedSelectFocusedOption {
+    background-color: rgb(60, 60, 60) !important;
+    color: rgb(230, 230, 230) !important;
+}
+
 /* Bootstrap form controls (dbc components) */
 .form-select, .form-control {
     background-color: rgb(35, 35, 35) !important;

workbench/themes/light/custom.css
CHANGED

@@ -180,6 +180,35 @@ div:has(> [class*="ag-theme-"]) {
     --bs-border-color: var(--wb-accent);
 }

+/* React-select dropdown styling - target actual rendered elements */
+.Select-control {
+    background-color: var(--wb-dropdown-bg) !important;
+    border-color: var(--wb-accent) !important;
+}
+
+.Select-value-label, .Select-input input {
+    color: var(--wb-text-primary) !important;
+}
+
+.Select-placeholder {
+    color: var(--wb-text-muted) !important;
+}
+
+.Select-menu-outer {
+    background-color: var(--wb-dropdown-bg) !important;
+    border-color: var(--wb-accent) !important;
+}
+
+.VirtualizedSelectOption {
+    background-color: var(--wb-dropdown-bg) !important;
+    color: var(--wb-text-primary) !important;
+}
+
+.VirtualizedSelectFocusedOption {
+    background-color: var(--wb-dropdown-hover) !important;
+    color: var(--wb-text-primary) !important;
+}
+
 /* Bootstrap form controls (dbc components) */
 .form-select, .form-control {
     background-color: var(--wb-dropdown-bg) !important;

workbench/themes/midnight_blue/custom.css
CHANGED

@@ -133,6 +133,34 @@ div:has(> [class*="ag-theme-"]) {
     --bs-border-color: rgb(80, 85, 115);
 }

+/* React-select dropdown styling - target actual rendered elements */
+.Select-control {
+    background-color: rgb(55, 60, 90) !important;
+    border-color: rgb(80, 85, 115) !important;
+}
+
+.Select-value-label, .Select-input input {
+    color: rgb(210, 210, 210) !important;
+}
+
+.Select-placeholder {
+    color: rgb(150, 150, 170) !important;
+}
+
+.Select-menu-outer {
+    background-color: rgb(55, 60, 90) !important;
+    border-color: rgb(80, 85, 115) !important;
+}
+
+.VirtualizedSelectOption {
+    background-color: rgb(55, 60, 90) !important;
+    color: rgb(210, 210, 210) !important;
+}
+
+.VirtualizedSelectFocusedOption {
+    background-color: rgb(70, 75, 110) !important;
+    color: rgb(230, 230, 230) !important;
+}

 /* Bootstrap form controls (dbc components) */
 .form-select, .form-control {
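The `.Select-*` and `.VirtualizedSelect*` class names above come from the react-virtualized-select component that Dash's dcc.Dropdown renders, so these rules theme every dropdown in the app. A throwaway sketch to eyeball a theme locally (hypothetical app; assumes the custom.css is served, e.g. from a Dash assets/ folder):

from dash import Dash, dcc, html

app = Dash(__name__)  # CSS dropped into ./assets/ is loaded automatically
app.layout = html.Div(
    [dcc.Dropdown(options=["caco2", "ppb", "sol"], placeholder="Select a model...")],
    style={"width": "300px", "padding": "2rem"},
)

if __name__ == "__main__":
    app.run(debug=True)  # use app.run_server on older Dash releases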
workbench/utils/model_utils.py
CHANGED
@@ -459,6 +459,12 @@ def uq_metrics(df: pd.DataFrame, target_col: str) -> Dict[str, Any]:
     # Spearman correlation for robustness
     interval_to_error_corr = spearmanr(width_68, abs_residuals)[0]

+    # --- Confidence to Error Correlation ---
+    # If confidence column exists, compute correlation (should be negative: high confidence = low error)
+    confidence_to_error_corr = None
+    if "confidence" in df.columns:
+        confidence_to_error_corr = spearmanr(df["confidence"], abs_residuals)[0]
+
     # Collect results
     results = {
         "coverage_68": coverage_68,
@@ -472,6 +478,7 @@ def uq_metrics(df: pd.DataFrame, target_col: str) -> Dict[str, Any]:
         "median_width_90": median_width_90,
         "median_width_95": median_width_95,
         "interval_to_error_corr": interval_to_error_corr,
+        "confidence_to_error_corr": confidence_to_error_corr,
         "n_samples": len(df),
     }

@@ -489,6 +496,8 @@ def uq_metrics(df: pd.DataFrame, target_col: str) -> Dict[str, Any]:
     print(f"CRPS: {mean_crps:.3f} (lower is better)")
     print(f"Interval Score 95%: {mean_is_95:.3f} (lower is better)")
     print(f"Interval/Error Corr: {interval_to_error_corr:.3f} (higher is better, target: >0.5)")
+    if confidence_to_error_corr is not None:
+        print(f"Confidence/Error Corr: {confidence_to_error_corr:.3f} (lower is better, target: <-0.5)")
     print(f"Samples: {len(df)}")
     return results
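For intuition on the new metric: a well-calibrated confidence score should rank inversely with absolute error, hence the stated target of a Spearman correlation below -0.5. A synthetic sanity check (toy data, not Workbench output):

import numpy as np
from scipy.stats import spearmanr

rng = np.random.default_rng(42)
abs_residuals = rng.exponential(scale=1.0, size=500)
# Confidence that (noisily) decreases as error grows
confidence = 1.0 / (1.0 + abs_residuals) + rng.normal(0.0, 0.05, size=500)

corr = spearmanr(confidence, abs_residuals)[0]
print(f"Confidence/Error Corr: {corr:.3f}")  # strongly negative, around -0.9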