wafer-core 0.1.37__py3-none-any.whl → 0.1.39__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer_core/lib/trace_compare/fusion_analyzer.py +2 -0
- wafer_core/rollouts/_logging/__init__.py +5 -1
- wafer_core/rollouts/_logging/logging_config.py +95 -3
- wafer_core/rollouts/_logging/sample_handler.py +66 -0
- wafer_core/rollouts/_pytui/__init__.py +114 -0
- wafer_core/rollouts/_pytui/app.py +809 -0
- wafer_core/rollouts/_pytui/console.py +291 -0
- wafer_core/rollouts/_pytui/renderer.py +210 -0
- wafer_core/rollouts/_pytui/spinner.py +73 -0
- wafer_core/rollouts/_pytui/terminal.py +489 -0
- wafer_core/rollouts/_pytui/text.py +470 -0
- wafer_core/rollouts/_pytui/theme.py +241 -0
- wafer_core/rollouts/evaluation.py +142 -177
- wafer_core/rollouts/progress_app.py +395 -0
- wafer_core/rollouts/tui/DESIGN.md +251 -115
- wafer_core/rollouts/tui/monitor.py +64 -20
- wafer_core/tools/compile/__init__.py +30 -0
- wafer_core/tools/compile/compiler.py +314 -0
- wafer_core/tools/compile/modal_compile.py +359 -0
- wafer_core/tools/compile/tests/__init__.py +1 -0
- wafer_core/tools/compile/tests/test_compiler.py +675 -0
- wafer_core/tools/compile/tests/test_data/utils.cuh +10 -0
- wafer_core/tools/compile/tests/test_data/vector_add.cu +7 -0
- wafer_core/tools/compile/tests/test_data/with_header.cu +9 -0
- wafer_core/tools/compile/tests/test_modal_integration.py +326 -0
- wafer_core/tools/compile/types.py +117 -0
- {wafer_core-0.1.37.dist-info → wafer_core-0.1.39.dist-info}/METADATA +1 -1
- {wafer_core-0.1.37.dist-info → wafer_core-0.1.39.dist-info}/RECORD +29 -12
- wafer_core/rollouts/events.py +0 -240
- wafer_core/rollouts/progress_display.py +0 -476
- wafer_core/utils/event_streaming.py +0 -63
- {wafer_core-0.1.37.dist-info → wafer_core-0.1.39.dist-info}/WHEEL +0 -0
wafer_core/rollouts/evaluation.py

@@ -17,6 +17,7 @@ from typing import Any
 
 import trio
 
+from ._logging.logging_config import EvalLoggingContext, setup_eval_logging
 from .agents import run_agent
 from .dtypes import (
     Actor,
@@ -29,16 +30,19 @@ from .dtypes import (
     Score,
     StopReason,
     StreamChunk,
+    TextDelta,
     TextEnd,
+    ThinkingDelta,
     ToolExecutionEnd,
     Trajectory,
 )
-from .events import EventEmitter, emit_event
-from .progress import MultiProgress
 from .training.types import Sample, Status
 
 logger = logging.getLogger(__name__)
 
+# Logger for structured eval events — handlers configured by setup_eval_logging()
+_event_logger = logging.getLogger("wafer.eval.events")
+
 
 # ── Runtime Context ───────────────────────────────────────────────────────────
 
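The handlers behind `_event_logger` live in the new `wafer_core/rollouts/_logging/logging_config.py` and `sample_handler.py`, which this diff does not display. As a minimal sketch of the pattern the calls below depend on, assuming a JSON-lines formatter that serializes each record's `extra={}` fields, the setup could look roughly like this (names such as `JsonLinesFormatter` and `setup_eval_logging_sketch` are illustrative, not the package's actual API):

```python
# Hypothetical sketch only: the real setup_eval_logging() lives in
# wafer_core/rollouts/_logging/logging_config.py and is not shown in this diff.
import json
import logging
from pathlib import Path

# LogRecord attributes that are not user-supplied extra={} fields
_STANDARD_ATTRS = set(vars(logging.makeLogRecord({})))


class JsonLinesFormatter(logging.Formatter):
    """Serialize a record's message plus its extra={} fields as one JSON object per line."""

    def format(self, record: logging.LogRecord) -> str:
        payload = {"event": record.getMessage(), "level": record.levelname}
        payload.update(
            {k: v for k, v in vars(record).items() if k not in _STANDARD_ATTRS}
        )
        return json.dumps(payload, default=str)


def setup_eval_logging_sketch(output_dir: Path) -> logging.Handler:
    """Attach a JSONL file handler to the 'wafer.eval.events' logger (illustration only)."""
    output_dir.mkdir(parents=True, exist_ok=True)
    handler = logging.FileHandler(output_dir / "events.jsonl")
    handler.setLevel(logging.INFO)  # DEBUG deltas would go to per-sample files instead
    handler.setFormatter(JsonLinesFormatter())
    logger = logging.getLogger("wafer.eval.events")
    logger.setLevel(logging.DEBUG)
    logger.addHandler(handler)
    return handler
```

With a handler like this attached, a call such as `_event_logger.info("sample_start", extra={"sample_id": "s1"})` appends one self-describing JSON line to events.jsonl, which is what the per-event calls in the hunks below rely on.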
@@ -56,7 +60,6 @@ class EvalRuntime:
     config: EvalConfig
     api_limiter: trio.CapacityLimiter | None = None
     tool_limiter: trio.CapacityLimiter | None = None
-    progress: MultiProgress | None = None
 
 
 # JSON-like recursive type for sanitize_api_keys
@@ -187,17 +190,12 @@ async def _evaluate_batch(
         Used for incremental report writing.
     """
     config = runtime.config
-    progress = runtime.progress
     results: list[Sample] = []
     # Lock for thread-safe results access during concurrent execution
     results_lock = trio.Lock()
 
     async def run_one(sample_id: str, sample_data: dict[str, Any]) -> Sample:
         """Evaluate a single sample."""
-        task_name = sample_data.get("name", sample_id)
-        if progress:
-            progress.add_task(sample_id, name=task_name)
-
         # Get environment: prefer direct environment, fall back to factory
         if config.environment is not None:
             env = config.environment
@@ -212,17 +210,6 @@ async def _evaluate_batch(
             environment=env,
         )
 
-        # Mark task complete
-        if progress:
-            reward = result.score.reward if result.score else 0.0
-            success = result.metadata.get("status") == "success"
-            if success:
-                message = f"reward={reward:.2f}"
-            else:
-                error = result.metadata.get("error", "failed")
-                message = error[:30] if len(error) > 30 else error
-            progress.complete_task(sample_id, success=success, message=message)
-
         return result
 
     if config.max_concurrent == 1:
@@ -477,11 +464,8 @@ def _strip_ansi(text: str) -> str:
 
 
 def _truncate(text: str, max_len: int = 2000) -> str:
-    """
-
-    if len(text) <= max_len:
-        return text
-    return text[:max_len] + f"\n\n... (truncated, {len(text)} chars total)"
+    """Strip ANSI codes from text."""
+    return _strip_ansi(text)
 
 
 def _format_message_content(content: Any) -> str:
@@ -642,7 +626,9 @@ class EvalReport:
     timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
     git_info: dict[str, Any] = field(default_factory=_get_git_info)
    config_path: str | None = None  # Path to config file relative to repo root
-    metadata: dict[str, Any] | None = None  # Custom metadata (waferbench_category, github_runner, etc.)
+    metadata: dict[str, Any] | None = (
+        None  # Custom metadata (waferbench_category, github_runner, etc.)
+    )
 
     async def save(self, output_dir: Path) -> None:
         """Save evaluation results to directory."""
@@ -974,7 +960,6 @@ async def evaluate_sample(
     """
     # Unpack runtime for convenience
     config = runtime.config
-    progress = runtime.progress
 
     # Prepare initial messages from sample
     initial_messages = config.prepare_messages(sample_data)
@@ -1004,73 +989,101 @@ async def evaluate_sample(
     async def on_chunk_with_sample_id(event: object) -> None:
         nonlocal last_status, current_turn
 
-        # Update MultiProgress on various events for granular status
         status = _get_progress_status_for_event(event)
-        turn = _get_turn_from_event(event)
-
-        if progress is not None:
-            if status is not None or turn is not None:
-                progress.update_task(
-                    sample_id,
-                    turn=turn if turn is not None else None,
-                    status=status if status is not None else None,
-                )
 
-        # Emit to
+        # Emit to JSONL files via logging — overview (INFO+) and per-sample (all levels)
         if isinstance(event, StreamChunk):
             if event.type == "turn_start":
                 turn_num = event.data.get("turn", 0)
                 current_turn[sample_id] = turn_num
-
+                _event_logger.info(
+                    "turn",
+                    extra={
+                        "sample_id": sample_id,
+                        "turn": turn_num,
+                        "status": "waiting",
+                    },
+                )
                 last_status[sample_id] = "waiting"
             elif event.type == "modal_progress":
-
+                _event_logger.info(
+                    "modal_progress",
+                    extra={
+                        "sample_id": sample_id,
+                        "phase": event.data.get("phase", ""),
+                    },
+                )
 
-        # Emit status changes
+        # Emit status changes (dedup to avoid flooding)
         if status is not None and status != last_status.get(sample_id):
-
+            _event_logger.info("turn", extra={"sample_id": sample_id, "status": status})
             last_status[sample_id] = status
 
-        # Wide events:
+        # Wide events: detailed timing for performance analysis
         sample_turn = current_turn.get(sample_id, 0)
         if isinstance(event, LLMCallEnd):
-
+            _event_logger.info(
                 "llm_call",
-
-
-
-
-
-
-
-
-
+                extra={
+                    "sample_id": sample_id,
+                    "turn": sample_turn,
+                    "duration_ms": round(event.duration_ms, 1),
+                    "provider": event.provider,
+                    "model": event.model,
+                    "tokens_in": event.tokens_in,
+                    "tokens_out": event.tokens_out,
+                    "status": event.status,
+                    "error": event.error,
+                },
             )
         elif isinstance(event, ToolExecutionEnd):
-
+            _event_logger.info(
                 "tool_execution",
-
-
-
-
-
-
-
+                extra={
+                    "sample_id": sample_id,
+                    "turn": sample_turn,
+                    "tool_name": event.tool_name,
+                    "duration_ms": round(event.duration_ms, 1),
+                    "status": event.status,
+                    "is_error": event.is_error,
+                    "result_summary": event.result_summary,
+                },
             )
         elif isinstance(event, TextEnd):
-            #
-            # Truncate long content to avoid bloating events file
+            # Truncate for events.jsonl (INFO), full content in per-sample (also INFO)
             content = event.content
             truncated = len(content) > 2000
             if truncated:
                 content = content[:2000] + "..."
-
+            _event_logger.info(
                 "assistant_message",
-
-
-
-
-
+                extra={
+                    "sample_id": sample_id,
+                    "turn": sample_turn,
+                    "content": content,
+                    "content_length": len(event.content),
+                    "truncated": truncated,
+                },
+            )
+
+        # DEBUG: streaming deltas — per-sample files only (filtered out of events.jsonl)
+        elif isinstance(event, TextDelta):
+            _event_logger.debug(
+                "text_delta",
+                extra={
+                    "sample_id": sample_id,
+                    "turn": sample_turn,
+                    "text": event.delta,
+                },
+            )
+        elif isinstance(event, ThinkingDelta):
+            _event_logger.debug(
+                "thinking_delta",
+                extra={
+                    "sample_id": sample_id,
+                    "turn": sample_turn,
+                    "text": event.delta,
+                },
             )
 
         # Wrap event with sample_id and forward to base handler
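For orientation, a single `llm_call` record emitted above would serialize to one line of events.jsonl whose keys mirror the `extra` dict. The envelope fields and example values below are assumptions (the actual formatter is defined in logging_config.py, which this diff does not show):

```python
# Illustrative only: possible shape of one events.jsonl line for an LLMCallEnd event,
# assuming a JSON-lines formatter like the sketch shown earlier. All values are placeholders.
import json

record = {
    "event": "llm_call",
    "level": "INFO",
    "sample_id": "sample_042",
    "turn": 3,
    "duration_ms": 1842.7,
    "provider": "anthropic",
    "model": "example-model",
    "tokens_in": 5120,
    "tokens_out": 380,
    "status": "ok",
    "error": None,
}
print(json.dumps(record))
```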
@@ -1107,13 +1120,13 @@ async def evaluate_sample(
         )
     )
 
-    #
+    # Emit sample_start for progress display
     # TODO: Retry logic can emit multiple sample_start events for the same sample_id
     # without a corresponding sample_end, causing progress display to show 100/100
     # while a sample is still running. Either emit sample_end before retry, or
     # don't emit sample_start on retries. See: chiraag/supabase-eval-traces PR #504
     sample_name = sample_data.get("name", sample_id)
-
+    _event_logger.info("sample_start", extra={"sample_id": sample_id, "sample_name": sample_name})
 
     # Run agent with error handling
     result = await _run_agent_with_error_handling(initial_state, run_config, sample_id)
@@ -1202,8 +1215,7 @@ async def evaluate_sample(
         )
     )
 
-
-    emit_event("sample_end", id=sample_id, score=reward)
+    _event_logger.info("sample_end", extra={"sample_id": sample_id, "score": reward})
 
     return sample
 
@@ -1296,13 +1308,26 @@ async def evaluate(
     logger.info(f"max concurrent: {config.max_concurrent}")
     logger.debug("=" * 50)
 
-    #
-
-    emitter: EventEmitter | None = None
+    # Set up eval logging: events.jsonl (overview) + samples/{id}.jsonl (per-sample)
+    eval_logging: EvalLoggingContext | None = None
     if config.output_dir:
-
-
-
+        eval_logging = setup_eval_logging(config.output_dir)
+        _event_logger.info(
+            "eval_start",
+            extra={
+                "eval_name": config.eval_name,
+                "total": len(samples_to_eval),
+            },
+        )
+
+    # Progress display: pytui subprocess that reads events.jsonl
+    # Requires output_dir since it reads events.jsonl from there
+    progress_ctx = None
+    if config.show_progress and config.output_dir:
+        from .progress_app import progress_display
+
+        progress_ctx = progress_display(output_dir=config.output_dir)
+        progress_ctx.__enter__()
 
     # Evaluate samples (with concurrency control)
     results: list[Sample] = []
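The progress display started here is decoupled from the logging context: it only reads events.jsonl. The real implementation lives in the new progress_app.py and _pytui package, which this diff does not show; the loop below is only a rough sketch of the tail-and-parse idea such a display presumably builds on, reusing the event names emitted in this file and the assumed "event"/"total" field names from the earlier formatter sketch:

```python
# Rough consumer-side sketch: follow events.jsonl and react to records.
# Not the actual progress_app.py; field names other than the emitted event
# names ("eval_start", "sample_end", "eval_end", "total") are assumptions.
import json
import time
from pathlib import Path


def follow_events(output_dir: Path) -> None:
    path = output_dir / "events.jsonl"
    done = 0
    total = None
    with path.open() as f:
        while True:
            line = f.readline()
            if not line:
                time.sleep(0.1)  # no new data yet; poll again
                continue
            event = json.loads(line)
            name = event.get("event")
            if name == "eval_start":
                total = event.get("total")
            elif name == "sample_end":
                done += 1
                print(f"{done}/{total} samples done")
            elif name == "eval_end":
                break
```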
@@ -1329,18 +1354,6 @@ async def evaluate(
         )
         last_report_count = len(all_results)
 
-    # Initialize progress display for sample-level tracking
-    # MultiProgress shows each concurrent sample with turn-by-turn updates
-    progress: MultiProgress | None = None
-    if config.show_progress:
-        progress = MultiProgress(
-            total=len(samples_to_eval),
-            desc=config.eval_name,
-            unit="sample",
-            verbose=config.verbose,  # verbose=True shows INFO logs, False shows only WARNING+
-        )
-        progress.__enter__()
-
     # Create two-level concurrency limiters if configured
     api_limiter = (
         trio.CapacityLimiter(config.max_api_concurrent)
@@ -1358,7 +1371,6 @@ async def evaluate(
         config=config,
         api_limiter=api_limiter,
         tool_limiter=tool_limiter,
-        progress=progress,
     )
 
     # Run initial evaluation batch with incremental report callback
@@ -1386,9 +1398,10 @@ async def evaluate(
         else:
             raise
 
-    # Close progress
-    if
-
+    # Close progress TUI before any print() calls
+    if progress_ctx:
+        progress_ctx.__exit__(None, None, None)
+        progress_ctx = None
 
     # Write final partial report and upload if interrupted
     if interrupted and config.output_dir:
@@ -1409,10 +1422,17 @@ async def evaluate(
             print("Upload complete.")
         except Exception as e:
             print(f"Upload failed: {e}")
-        #
-        if
-
-
+        # Emit eval_end and clean up logging
+        if eval_logging:
+            _event_logger.info(
+                "eval_end",
+                extra={
+                    "eval_name": config.eval_name,
+                    "total": len(results),
+                    "interrupted": True,
+                },
+            )
+            eval_logging.teardown()
         # Exit cleanly - don't re-raise to avoid big traceback
         print(f"Partial results saved to {config.output_dir}")
         print("Resume with: --resume", config.output_dir)
@@ -1435,21 +1455,16 @@ async def evaluate(
             f"Retrying {len(failed_samples)} failed samples "
             f"(attempt {retry_attempt + 1}/{config.max_sample_retries}, waiting {wait_seconds}s)"
         )
-
-            progress.log(retry_msg)
-        else:
-            logger.info(retry_msg)
+        logger.info(retry_msg)
         await trio.sleep(wait_seconds)
 
         # Remove failed samples and retry
         failed_ids = {sid for sid, _ in failed_samples}
         results = [r for r in results if r.id not in failed_ids]
-        # Create runtime without progress for retries (no incremental reports during retry)
         retry_runtime = EvalRuntime(
             config=config,
             api_limiter=api_limiter,
             tool_limiter=tool_limiter,
-            progress=None,
         )
         retry_results = await _evaluate_batch(failed_samples, retry_runtime)
         results.extend(retry_results)
@@ -1471,10 +1486,7 @@ async def evaluate(
         retry_result_msg = (
             f"Retry {retry_attempt + 1}: {succeeded} succeeded, {still_failed} still failing"
         )
-
-            progress.log(retry_result_msg)
-        else:
-            logger.info(retry_result_msg)
+        logger.info(retry_result_msg)
 
     # Compute summary metrics
     summary_metrics = compute_summary_metrics(results)
@@ -1517,10 +1529,21 @@ async def evaluate(
         else:
             logger.info(f"{key}: {value}")
 
-    #
-    if
-
-
+    # Emit eval_end and clean up logging (before progress TUI so it sees eval_end)
+    if eval_logging:
+        _event_logger.info(
+            "eval_end",
+            extra={
+                "eval_name": config.eval_name,
+                "total": len(results),
+            },
+        )
+        eval_logging.teardown()
+
+    # Close progress TUI
+    if progress_ctx:
+        progress_ctx.__exit__(None, None, None)
+        progress_ctx = None
 
     return report
 
@@ -1571,7 +1594,7 @@ def compute_summary_metrics(results: list[Sample]) -> dict[str, float]:
             median_val = (sorted_values[n // 2 - 1] + sorted_values[n // 2]) / 2
         else:
             median_val = sorted_values[n // 2]
-
+
         summary[f"mean_{metric_name}"] = mean_val
         summary[f"median_{metric_name}"] = median_val
         summary[f"min_{metric_name}"] = min(values)
@@ -1800,75 +1823,22 @@ def get_api_key(provider: str = "anthropic") -> str | None:
     return None
 
 
-def run_with_progress(
-    eval_fn: Callable[[Any], Any],
-    config: Any,
-    output_dir: Path,
-    quiet_config_fn: Callable[[Any], Any] | None = None,
-    async_wrapper: Callable[[Callable, Any], Callable[[], Any]] | None = None,
-) -> dict[str, Any]:
-    """Run evaluation with progress display TUI.
-
-    Wraps an async eval function with the progress_display context manager,
-    which redirects stdout/stderr to output.log and renders a TUI.
-
-    Args:
-        eval_fn: Async evaluation function that takes config and returns results
-        config: Configuration object
-        output_dir: Directory for output files (events.jsonl, output.log)
-        quiet_config_fn: Optional function to create a quiet version of config
-            (disables internal verbose/show_progress flags)
-        async_wrapper: Optional wrapper for async runtime compatibility (e.g., trio_asyncio).
-            Takes (eval_fn, config) and returns an async callable for trio.run().
-
-    Returns:
-        Results dict from eval_fn
-
-    Example:
-        def my_quiet_config(config):
-            return replace(config, run=replace(config.run, verbose=False, show_progress=False))
-
-        result = run_with_progress(
-            evaluate_my_task,
-            config,
-            config.output.output_dir,
-            quiet_config_fn=my_quiet_config,
-        )
-    """
-    from .progress_display import progress_display
-
-    # Apply quiet config transformation if provided
-    run_config = quiet_config_fn(config) if quiet_config_fn else config
-
-    with progress_display(output_dir=output_dir):
-        if async_wrapper:
-            result = trio.run(async_wrapper(eval_fn, run_config))
-        else:
-            result = trio.run(eval_fn, run_config)
-        assert result is not None, "Evaluation was cancelled"
-
-    return result
-
-
 def run_eval(
     eval_fn: Callable[[Any], Any],
     config: Any,
     output_dir: Path,
-    show_progress: bool = False,
-    quiet_config_fn: Callable[[Any], Any] | None = None,
     print_summary_fn: Callable[[dict[str, Any], Path], None] | None = None,
     async_wrapper: Callable[[Callable, Any], Callable[[], Any]] | None = None,
 ) -> dict[str, Any]:
     """Standard entry point for running evaluations.
 
-
+    Runs the async eval function via trio.run(). Progress display is handled
+    internally by evaluate() based on config.show_progress.
 
     Args:
         eval_fn: Async evaluation function that takes config and returns results
         config: Configuration object
         output_dir: Directory for output files
-        show_progress: Whether to show progress TUI
-        quiet_config_fn: Optional function to create quiet config for progress mode
         print_summary_fn: Optional function to print summary after completion
         async_wrapper: Optional wrapper for async runtime compatibility (e.g., trio_asyncio).
             Takes (eval_fn, config) and returns an async callable for trio.run().
@@ -1883,8 +1853,6 @@ def run_eval(
             eval_fn=evaluate_my_task,
             config=config,
             output_dir=config.output.output_dir,
-            show_progress=config.run.show_progress,
-            quiet_config_fn=lambda c: replace(c, run=replace(c.run, verbose=False, show_progress=False)),
             print_summary_fn=print_my_summary,
         )
 
@@ -1904,14 +1872,11 @@ def run_eval(
             async_wrapper=asyncio_compat_wrapper,
         )
     """
-    if
-        result =
+    if async_wrapper:
+        result = trio.run(async_wrapper(eval_fn, config))
     else:
-
-
-        else:
-            result = trio.run(eval_fn, config)
-        assert result is not None, "Evaluation was cancelled"
+        result = trio.run(eval_fn, config)
+    assert result is not None, "Evaluation was cancelled"
 
     if print_summary_fn:
         print_summary_fn(result, output_dir)