wafer-cli 0.2.14 (py3-none-any.whl)

wafer/autotuner.py ADDED
@@ -0,0 +1,1080 @@
1
+ """Autotuner CLI - Run hyperparameter sweep experiments.
2
+
3
+ This module provides the implementation for the `wafer autotuner` commands.
4
+ """
5
+
6
+ import asyncio
7
+ import json
8
+ from datetime import UTC
9
+ from pathlib import Path
10
+ from typing import Any
11
+
12
+
13
+ def run_sweep_command(
14
+ config_file: Path | None = None,
15
+ parallel: int = 4,
16
+ resume_sweep_id: str | None = None,
17
+ ) -> str:
18
+ """Run an autotuner sweep from a JSON config file or resume existing sweep.
19
+
20
+ Args:
21
+ config_file: Path to JSON config file (required if not resuming)
22
+ parallel: Number of trials to run concurrently
23
+ resume_sweep_id: Sweep ID to resume (optional)
24
+
25
+ Returns:
26
+ JSON string with sweep results
27
+
28
+ Raises:
29
+ ValueError: If config file is invalid or sweep not found
30
+ FileNotFoundError: If config file doesn't exist
31
+ """
32
+ if config_file and not config_file.exists():
33
+ raise FileNotFoundError(f"Config file not found: {config_file}")
34
+
35
+ # Import autotuner core
36
+ from datetime import datetime
37
+ from uuid import uuid4
38
+
39
+ import trio
40
+ from wafer_core.tools.autotuner import AutotunerConfig, run_sweep
41
+ from wafer_core.tools.autotuner.dtypes import Sweep, Trial
42
+ from wafer_core.tools.autotuner.search import generate_grid_trials
43
+ from wafer_core.tools.autotuner.storage import add_trial, create_sweep, get_sweep, get_trials
44
+
45
+ # Load or reconstruct config
46
+ if resume_sweep_id:
47
+ # Resume existing sweep - load from database
48
+ try:
49
+ existing_sweep = asyncio.run(get_sweep(resume_sweep_id))
50
+ existing_trials = asyncio.run(get_trials(resume_sweep_id))
51
+ except Exception as e:
52
+ raise ValueError(f"Failed to load sweep {resume_sweep_id}: {e}") from e
53
+
54
+ # Reconstruct config from stored sweep config
55
+ config_dict = existing_sweep.config
56
+ config = AutotunerConfig(
57
+ name=config_dict["name"],
58
+ description=config_dict.get("description"),
59
+ search_space=config_dict["search_space"],
60
+ command=config_dict["command"],
61
+ metrics=config_dict["metrics"],
62
+ max_trials=config_dict.get("max_trials", 0),
63
+ parallel=parallel, # Use new parallel value
64
+ timeout=config_dict.get("timeout", 300),
65
+ trials_per_config=config_dict.get("trials_per_config", 1),
66
+ )
67
+
68
+ # Reconstruct objectives if present
69
+ if "objectives" in config_dict:
70
+ from wafer_core.tools.autotuner.dtypes import Objective
71
+ config.objectives = [
72
+ Objective(
73
+ metric=obj["metric"],
74
+ direction=obj["direction"],
75
+ weight=obj.get("weight", 1.0),
76
+ )
77
+ for obj in config_dict["objectives"]
78
+ ]
79
+
80
+ # Reconstruct constraints if present
81
+ if "constraints" in config_dict:
82
+ from wafer_core.tools.autotuner.dtypes import Constraint
83
+ config.constraints = [
84
+ Constraint(
85
+ metric=c["metric"],
86
+ min=c.get("min"),
87
+ max=c.get("max"),
88
+ equals=c.get("equals"),
89
+ )
90
+ for c in config_dict["constraints"]
91
+ ]
92
+
93
+ actual_sweep_id = resume_sweep_id
94
+ is_resume = True
95
+ # Use current working directory for resume (user should run from same place)
96
+ working_dir = Path.cwd()
97
+ else:
98
+ # New sweep
99
+ if not config_file:
100
+ raise ValueError("config_file is required when not resuming")
101
+
102
+ # Load config
103
+ try:
104
+ config = AutotunerConfig.from_json(config_file)
105
+ except Exception as e:
106
+ raise ValueError(f"Failed to parse config: {e}") from e
107
+
108
+ # Override parallel from CLI flag
109
+ config.parallel = parallel
110
+ is_resume = False
111
+ working_dir = config_file.parent
112
+ actual_sweep_id = None # Will be set after creating sweep
113
+
114
+ # Run sweep synchronously
115
+ try:
116
+ async def _run_sweep() -> str:
117
+ nonlocal actual_sweep_id
118
+
119
+ # Calculate total trials
120
+ search_space = config.get_search_space()
121
+ trial_configs = generate_grid_trials(search_space, config.max_trials)
122
+ total_trials = len(trial_configs) * config.trials_per_config
123
+
124
+ if is_resume:
125
+ # Resume mode
126
+ # Print resume status
127
+ num_configs = len(trial_configs)
128
+ print(f"\n🔄 Resuming sweep: {config.name}")
129
+ print(f"Sweep ID: {actual_sweep_id}")
130
+ if config.trials_per_config > 1:
131
+ print(f"Configurations: {num_configs}")
132
+ print(f"Trials per config: {config.trials_per_config}")
133
+ print(f"Total trials: {total_trials}")
134
+ else:
135
+ print(f"Total trials: {total_trials}")
136
+ print(f"Already completed: {len(existing_trials)}")
137
+ print(f"Parallelism: {config.parallel}")
138
+ print()
139
+
140
+ # Track progress (starting from existing)
141
+ completed_count = len(existing_trials)
142
+ success_count = sum(1 for t in existing_trials if t.status.value == "success" and t.passed_constraints)
144
+ failed_count = sum(1 for t in existing_trials if t.status.value != "success")
144
+
145
+ else:
146
+ # New sweep mode
147
+ # Generate sweep ID
148
+ sweep_id = str(uuid4())
149
+
150
+ # Serialize config to dict for storage
151
+ config_dict: dict[str, Any] = {
152
+ "name": config.name,
153
+ "description": config.description,
154
+ "search_space": config.search_space,
155
+ "command": config.command,
156
+ "metrics": config.metrics,
157
+ "max_trials": config.max_trials,
158
+ "parallel": config.parallel,
159
+ "timeout": config.timeout,
160
+ "trials_per_config": config.trials_per_config,
161
+ }
162
+
163
+ if config.objectives:
164
+ config_dict["objectives"] = [
165
+ {
166
+ "metric": obj.metric,
167
+ "direction": obj.direction,
168
+ "weight": obj.weight,
169
+ }
170
+ for obj in config.objectives
171
+ ]
172
+
173
+ if config.constraints:
174
+ config_dict["constraints"] = [
175
+ {
176
+ "metric": c.metric,
177
+ "min": c.min,
178
+ "max": c.max,
179
+ "equals": c.equals,
180
+ }
181
+ for c in config.constraints
182
+ ]
183
+
184
+ # Create sweep in database
185
+ sweep = Sweep(
186
+ id=sweep_id,
187
+ user_id="", # Will be filled by API from auth
188
+ name=config.name,
189
+ description=config.description,
190
+ config=config_dict,
191
+ status="running",
192
+ total_trials=total_trials,
193
+ completed_trials=0,
194
+ created_at=datetime.now(UTC),
195
+ updated_at=datetime.now(UTC),
196
+ )
197
+
198
+ # Create sweep and get the actual ID from the API
199
+ actual_sweep_id = await create_sweep(sweep)
200
+
201
+ # Print initial status
202
+ num_configs = len(trial_configs)
203
+ print(f"\n🚀 Starting sweep: {config.name}")
204
+ print(f"Sweep ID: {actual_sweep_id}")
205
+ if config.trials_per_config > 1:
206
+ print(f"Configurations: {num_configs}")
207
+ print(f"Trials per config: {config.trials_per_config}")
208
+ print(f"Total trials: {total_trials}")
209
+ else:
210
+ print(f"Total trials: {total_trials}")
211
+ print(f"Parallelism: {config.parallel}")
212
+ print()
213
+
214
+ # Track progress
215
+ completed_count = 0
216
+ success_count = 0
217
+ failed_count = 0
218
+
219
+ # Define callback to upload and print progress
220
+ async def on_trial_complete(trial: Trial) -> None:
221
+ nonlocal completed_count, success_count, failed_count
222
+
223
+ # Upload trial to database immediately
224
+ await add_trial(trial)
225
+
226
+ # Update counters
227
+ completed_count += 1
228
+ if trial.status.value == "success" and trial.passed_constraints:  # constraint violations are counted separately in the summary
229
+ success_count += 1
230
+ elif trial.status.value != "success":
231
+ failed_count += 1
232
+
233
+ # Print progress
234
+ status_icon = "✓" if trial.status.value == "success" else "✗"
235
+ constraint_str = (" (passed)" if trial.passed_constraints else " (constraint violation)") if trial.status.value == "success" else ""
236
+
237
+ # Calculate config number and run number
238
+ config_idx = trial.trial_number // config.trials_per_config
239
+ run_idx = (trial.trial_number % config.trials_per_config) + 1
240
+
241
+ # Show config and run info
242
+ if config.trials_per_config > 1:
243
+ print(f"[{completed_count}/{total_trials}] {status_icon} Config #{config_idx + 1}, Run {run_idx}/{config.trials_per_config}{constraint_str}")
244
+ else:
245
+ print(f"[{completed_count}/{total_trials}] {status_icon} Config #{config_idx + 1}{constraint_str}")
246
+
247
+ # Helper to update sweep status
248
+ async def update_sweep_status(status: str) -> None:
249
+ import httpx
250
+ from wafer_core.tools.autotuner.storage import _get_auth_headers, get_api_url
251
+
252
+ api_url = get_api_url()
253
+ headers = _get_auth_headers()
254
+
255
+ async with httpx.AsyncClient(timeout=30.0, headers=headers) as client:
256
+ await client.patch(
257
+ f"{api_url}/v1/autotuner/sweeps/{actual_sweep_id}/status",
258
+ json={"status": status},
259
+ )
260
+
261
+ # Run trials with the actual sweep ID and callback
262
+ # Note: working_dir already set based on is_resume flag
263
+
264
+ try:
265
+ await run_sweep(
266
+ config=config,
267
+ sweep_id=actual_sweep_id,
268
+ working_dir=working_dir,
269
+ on_trial_complete=on_trial_complete,
270
+ existing_trials=existing_trials if is_resume else None,
271
+ )
272
+
273
+ # Mark as completed
274
+ await update_sweep_status("completed")
275
+
276
+ # Print final summary
277
+ print()
278
+ print("✅ Sweep completed!")
279
+ print(f" Total: {total_trials} trials")
280
+ print(f" Success: {success_count}")
281
+ print(f" Failed: {failed_count}")
282
+ print(f" Constraint violations: {completed_count - success_count - failed_count}")
283
+ print()
284
+
285
+ # Return result
286
+ return json.dumps(
287
+ {
288
+ "success": True,
289
+ "sweep_id": actual_sweep_id,
290
+ "name": config.name,
291
+ "total_trials": total_trials,
292
+ "completed_trials": completed_count,
293
+ "success_trials": success_count,
294
+ "failed_trials": failed_count,
295
+ },
296
+ indent=2,
297
+ )
298
+
299
+ except KeyboardInterrupt:
300
+ # User pressed Ctrl+C
301
+ print()
302
+ print("❌ Sweep interrupted by user (Ctrl+C)")
303
+ print(f" Completed: {completed_count}/{total_trials} trials")
304
+ await update_sweep_status("failed")
305
+ raise
306
+
307
+ except Exception as e:
308
+ # Any other error
309
+ import traceback
310
+ print()
311
+ print(f"❌ Sweep failed with error: {e}")
312
+ print(f" Completed: {completed_count}/{total_trials} trials")
313
+
314
+ # For Trio nursery exceptions (MultiError/ExceptionGroup), show all sub-exceptions
315
+ if hasattr(e, '__cause__') and e.__cause__ is not None:
316
+ print(f"\nCause: {e.__cause__}")
317
+ if hasattr(e, 'exceptions'):
318
+ print(f"\nSub-exceptions ({len(e.exceptions)}):")
319
+ for i, exc in enumerate(e.exceptions, 1):
320
+ print(f" {i}. {type(exc).__name__}: {exc}")
321
+ if hasattr(exc, '__traceback__'):
322
+ print(f" {''.join(traceback.format_tb(exc.__traceback__, limit=3))}")
323
+
324
+ await update_sweep_status("failed")
325
+ raise
326
+
327
+ return trio.run(_run_sweep)
328
+
329
+ except Exception as e:
330
+ raise ValueError(f"Failed to run sweep: {e}") from e
331
+
332
+
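Editor's note: as a usage illustration for run_sweep_command above, the sketch below writes a config file and calls the function directly. The top-level keys follow the fields this module stores and reads back (name, description, search_space, command, metrics, max_trials, parallel, timeout, trials_per_config, objectives); the value formats, the metric name, and the {param} command templating are assumptions about wafer_core, not documented here, and running it requires wafer API credentials to be configured.

# Hypothetical usage sketch; config values are illustrative placeholders.
import json
from pathlib import Path

from wafer.autotuner import run_sweep_command

config = {
    "name": "matmul-tiling",
    "description": "Example sweep",
    "search_space": {"block_m": [64, 128], "block_n": [64, 128]},
    "command": "python bench.py --block-m {block_m} --block-n {block_n}",
    "metrics": ["throughput_gflops"],
    "max_trials": 0,
    "parallel": 4,
    "timeout": 300,
    "trials_per_config": 1,
    "objectives": [{"metric": "throughput_gflops", "direction": "maximize"}],
}

config_path = Path("sweep.json")
config_path.write_text(json.dumps(config, indent=2))

# Returns a JSON summary string (see the function's docstring above).
print(run_sweep_command(config_file=config_path, parallel=4))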
333
+ def results_command(
334
+ sweep_id: str,
335
+ sort_by: str | None = None,
336
+ direction: str = "maximize",
337
+ pareto: str | None = None,
338
+ show_all: bool = False,
339
+ limit: int | None = None,
340
+ ) -> str:
341
+ """Show results from a sweep with optional sorting.
342
+
343
+ Args:
344
+ sweep_id: Sweep ID to retrieve
345
+ sort_by: Metric name to sort by (optional)
346
+ direction: Sort direction - "maximize" or "minimize"
347
+ pareto: Comma-separated list of metrics for Pareto frontier
348
+ show_all: Include failed and constraint-violated trials
349
+ limit: Maximum number of results to show (default: all)
350
+
351
+ Returns:
352
+ Formatted string with results
353
+ """
354
+ from wafer_core.tools.autotuner import compute_pareto_frontier
355
+ from wafer_core.tools.autotuner.aggregation import aggregate_trials_by_config
356
+ from wafer_core.tools.autotuner.storage import get_sweep, get_trials
357
+
358
+ try:
359
+ # Get sweep and trials
360
+ sweep = asyncio.run(get_sweep(sweep_id))
361
+ trials = asyncio.run(get_trials(sweep_id))
362
+
363
+ # Check if we should aggregate by config
364
+ trials_per_config = sweep.config.get("trials_per_config", 1) if sweep.config else 1
365
+ use_aggregation = trials_per_config > 1
366
+
367
+ # Filter trials based on show_all flag
368
+ if show_all:
369
+ # Show all trials, but separate them by status
370
+ valid_trials = [t for t in trials if t.status.value in ("success", "completed") and t.passed_constraints]
371
+ failed_trials = [t for t in trials if t.status.value in ("failed", "timeout")]
372
+ constraint_violated_trials = [t for t in trials if t.status.value == "constraint_violation" or not t.passed_constraints]
373
+ completed_trials = valid_trials # For ranking/sorting
374
+ else:
375
+ # Default: only show successful trials
376
+ completed_trials = [t for t in trials if t.status.value in ("success", "completed")]
377
+ valid_trials = completed_trials
378
+ failed_trials = []
379
+ constraint_violated_trials = []
380
+
381
+ if not valid_trials and not show_all:
382
+ return f"No completed trials found for sweep {sweep_id}"
383
+
384
+ # Aggregate trials if trials_per_config > 1
385
+ aggregated_configs = None
386
+ if use_aggregation:
387
+ aggregated_configs = aggregate_trials_by_config(completed_trials, trials_per_config)
388
+ if not aggregated_configs:
389
+ return f"No valid configurations found for sweep {sweep_id}"
390
+
391
+ # Build result string
392
+ lines = [
393
+ f"Sweep: {sweep.name}",
394
+ f"Status: {sweep.status}",
395
+ ]
396
+
397
+ if show_all:
398
+ lines.append(f"Trials: {len(valid_trials)} valid, {len(constraint_violated_trials)} constraint violations, {len(failed_trials)} failed / {sweep.total_trials} total")
399
+ else:
400
+ lines.append(f"Trials: {len(completed_trials)} completed / {sweep.total_trials} total")
401
+
402
+ lines.append("")
403
+
404
+ # Handle Pareto frontier
405
+ if pareto:
406
+ from wafer_core.tools.autotuner.dtypes import Objective
407
+
408
+ metrics = [m.strip() for m in pareto.split(",")]
409
+ # Create objectives (default to maximize for all)
410
+ objectives = [Objective(metric=m, direction="maximize") for m in metrics]
411
+
412
+ if use_aggregation:
413
+ from wafer_core.tools.autotuner.scoring import compute_pareto_frontier_configs
414
+ pareto_configs = compute_pareto_frontier_configs(aggregated_configs, objectives)
415
+
416
+ lines.append(f"Pareto Frontier ({len(pareto_configs)} configs):")
417
+ lines.append("No single config dominates on all metrics.")
418
+ lines.append("")
419
+
420
+ for i, config in enumerate(pareto_configs, 1):
421
+ lines.append(f"Config {i}: {json.dumps(config.config)}")
422
+ for metric in metrics:
423
+ if metric in config.metrics:
424
+ stats = config.metrics[metric]
425
+ lines.append(f" {metric}: {stats.mean:.2f} ± {stats.std:.2f}")
426
+ else:
427
+ lines.append(f" {metric}: N/A")
428
+ lines.append(f" runs: {len(config.trials)}")
429
+ lines.append("")
430
+ else:
431
+ pareto_trials = compute_pareto_frontier(completed_trials, objectives)
432
+
433
+ lines.append(f"Pareto Frontier ({len(pareto_trials)} configs):")
434
+ lines.append("No single config dominates on all metrics.")
435
+ lines.append("")
436
+
437
+ for i, trial in enumerate(pareto_trials, 1):
438
+ lines.append(f"Config {i}: {json.dumps(trial.config)}")
439
+ for metric in metrics:
440
+ value = trial.metrics.get(metric, "N/A")
441
+ lines.append(f" {metric}: {value}")
442
+ lines.append("")
443
+
444
+ # Handle single metric sorting
445
+ elif sort_by:
446
+ from wafer_core.tools.autotuner.dtypes import Objective
447
+
448
+ objective = Objective(metric=sort_by, direction=direction)
449
+
450
+ if use_aggregation:
451
+ from wafer_core.tools.autotuner.scoring import rank_configs_single_objective
452
+ best_configs = rank_configs_single_objective(aggregated_configs, objective)
453
+
454
+ lines.append(f"Results (sorted by {sort_by}, {direction}):")
455
+ lines.append("")
456
+
457
+ # Apply limit if specified
458
+ configs_to_show = best_configs[:limit] if limit else best_configs
459
+
460
+ for i, config in enumerate(configs_to_show, 1):
461
+ marker = " ⭐" if i == 1 else ""
462
+ lines.append(f"Rank {i}{marker}: {json.dumps(config.config)}")
463
+ for metric_name, stats in config.metrics.items():
464
+ lines.append(f" {metric_name}: {stats.mean:.2f} ± {stats.std:.2f}")
465
+ lines.append(f" runs: {len(config.trials)}")
466
+ lines.append("")
467
+
468
+ # Show count if limited
469
+ if limit and len(best_configs) > limit:
470
+ lines.append(f"... and {len(best_configs) - limit} more results")
471
+ lines.append("")
472
+ else:
473
+ from wafer_core.tools.autotuner.scoring import rank_trials_single_objective
474
+ best_trials = rank_trials_single_objective(completed_trials, objective)
475
+
476
+ lines.append(f"Results (sorted by {sort_by}, {direction}):")
477
+ lines.append("")
478
+
479
+ # Apply limit if specified
480
+ trials_to_show = best_trials[:limit] if limit else best_trials
481
+
482
+ for i, trial in enumerate(trials_to_show, 1):
483
+ marker = " ⭐" if i == 1 else ""
484
+ lines.append(f"Rank {i}{marker}: {json.dumps(trial.config)}")
485
+ for metric_name, metric_value in trial.metrics.items():
486
+ lines.append(f" {metric_name}: {metric_value}")
487
+ lines.append(f" duration: {trial.duration_ms}ms")
488
+ lines.append("")
489
+
490
+ # Show count if limited
491
+ if limit and len(best_trials) > limit:
492
+ lines.append(f"... and {len(best_trials) - limit} more results")
493
+ lines.append("")
494
+
495
+ # Default: use objectives from config
496
+ else:
497
+ if sweep.config and "objectives" in sweep.config:
498
+ from wafer_core.tools.autotuner.dtypes import Objective
499
+
500
+ objectives_data = sweep.config["objectives"]
501
+
502
+ if use_aggregation:
503
+ # Use aggregated config scoring
504
+ if len(objectives_data) > 1:
505
+ # Multi-objective: compute Pareto
506
+ from wafer_core.tools.autotuner.scoring import (
507
+ compute_pareto_frontier_configs,
508
+ rank_pareto_configs,
509
+ )
510
+ objectives = [
511
+ Objective(
512
+ metric=obj["metric"],
513
+ direction=obj["direction"],
514
+ weight=obj.get("weight", 1.0)
515
+ )
516
+ for obj in objectives_data
517
+ ]
518
+ pareto_configs = compute_pareto_frontier_configs(aggregated_configs, objectives)
519
+ ranked_configs = rank_pareto_configs(pareto_configs, objectives)
520
+
521
+ lines.append("Pareto Frontier (using config objectives):")
522
+ lines.append(f"Found {len(ranked_configs)} non-dominated configurations.")
523
+ lines.append("")
524
+
525
+ for i, config in enumerate(ranked_configs, 1):
526
+ lines.append(f"Config {i}: {json.dumps(config.config)}")
527
+ # Show all metrics, not just objectives
528
+ for metric_name, stats in sorted(config.metrics.items()):
529
+ lines.append(f" {metric_name}: {stats.mean:.2f} ± {stats.std:.2f}")
530
+ lines.append(f" runs: {len(config.trials)}")
531
+ lines.append("")
532
+ else:
533
+ # Single objective
534
+ from wafer_core.tools.autotuner.scoring import rank_configs_single_objective
535
+
536
+ obj = objectives_data[0]
537
+ objective = Objective(metric=obj["metric"], direction=obj["direction"])
538
+ best_configs = rank_configs_single_objective(aggregated_configs, objective)
539
+
540
+ lines.append(f"Results (sorted by {obj['metric']}, {obj['direction']}):")
541
+ lines.append("")
542
+
543
+ # Apply limit if specified
544
+ configs_to_show = best_configs[:limit] if limit else best_configs
545
+
546
+ for i, config in enumerate(configs_to_show, 1):
547
+ lines.append(f"Rank {i}: {json.dumps(config.config)}")
548
+ for metric_name, stats in config.metrics.items():
549
+ lines.append(f" {metric_name}: {stats.mean:.2f} ± {stats.std:.2f}")
550
+ lines.append(f" runs: {len(config.trials)}")
551
+ lines.append("")
552
+
553
+ # Show count if limited
554
+ if limit and len(best_configs) > limit:
555
+ lines.append(f"... and {len(best_configs) - limit} more results")
556
+ lines.append("")
557
+ else:
558
+ # Use individual trial scoring
559
+ if len(objectives_data) > 1:
560
+ # Multi-objective: compute Pareto
561
+ objectives = [
562
+ Objective(
563
+ metric=obj["metric"],
564
+ direction=obj["direction"],
565
+ weight=obj.get("weight", 1.0)
566
+ )
567
+ for obj in objectives_data
568
+ ]
569
+ pareto_trials = compute_pareto_frontier(completed_trials, objectives)
570
+
571
+ lines.append("Pareto Frontier (using config objectives):")
572
+ lines.append(f"Found {len(pareto_trials)} non-dominated configurations.")
573
+ lines.append("")
574
+
575
+ for i, trial in enumerate(pareto_trials, 1):
576
+ lines.append(f"Config {i}: {json.dumps(trial.config)}")
577
+ lines.append(f" Trial: {trial.trial_number + 1}")
578
+ # Show all metrics, not just objectives
579
+ for metric_name, metric_value in sorted(trial.metrics.items()):
580
+ lines.append(f" {metric_name}: {metric_value}")
581
+ lines.append(f" duration: {trial.duration_ms}ms")
582
+ lines.append("")
583
+ else:
584
+ # Single objective
585
+ from wafer_core.tools.autotuner.scoring import rank_trials_single_objective
586
+
587
+ obj = objectives_data[0]
588
+ objective = Objective(metric=obj["metric"], direction=obj["direction"])
589
+ best_trials = rank_trials_single_objective(completed_trials, objective)
590
+
591
+ lines.append(f"Results (sorted by {obj['metric']}, {obj['direction']}):")
592
+
593
+ # Apply limit if specified
594
+ trials_to_show = best_trials[:limit] if limit else best_trials
595
+
596
+ for i, trial in enumerate(trials_to_show, 1):
597
+ lines.append(f"Rank {i}: {json.dumps(trial.config)}")
598
+ for metric_name, metric_value in trial.metrics.items():
599
+ lines.append(f" {metric_name}: {metric_value}")
600
+ lines.append("")
601
+
602
+ # Show count if limited
603
+ if limit and len(best_trials) > limit:
604
+ lines.append(f"... and {len(best_trials) - limit} more results")
605
+ lines.append("")
606
+ else:
607
+ # No objectives defined - just list trials or configs
608
+ if use_aggregation:
609
+ lines.append("Results (no objectives defined):")
610
+
611
+ # Apply limit if specified
612
+ configs_to_show = aggregated_configs[:limit] if limit else aggregated_configs
613
+
614
+ for i, config in enumerate(configs_to_show, 1):
615
+ lines.append(f"Config {i}: {json.dumps(config.config)}")
616
+ for metric_name, stats in config.metrics.items():
617
+ lines.append(f" {metric_name}: {stats.mean:.2f} ± {stats.std:.2f}")
618
+ lines.append(f" runs: {len(config.trials)}")
619
+ lines.append("")
620
+
621
+ # Show count if limited
622
+ if limit and len(aggregated_configs) > limit:
623
+ lines.append(f"... and {len(aggregated_configs) - limit} more results")
624
+ lines.append("")
625
+ else:
626
+ lines.append("Results (no objectives defined):")
627
+
628
+ # Apply limit if specified
629
+ trials_to_show = completed_trials[:limit] if limit else completed_trials
630
+
631
+ for i, trial in enumerate(trials_to_show, 1):
632
+ lines.append(f"Trial {i}: {json.dumps(trial.config)}")
633
+ for metric_name, metric_value in trial.metrics.items():
634
+ lines.append(f" {metric_name}: {metric_value}")
635
+ lines.append("")
636
+
637
+ # Show count if limited
638
+ if limit and len(completed_trials) > limit:
639
+ lines.append(f"... and {len(completed_trials) - limit} more results")
640
+ lines.append("")
641
+
642
+ # If show_all is enabled, append failed and constraint-violated trials
643
+ if show_all and (constraint_violated_trials or failed_trials):
644
+ lines.append("")
645
+ lines.append("=" * 60)
646
+ lines.append("Failed and Constraint-Violated Trials")
647
+ lines.append("=" * 60)
648
+ lines.append("")
649
+
650
+ if constraint_violated_trials:
651
+ lines.append(f"Constraint Violations ({len(constraint_violated_trials)} trials):")
652
+ lines.append("These configs failed correctness checks or other constraints")
653
+ lines.append("")
654
+
655
+ for i, trial in enumerate(constraint_violated_trials[:20], 1): # Show up to 20
656
+ lines.append(f"Trial {trial.trial_number}: {json.dumps(trial.config)}")
657
+ lines.append(f" status: {trial.status.value}")
658
+ if trial.metrics:
659
+ for metric_name, metric_value in list(trial.metrics.items())[:5]: # First 5 metrics
660
+ lines.append(f" {metric_name}: {metric_value}")
661
+ if trial.stderr and len(trial.stderr) < 200:
662
+ lines.append(f" error: {trial.stderr.strip()}")
663
+ lines.append("")
664
+
665
+ if len(constraint_violated_trials) > 20:
666
+ lines.append(f"... and {len(constraint_violated_trials) - 20} more constraint violations")
667
+ lines.append("")
668
+
669
+ if failed_trials:
670
+ lines.append(f"Failed Trials ({len(failed_trials)} trials):")
671
+ lines.append("These configs crashed, timed out, or had execution errors")
672
+ lines.append("")
673
+
674
+ for i, trial in enumerate(failed_trials[:20], 1): # Show up to 20
675
+ lines.append(f"Trial {trial.trial_number}: {json.dumps(trial.config)}")
676
+ lines.append(f" status: {trial.status.value}")
677
+ lines.append(f" exit_code: {trial.exit_code}")
678
+ if trial.stderr and len(trial.stderr) < 200:
679
+ lines.append(f" error: {trial.stderr.strip()}")
680
+ elif trial.stderr:
681
+ lines.append(f" error: {trial.stderr[:200].strip()}...")
682
+ lines.append("")
683
+
684
+ if len(failed_trials) > 20:
685
+ lines.append(f"... and {len(failed_trials) - 20} more failed trials")
686
+ lines.append("")
687
+
688
+ return "\n".join(lines)
689
+
690
+ except Exception as e:
691
+ raise ValueError(f"Failed to get results: {e}") from e
692
+
693
+
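Editor's note: the --pareto and multi-objective branches above delegate to wafer_core's compute_pareto_frontier / compute_pareto_frontier_configs. Purely as a mental model, and not the library implementation, here is a self-contained non-dominated filter, assuming every metric is maximized as in the default Objective direction above.

# Illustrative Pareto filter over metric dicts; assumes all points share the
# same metric keys and that every metric is maximized.
def pareto_frontier(points: list[dict[str, float]]) -> list[dict[str, float]]:
    def dominates(a: dict[str, float], b: dict[str, float]) -> bool:
        # a dominates b if it is at least as good everywhere and strictly better somewhere
        return all(a[m] >= b[m] for m in b) and any(a[m] > b[m] for m in b)

    return [p for p in points if not any(dominates(q, p) for q in points if q is not p)]

print(pareto_frontier([
    {"throughput": 120.0, "accuracy": 0.91},
    {"throughput": 150.0, "accuracy": 0.88},
    {"throughput": 110.0, "accuracy": 0.90},  # dominated by the first point
]))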
694
+ def best_command(
695
+ sweep_id: str,
696
+ metric: str,
697
+ ) -> str:
698
+ """Show the single best config from a sweep by a specific metric.
699
+
700
+ Args:
701
+ sweep_id: Sweep ID to retrieve
702
+ metric: Metric to optimize (REQUIRED)
703
+
704
+ Returns:
705
+ Formatted string with best config
706
+ """
707
+ from wafer_core.tools.autotuner.aggregation import aggregate_trials_by_config
708
+ from wafer_core.tools.autotuner.storage import get_sweep, get_trials
709
+
710
+ try:
711
+ # Get sweep and trials
712
+ sweep = asyncio.run(get_sweep(sweep_id))
713
+ trials = asyncio.run(get_trials(sweep_id))
714
+
715
+ # Filter to completed trials only
716
+ completed_trials = [t for t in trials if t.status.value in ("success", "completed")]
717
+
718
+ if not completed_trials:
719
+ return f"No completed trials found for sweep {sweep_id}"
720
+
721
+ # Check if we should aggregate
722
+ trials_per_config = sweep.config.get("trials_per_config", 1) if sweep.config else 1
723
+ use_aggregation = trials_per_config > 1
724
+
725
+ # Determine direction from config objectives if available
726
+ from wafer_core.tools.autotuner.dtypes import Objective
727
+
728
+ direction = "maximize" # Default
729
+ if sweep.config and "objectives" in sweep.config:
730
+ for obj in sweep.config["objectives"]:
731
+ if obj["metric"] == metric:
732
+ direction = obj["direction"]
733
+ break
734
+
735
+ objective = Objective(metric=metric, direction=direction)
736
+
737
+ if use_aggregation:
738
+ # Use aggregated configs
739
+ from wafer_core.tools.autotuner.scoring import rank_configs_single_objective
740
+
741
+ aggregated_configs = aggregate_trials_by_config(completed_trials, trials_per_config)
742
+ if not aggregated_configs:
743
+ return f"No valid configurations found for sweep {sweep_id}"
744
+
745
+ best_configs = rank_configs_single_objective(aggregated_configs, objective)
746
+ if not best_configs:
747
+ return f"No configs found with metric '{metric}'"
748
+
749
+ best_config = best_configs[0]
750
+
751
+ # Format output with aggregated stats
752
+ lines = [
753
+ f"=== Best Config (by {metric}, {direction}) ===",
754
+ "",
755
+ f"Config: {best_config.config_number + 1}",
756
+ f"Runs: {len(best_config.trials)} (all successful)",
757
+ f"All Passed Constraints: {best_config.all_passed_constraints}",
758
+ "",
759
+ "Configuration:",
760
+ json.dumps(best_config.config, indent=2),
761
+ "",
762
+ "Metrics (mean ± std):",
763
+ ]
764
+
765
+ for metric_name, stats in best_config.metrics.items():
766
+ lines.append(f" {metric_name}: {stats.mean:.4f} ± {stats.std:.4f} (min: {stats.min:.4f}, max: {stats.max:.4f})")
767
+
768
+ # Show one representative trial's stdout/stderr
769
+ representative_trial = best_config.trials[0]
770
+ lines.append("")
771
+ lines.append("=" * 60)
772
+ lines.append("STDOUT (from first run):")
773
+ lines.append("=" * 60)
774
+ if representative_trial.stdout.strip():
775
+ lines.append(representative_trial.stdout)
776
+ else:
777
+ lines.append("(empty)")
778
+
779
+ lines.append("")
780
+ lines.append("=" * 60)
781
+ lines.append("STDERR (from first run):")
782
+ lines.append("=" * 60)
783
+ if representative_trial.stderr.strip():
784
+ lines.append(representative_trial.stderr)
785
+ else:
786
+ lines.append("(empty)")
787
+
788
+ return "\n".join(lines)
789
+
790
+ else:
791
+ # Use individual trials
792
+ from wafer_core.tools.autotuner.scoring import rank_trials_single_objective
793
+
794
+ best_trials = rank_trials_single_objective(completed_trials, objective)
795
+ if not best_trials:
796
+ return f"No trials found with metric '{metric}'"
797
+ best = best_trials[0]
798
+
799
+ # Format output with full details (similar to trial command)
800
+ lines = [
801
+ f"=== Best Config (by {metric}, {direction}) ===",
802
+ "",
803
+ f"Trial: {best.trial_number}",
804
+ f"Status: {best.status.value}",
805
+ f"Duration: {best.duration_ms}ms",
806
+ f"Exit Code: {best.exit_code}",
807
+ f"Passed Constraints: {best.passed_constraints}",
808
+ f"Started: {best.started_at.isoformat()}",
809
+ f"Completed: {best.completed_at.isoformat()}",
810
+ "",
811
+ "Configuration:",
812
+ json.dumps(best.config, indent=2),
813
+ "",
814
+ "Metrics:",
815
+ ]
816
+
817
+ for metric_name, metric_value in best.metrics.items():
818
+ lines.append(f" {metric_name}: {metric_value}")
819
+
820
+ lines.append("")
821
+ lines.append("=" * 60)
822
+ lines.append("STDOUT:")
823
+ lines.append("=" * 60)
824
+ if best.stdout.strip():
825
+ lines.append(best.stdout)
826
+ else:
827
+ lines.append("(empty)")
828
+
829
+ lines.append("")
830
+ lines.append("=" * 60)
831
+ lines.append("STDERR:")
832
+ lines.append("=" * 60)
833
+ if best.stderr.strip():
834
+ lines.append(best.stderr)
835
+ else:
836
+ lines.append("(empty)")
837
+
838
+ return "\n".join(lines)
839
+
840
+ except Exception as e:
841
+ raise ValueError(f"Failed to get best config: {e}") from e
842
+
843
+
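Editor's note: when trials_per_config > 1, the mean ± std figures reported by best_command come from aggregate_trials_by_config, which groups repeat runs of the same configuration and summarizes each metric. The sketch below shows only that grouping idea; the trial dict shape and the latency_ms metric name are invented, and wafer_core's real aggregation may differ.

# Rough sketch of grouping repeat runs by identical config and summarizing
# one metric as mean ± std. Not the wafer_core implementation.
import json
from statistics import mean, stdev

def summarize(trials: list[dict]) -> None:
    groups: dict[str, list[float]] = {}
    for t in trials:
        key = json.dumps(t["config"], sort_keys=True)  # identical configs group together
        groups.setdefault(key, []).append(t["metrics"]["latency_ms"])
    for key, values in groups.items():
        std = stdev(values) if len(values) > 1 else 0.0
        print(f"{key}: latency_ms {mean(values):.2f} ± {std:.2f} over {len(values)} runs")

summarize([
    {"config": {"block": 64}, "metrics": {"latency_ms": 10.2}},
    {"config": {"block": 64}, "metrics": {"latency_ms": 10.6}},
    {"config": {"block": 128}, "metrics": {"latency_ms": 9.1}},
])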
844
+ def trial_command(
845
+ sweep_id: str,
846
+ trial_number: int,
847
+ ) -> str:
848
+ """Show detailed information about a specific trial.
849
+
850
+ Args:
851
+ sweep_id: Sweep ID
852
+ trial_number: Trial number to inspect (1-indexed, as displayed to user)
853
+
854
+ Returns:
855
+ Formatted string with trial details including stdout, stderr, config, and metrics
856
+ """
857
+ from wafer_core.tools.autotuner.storage import get_trials
858
+
859
+ try:
860
+ # Convert from 1-indexed (user input) to 0-indexed (internal storage)
861
+ trial_number_internal = trial_number - 1
862
+
863
+ # Get all trials for this sweep
864
+ trials = asyncio.run(get_trials(sweep_id))
865
+
866
+ # Find the specific trial
867
+ trial = None
868
+ for t in trials:
869
+ if t.trial_number == trial_number_internal:
870
+ trial = t
871
+ break
872
+
873
+ if not trial:
874
+ return f"Config #{trial_number} not found in sweep {sweep_id}"
875
+
876
+ # Format output with full details (display as 1-indexed)
877
+ lines = [
878
+ f"=== Config #{trial.trial_number + 1} (Sweep: {sweep_id[:8]}...) ===",
879
+ "",
880
+ f"Status: {trial.status.value}",
881
+ f"Duration: {trial.duration_ms}ms",
882
+ f"Exit Code: {trial.exit_code}",
883
+ f"Passed Constraints: {trial.passed_constraints}",
884
+ f"Started: {trial.started_at.isoformat()}",
885
+ f"Completed: {trial.completed_at.isoformat()}",
886
+ "",
887
+ "Configuration:",
888
+ json.dumps(trial.config, indent=2),
889
+ "",
890
+ "Metrics:",
891
+ ]
892
+
893
+ if trial.metrics:
894
+ for metric_name, metric_value in trial.metrics.items():
895
+ lines.append(f" {metric_name}: {metric_value}")
896
+ else:
897
+ lines.append(" (none)")
898
+
899
+ lines.append("")
900
+ lines.append("=" * 60)
901
+ lines.append("STDOUT:")
902
+ lines.append("=" * 60)
903
+ if trial.stdout.strip():
904
+ lines.append(trial.stdout)
905
+ else:
906
+ lines.append("(empty)")
907
+
908
+ lines.append("")
909
+ lines.append("=" * 60)
910
+ lines.append("STDERR:")
911
+ lines.append("=" * 60)
912
+ if trial.stderr.strip():
913
+ lines.append(trial.stderr)
914
+ else:
915
+ lines.append("(empty)")
916
+
917
+ return "\n".join(lines)
918
+
919
+ except Exception as e:
920
+ raise ValueError(f"Failed to get trial details: {e}") from e
921
+
922
+
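Editor's note on indexing: the number passed to trial_command is the 1-indexed value printed in results output, matched against the 0-indexed trial_number stored server-side. A hypothetical call (the sweep ID is a placeholder):

from wafer.autotuner import trial_command

# Config #3 as printed by results_command; internally this looks up the
# trial whose stored trial_number == 2.
print(trial_command("00000000-0000-0000-0000-000000000000", 3))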
923
+ def list_command(show_all: bool = False) -> str:
924
+ """List sweeps for the current user.
925
+
926
+ Args:
927
+ show_all: If False (default), only show running and completed sweeps.
928
+ If True, show all sweeps including pending and failed.
929
+
930
+ Returns:
931
+ Formatted string with sweep list
932
+ """
933
+ from wafer_core.tools.autotuner.storage import list_sweeps
934
+
935
+ try:
936
+ # Get all sweeps
937
+ all_sweeps = asyncio.run(list_sweeps())
938
+
939
+ if not all_sweeps:
940
+ return "No sweeps found."
941
+
942
+ # Filter by status unless --all is specified
943
+ if show_all:
944
+ sweeps = all_sweeps
945
+ else:
946
+ sweeps = [s for s in all_sweeps if s.status in ("running", "completed")]
947
+
948
+ if not sweeps:
949
+ if show_all:
950
+ return "No sweeps found."
951
+ else:
952
+ return "No running or completed sweeps found. Use --all to see pending/failed sweeps."
953
+
954
+ # Sort by creation time (most recent first)
955
+ sweeps.sort(key=lambda s: s.created_at, reverse=True)
956
+
957
+ lines = [
958
+ f"Found {len(sweeps)} sweep(s)" + (" (showing all)" if show_all else " (running/completed only)") + ":",
959
+ "",
960
+ ]
961
+
962
+ for sweep in sweeps:
963
+ # Format timestamps
964
+ created = sweep.created_at.strftime("%Y-%m-%d %H:%M:%S")
965
+
966
+ # Status emoji
967
+ status_emoji = {
968
+ "pending": "⏳",
969
+ "running": "🔄",
970
+ "completed": "✅",
971
+ "failed": "❌",
972
+ }.get(sweep.status, "❓")
973
+
974
+ lines.append(f"{status_emoji} {sweep.name}")
975
+ lines.append(f" ID: {sweep.id}")
976
+ lines.append(f" Status: {sweep.status}")
977
+ lines.append(f" Trials: {sweep.completed_trials}/{sweep.total_trials}")
978
+ lines.append(f" Created: {created}")
979
+ if sweep.description:
980
+ lines.append(f" Description: {sweep.description}")
981
+ lines.append("")
982
+
983
+ return "\n".join(lines)
984
+
985
+ except Exception as e:
986
+ raise ValueError(f"Failed to list sweeps: {e}") from e
987
+
988
+
989
+ def delete_command(sweep_id: str) -> str:
990
+ """Delete a sweep and all its trials.
991
+
992
+ Args:
993
+ sweep_id: Sweep ID to delete
994
+
995
+ Returns:
996
+ Success message
997
+ """
998
+ import httpx
999
+ from wafer_core.tools.autotuner.storage import _get_auth_headers, get_api_url
1000
+
1001
+ try:
1002
+ api_url = get_api_url()
1003
+ headers = _get_auth_headers()
1004
+
1005
+ async def _delete() -> str:
1006
+ async with httpx.AsyncClient(timeout=30.0, headers=headers) as client:
1007
+ response = await client.delete(f"{api_url}/v1/autotuner/sweeps/{sweep_id}")
1008
+ response.raise_for_status()
1009
+ return f"Successfully deleted sweep {sweep_id}"
1010
+
1011
+ return asyncio.run(_delete())
1012
+
1013
+ except httpx.HTTPStatusError as e:
1014
+ if e.response.status_code == 404:
1015
+ raise ValueError(f"Sweep {sweep_id} not found")
1016
+ raise ValueError(f"Failed to delete sweep: {e}")
1017
+ except Exception as e:
1018
+ raise ValueError(f"Failed to delete sweep: {e}") from e
1019
+
1020
+
1021
+ def delete_all_command(status_filter: str | None = None) -> str:
1022
+ """Delete all sweeps (optionally filtered by status).
1023
+
1024
+ Args:
1025
+ status_filter: Optional status to filter by (pending, running, completed, failed)
1026
+
1027
+ Returns:
1028
+ Summary of deletions
1029
+ """
1030
+ import httpx
1031
+ from wafer_core.tools.autotuner.storage import _get_auth_headers, get_api_url, list_sweeps
1032
+
1033
+ try:
1034
+ # Get all sweeps
1035
+ all_sweeps = asyncio.run(list_sweeps())
1036
+
1037
+ if not all_sweeps:
1038
+ return "No sweeps found."
1039
+
1040
+ # Filter by status if specified
1041
+ if status_filter:
1042
+ sweeps_to_delete = [s for s in all_sweeps if s.status == status_filter]
1043
+ if not sweeps_to_delete:
1044
+ return f"No sweeps found with status '{status_filter}'."
1045
+ else:
1046
+ sweeps_to_delete = all_sweeps
1047
+
1048
+ # Show what will be deleted
1049
+ count = len(sweeps_to_delete)
1050
+ status_msg = f" with status '{status_filter}'" if status_filter else ""
1051
+
1052
+ print(f"Found {count} sweep(s){status_msg} to delete:")
1053
+ print()
1054
+ for sweep in sweeps_to_delete:
1055
+ print(f" - {sweep.name} ({sweep.id})")
1056
+ print(f" Status: {sweep.status}, Trials: {sweep.completed_trials}/{sweep.total_trials}")
1057
+ print()
1058
+
1059
+ # Delete all
1060
+ api_url = get_api_url()
1061
+ headers = _get_auth_headers()
1062
+
1063
+ async def _delete_all() -> int:
1064
+ deleted_count = 0
1065
+ async with httpx.AsyncClient(timeout=30.0, headers=headers) as client:
1066
+ for sweep in sweeps_to_delete:
1067
+ try:
1068
+ response = await client.delete(f"{api_url}/v1/autotuner/sweeps/{sweep.id}")
1069
+ response.raise_for_status()
1070
+ deleted_count += 1
1071
+ print(f"✓ Deleted {sweep.id}")
1072
+ except Exception as e:
1073
+ print(f"✗ Failed to delete {sweep.id}: {e}")
1074
+ return deleted_count
1075
+
1076
+ deleted_count = asyncio.run(_delete_all())
1077
+ return f"\nSuccessfully deleted {deleted_count}/{count} sweeps"
1078
+
1079
+ except Exception as e:
1080
+ raise ValueError(f"Failed to delete sweeps: {e}") from e