workbench 0.8.234__py3-none-any.whl → 0.8.239__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. workbench/algorithms/dataframe/smart_aggregator.py +17 -12
  2. workbench/api/endpoint.py +13 -4
  3. workbench/api/model.py +2 -2
  4. workbench/cached/cached_model.py +2 -2
  5. workbench/core/artifacts/athena_source.py +5 -3
  6. workbench/core/artifacts/endpoint_core.py +30 -5
  7. workbench/core/cloud_platform/aws/aws_meta.py +2 -1
  8. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +27 -14
  9. workbench/model_script_utils/model_script_utils.py +225 -0
  10. workbench/model_script_utils/uq_harness.py +39 -21
  11. workbench/model_scripts/chemprop/chemprop.template +30 -15
  12. workbench/model_scripts/chemprop/generated_model_script.py +35 -18
  13. workbench/model_scripts/chemprop/model_script_utils.py +225 -0
  14. workbench/model_scripts/pytorch_model/generated_model_script.py +29 -15
  15. workbench/model_scripts/pytorch_model/model_script_utils.py +225 -0
  16. workbench/model_scripts/pytorch_model/pytorch.template +28 -14
  17. workbench/model_scripts/pytorch_model/uq_harness.py +39 -21
  18. workbench/model_scripts/xgb_model/generated_model_script.py +35 -22
  19. workbench/model_scripts/xgb_model/model_script_utils.py +225 -0
  20. workbench/model_scripts/xgb_model/uq_harness.py +39 -21
  21. workbench/model_scripts/xgb_model/xgb_model.template +29 -18
  22. workbench/scripts/ml_pipeline_batch.py +47 -2
  23. workbench/scripts/ml_pipeline_launcher.py +410 -0
  24. workbench/scripts/ml_pipeline_sqs.py +22 -2
  25. workbench/themes/dark/custom.css +29 -0
  26. workbench/themes/light/custom.css +29 -0
  27. workbench/themes/midnight_blue/custom.css +28 -0
  28. workbench/utils/model_utils.py +9 -0
  29. workbench/utils/theme_manager.py +95 -0
  30. workbench/web_interface/components/component_interface.py +3 -0
  31. workbench/web_interface/components/plugin_interface.py +26 -0
  32. workbench/web_interface/components/plugins/ag_table.py +4 -11
  33. workbench/web_interface/components/plugins/confusion_matrix.py +14 -8
  34. workbench/web_interface/components/plugins/model_plot.py +156 -0
  35. workbench/web_interface/components/plugins/scatter_plot.py +9 -2
  36. workbench/web_interface/components/plugins/shap_summary_plot.py +12 -4
  37. workbench/web_interface/components/settings_menu.py +10 -49
  38. {workbench-0.8.234.dist-info → workbench-0.8.239.dist-info}/METADATA +2 -2
  39. {workbench-0.8.234.dist-info → workbench-0.8.239.dist-info}/RECORD +43 -42
  40. {workbench-0.8.234.dist-info → workbench-0.8.239.dist-info}/WHEEL +1 -1
  41. {workbench-0.8.234.dist-info → workbench-0.8.239.dist-info}/entry_points.txt +1 -0
  42. workbench/web_interface/components/model_plot.py +0 -75
  43. {workbench-0.8.234.dist-info → workbench-0.8.239.dist-info}/licenses/LICENSE +0 -0
  44. {workbench-0.8.234.dist-info → workbench-0.8.239.dist-info}/top_level.txt +0 -0
workbench/scripts/ml_pipeline_launcher.py
@@ -0,0 +1,410 @@
+"""Launch ML pipelines via SQS for testing.
+
+Run this from a directory containing pipeline subdirectories (e.g., ml_pipelines/).
+
+Usage:
+    ml_pipeline_launcher --dt                  # Launch 1 random pipeline group (all scripts in a directory)
+    ml_pipeline_launcher --dt -n 3             # Launch 3 random pipeline groups
+    ml_pipeline_launcher --dt --all            # Launch ALL pipelines
+    ml_pipeline_launcher --dt caco2            # Launch pipelines matching 'caco2'
+    ml_pipeline_launcher --dt caco2 ppb        # Launch pipelines matching 'caco2' or 'ppb'
+    ml_pipeline_launcher --promote --all       # Promote ALL pipelines
+    ml_pipeline_launcher --test-promote --all  # Test-promote ALL pipelines
+    ml_pipeline_launcher --dt --dry-run        # Show what would be launched without launching
+"""
+
+import argparse
+import ast
+import random
+import re
+import subprocess
+import time
+from pathlib import Path
+
+
+def parse_workbench_batch(script_path: Path) -> dict | None:
+    """Parse WORKBENCH_BATCH config from a script file."""
+    content = script_path.read_text()
+    match = re.search(r"WORKBENCH_BATCH\s*=\s*(\{[^}]+\})", content, re.DOTALL)
+    if match:
+        try:
+            return ast.literal_eval(match.group(1))
+        except (ValueError, SyntaxError):
+            return None
+    return None
+
+
+def build_dependency_graph(configs: dict[Path, dict]) -> dict[str, str]:
+    """Build a mapping from each output to its root producer.
+
+    For a chain like A -> B -> C (where B depends on A, C depends on B),
+    this returns {A: A, B: A, C: A} so all are in the same message group.
+    """
+    # Build output -> input mapping (what does each output depend on?)
+    output_to_input = {}
+    for config in configs.values():
+        if not config:
+            continue
+        outputs = config.get("outputs", [])
+        inputs = config.get("inputs", [])
+        for output in outputs:
+            output_to_input[output] = inputs[0] if inputs else None
+
+    # Walk chain to find root
+    def find_root(output: str, visited: set = None) -> str:
+        if visited is None:
+            visited = set()
+        if output in visited:
+            return output
+        visited.add(output)
+        parent = output_to_input.get(output)
+        if parent is None:
+            return output
+        return find_root(parent, visited)
+
+    return {output: find_root(output) for output in output_to_input}
+
+
+def get_group_id(config: dict | None, root_map: dict[str, str]) -> str | None:
+    """Get the root group_id for a pipeline based on its config and root_map."""
+    if not config:
+        return None
+    outputs = config.get("outputs", [])
+    inputs = config.get("inputs", [])
+    # Check inputs first (this script depends on something)
+    if inputs and inputs[0] in root_map:
+        return root_map[inputs[0]]
+    # Check outputs (this script produces something)
+    if outputs and outputs[0] in root_map:
+        return root_map[outputs[0]]
+    return None
+
+
+def sort_by_dependencies(pipelines: list[Path]) -> tuple[list[Path], dict[Path, dict], dict[str, str]]:
+    """Sort pipelines by dependency chains. Returns (sorted_list, configs, root_map)."""
+    # Parse all configs
+    configs = {}
+    for pipeline in pipelines:
+        configs[pipeline] = parse_workbench_batch(pipeline)
+
+    # Build root map for group_id resolution
+    root_map = build_dependency_graph(configs)
+
+    # Build output -> pipeline mapping
+    output_to_pipeline = {}
+    for pipeline, config in configs.items():
+        if config and config.get("outputs"):
+            for output in config["outputs"]:
+                output_to_pipeline[output] = pipeline
+
+    # Build chains by walking from root producers
+    sorted_pipelines = []
+    used = set()
+
+    for pipeline in sorted(pipelines):
+        config = configs.get(pipeline)
+
+        # Skip if already used or has inputs (not a root)
+        if pipeline in used:
+            continue
+        if config and config.get("inputs"):
+            continue
+
+        # Walk the chain from this root
+        chain = [pipeline]
+        used.add(pipeline)
+
+        current = pipeline
+        while True:
+            current_config = configs.get(current)
+            if not current_config or not current_config.get("outputs"):
+                break
+
+            current_output = current_config["outputs"][0]
+            # Find pipeline that consumes this output
+            next_pipeline = None
+            for p, c in configs.items():
+                if p in used or p not in pipelines:
+                    continue
+                if c and c.get("inputs") and current_output in c["inputs"]:
+                    next_pipeline = p
+                    break
+
+            if next_pipeline:
+                chain.append(next_pipeline)
+                used.add(next_pipeline)
+                current = next_pipeline
+            else:
+                break
+
+        sorted_pipelines.extend(chain)
+
+    # Add any remaining pipelines not in chains
+    for pipeline in sorted(pipelines):
+        if pipeline not in used:
+            sorted_pipelines.append(pipeline)
+
+    return sorted_pipelines, configs, root_map
+
+
+def format_dependency_chains(pipelines: list[Path], configs: dict[Path, dict]) -> list[str]:
+    """Format pipelines as dependency chains for display."""
+    # Build output -> pipeline mapping
+    output_to_pipeline = {}
+    for pipeline, config in configs.items():
+        if config and config.get("outputs"):
+            for output in config["outputs"]:
+                output_to_pipeline[output] = pipeline
+
+    # Build chains by walking from root producers
+    chains = []
+    used = set()
+
+    for pipeline in pipelines:
+        config = configs.get(pipeline)
+
+        # Skip if already part of a chain or has inputs (not a root)
+        if pipeline in used:
+            continue
+        if config and config.get("inputs"):
+            continue
+
+        # Start a new chain from this root producer (or standalone)
+        chain = [pipeline]
+        used.add(pipeline)
+
+        # Walk the chain: find who consumes our output
+        current = pipeline
+        while True:
+            current_config = configs.get(current)
+            if not current_config or not current_config.get("outputs"):
+                break
+
+            current_output = current_config["outputs"][0]
+            # Find a pipeline that takes this output as input
+            next_pipeline = None
+            for p, c in configs.items():
+                if p in used or p not in pipelines:
+                    continue
+                if c and c.get("inputs") and current_output in c["inputs"]:
+                    next_pipeline = p
+                    break
+
+            if next_pipeline:
+                chain.append(next_pipeline)
+                used.add(next_pipeline)
+                current = next_pipeline
+            else:
+                break
+
+        chains.append(chain)
+
+    # Add any remaining pipelines not in chains (shouldn't happen but just in case)
+    for pipeline in pipelines:
+        if pipeline not in used:
+            chains.append([pipeline])
+
+    # Format chains as strings
+    lines = []
+    for chain in chains:
+        names = [p.stem for p in chain]
+        lines.append(" " + " --> ".join(names))
+
+    return lines
+
+
+def get_all_pipelines() -> list[Path]:
+    """Get all ML pipeline scripts from subdirectories of current working directory."""
+    cwd = Path.cwd()
+    # Find all .py files in subdirectories (not in cwd itself)
+    pipelines = []
+    for subdir in cwd.iterdir():
+        if subdir.is_dir():
+            pipelines.extend(subdir.rglob("*.py"))
+    return pipelines
+
+
+def get_pipeline_groups(pipelines: list[Path]) -> dict[Path, list[Path]]:
+    """Group pipelines by their parent directory (leaf directories)."""
+    groups = {}
+    for pipeline in pipelines:
+        parent = pipeline.parent
+        groups.setdefault(parent, []).append(pipeline)
+    return groups
+
+
+def select_random_groups(pipelines: list[Path], num_groups: int) -> list[Path]:
+    """Select pipelines from n random leaf directories."""
+    groups = get_pipeline_groups(pipelines)
+    if not groups:
+        return []
+
+    # Select up to num_groups random directories
+    dirs = list(groups.keys())
+    selected_dirs = random.sample(dirs, min(num_groups, len(dirs)))
+
+    # Return all pipelines from those directories
+    selected = []
+    for d in selected_dirs:
+        selected.extend(groups[d])
+    return selected
+
+
+def filter_pipelines_by_patterns(pipelines: list[Path], patterns: list[str]) -> list[Path]:
+    """Filter pipelines by substring patterns matching the basename."""
+    if not patterns:
+        return pipelines
+
+    matched = []
+    for pipeline in pipelines:
+        basename = pipeline.stem.lower()
+        if any(pattern.lower() in basename for pattern in patterns):
+            matched.append(pipeline)
+    return matched
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Launch ML pipelines via SQS for testing")
+    parser.add_argument(
+        "patterns",
+        nargs="*",
+        help="Substring patterns to filter pipelines by basename (e.g., 'caco2' 'ppb')",
+    )
+    parser.add_argument(
+        "-n",
+        "--num-groups",
+        type=int,
+        default=1,
+        help="Number of random pipeline groups to launch (default: 1, ignored if --all or patterns specified)",
+    )
+    parser.add_argument(
+        "--all",
+        action="store_true",
+        help="Launch ALL pipelines (ignores -n)",
+    )
+    parser.add_argument(
+        "--realtime",
+        action="store_true",
+        help="Create realtime endpoints (default is serverless)",
+    )
+    parser.add_argument(
+        "--dry-run",
+        action="store_true",
+        help="Show what would be launched without actually launching",
+    )
+
+    # Mode flags (mutually exclusive)
+    mode_group = parser.add_mutually_exclusive_group(required=True)
+    mode_group.add_argument(
+        "--dt",
+        action="store_true",
+        help="Launch with DT=True (dynamic training mode)",
+    )
+    mode_group.add_argument(
+        "--promote",
+        action="store_true",
+        help="Launch with PROMOTE=True (promotion mode)",
+    )
+    mode_group.add_argument(
+        "--test-promote",
+        action="store_true",
+        help="Launch with TEST_PROMOTE=True (test promotion mode)",
+    )
+
+    args = parser.parse_args()
+
+    # Get all pipelines from subdirectories of current working directory
+    all_pipelines = get_all_pipelines()
+    if not all_pipelines:
+        print(f"No pipeline scripts found in subdirectories of {Path.cwd()}")
+        exit(1)
+
+    # Determine which pipelines to run
+    if args.patterns:
+        # Filter by patterns
+        selected_pipelines = filter_pipelines_by_patterns(all_pipelines, args.patterns)
+        if not selected_pipelines:
+            print(f"No pipelines matching patterns: {args.patterns}")
+            exit(1)
+        selection_mode = f"matching {args.patterns}"
+    elif args.all:
+        # Run all pipelines
+        selected_pipelines = all_pipelines
+        selection_mode = "ALL"
+    else:
+        # Random group selection
+        selected_pipelines = select_random_groups(all_pipelines, args.num_groups)
+        if not selected_pipelines:
+            print("No pipeline groups found")
+            exit(1)
+        # Get the directory names for display
+        groups = get_pipeline_groups(selected_pipelines)
+        group_names = [d.name for d in groups.keys()]
+        selection_mode = f"RANDOM {args.num_groups} group(s): {group_names}"
+
+    # Sort by dependencies (producers before consumers)
+    selected_pipelines, configs, root_map = sort_by_dependencies(selected_pipelines)
+
+    # Determine mode for display and CLI flag
+    if args.dt:
+        mode_name = "DT (Dynamic Training)"
+        mode_flag = "--dt"
+    elif args.promote:
+        mode_name = "PROMOTE"
+        mode_flag = "--promote"
+    else:
+        mode_name = "TEST_PROMOTE"
+        mode_flag = "--test-promote"
+
+    print(f"\n{'=' * 60}")
+    print(f"{'DRY RUN - ' if args.dry_run else ''}LAUNCHING {len(selected_pipelines)} PIPELINES")
+    print(f"{'=' * 60}")
+    print(f"Source: {Path.cwd()}")
+    print(f"Selection: {selection_mode}")
+    print(f"Mode: {mode_name}")
+    print(f"Endpoint: {'Realtime' if args.realtime else 'Serverless'}")
+    print("\nPipeline Chains:")
+    for line in format_dependency_chains(selected_pipelines, configs):
+        print(line)
+    print()
+
+    # Dry run - just show what would be launched
+    if args.dry_run:
+        print("Dry run complete. No pipelines were launched.\n")
+        return
+
+    # Countdown before launching
+    print("Launching in ", end="", flush=True)
+    for i in range(10, 0, -1):
+        print(f"{i}...", end="", flush=True)
+        time.sleep(1)
+    print(" GO!\n")
+
+    # Launch each pipeline using the CLI
+    for i, pipeline in enumerate(selected_pipelines, 1):
+        print(f"\n{'─' * 60}")
+        print(f"Launching pipeline {i}/{len(selected_pipelines)}: {pipeline.name}")
+        print(f"{'─' * 60}")
+
+        # Build the command
+        cmd = ["ml_pipeline_sqs", str(pipeline), mode_flag]
+        if args.realtime:
+            cmd.append("--realtime")
+
+        # Pass root group_id for dependency chain ordering
+        group_id = get_group_id(configs.get(pipeline), root_map)
+        if group_id:
+            cmd.extend(["--group-id", group_id])
+
+        print(f"Running: {' '.join(cmd)}\n")
+        result = subprocess.run(cmd)
+        if result.returncode != 0:
+            print(f"Failed to launch {pipeline.name} (exit code: {result.returncode})")
+
+    print(f"\n{'=' * 60}")
+    print(f"FINISHED LAUNCHING {len(selected_pipelines)} PIPELINES")
+    print(f"{'=' * 60}\n")
+
+
+if __name__ == "__main__":
+    main()
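For illustration (not part of the diff), here is a minimal sketch of how the launcher's helpers fold a two-script chain into a single message group. The script paths and output names are made up, and it assumes the module imports as workbench.scripts.ml_pipeline_launcher:

from pathlib import Path

from workbench.scripts.ml_pipeline_launcher import build_dependency_graph, get_group_id

# Two hypothetical scripts: train.py consumes the output of featurize.py
configs = {
    Path("solubility/featurize.py"): {"outputs": ["sol_features"]},
    Path("solubility/train.py"): {"inputs": ["sol_features"], "outputs": ["sol_model"]},
}

# Every output maps back to the chain's root producer
root_map = build_dependency_graph(configs)
print(root_map)  # {'sol_features': 'sol_features', 'sol_model': 'sol_features'}

# Both scripts therefore share the same group id
print(get_group_id(configs[Path("solubility/train.py")], root_map))  # sol_features

Because every script in a chain is submitted with the root producer's group id, SQS FIFO ordering keeps consumers queued behind their producers.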
workbench/scripts/ml_pipeline_sqs.py
@@ -71,6 +71,8 @@ def submit_to_sqs(
     realtime: bool = False,
     dt: bool = False,
     promote: bool = False,
+    test_promote: bool = False,
+    group_id: str | None = None,
 ) -> None:
     """
     Upload script to S3 and submit message to SQS queue for processing.
@@ -81,6 +83,8 @@ def submit_to_sqs(
         realtime: If True, sets serverless=False for real-time processing (default: False)
         dt: If True, sets DT=True in environment (default: False)
         promote: If True, sets PROMOTE=True in environment (default: False)
+        test_promote: If True, sets TEST_PROMOTE=True in environment (default: False)
+        group_id: Optional MessageGroupId override for dependency chains (default: derived from script)
 
     Raises:
         ValueError: If size is invalid or script file not found
@@ -99,7 +103,8 @@ def submit_to_sqs(
     # Read script content and parse WORKBENCH_BATCH config
     script_content = script_file.read_text()
    batch_config = parse_workbench_batch(script_content)
-    group_id = get_message_group_id(batch_config)
+    if group_id is None:
+        group_id = get_message_group_id(batch_config)
     outputs = (batch_config or {}).get("outputs", [])
     inputs = (batch_config or {}).get("inputs", [])
 
@@ -108,6 +113,7 @@ def submit_to_sqs(
     print(f"⚡ Mode: {'Real-time' if realtime else 'Serverless'} (serverless={'False' if realtime else 'True'})")
     print(f"🔄 DynamicTraining: {dt}")
     print(f"🆕 Promote: {promote}")
+    print(f"🧪 Test Promote: {test_promote}")
     print(f"🪣 Bucket: {workbench_bucket}")
     if outputs:
         print(f"📤 Outputs: {outputs}")
@@ -174,6 +180,7 @@ def submit_to_sqs(
         "SERVERLESS": "False" if realtime else "True",
         "DT": str(dt),
         "PROMOTE": str(promote),
+        "TEST_PROMOTE": str(test_promote),
     }
 
     # Send the message to SQS
@@ -200,6 +207,7 @@ def submit_to_sqs(
     print(f"⚡ Mode: {'Real-time' if realtime else 'Serverless'} (SERVERLESS={'False' if realtime else 'True'})")
     print(f"🔄 DynamicTraining: {dt}")
     print(f"🆕 Promote: {promote}")
+    print(f"🧪 Test Promote: {test_promote}")
     if outputs:
         print(f"📤 Outputs: {outputs}")
     if inputs:
@@ -234,7 +242,17 @@ def main():
     parser.add_argument(
         "--promote",
         action="store_true",
-        help="Set Promote=True (models and endpoints will use promoted naming",
+        help="Set Promote=True (models and endpoints will use promoted naming)",
+    )
+    parser.add_argument(
+        "--test-promote",
+        action="store_true",
+        help="Set TEST_PROMOTE=True (creates test endpoint with '-test' suffix)",
+    )
+    parser.add_argument(
+        "--group-id",
+        default=None,
+        help="Override MessageGroupId for SQS (used for dependency chain ordering)",
     )
     args = parser.parse_args()
     try:
@@ -244,6 +262,8 @@ def main():
             realtime=args.realtime,
             dt=args.dt,
             promote=args.promote,
+            test_promote=args.test_promote,
+            group_id=args.group_id,
         )
     except Exception as e:
         print(f"\n❌ ERROR: {e}")
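As a sketch of the config that submit_to_sqs parses (the inputs/outputs names are illustrative, matching the launcher example above), a pipeline script declares its dependencies in a module-level WORKBENCH_BATCH dict, and the launcher can now override the derived group id via --group-id:

# Hypothetical header of a pipeline script picked up by ml_pipeline_sqs;
# the dict literal is what parse_workbench_batch() extracts via regex + ast.
WORKBENCH_BATCH = {
    "inputs": ["sol_features"],
    "outputs": ["sol_model"],
}

# Equivalent CLI call the launcher would build for this script:
#   ml_pipeline_sqs solubility/train.py --test-promote --group-id sol_features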
workbench/themes/dark/custom.css
@@ -135,6 +135,35 @@ div:has(> [class*="ag-theme-"]) {
     --bs-border-color: rgb(60, 60, 60);
 }
 
+/* React-select dropdown styling - target actual rendered elements */
+.Select-control {
+    background-color: rgb(35, 35, 35) !important;
+    border-color: rgb(60, 60, 60) !important;
+}
+
+.Select-value-label, .Select-input input {
+    color: rgb(210, 210, 210) !important;
+}
+
+.Select-placeholder {
+    color: rgb(150, 150, 150) !important;
+}
+
+.Select-menu-outer {
+    background-color: rgb(35, 35, 35) !important;
+    border-color: rgb(60, 60, 60) !important;
+}
+
+.VirtualizedSelectOption {
+    background-color: rgb(35, 35, 35) !important;
+    color: rgb(210, 210, 210) !important;
+}
+
+.VirtualizedSelectFocusedOption {
+    background-color: rgb(60, 60, 60) !important;
+    color: rgb(230, 230, 230) !important;
+}
+
 /* Bootstrap form controls (dbc components) */
 .form-select, .form-control {
     background-color: rgb(35, 35, 35) !important;
workbench/themes/light/custom.css
@@ -180,6 +180,35 @@ div:has(> [class*="ag-theme-"]) {
     --bs-border-color: var(--wb-accent);
 }
 
+/* React-select dropdown styling - target actual rendered elements */
+.Select-control {
+    background-color: var(--wb-dropdown-bg) !important;
+    border-color: var(--wb-accent) !important;
+}
+
+.Select-value-label, .Select-input input {
+    color: var(--wb-text-primary) !important;
+}
+
+.Select-placeholder {
+    color: var(--wb-text-muted) !important;
+}
+
+.Select-menu-outer {
+    background-color: var(--wb-dropdown-bg) !important;
+    border-color: var(--wb-accent) !important;
+}
+
+.VirtualizedSelectOption {
+    background-color: var(--wb-dropdown-bg) !important;
+    color: var(--wb-text-primary) !important;
+}
+
+.VirtualizedSelectFocusedOption {
+    background-color: var(--wb-dropdown-hover) !important;
+    color: var(--wb-text-primary) !important;
+}
+
 /* Bootstrap form controls (dbc components) */
 .form-select, .form-control {
     background-color: var(--wb-dropdown-bg) !important;
workbench/themes/midnight_blue/custom.css
@@ -133,6 +133,34 @@ div:has(> [class*="ag-theme-"]) {
     --bs-border-color: rgb(80, 85, 115);
 }
 
+/* React-select dropdown styling - target actual rendered elements */
+.Select-control {
+    background-color: rgb(55, 60, 90) !important;
+    border-color: rgb(80, 85, 115) !important;
+}
+
+.Select-value-label, .Select-input input {
+    color: rgb(210, 210, 210) !important;
+}
+
+.Select-placeholder {
+    color: rgb(150, 150, 170) !important;
+}
+
+.Select-menu-outer {
+    background-color: rgb(55, 60, 90) !important;
+    border-color: rgb(80, 85, 115) !important;
+}
+
+.VirtualizedSelectOption {
+    background-color: rgb(55, 60, 90) !important;
+    color: rgb(210, 210, 210) !important;
+}
+
+.VirtualizedSelectFocusedOption {
+    background-color: rgb(70, 75, 110) !important;
+    color: rgb(230, 230, 230) !important;
+}
 
 /* Bootstrap form controls (dbc components) */
 .form-select, .form-control {
uq_harness.py
@@ -459,6 +459,12 @@ def uq_metrics(df: pd.DataFrame, target_col: str) -> Dict[str, Any]:
     # Spearman correlation for robustness
     interval_to_error_corr = spearmanr(width_68, abs_residuals)[0]
 
+    # --- Confidence to Error Correlation ---
+    # If confidence column exists, compute correlation (should be negative: high confidence = low error)
+    confidence_to_error_corr = None
+    if "confidence" in df.columns:
+        confidence_to_error_corr = spearmanr(df["confidence"], abs_residuals)[0]
+
     # Collect results
     results = {
         "coverage_68": coverage_68,
@@ -472,6 +478,7 @@ def uq_metrics(df: pd.DataFrame, target_col: str) -> Dict[str, Any]:
         "median_width_90": median_width_90,
         "median_width_95": median_width_95,
         "interval_to_error_corr": interval_to_error_corr,
+        "confidence_to_error_corr": confidence_to_error_corr,
         "n_samples": len(df),
     }
 
@@ -489,6 +496,8 @@ def uq_metrics(df: pd.DataFrame, target_col: str) -> Dict[str, Any]:
     print(f"CRPS: {mean_crps:.3f} (lower is better)")
     print(f"Interval Score 95%: {mean_is_95:.3f} (lower is better)")
     print(f"Interval/Error Corr: {interval_to_error_corr:.3f} (higher is better, target: >0.5)")
+    if confidence_to_error_corr is not None:
+        print(f"Confidence/Error Corr: {confidence_to_error_corr:.3f} (lower is better, target: <-0.5)")
     print(f"Samples: {len(df)}")
     return results
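To see what the new confidence_to_error_corr metric reports, here is a self-contained toy (synthetic data, not from the harness) where higher confidence tracks lower absolute error, so the Spearman correlation comes out strongly negative:

import numpy as np
import pandas as pd
from scipy.stats import spearmanr

rng = np.random.default_rng(42)
abs_residuals = rng.exponential(scale=1.0, size=500)

# Confidence constructed to decrease (noisily) as error grows
df = pd.DataFrame({"confidence": 1.0 / (1.0 + abs_residuals + rng.normal(0, 0.05, 500))})

corr = spearmanr(df["confidence"], abs_residuals)[0]
print(f"Confidence/Error Corr: {corr:.3f}")  # strongly negative, well past the <-0.5 target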