synth-ai 0.2.12__py3-none-any.whl → 0.2.13.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic.
- examples/multi_step/configs/crafter_rl_outcome.toml +74 -0
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +186 -0
- examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +83 -0
- examples/multi_step/configs/crafter_rl_stepwise_simple.toml +78 -0
- examples/multi_step/crafter_rl_lora.md +51 -10
- examples/multi_step/sse_metrics_streaming_notes.md +357 -0
- examples/multi_step/task_app_config_notes.md +7 -1
- examples/swe/task_app/grpo_swe_mini.py +55 -26
- examples/swe/task_app/hosted/rollout.py +40 -0
- examples/swe/task_app/hosted/test_service.py +5 -6
- examples/task_apps/TESTING.md +275 -0
- examples/task_apps/__init__.py +0 -0
- examples/task_apps/crafter/__init__.py +0 -0
- examples/task_apps/crafter/task_app/__init__.py +2 -0
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/grpo_crafter.py +21 -46
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/grpo_crafter_task_app.py +1 -1
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/policy.py +60 -4
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/inference/openai_client.py +109 -45
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/policy_routes.py +67 -49
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/rollout.py +242 -193
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/test_service.py +5 -6
- examples/task_apps/dev/pokemon_emerald/__init__.py +2 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/README.md +811 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/__init__.py +120 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/action.py +160 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/memory.py +155 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/perception.py +69 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/planning.py +96 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/simple.py +1502 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/system_prompt.py +4 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/grab_map.py +68 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/manual.py +216 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/__init__.py +35 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emerald_utils.py +631 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emulator.py +1544 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/enums.py +1428 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/memory_reader.py +4848 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/types.py +41 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/utils.py +298 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pyproject.toml +95 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/run.py +204 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/__init__.py +0 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/app.py +2152 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/client.py +429 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/frame_server.py +155 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/README.md +78 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/__init__.py +0 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/run_tests.py +122 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_direct.py +76 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_prompts.py +413 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_battle_state_formatting.py +204 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection.py +133 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection_comprehensive.py +229 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_direct_agent_emulator.py +300 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_fps_adjustment_pytest.py +205 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_direct.py +200 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_transition.py +284 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_map_ground_truth_comparison.py +468 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_memory_map.py +575 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_server_map_validation.py +311 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_torchic_state.py +259 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/__init__.py +0 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/anticheat.py +372 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/checkpoint.py +296 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/error_handler.py +275 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/get_local_ip.py +22 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/helpers.py +44 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/llm_logger.py +514 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_formatter.py +415 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher.py +1763 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher_singleton.py +33 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_trimmer.py +106 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_visualizer.py +334 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/ocr_dialogue.py +1020 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/recording.py +188 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/state_formatter.py +1481 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/vlm.py +862 -0
- examples/task_apps/dev/pokemon_emerald/modal_app.py +114 -0
- examples/task_apps/dev/pokemon_emerald/task_app/README.md +81 -0
- examples/task_apps/dev/pokemon_emerald/task_app/__init__.py +6 -0
- examples/task_apps/dev/pokemon_emerald/task_app/pokemon_emerald.py +685 -0
- examples/task_apps/enron/__init__.py +1 -0
- examples/task_apps/enron/eval_groq_qwen32.toml +16 -0
- examples/task_apps/enron/task_app/README.md +14 -0
- examples/task_apps/enron/task_app/__init__.py +1 -0
- examples/task_apps/enron/task_app/grpo_enron.py +906 -0
- examples/task_apps/enron/task_app/grpo_enron_task_app.py +146 -0
- examples/task_apps/enron/tests/__init__.py +2 -0
- examples/task_apps/enron/tests/conftest.py +115 -0
- examples/task_apps/enron/tests/integration/__init__.py +2 -0
- examples/task_apps/enron/tests/integration/test_enron_eval.py +177 -0
- examples/task_apps/enron/tests/integration/test_enron_rollout.py +135 -0
- examples/task_apps/enron/tests/unit/__init__.py +2 -0
- examples/task_apps/enron/tests/unit/test_enron_environment.py +126 -0
- examples/task_apps/math/__init__.py +0 -0
- examples/{rl/task_app → task_apps/math}/math_single_step.py +19 -10
- examples/task_apps/pokemon_battle/__init__.py +2 -0
- examples/task_apps/pokemon_battle/modal_app.py +104 -0
- examples/task_apps/pokemon_battle/task_app/README.md +68 -0
- examples/task_apps/pokemon_battle/task_app/__init__.py +6 -0
- examples/task_apps/pokemon_battle/task_app/pokemon_showdown.py +932 -0
- examples/task_apps/pokemon_red/README.md +357 -0
- examples/task_apps/pokemon_red/__init__.py +3 -0
- examples/task_apps/pokemon_red/eval_pokemon_red_policy.py +225 -0
- examples/task_apps/pokemon_red/pallet_town_rl_config.toml +73 -0
- examples/task_apps/pokemon_red/task_app.py +606 -0
- examples/task_apps/pokemon_red/test_pallet_town_rewards.py +191 -0
- examples/task_apps/sokoban/README.md +307 -0
- examples/task_apps/sokoban/__init__.py +3 -0
- examples/task_apps/sokoban/eval_groq_qwen32.toml +16 -0
- examples/task_apps/sokoban/eval_openai_gpt5.toml +16 -0
- examples/task_apps/sokoban/task_app.py +1058 -0
- examples/task_apps/sokoban/tests/__init__.py +2 -0
- examples/task_apps/sokoban/tests/conftest.py +113 -0
- examples/task_apps/sokoban/tests/integration/__init__.py +2 -0
- examples/task_apps/sokoban/tests/integration/test_sokoban_eval.py +57 -0
- examples/task_apps/sokoban/tests/integration/test_sokoban_rollout.py +198 -0
- examples/task_apps/sokoban/tests/unit/__init__.py +2 -0
- examples/task_apps/sokoban/tests/unit/test_sokoban_environment.py +114 -0
- examples/task_apps/verilog/__init__.py +1 -0
- examples/task_apps/verilog/eval_groq_qwen32b.toml +20 -0
- examples/task_apps/verilog/task_app/README.md +12 -0
- examples/task_apps/verilog/task_app/__init__.py +1 -0
- examples/task_apps/verilog/task_app/grpo_verilog.py +931 -0
- examples/task_apps/verilog/task_app/grpo_verilog_task_app.py +145 -0
- examples/task_apps/verilog/tests/__init__.py +2 -0
- examples/task_apps/verilog/tests/conftest.py +115 -0
- examples/task_apps/verilog/tests/integration/__init__.py +2 -0
- examples/task_apps/verilog/tests/integration/test_verilog_eval.py +179 -0
- examples/task_apps/verilog/tests/integration/test_verilog_rollout.py +55 -0
- examples/task_apps/verilog/tests/unit/__init__.py +2 -0
- examples/task_apps/verilog/tests/unit/test_verilog_scoring.py +118 -0
- examples/vlm/crafter_openai_vlm_agent.py +4 -4
- examples/vlm/run_crafter_vlm_benchmark.py +4 -4
- examples/warming_up_to_rl/configs/eval_stepwise_complex.toml +4 -2
- examples/warming_up_to_rl/configs/eval_stepwise_simple.toml +4 -2
- examples/warming_up_to_rl/run_eval.py +127 -18
- examples/workflows/__init__.py +0 -0
- examples/workflows/math_rl/__init__.py +0 -0
- examples/workflows/math_rl/download_dataset.py +80 -0
- synth_ai/__init__.py +41 -1
- synth_ai/api/train/builders.py +73 -29
- synth_ai/api/train/cli.py +12 -6
- synth_ai/api/train/configs/__init__.py +44 -0
- synth_ai/api/train/configs/rl.py +134 -0
- synth_ai/api/train/configs/sft.py +95 -0
- synth_ai/api/train/configs/shared.py +24 -0
- synth_ai/api/train/env_resolver.py +5 -2
- synth_ai/api/train/supported_algos.py +10 -5
- synth_ai/api/train/utils.py +7 -4
- synth_ai/cli/__init__.py +7 -51
- synth_ai/cli/_storage.py +4 -3
- synth_ai/cli/_validate_task_app.py +11 -0
- synth_ai/cli/balance.py +4 -3
- synth_ai/cli/calc.py +2 -2
- synth_ai/cli/demo.py +49 -43
- synth_ai/cli/legacy_root_backup.py +1 -1
- synth_ai/cli/rl_demo.py +86 -106
- synth_ai/cli/root.py +0 -97
- synth_ai/cli/task_apps.py +1710 -186
- synth_ai/demos/core/cli.py +121 -159
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +28 -16
- synth_ai/environments/examples/crafter_classic/environment.py +16 -0
- synth_ai/environments/examples/enron/engine.py +7 -2
- synth_ai/environments/examples/enron/environment.py +68 -0
- synth_ai/environments/examples/red/engine.py +27 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +7 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_progression.py +477 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +32 -0
- synth_ai/environments/examples/red/environment.py +60 -0
- synth_ai/environments/examples/sokoban/taskset.py +116 -0
- synth_ai/environments/examples/verilog/engine.py +30 -4
- synth_ai/evals/__init__.py +15 -0
- synth_ai/evals/client.py +82 -0
- synth_ai/evals/types.py +42 -0
- synth_ai/jobs/client.py +16 -4
- synth_ai/judge_schemas.py +127 -0
- synth_ai/py.typed +0 -0
- synth_ai/task/__init__.py +14 -5
- synth_ai/task/contracts.py +124 -38
- synth_ai/task/proxy.py +48 -56
- synth_ai/task/rubrics/__init__.py +53 -0
- synth_ai/task/rubrics/loaders.py +133 -0
- synth_ai/task/rubrics/models.py +57 -0
- synth_ai/task/rubrics/scoring.py +113 -0
- synth_ai/task/rubrics/strict.py +149 -0
- synth_ai/task/server.py +8 -7
- synth_ai/task/validators.py +269 -6
- synth_ai/tracing_v3/decorators.py +7 -3
- synth_ai/tracing_v3/replica_sync.py +4 -4
- synth_ai/tracing_v3/serialization.py +130 -0
- synth_ai/tracing_v3/trace_utils.py +317 -0
- synth_ai/tracing_v3/turso/native_manager.py +3 -3
- {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev2.dist-info}/METADATA +4 -1
- {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev2.dist-info}/RECORD +228 -89
- {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev2.dist-info}/entry_points.txt +0 -1
- synth_ai/task/rubrics.py +0 -219
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/README.md +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/README.md +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/branching.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/environment_routes.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/app.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/environment.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/react_agent.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/shared.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/tools.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/hosted_app.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/inference/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/main.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/registry.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/storage/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/storage/volume.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/test_agents.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/utils.py +0 -0
- /examples/{rl/task_app → task_apps/math}/README.md +0 -0
- /examples/{rl/task_app → task_apps/math}/math_task_app.py +0 -0
- /examples/{rl → workflows/math_rl}/configs/eval_base_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/eval_rl_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/rl_from_base_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/rl_from_base_qwen17.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/rl_from_ft_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/run_eval.py +0 -0
- /examples/{rl → workflows/math_rl}/run_rl_and_save.py +0 -0
- {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev2.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev2.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev2.dist-info}/top_level.txt +0 -0
synth_ai/cli/task_apps.py
CHANGED
@@ -1,23 +1,30 @@
 from __future__ import annotations
 
+import argparse
 import ast
 import asyncio
 import contextlib
+import functools
 import hashlib
 import importlib
 import importlib.util
 import inspect
 import json
 import os
+import shlex
 import shutil
 import signal
+import sqlite3
 import subprocess
 import sys
 import tempfile
 import textwrap
+import time
 import types
+import uuid
 from collections.abc import Callable, Iterable, Iterator, Sequence
 from dataclasses import dataclass
+from datetime import UTC, datetime
 from pathlib import Path
 from typing import Any, cast
 
@@ -25,24 +32,39 @@ try: # Python 3.11+
     import tomllib as _toml
 except Exception: # pragma: no cover - fallback
     _toml = None  # type: ignore
-import uuid
 
 import click
 from click.exceptions import Abort
 
+from synth_ai.tracing_v3 import (  # type: ignore[import-untyped]
+    BaseEvent,
+    EnvironmentEvent,
+    RuntimeEvent,
+    SessionEventMarkovBlanketMessage,
+    SessionMessageContent,
+    SessionTimeStep,
+    SessionTracer,
+    TimeRecord,
+)
+from synth_ai.tracing_v3 import (  # type: ignore[import-untyped]
+    SessionTrace as V3SessionTrace,
+)
+
 # ---------------------------------------------------------------------------
 # Dynamic imports to avoid hard dependencies during type checking.
 # ---------------------------------------------------------------------------
 ModalDeploymentConfigType = TaskAppConfigType = TaskAppEntryType = Any
 
 try:  # Resolve base URL defaults lazily
-    _config_module =
+    _config_module = cast(
+        Any, importlib.import_module("synth_ai.config.base_url")
+    )
     PROD_BASE_URL_DEFAULT = cast(str, _config_module.PROD_BASE_URL_DEFAULT)
 except Exception:  # pragma: no cover - fallback
     PROD_BASE_URL_DEFAULT = "https://agent-learning.onrender.com"
 
 try:
-    _task_apps_module = importlib.import_module("synth_ai.task.apps")
+    _task_apps_module = cast(Any, importlib.import_module("synth_ai.task.apps"))
     ModalDeploymentConfig = cast(
         type[ModalDeploymentConfigType], _task_apps_module.ModalDeploymentConfig
     )
@@ -53,9 +75,9 @@ except Exception as exc: # pragma: no cover - critical dependency
     raise RuntimeError("Unable to load task app registry") from exc
 
 try:
-    _task_server_module = importlib.import_module("synth_ai.task.server")
-    create_task_app = _task_server_module.create_task_app
-    run_task_app = _task_server_module.run_task_app
+    _task_server_module = cast(Any, importlib.import_module("synth_ai.task.server"))
+    create_task_app = cast(Callable[..., Any], _task_server_module.create_task_app)
+    run_task_app = cast(Callable[..., Any], _task_server_module.run_task_app)
 except Exception as exc: # pragma: no cover - critical dependency
     raise RuntimeError("Unable to load task app server utilities") from exc
 
@@ -64,10 +86,12 @@ def _load_demo_directory() -> Path | None:
     """Return the demo task apps directory if available."""
 
     try:
-        module =
-
+        module = cast(
+            Any, importlib.import_module("synth_ai.demos.demo_task_apps.core")
+        )
+        loader = cast(Callable[[], str | Path | None], module.load_demo_dir)
         demo_dir = loader()
-        if isinstance(demo_dir,
+        if isinstance(demo_dir, str | Path):
             demo_path = Path(demo_dir)
             if demo_path.exists():
                 return demo_path.resolve()
@@ -105,6 +129,25 @@ DEFAULT_SEARCH_RELATIVE = (
 )
 
 
+def _pearson(xs: Sequence[float], ys: Sequence[float]) -> float | None:
+    if len(xs) != len(ys) or len(xs) < 2:
+        return None
+    mean_x = sum(xs) / len(xs)
+    mean_y = sum(ys) / len(ys)
+    num = 0.0
+    denom_x = 0.0
+    denom_y = 0.0
+    for x, y in zip(xs, ys, strict=False):
+        dx = x - mean_x
+        dy = y - mean_y
+        num += dx * dy
+        denom_x += dx * dx
+        denom_y += dy * dy
+    if denom_x <= 0 or denom_y <= 0:
+        return None
+    return num / (denom_x ** 0.5 * denom_y ** 0.5)
+
+
 @dataclass
 class AppChoice:
     app_id: str
@@ -128,6 +171,171 @@ class AppChoice:
         return entry
 
 
+@dataclass
+class JudgeSpec:
+    name: str
+    fn: Callable[..., Any]
+    kwargs: dict[str, Any]
+
+
+def _parse_datetime_for_trace(value: Any) -> datetime | None:
+    if isinstance(value, datetime):
+        return value if value.tzinfo else value.replace(tzinfo=UTC)
+    if isinstance(value, str):
+        value = value.replace("Z", "+00:00")
+        try:
+            dt = datetime.fromisoformat(value)
+        except ValueError:
+            try:
+                dt = datetime.fromtimestamp(float(value), tz=UTC)
+            except Exception:
+                return None
+        return dt if dt.tzinfo else dt.replace(tzinfo=UTC)
+    if isinstance(value, int | float):
+        return datetime.fromtimestamp(float(value), tz=UTC)
+    return None
+
+
+def _time_record_from_dict(payload: dict[str, Any] | None) -> TimeRecord:
+    payload = payload or {}
+    event_time = payload.get("event_time")
+    if not isinstance(event_time, int | float):
+        try:
+            event_time = float(event_time)
+        except Exception:
+            event_time = float(time.time())
+    message_time = payload.get("message_time")
+    if message_time is not None:
+        try:
+            message_time = int(message_time)
+        except Exception:
+            message_time = None
+    return TimeRecord(event_time=event_time, message_time=message_time)
+
+
+def _event_from_dict(payload: dict[str, Any]) -> BaseEvent:
+    base_kwargs = {
+        "system_instance_id": payload.get("system_instance_id", ""),
+        "time_record": _time_record_from_dict(payload.get("time_record")),
+        "metadata": payload.get("metadata") or {},
+        "event_metadata": payload.get("event_metadata"),
+    }
+    if "actions" in payload:
+        return RuntimeEvent(actions=payload.get("actions") or [], **base_kwargs)
+    if any(key in payload for key in ("reward", "terminated", "truncated")):
+        return EnvironmentEvent(
+            reward=float(payload.get("reward", 0.0) or 0.0),
+            terminated=bool(payload.get("terminated", False)),
+            truncated=bool(payload.get("truncated", False)),
+            system_state_before=payload.get("system_state_before"),
+            system_state_after=payload.get("system_state_after"),
+            **base_kwargs,
+        )
+    return BaseEvent(**base_kwargs)
+
+
+def _markov_message_from_dict(payload: dict[str, Any]) -> SessionEventMarkovBlanketMessage:
+    content_payload = payload.get("content") or {}
+    content = SessionMessageContent(
+        text=content_payload.get("text"),
+        json_payload=content_payload.get("json_payload"),
+    )
+    raw_type = (payload.get("message_type") or "").lower()
+    if raw_type == "observation":
+        normalized_type = "system"
+    elif raw_type == "action":
+        normalized_type = "assistant"
+    elif raw_type in {"user", "assistant", "system", "tool_use", "tool_result"}:
+        normalized_type = raw_type
+    else:
+        normalized_type = "system"
+
+    return SessionEventMarkovBlanketMessage(
+        content=content,
+        message_type=normalized_type,
+        time_record=_time_record_from_dict(payload.get("time_record")),
+        metadata=payload.get("metadata") or {},
+    )
+
+
+def _step_from_dict(payload: dict[str, Any]) -> SessionTimeStep:
+    events = [
+        _event_from_dict(event)
+        for event in payload.get("events", [])
+        if isinstance(event, dict)
+    ]
+    messages = [
+        _markov_message_from_dict(msg)
+        for msg in payload.get("markov_blanket_messages", [])
+        if isinstance(msg, dict)
+    ]
+    timestamp = _parse_datetime_for_trace(payload.get("timestamp")) or datetime.now(UTC)
+    completed_at = _parse_datetime_for_trace(payload.get("completed_at"))
+    return SessionTimeStep(
+        step_id=payload.get("step_id", ""),
+        step_index=int(payload.get("step_index", 0) or 0),
+        timestamp=timestamp,
+        turn_number=payload.get("turn_number"),
+        events=events,
+        markov_blanket_messages=messages,
+        step_metadata=payload.get("step_metadata") or {},
+        completed_at=completed_at,
+    )
+
+
+def _session_trace_from_dict(payload: dict[str, Any]) -> V3SessionTrace | None:
+    if not isinstance(payload, dict):
+        return None
+    steps = [
+        _step_from_dict(step)
+        for step in payload.get("session_time_steps", [])
+        if isinstance(step, dict)
+    ]
+    events = [
+        _event_from_dict(event)
+        for event in payload.get("event_history", [])
+        if isinstance(event, dict)
+    ]
+    markov_history = [
+        _markov_message_from_dict(msg)
+        for msg in payload.get("markov_blanket_message_history", [])
+        if isinstance(msg, dict)
+    ]
+    created_at = _parse_datetime_for_trace(payload.get("created_at")) or datetime.now(UTC)
+    metadata = payload.get("metadata") or {}
+    session_metadata = payload.get("session_metadata")
+    return V3SessionTrace(
+        session_id=payload.get("session_id", ""),
+        created_at=created_at,
+        session_time_steps=steps,
+        event_history=events,
+        markov_blanket_message_history=markov_history,
+        metadata=metadata,
+        session_metadata=session_metadata,
+    )
+
+
+async def _store_trace(
+    tracer: SessionTracer | None,
+    trace_namespace: dict[str, Any] | None,
+    extra_metadata: dict[str, Any] | None = None,
+):
+    if tracer is None or not isinstance(trace_namespace, dict):
+        return
+    session_payload = trace_namespace.get("session_trace")
+    if not isinstance(session_payload, dict):
+        return
+    trace_obj = _session_trace_from_dict(session_payload)
+    if trace_obj is None:
+        return
+    if tracer.db is None:
+        await tracer.initialize()
+    meta = dict(trace_obj.metadata or {})
+    if extra_metadata:
+        meta.update(extra_metadata)
+    trace_obj.metadata = meta
+    await tracer.db.insert_session_trace(trace_obj)
+
 def _temporary_sys_path(paths: Sequence[Path]):
     """Context manager to prepend entries to sys.path temporarily."""
 
@@ -676,36 +884,44 @@ def _build_modal_config_from_ast(modal_call: ast.Call) -> ModalDeploymentConfigT
         elif kw.arg == "pip_packages" and isinstance(kw.value, (ast.List, ast.Tuple)):
             # Handle pip_packages list/tuple
             packages: list[str] = []
-
-
-
+            value_node = kw.value
+            if isinstance(value_node, (ast.List, ast.Tuple)):
+                for elt in value_node.elts:
+                    if isinstance(elt, ast.Constant):
+                        packages.append(elt.value)
             kwargs[kw.arg] = tuple(packages)
         elif kw.arg == "extra_local_dirs" and isinstance(kw.value, (ast.List, ast.Tuple)):
             # Handle extra_local_dirs list/tuple of tuples
             dirs = []
-
-
-
-
-
-
+            value_node = kw.value
+            if isinstance(value_node, (ast.List, ast.Tuple)):
+                for elt in value_node.elts:
+                    if isinstance(elt, (ast.List, ast.Tuple)) and len(elt.elts) == 2:
+                        src = elt.elts[0].value if isinstance(elt.elts[0], ast.Constant) else None
+                        dst = elt.elts[1].value if isinstance(elt.elts[1], ast.Constant) else None
+                        if src and dst:
+                            dirs.append((src, dst))
             kwargs[kw.arg] = tuple(dirs)
         elif kw.arg == "secret_names" and isinstance(kw.value, (ast.List, ast.Tuple)):
             # Handle secret_names list/tuple
             secrets = []
-
-
-
+            value_node = kw.value
+            if isinstance(value_node, (ast.List, ast.Tuple)):
+                for elt in value_node.elts:
+                    if isinstance(elt, ast.Constant):
+                        secrets.append(elt.value)
             kwargs[kw.arg] = tuple(secrets)
         elif kw.arg == "volume_mounts" and isinstance(kw.value, (ast.List, ast.Tuple)):
             # Handle volume_mounts list/tuple of tuples
             mounts = []
-
-
-
-
-
-
+            value_node = kw.value
+            if isinstance(value_node, (ast.List, ast.Tuple)):
+                for elt in value_node.elts:
+                    if isinstance(elt, (ast.List, ast.Tuple)) and len(elt.elts) == 2:
+                        name = elt.elts[0].value if isinstance(elt.elts[0], ast.Constant) else None
+                        mount = elt.elts[1].value if isinstance(elt.elts[1], ast.Constant) else None
+                        if name and mount:
+                            mounts.append((name, mount))
             kwargs[kw.arg] = tuple(mounts)
 
     return ModalDeploymentConfig(**kwargs)
@@ -778,6 +994,9 @@ def _select_app_choice(app_id: str | None, purpose: str) -> AppChoice:
     if not matches:
         available = ", ".join(sorted({c.app_id for c in filtered}))
         raise click.ClickException(f"Task app '{app_id}' not found. Available: {available}")
+    exact_matches = [c for c in matches if c.app_id == app_id]
+    if len(exact_matches) == 1:
+        return exact_matches[0]
     if len(matches) == 1:
         return matches[0]
     # Prefer entries with modal support when required
@@ -829,6 +1048,71 @@ def _import_task_app_module(
     return module
 
 
+@contextlib.contextmanager
+def _safe_import_context() -> Iterator[None]:
+    """Guard module imports against argparse/uvicorn side effects."""
+
+    original_argv = sys.argv[:]
+    sys.argv = [original_argv[0]] if original_argv else ["python"]
+
+    parser_cls = argparse.ArgumentParser
+    old_parse_args = parser_cls.parse_args
+
+    def _parse_noargs(self, args=None, namespace=None):  # type: ignore[override]
+        if args is None:
+            args = []
+        if namespace is None:
+            namespace = argparse.Namespace()
+        try:
+            return old_parse_args(self, args, namespace)
+        except SystemExit:
+            return namespace
+
+    parser_cls.parse_args = _parse_noargs  # type: ignore[assignment]
+
+    uvicorn_run = None
+    run_task_app_orig = None
+    try:
+        import uvicorn  # type: ignore
+
+        uvicorn_run = uvicorn.run
+        uvicorn.run = lambda *args, **kwargs: None  # type: ignore[assignment]
+    except Exception:
+        uvicorn_run = None
+
+    try:
+        _task_server_patch = cast(
+            Any, importlib.import_module("synth_ai.task.server")
+        )
+        run_task_app_orig = cast(Callable[..., Any], _task_server_patch.run_task_app)
+        _task_server_patch.run_task_app = (  # type: ignore[assignment]
+            lambda *args, **kwargs: None
+        )
+    except Exception:
+        run_task_app_orig = None
+
+    try:
+        yield
+    finally:
+        sys.argv = original_argv
+        parser_cls.parse_args = old_parse_args  # type: ignore[assignment]
+        if uvicorn_run is not None:
+            try:
+                import uvicorn  # type: ignore
+
+                uvicorn.run = uvicorn_run  # type: ignore[assignment]
+            except Exception:
+                pass
+        if run_task_app_orig is not None:
+            try:
+                _task_server_patch = cast(
+                    Any, importlib.import_module("synth_ai.task.server")
+                )
+                _task_server_patch.run_task_app = run_task_app_orig  # type: ignore[assignment]
+            except Exception:
+                pass
+
+
 def _load_entry_from_path(
     path: Path, app_id: str, module_search_roots: Sequence[Path] | None = None
 ) -> TaskAppEntryType:
@@ -856,13 +1140,14 @@ def _load_entry_from_path(
 
     for module_name, namespace_root in _possible_module_names(resolved, search_roots):
         try:
-
-
-
-
-
-
-
+            with _safe_import_context():
+                module = _import_task_app_module(
+                    resolved,
+                    module_name,
+                    namespace_root=namespace_root,
+                    sys_path_roots=search_roots,
+                    ensure_namespace=True,
+                )
             break
         except Exception as exc:  # pragma: no cover - best-effort fallbacks
             last_error = exc
@@ -871,13 +1156,14 @@ def _load_entry_from_path(
     if module is None:
         hashed_name = f"_synth_task_app_{hashlib.md5(str(resolved).encode(), usedforsecurity=False).hexdigest()}"
         try:
-
-
-
-
-
-
-
+            with _safe_import_context():
+                module = _import_task_app_module(
+                    resolved,
+                    hashed_name,
+                    namespace_root=None,
+                    sys_path_roots=search_roots,
+                    ensure_namespace=False,
+                )
         except Exception as exc:  # pragma: no cover - propagate meaningful error
             detail = last_error or exc
             raise click.ClickException(f"Failed to import {resolved}: {detail}") from detail
@@ -925,7 +1211,10 @@ def _load_entry_from_path(
         if has_required:
             continue
         try:
-
+            with _safe_import_context():
+                result = attr()
+        except SystemExit:
+            continue
         except Exception:
             continue
         if isinstance(result, TaskAppConfig) and result.app_id == app_id:
@@ -1021,21 +1310,173 @@ def _resolve_env_paths_for_script(script_path: Path, explicit: Sequence[str]) ->
         return [env_candidates[choice - 1]]
 
 
+def _path_is_within(child: Path, parent: Path) -> bool:
+    try:
+        child.resolve().relative_to(parent.resolve())
+        return True
+    except Exception:
+        return False
+
+
+@functools.lru_cache(maxsize=16)
+def _is_modal_shim(path_str: str) -> bool:
+    """Return True if the candidate CLI path refers to the synth-ai shim."""
+
+    path = Path(path_str)
+    try:
+        resolved = path.resolve(strict=True)
+    except Exception:
+        resolved = path
+
+    if not resolved.exists() or resolved.is_dir():
+        return False
+
+    snippet = ""
+    try:
+        snippet = resolved.read_bytes()[:4096].decode("utf-8", errors="ignore")
+    except Exception:
+        snippet = ""
+
+    shim_markers = (
+        "synth_ai.cli._modal_wrapper",
+        "from modal.__main__ import main",
+        "import modal.__main__",
+        "run_module('modal.__main__'",
+    )
+    if snippet and any(marker in snippet for marker in shim_markers):
+        return True
+
+    try:
+        size = resolved.stat().st_size
+    except Exception:
+        size = None
+
+    if (
+        size is not None
+        and size < 2048
+        and "python" in (snippet.splitlines() or [""])[0]
+        and (
+            "modal.__main__" in snippet
+            or "modal.__main__" in snippet.replace(" ", "")
+        )
+    ):
+        return True
+
+    virtual_env = os.environ.get("VIRTUAL_ENV")
+    if virtual_env and _path_is_within(resolved, Path(virtual_env)):
+        return True
+
+    if _path_is_within(resolved, REPO_ROOT):
+        return True
+
+    uv_tools_dir = Path.home() / ".local" / "share" / "uv" / "tools"
+    return uv_tools_dir.exists() and _path_is_within(resolved, uv_tools_dir)
+
+
+def _find_modal_executable(modal_cli: str) -> tuple[str | None, str | None]:
+    """Return the first non-shim executable and the first shim discovered on PATH."""
+
+    if not modal_cli:
+        modal_cli = "modal"
+
+    candidate_path = Path(modal_cli).expanduser()
+    if candidate_path.is_absolute() or len(candidate_path.parts) > 1:
+        resolved_candidate = candidate_path
+        if not resolved_candidate.is_absolute():
+            resolved_candidate = (Path.cwd() / resolved_candidate).resolve()
+        else:
+            resolved_candidate = resolved_candidate.resolve()
+        if not resolved_candidate.exists():
+            raise click.ClickException(f"--modal-cli path does not exist: {resolved_candidate}")
+        if not os.access(resolved_candidate, os.X_OK):
+            raise click.ClickException(f"--modal-cli is not executable: {resolved_candidate}")
+        return str(resolved_candidate), None
+
+    path_env = os.environ.get("PATH", "")
+    if not path_env:
+        return None, None
+
+    seen_dirs: set[str] = set()
+    seen_candidates: set[str] = set()
+    shim_path: str | None = None
+
+    for raw_entry in path_env.split(os.pathsep):
+        if not raw_entry:
+            continue
+        try:
+            resolved_entry = str(Path(raw_entry).resolve())
+        except Exception:
+            resolved_entry = os.path.normpath(raw_entry)
+        if resolved_entry in seen_dirs:
+            continue
+        seen_dirs.add(resolved_entry)
+
+        candidate = shutil.which(modal_cli, path=raw_entry)
+        if candidate is None:
+            continue
+        if candidate in seen_candidates:
+            continue
+        seen_candidates.add(candidate)
+
+        if _is_modal_shim(candidate):
+            if shim_path is None:
+                shim_path = candidate
+            continue
+        return candidate, shim_path
+
+    return None, shim_path
+
+
 def _modal_command_prefix(modal_cli: str) -> list[str]:
     """Resolve a command prefix for invoking the Modal CLI within the active environment."""
-
+
+    force_wrapper_env = os.environ.get("SYNTH_FORCE_MODAL_WRAPPER", "").strip().lower()
+    if force_wrapper_env in {"1", "true", "yes"}:
+        click.secho(
+            "[modal-prefix] SYNTH_FORCE_MODAL_WRAPPER=1 -> using in-process wrapper",
+            fg="yellow",
+        )
         return [sys.executable, "-m", "synth_ai.cli._modal_wrapper"]
 
-
-    if
-
+    lookup = modal_cli or "modal"
+    spec = importlib.util.find_spec("modal") if lookup == "modal" else None
+
+    preferred, shim_candidate = _find_modal_executable(lookup)
+    if preferred is not None:
+        detail = f"[modal-prefix] modal_cli={lookup} selected={preferred}"
+        if lookup == "modal":
+            detail += f" spec={'yes' if spec else 'no'}"
+        click.secho(detail, fg="cyan")
+        return [preferred]
+
+    if lookup != "modal":
+        raise click.ClickException(f"Modal CLI not found (looked for '{lookup}')")
+
+    if spec is not None:
+        warning = "[modal-prefix] Using synth-ai modal shim; pass --modal-cli /path/to/modal to override."
+        if shim_candidate is not None:
+            warning = (
+                f"[modal-prefix] Using synth-ai modal shim at {shim_candidate}; "
+                "pass --modal-cli /path/to/modal to override."
+            )
+        click.secho(warning, fg="yellow")
+        click.secho(
+            "[modal-prefix] modal_cli=modal selected=module-wrapper spec=yes",
+            fg="yellow",
+        )
+        return [sys.executable, "-m", "synth_ai.cli._modal_wrapper"]
 
-    if
+    if shim_candidate is not None:
         raise click.ClickException(
-            "Modal CLI
-            "
+            "Modal CLI resolution found the synth-ai shim but the 'modal' package "
+            "is not importable in this environment. Install the official Modal CLI "
+            "or pass --modal-cli with its path."
         )
-
+
+    raise click.ClickException(
+        "Modal CLI not found. Install the 'modal' package in this environment or pass "
+        "--modal-cli with an explicit path."
+    )
 
 
 def _build_modal_app_wrapper(original_script: Path) -> tuple[Path, Path]:
@@ -1170,8 +1611,15 @@ def _run_modal_script(
     if modal_name and command == "deploy":
         cmd.extend(["--name", modal_name])
     if dry_run:
-        click.echo(
+        click.echo(
+            "Dry run: " + " ".join(shlex.quote(component) for component in cmd),
+            err=False,
+        )
         return
+    click.secho(
+        "[modal-exec] " + " ".join(shlex.quote(component) for component in cmd),
+        fg="cyan",
+    )
     try:
         # Stream output live for better diagnostics
         proc = subprocess.Popen(
@@ -1426,7 +1874,6 @@ def _run_modal_with_entry(
         inline_secret_values=inline_secret_values,
     )
     cmd = [*_modal_command_prefix(modal_cli), command, str(script_path)]
-
     if modal_name and command == "deploy":
         cmd.extend(["--name", modal_name])
 
@@ -1441,9 +1888,13 @@ def _run_modal_with_entry(
     proc_env["PYTHONPATH"] = os.pathsep.join(list(dict.fromkeys(pythonpath_entries)))
 
     if dry_run:
-        click.echo("Dry run: " + " ".join(cmd))
+        click.echo("Dry run: " + " ".join(shlex.quote(component) for component in cmd))
         script_path.unlink(missing_ok=True)
         return
+    click.secho(
+        "[modal-exec] " + " ".join(shlex.quote(component) for component in cmd),
+        fg="cyan",
+    )
 
     try:
         # Stream output live for better diagnostics
@@ -1528,6 +1979,10 @@ def _parse_env_file(path: Path) -> dict[str, str]:
 
 
 def _interactive_fill_env(env_path: Path) -> Path | None:
+    if not sys.stdin.isatty():
+        raise click.ClickException(
+            "ENVIRONMENT_API_KEY missing. Provide --env-file or run `synth-ai setup` in an interactive shell to create one."
+        )
     existing = _parse_env_file(env_path) if env_path.exists() else {}
 
     def _prompt(label: str, *, default: str = "", required: bool) -> str | None:
@@ -1567,6 +2022,10 @@ def _ensure_env_values(env_paths: list[Path], fallback_dir: Path) -> None:
     if (os.environ.get("ENVIRONMENT_API_KEY") or "").strip():
         return
     target = env_paths[0] if env_paths else (fallback_dir / ".env").resolve()
+    click.echo(
+        "⚠️ ENVIRONMENT_API_KEY not set. Run `uvx synth-ai setup`, "
+        "or pass --env-file pointing at a .env with ENVIRONMENT_API_KEY."
+    )
     result = _interactive_fill_env(target)
     if result is None:
         raise click.ClickException("ENVIRONMENT_API_KEY required to continue")
@@ -1590,7 +2049,7 @@ def _deploy_entry(
             f"Task app '{entry.app_id}' does not define Modal deployment settings"
         )
 
-    env_paths = _determine_env_files(entry, env_file)
+    env_paths = _determine_env_files(entry, env_file, original_path=original_path)
     click.echo("Using env file(s): " + ", ".join(str(p.resolve()) for p in env_paths))
     _run_modal_with_entry(
         entry,
@@ -1617,7 +2076,7 @@ def _modal_serve_entry(
             f"Task app '{entry.app_id}' does not define Modal deployment settings"
         )
 
-    env_paths = _determine_env_files(entry, env_file)
+    env_paths = _determine_env_files(entry, env_file, original_path=original_path)
     click.echo("Using env file(s): " + ", ".join(str(p.resolve()) for p in env_paths))
     _run_modal_with_entry(
         entry,
@@ -1648,6 +2107,255 @@ def list_apps() -> None:
         click.echo(f"- {entry.app_id}{aliases}: {entry.description}")
 
 
+@task_app_group.command("validate")
+@click.argument("app_id", type=str, required=True)
+@click.option(
+    "--url",
+    type=str,
+    default=None,
+    help="Task app URL to validate (if not provided, starts a local server)",
+)
+@click.option(
+    "--port",
+    type=int,
+    default=8765,
+    help="Port to use for temporary server (default: 8765)",
+)
+@click.option(
+    "--api-key",
+    type=str,
+    default=None,
+    envvar="ENVIRONMENT_API_KEY",
+    help="API key for authentication (default: $ENVIRONMENT_API_KEY)",
+)
+@click.option(
+    "--min-instances",
+    type=int,
+    default=10,
+    help="Minimum number of task instances required (default: 10)",
+)
+@click.option(
+    "--verbose",
+    "-v",
+    is_flag=True,
+    help="Show detailed information about the task app",
+)
+@click.option(
+    "--json",
+    "output_json",
+    is_flag=True,
+    help="Output results as JSON",
+)
+def validate_task_app_cmd(
+    app_id: str,
+    url: str | None,
+    port: int,
+    api_key: str | None,
+    min_instances: int,
+    verbose: bool,
+    output_json: bool,
+) -> None:
+    """Validate a task app deployment readiness.
+
+    This command verifies that a task app is properly configured and ready to run
+    by checking all required HTTP endpoints, authentication, and task availability.
+
+    By default, it starts a temporary local server for validation. You can also
+    validate a remote deployment by passing --url.
+
+    \b
+    What gets validated:
+      • Root endpoint (/) responds correctly
+      • Health endpoint (/health) is accessible with proper authentication
+      • Info endpoint (/info) returns valid task metadata
+      • Task info endpoint (/task_info) provides task instances
+      • Rollout endpoint (/rollout) is registered
+      • At least N task instances are available (default: 10)
+
+    \b
+    Examples:
+
+    \b
+    Validate grpo-crafter (starts local server automatically):
+        $ synth-ai task-app validate grpo-crafter
+
+    \b
+    Validate sokoban with verbose output:
+        $ synth-ai task-app validate sokoban --verbose
+
+    \b
+    Validate with custom port:
+        $ synth-ai task-app validate sokoban --port 9000
+
+    \b
+    Validate a remote deployment:
+        $ synth-ai task-app validate grpo-crafter --url https://my-crafter.modal.run
+
+    \b
+    Require at least 20 task instances:
+        $ synth-ai task-app validate grpo-crafter --min-instances 20
+
+    \b
+    Get JSON output for automation:
+        $ synth-ai task-app validate sokoban --json
+
+    \b
+    Common use cases:
+      • Pre-deployment verification: Check task app works before deploying to Modal
+      • CI/CD integration: Use --json flag for automated validation in pipelines
+      • Debug failing deployments: Use --verbose to see detailed endpoint responses
+      • Test API key configuration: Verify authentication is set up correctly
+    """
+    import asyncio
+    import socket
+    import subprocess
+    import tempfile
+    import time
+
+    # Import the validate_task_app function defined in this module
+    from synth_ai.cli._validate_task_app import validate_task_app  # type: ignore[attr-defined]
+
+    proc = None
+    task_app_url = url
+
+    try:
+        # If no URL provided, start a temporary server
+        if not task_app_url:
+            # Find an available port
+            def is_port_available(port: int) -> bool:
+                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+                    try:
+                        s.bind(("", port))
+                        return True
+                    except OSError:
+                        return False
+
+            while not is_port_available(port):
+                port += 1
+
+            task_app_url = f"http://localhost:{port}"
+
+            if not output_json:
+                click.echo(f"Starting temporary {app_id} server on port {port}...")
+
+            # Start the server in background
+            env = os.environ.copy()
+            if api_key:
+                env["ENVIRONMENT_API_KEY"] = api_key
+
+            # Create a temporary trace DB and trace dir to avoid prompts
+            import tempfile
+            temp_dir = tempfile.mkdtemp()
+            temp_trace_db = os.path.join(temp_dir, "validate_trace.db")
+            temp_trace_dir = os.path.join(temp_dir, "traces")
+            os.makedirs(temp_trace_dir, exist_ok=True)
+
+            proc = subprocess.Popen(
+                [
+                    "uv",
+                    "run",
+                    "synth-ai",
+                    "task-app",
+                    "serve",
+                    app_id,
+                    "--port",
+                    str(port),
+                    "--no-reload",
+                    "--trace",
+                    temp_trace_dir,
+                    "--trace-db",
+                    temp_trace_db,
+                ],
+                env=env,
+                stdin=subprocess.PIPE,  # Add stdin to handle any prompts
+                stdout=subprocess.DEVNULL if output_json else subprocess.PIPE,
+                stderr=subprocess.DEVNULL if output_json else subprocess.PIPE,
+                text=True,
+            )
+
+            # Write empty input to stdin to skip any prompts
+            if proc.stdin:
+                try:
+                    proc.stdin.write("\n")
+                    proc.stdin.flush()
+                    proc.stdin.close()
+                except Exception:
+                    pass
+
+            # Wait for server to be ready
+            if not output_json:
+                click.echo("Waiting for server to start...")
+
+            import httpx
+            for _attempt in range(60):  # 30 seconds timeout
+                try:
+                    async def check_health():
+                        async with httpx.AsyncClient(timeout=2.0) as client:
+                            resp = await client.get(f"{task_app_url}/")
+                            return resp.status_code == 200
+
+                    if asyncio.run(check_health()):
+                        break
+                except Exception:
+                    pass
+
+                # Check if process died
+                if proc.poll() is not None:
+                    stderr_output = ""
+                    if proc.stderr and not output_json:
+                        stderr_output = proc.stderr.read()
+                    click.echo(click.style("✗ Server process exited unexpectedly", fg="red"), err=True)
+                    if stderr_output and not output_json:
+                        click.echo(f"Error output:\n{stderr_output}", err=True)
+                    sys.exit(1)
+
+                time.sleep(0.5)
+            else:
+                click.echo(click.style("✗ Server failed to start within 30 seconds", fg="red"), err=True)
+                sys.exit(1)
+
+            if not output_json:
+                click.echo(click.style("✓ Server started", fg="green"))
+                click.echo()
+
+        # Ensure URL doesn't have trailing slash
+        task_app_url = task_app_url.rstrip("/")
+
+        async def _run() -> tuple[bool, dict[str, Any]]:
+            return await validate_task_app(
+                url=task_app_url,
+                api_key=api_key,
+                min_instances=min_instances,
+                verbose=verbose,
+            )
+
+        success, results = asyncio.run(_run())
+
+        if output_json:
+            import json as _json
+            click.echo(_json.dumps(results, indent=2))
+
+        sys.exit(0 if success else 1)
+
+    finally:
+        # Cleanup: stop the temporary server
+        if proc is not None:
+            if not output_json:
+                click.echo("\nStopping temporary server...")
+            try:
+                proc.terminate()
+                proc.wait(timeout=5)
+            except Exception:
+                proc.kill()
+
+        # Cleanup temp trace DB
+        if not url and 'temp_dir' in locals():
+            import contextlib
+            import shutil
+            with contextlib.suppress(Exception):
+                shutil.rmtree(temp_dir, ignore_errors=True)
+
+
 def _load_env_files_into_process(paths: Sequence[str]) -> None:
     for p in paths:
         try:
@@ -1904,7 +2612,9 @@ def serve_task_group(
 )
 
 
-def _determine_env_files(
+def _determine_env_files(
+    entry: TaskAppEntryType, user_env_files: Sequence[str], *, original_path: Path | None = None
+) -> list[Path]:
     resolved: list[Path] = []
     for candidate in user_env_files:
         p = Path(candidate).expanduser()
@@ -1914,30 +2624,46 @@ def _determine_env_files(entry: TaskAppEntryType, user_env_files: Sequence[str])
     if resolved:
         return resolved
 
-
-
-
-
-
-
-
-
+    declared: list[Path] = []
+    for candidate in getattr(entry, "env_files", ()) or ():
+        try:
+            p = Path(candidate).expanduser()
+        except Exception:
+            continue
+        if p.exists() and p.is_file():
+            declared.append(p)
+    if declared:
+        return declared
 
-
-
-
-    for repo_file in repo_env_files:
-        if repo_file not in env_candidates:
-            env_candidates.append(repo_file)
+    def _append_candidate(collection: list[Path], candidate: Path) -> None:
+        if candidate.exists() and candidate.is_file() and candidate not in collection:
+            collection.append(candidate)
 
-
-    raise click.ClickException("No env file found. Pass --env-file explicitly.")
+    auto_candidates: list[Path] = []
 
-
-
-
-
-
+    search_dirs: list[Path] = []
+    if original_path is not None:
+        search_dirs.append(original_path.parent.resolve())
+        for parent in original_path.parent.resolve().parents:
+            search_dirs.append(parent)
+    cwd = Path.cwd().resolve()
+    if cwd not in search_dirs:
+        search_dirs.append(cwd)
+    repo_root = REPO_ROOT.resolve()
+    if repo_root not in search_dirs:
+        search_dirs.append(repo_root)
+
+    for directory in search_dirs:
+        _append_candidate(auto_candidates, directory / ".env")
+        for candidate in sorted(directory.glob("*.env")):
+            _append_candidate(auto_candidates, candidate)
+
+    if auto_candidates:
+        return [auto_candidates[0]]
+
+    raise click.ClickException(
+        "No .env file discovered automatically. Pass --env-file /path/to/.env or generate one with `uvx synth-ai setup`."
+    )
 
 
 def _ensure_port_free(port: int, host: str, *, force: bool) -> None:
@@ -2239,7 +2965,14 @@ def deploy_app(
 def modal_serve_app(
     app_id: str | None, modal_cli: str, modal_name: str | None, env_file: Sequence[str]
 ) -> None:
-
+    click.echo(f"[modal-serve] requested app_id={app_id or '(auto)'} modal_cli={modal_cli}")
+    try:
+        choice = _select_app_choice(app_id, purpose="modal-serve")
+    except SystemExit as exc:  # bubble up with context (legacy argparse would trigger this)
+        raise click.ClickException(
+            f"Legacy CLI intercepted modal-serve (exit {exc.code}). "
+            "Make sure you're running the Click CLI (synth_ai.cli:cli)."
+        ) from exc
 
     if choice.modal_script:
         env_paths = _resolve_env_paths_for_script(choice.modal_script, env_file)
@@ -2248,6 +2981,7 @@ def modal_serve_app(
         return
 
     entry = choice.ensure_entry()
+    click.echo(f"[modal-serve] serving entry {entry.app_id} from {choice.path}")
     _modal_serve_entry(entry, modal_name, modal_cli, env_file, original_path=choice.path)
 
 
@@ -2477,22 +3211,60 @@ def register(cli: click.Group) -> None:
     cli.add_command(serve_command)
     cli.add_command(task_app_group)
     cli.add_command(eval_command)
+    cli.add_command(filter_command)
 
 
-@click.command(
+@click.command(
+    "eval",
+    help="Run one-off rollouts against a task app and print judge/eval summaries.",
+)
 @click.argument("app_id", type=str, required=False)
-@click.option(
+@click.option(
+    "--config",
+    type=click.Path(),
+    default=None,
+    help="Path to eval TOML (short schema). Auto-discovers the first matching file when omitted.",
+)
 @click.option(
     "--url",
     "task_app_url",
     type=str,
     default=None,
-    help="Base URL of a running task app (
+    help="Base URL of a running task app instead of spawning locally (requires --env-file for secrets).",
+)
+@click.option(
+    "--seeds",
+    default="0,1,2,3,4",
+    help="Comma-separated seeds/indices to evaluate. Use negative numbers to wrap around the dataset.",
 )
-@click.option("--seeds", default="0,1,2,3,4", help="Comma-separated seeds/indices to evaluate")
 @click.option("--split", default="train", show_default=True, help="Dataset split to use")
-@click.option(
-
+@click.option(
+    "--model",
+    default=None,
+    help="Model identifier. When omitted the CLI will prompt based on task metadata.",
+)
+@click.option(
+    "--env-file",
+    multiple=True,
+    type=click.Path(),
+    help="Env file(s) to load (API keys, etc.). Required when using --url or remote judges.",
+)
+@click.option(
+    "--trace-db",
+    default="traces/v3/eval_traces.db",
+    show_default=True,
+    help="SQLite/Turso URL for storing rollout traces set to 'none' to disable persistence.",
+)
+@click.option(
+    "--metadata",
+    multiple=True,
+    help="Filter tasks by key=value metadata (e.g., --metadata difficulty=easy)",
+)
+@click.option(
+    "--metadata-sql",
+    default=None,
+    help="SQLite query that returns seeds to evaluate (e.g., SELECT seed FROM tasks WHERE difficulty='easy' LIMIT 5)",
+)
 def eval_command(
     app_id: str | None,
     config: str | None,
@@ -2501,8 +3273,17 @@ def eval_command(
     split: str,
     model: str | None,
     env_file: Sequence[str],
+    trace_db: str,
+    metadata: Sequence[str],
+    metadata_sql: str | None,
 ) -> None:
-    """Run
+    """Run rollouts against a task app and report judge statistics.
+
+    By default the command spins up the selected task app in-process, executes the
+    requested seeds, and prints aggregate scores (official and custom judges). When
+    pointing at a remote `--url`, supply matching `--env-file` values so the CLI can
+    forward authentication headers to the running service.
+    """
     cfg: dict[str, Any] = {}
     config_path: Path | None = None
     if config:
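Reader's note (not part of the package diff): the hunks below read a number of optional keys from the eval TOML, among them `app_id`, `model`, `seeds`, `env_file`, a `[policy]` table (or top-level `temperature`/`max_tokens`), `concurrency`, `[judge]`/`[judges]`, `[metadata]`, and `metadata_sql`. A minimal config consistent with those lookups might look like the sketch below; the app id, model, paths, and thresholds are made up, and the TOML is parsed with stdlib `tomllib` here only to show the expected shape.

    import tomllib  # Python 3.11+ standard library

    EXAMPLE_EVAL_TOML = """
    app_id = "grpo-crafter"            # hypothetical task app id
    model = "groq:llama-3.1-70b-versatile"
    seeds = "0,1,2,3"                  # seeds may also be a list; the exact type is not shown in this hunk
    concurrency = 2
    metadata_sql = "SELECT seed FROM tasks WHERE difficulty='easy' LIMIT 5"

    [policy]                           # forwarded into the rollout policy config
    temperature = 0.2
    max_tokens = 512

    [judge.quality]                    # custom judge loaded from a local file (hypothetical path)
    path = "judges/quality.py"
    callable = "judge"
    """

    cfg = tomllib.loads(EXAMPLE_EVAL_TOML)
    print(sorted(cfg))  # ['app_id', 'concurrency', 'judge', 'metadata_sql', 'model', 'policy', 'seeds']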
@@ -2531,6 +3312,50 @@ def eval_command(
 
     app_id = app_id or (cfg.get("app_id") if isinstance(cfg.get("app_id"), str) else None)  # type: ignore
 
+    metadata_filters: dict[str, str] = {}
+    cfg_metadata = cfg.get("metadata")
+    if isinstance(cfg_metadata, dict):
+        for key, value in cfg_metadata.items():
+            metadata_filters[str(key)] = str(value)
+    elif isinstance(cfg_metadata, list):
+        for item in cfg_metadata:
+            if isinstance(item, str) and "=" in item:
+                key, value = item.split("=", 1)
+                metadata_filters[key.strip()] = value.strip()
+
+    for item in metadata or ():
+        if "=" not in item:
+            raise click.ClickException(f"Metadata filters must be key=value (got: {item})")
+        key, value = item.split("=", 1)
+        key = key.strip()
+        value = value.strip()
+        if not key or not value:
+            raise click.ClickException(f"Invalid metadata filter: {item}")
+        metadata_filters[key] = value
+
+    metadata_sql_query: str | None = None
+    cfg_metadata_sql = cfg.get("metadata_sql")
+    if isinstance(cfg_metadata_sql, dict):
+        metadata_sql_query = cfg_metadata_sql.get("query") or cfg_metadata_sql.get("sql")
+    elif isinstance(cfg_metadata_sql, str):
+        metadata_sql_query = cfg_metadata_sql
+
+    if metadata_sql:
+        metadata_sql_query = metadata_sql
+    if metadata_sql_query is not None:
+        metadata_sql_query = str(metadata_sql_query)
+
+    trace_db_url: str | None = None
+    trace_db = (trace_db or "").strip()
+    if trace_db and trace_db.lower() not in {"none", "off", "disable"}:
+        if "://" in trace_db:
+            trace_db_url = trace_db
+        else:
+            trace_path = Path(trace_db).expanduser()
+            trace_path.parent.mkdir(parents=True, exist_ok=True)
+            trace_db_url = f"sqlite+aiosqlite:///{trace_path}"
+    trace_tracer: SessionTracer | None = SessionTracer(db_url=trace_db_url, auto_save=True) if trace_db_url else None
+
     # Determine selection params (CLI takes precedence; TOML only fills unset model/seeds/env)
     if cfg.get("model") and not model:
         model = str(cfg["model"])  # type: ignore[index]
@@ -2550,14 +3375,16 @@ def eval_command(
     elif isinstance(ef, list):
         env_file = tuple(str(x) for x in ef)  # type: ignore[assignment]
 
+    choice_for_env: AppChoice | None = None
     entry: TaskAppEntryType | None = None
     if task_app_url is None:
-
-        entry =
+        choice_for_env = _select_app_choice(app_id, purpose="eval")
+        entry = choice_for_env.ensure_entry()
 
     env_paths: list[Path] = []
     if entry is not None:
-
+        original_env_path = choice_for_env.path if choice_for_env is not None else None
+        env_paths = _determine_env_files(entry, env_file, original_path=original_env_path)
     else:
         if not env_file:
             raise click.ClickException("--env-file is required when using --url")
@@ -2580,12 +3407,30 @@ def eval_command(
         app = create_task_app(config)
 
     # Determine supported models
+    inference_meta: dict[str, Any] = {}
     supported: list[str] = []
+    seen_models: set[str] = set()
+
+    def _add_supported_model(candidate: Any) -> None:
+        if not candidate:
+            return
+        text = str(candidate).strip()
+        if not text or text in seen_models:
+            return
+        supported.append(text)
+        seen_models.add(text)
+
     if task_app_url is None:
         try:
-
+            if hasattr(config, "base_task_info") and config.base_task_info:
+                inf_obj = getattr(config.base_task_info, "inference", None)
+                if inf_obj is not None:
+                    if hasattr(inf_obj, "model_dump"):
+                        inference_meta = dict(inf_obj.model_dump(exclude_none=True))  # type: ignore[attr-defined]
+                    elif isinstance(inf_obj, dict):
+                        inference_meta = dict(inf_obj)
         except Exception:
-
+            inference_meta = {}
     else:
         try:
             import httpx as _hx
@@ -2598,38 +3443,38 @@ def eval_command(
                 info = c.get("/info").json()
                 inf = info.get("inference") if isinstance(info, dict) else None
                 if isinstance(inf, dict):
-
-                    if isinstance(m, list):
-                        supported = [str(x) for x in m]
-                    if not supported:
-                        providers = inf.get("providers")
-                        if isinstance(providers, list):
-                            if "openai" in providers:
-                                supported.append("gpt-5")
-                            if "groq" in providers:
-                                supported.append("groq:llama-3.1-70b-versatile")
-                    supported.append("synth:qwen-0.6b")
+                    inference_meta = dict(inf)
         except Exception:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            inference_meta = {}
+
+    default_model = inference_meta.get("model")
+    if isinstance(default_model, str):
+        _add_supported_model(default_model)
+
+    models_field = inference_meta.get("models")
+    if isinstance(models_field, list):
+        for candidate in models_field:
+            _add_supported_model(candidate)
+
+    supported_models = inference_meta.get("supported_models")
+    if isinstance(supported_models, list):
+        for candidate in supported_models:
+            _add_supported_model(candidate)
+
+    providers = inference_meta.get("providers")
+    if isinstance(providers, list):
+        if "openai" in providers:
+            _add_supported_model("gpt-5")
+        if "groq" in providers:
+            _add_supported_model("groq:llama-3.1-70b-versatile")
+
+    _add_supported_model("synth:qwen-0.6b")
 
     selected_model = model
     if not selected_model:
         if not supported:
             raise click.ClickException(
-                "No supported models; supply --model or add base_task_info.inference.
+                "No supported models; supply --model or add base_task_info.inference.model"
             )
         click.echo("Select model to evaluate:")
         for idx, m in enumerate(supported, start=1):
@@ -2649,70 +3494,347 @@ def eval_command(
     if api_key:
         headers["X-API-Key"] = api_key
 
+    # Precompute optional policy overrides from TOML
+    policy_overrides: dict[str, Any] = {}
+    try:
+        # Accept [eval.policy] table or top-level keys for convenience
+        if isinstance(cfg.get("policy"), dict):
+            policy_overrides.update(dict(cfg["policy"]))
+        # Back-compat: allow temperature/max_tokens at top level
+        for k in (
+            "temperature",
+            "max_tokens",
+            "reasoning_effort",
+            "system_hint",
+            "tool_choice",
+            "inference_url",
+        ):
+            if k in cfg and k not in policy_overrides:
+                policy_overrides[k] = cfg.get(k)
+    except Exception:
+        policy_overrides = {}
+
+    raw_concurrency = cfg.get("concurrency")
+    try:
+        concurrency_limit = int(raw_concurrency) if raw_concurrency is not None else 1
+    except Exception:
+        concurrency_limit = 1
+    if concurrency_limit <= 0:
+        concurrency_limit = 1
+    concurrency_limit = min(concurrency_limit, max(1, len(seed_values)))
+
+    judge_specs: list[JudgeSpec] = []
+
+    def _register_judge(name_hint: str | None, judge_cfg: dict[str, Any]) -> None:
+        if not judge_cfg:
+            return
+        judge_module = judge_cfg.get("module")
+        judge_path = judge_cfg.get("path")
+        judge_callable_name = judge_cfg.get("callable") or judge_cfg.get("function")
+        if judge_module and judge_path:
+            raise click.ClickException("Judge config cannot set both 'module' and 'path'")
+        if not judge_module and not judge_path:
+            raise click.ClickException("Judge config requires 'module' or 'path'")
+        try:
+            if judge_module:
+                module = importlib.import_module(str(judge_module))
+            else:
+                path = Path(str(judge_path)).expanduser()
+                if not path.exists():
+                    raise click.ClickException(f"Judge module path not found: {path}")
+                spec = importlib.util.spec_from_file_location(
+                    f"_eval_judge_{path.stem}", path
+                )
+                if not spec or not spec.loader:
+                    raise click.ClickException(f"Failed to load judge module from {path}")
+                module = importlib.util.module_from_spec(spec)
+                sys.modules[spec.name] = module
+                spec.loader.exec_module(module)
+        except click.ClickException:
+            raise
+        except Exception as exc:
+            raise click.ClickException(f"Unable to load judge module: {exc}") from exc
+
+        if judge_callable_name:
+            try:
+                judge_fn = getattr(module, str(judge_callable_name))
+            except AttributeError as exc:
+                raise click.ClickException(
+                    f"Judge callable '{judge_callable_name}' not found in module"
+                ) from exc
+        else:
+            if hasattr(module, "judge"):
+                judge_fn = module.judge
+            else:
+                raise click.ClickException("Judge module must expose 'judge' callable")
+
+        if not callable(judge_fn):
+            raise click.ClickException("Judge callable is not callable")
+
+        judge_kwargs = {
+            k: v
+            for k, v in judge_cfg.items()
+            if k not in {"module", "path", "callable", "function", "name"}
+        }
+        display_name = str(
+            judge_cfg.get("name")
+            or name_hint
+            or f"judge{len(judge_specs) + 1}"
+        )
+        judge_specs.append(JudgeSpec(display_name, judge_fn, judge_kwargs))
+
+    raw_judge_cfg = cfg.get("judge")
+    if isinstance(raw_judge_cfg, dict) and raw_judge_cfg:
+        direct_keys = {"module", "path", "callable", "function", "name"}
+        has_direct_keys = any(key in raw_judge_cfg for key in direct_keys)
+        nested_candidates = [
+            (key, value)
+            for key, value in raw_judge_cfg.items()
+            if isinstance(value, dict)
+        ]
+        if has_direct_keys and not nested_candidates:
+            _register_judge(None, raw_judge_cfg)
+        else:
+            for sub_name, sub_cfg in nested_candidates:
+                _register_judge(sub_name, sub_cfg)
+
+    raw_judges_list = cfg.get("judges")
+    if isinstance(raw_judges_list, list):
+        for _index, entry in enumerate(raw_judges_list, start=1):
+            if isinstance(entry, dict):
+                _register_judge(entry.get("name") or f"judge{len(judge_specs) + 1}", entry)
+
+    records: list[dict[str, Any]] = []
+
     successes = 0
     failures = 0
     # Aggregate outcome stats across successful seeds
     outcome_sum: float = 0.0
     outcome_count: int = 0
     outcome_correct: int = 0
-
-
-
-
-
-
-
-
-        )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            "
-            "
-            "
-
-
-                policy_overrides[k] = cfg.get(k)
-    except Exception:
-        policy_overrides = {}
-
-        for seed_val in seed_values:
-            body = {
-                "run_id": str(uuid.uuid4()),
-                "env": {"config": {"split": split, "index": seed_val}, "seed": seed_val},
-                "policy": {
-                    "policy_name": selected_model,
-                    "config": {"model": selected_model, **policy_overrides},
-                },
-                "ops": [],
+
+    def _build_task_rows(taskset: Any) -> dict[int, dict[str, Any]]:
+        rows: dict[int, dict[str, Any]] = {}
+        if not isinstance(taskset, dict):
+            return rows
+
+        scenario_ids = taskset.get("scenario_ids") or []
+        loop_ids = taskset.get("loop_ids") or []
+        thread_ids = taskset.get("thread_ids") or []
+        difficulty_map = taskset.get("difficulty_map") or {}
+
+        max_len = max(len(scenario_ids), len(loop_ids), len(thread_ids))
+        for seed in range(max_len):
+            scenario_id = scenario_ids[seed] if seed < len(scenario_ids) else None
+            loop_id = loop_ids[seed] if seed < len(loop_ids) else None
+            thread_id = thread_ids[seed] if seed < len(thread_ids) else None
+            difficulty = None
+            if isinstance(difficulty_map, dict):
+                if scenario_id and scenario_id in difficulty_map:
+                    difficulty = difficulty_map.get(scenario_id)
+                elif str(seed) in difficulty_map:
+                    difficulty = difficulty_map.get(str(seed))
+
+            rows[seed] = {
+                "seed": seed,
+                "scenario_id": scenario_id,
+                "loop_id": loop_id,
+                "thread_id": thread_id,
+                "difficulty": difficulty,
             }
+        return rows
+
+    def _apply_metadata_filters(
+        rows: dict[int, dict[str, Any]], seeds_list: list[int], filters: dict[str, str]
+    ) -> list[int]:
+        if not filters:
+            return seeds_list
+        filtered: list[int] = []
+        for seed in seeds_list:
+            row = rows.get(seed)
+            if not row:
+                continue
+            include = True
+            for key, expected in filters.items():
+                actual = row.get(key)
+                if actual is None:
+                    include = False
+                    break
+                if str(actual).lower() != expected.lower():
+                    include = False
+                    break
+            if include:
+                filtered.append(seed)
+        return filtered
+
+    def _apply_metadata_sql(
+        rows: dict[int, dict[str, Any]], seeds_list: list[int], query: str
+    ) -> list[int]:
+        """Return seeds that satisfy an arbitrary SQL query.
+
+        The query is executed against an in-memory SQLite table named `tasks`
+        with columns (seed INTEGER, scenario_id TEXT, loop_id TEXT, thread_id TEXT, difficulty TEXT).
+        Any rows whose `seed` value (or first column if `seed` is absent) appear in the result set are retained.
+        """
+        if not query:
+            return seeds_list
+        conn = sqlite3.connect(":memory:")
+        try:
+            cur = conn.cursor()
+            cur.execute(
+                "CREATE TABLE tasks (seed INTEGER, scenario_id TEXT, loop_id TEXT, thread_id TEXT, difficulty TEXT)"
+            )
+            insert_stmt = (
+                "INSERT INTO tasks (seed, scenario_id, loop_id, thread_id, difficulty) VALUES (?,?,?,?,?)"
+            )
+            for seed in seeds_list:
+                row = rows.get(seed, {})
+                cur.execute(
+                    insert_stmt,
+                    [
+                        seed,
+                        row.get("scenario_id"),
+                        row.get("loop_id"),
+                        row.get("thread_id"),
+                        row.get("difficulty"),
+                    ],
+                )
+
+            result = cur.execute(query)
+            fetched = result.fetchall()
+            if not fetched:
+                return []
+            description = result.description or []
+            col_names = [col[0] for col in description]
+            seeds_out: list[int] = []
+            for entry in fetched:
+                value = entry[col_names.index("seed")] if "seed" in col_names else entry[0]
+                try:
+                    seeds_out.append(int(value))
+                except Exception as exc:
+                    raise click.ClickException(
+                        "metadata SQL query must return seed integers"
+                    ) from exc
+            seeds_set = set(seeds_out)
+            return [seed for seed in seeds_list if seed in seeds_set]
+        except sqlite3.Error as exc:
+            raise click.ClickException(f"Failed to execute metadata SQL query: {exc}") from exc
+        finally:
+            conn.close()
+
+    async def _run_eval() -> None:
+        nonlocal successes, failures, outcome_sum, outcome_count, outcome_correct, records, seed_values
+
+        if trace_tracer is not None and trace_tracer.db is None:
+            await trace_tracer.initialize()
+
+        if task_app_url is None:
+            transport = httpx.ASGITransport(app=app)  # type: ignore[name-defined]
+            async_client = httpx.AsyncClient(
+                transport=cast(Any, transport),
+                base_url="http://eval.local",
+                timeout=300.0,
+                follow_redirects=True,
+                headers=headers,
+            )
+        else:
+            async_client = httpx.AsyncClient(
+                base_url=task_app_url,
+                timeout=300.0,
+                follow_redirects=True,
+                headers=headers,
+            )
+
+        try:
+            taskset_payload: dict[str, Any] | None = None
             try:
-
-
+                task_info_response = await async_client.get("/task_info")
+            except Exception:
+                task_info_response = None
+            if task_info_response is not None and task_info_response.status_code == 200:
+                with contextlib.suppress(Exception):
+                    payload_json = task_info_response.json()
+                    if isinstance(payload_json, dict) and "taskset" in payload_json:
+                        taskset_payload = payload_json.get("taskset")
+                        if not isinstance(taskset_payload, dict):
+                            taskset_payload = None
+                    elif isinstance(payload_json, dict):
+                        taskset_payload = payload_json
+
+            available_seeds = list(seed_values)
+            if metadata_sql_query or metadata_filters:
+                if not taskset_payload:
+                    raise click.ClickException(
+                        "Task metadata filters require the task app to expose /task_info metadata"
+                    )
+                rows = _build_task_rows(taskset_payload)
+                if metadata_sql_query:
+                    available_seeds = _apply_metadata_sql(rows, available_seeds, metadata_sql_query)
+                if metadata_filters:
+                    available_seeds = _apply_metadata_filters(rows, available_seeds, metadata_filters)
+                if not available_seeds:
+                    raise click.ClickException("No seeds match the provided metadata filters")
+                seed_values = available_seeds
+
+            semaphore = asyncio.Semaphore(concurrency_limit)
+
+            async def _run_seed(seed_val: int) -> None:
+                nonlocal successes, failures, outcome_sum, outcome_count, outcome_correct, records
+                body = {
+                    "run_id": str(uuid.uuid4()),
+                    "env": {"config": {"split": split, "index": seed_val}, "seed": seed_val},
+                    "policy": {
+                        "policy_name": selected_model,
+                        "config": {"model": selected_model, **policy_overrides},
+                    },
+                    "ops": [],
+                }
+                rollout_elapsed: float | None = None
+                rollout_start = time.perf_counter()
+                try:
+                    async with semaphore:
+                        response = await async_client.post("/rollout", json=body)
+                        rollout_elapsed = time.perf_counter() - rollout_start
+                except Exception as exc:
+                    failures += 1
+                    click.echo(f"seed={seed_val} error={exc}")
+                    return
+
+                ok = 200 <= response.status_code < 300
                 if ok:
                     successes += 1
                 else:
                     failures += 1
 
-
-
+                summary = [f"seed={seed_val}", f"status={response.status_code}"]
+                data: Any
                 try:
-            data =
+                    data = response.json()
                 except Exception:
                     data = None
+
+                metrics: dict[str, Any] | None = None
+                completion: str | None = None
+                prompt_index: int | None = None
+                prompt_text: str | None = None
+                task_id: str | None = None
+                task_split: str | None = None
+                task_rubric_id: str | None = None
+
+                trace_namespace: dict[str, Any] | None = None
+                session_trace_dict: dict[str, Any] | None = None
+
                 if isinstance(data, dict):
+                    trace_namespace = data.get("trace")
+                    if not isinstance(trace_namespace, dict):
+                        raise RuntimeError(
+                            "rollout response missing trace payload; task app must return tracing_v3 data"
+                        )
+                    session_trace_dict = trace_namespace.get("session_trace")
+                    if not isinstance(session_trace_dict, dict):
+                        raise RuntimeError(
+                            "rollout response trace missing 'session_trace'; ensure the task app is serving the tracing_v3 build"
+                        )
                     metrics = data.get("metrics") if isinstance(data.get("metrics"), dict) else None
                     if metrics:
                         mean_return = metrics.get("mean_return") or metrics.get("total_reward")
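Reader's note: `_register_judge` above resolves a judge either from an importable module or from a file path, defaults to a callable named `judge`, and later invokes it as `spec.fn(judge_payload, **spec.kwargs)`, keeping the result only when it is numeric. Under those assumptions, a compatible judge file can be as small as the sketch below; the length heuristic is invented for illustration, and only the payload keys, the keyword-argument passthrough, and the numeric return come from the diff.

    from typing import Any


    def judge(payload: dict[str, Any], *, min_length: int = 20, **_: Any) -> float:
        """Toy judge: score a completion by whether it looks non-trivial.

        The eval command passes seed, prompt_index, prompt, completion, metrics,
        response, and trace inside ``payload``; any extra keys from the judge's
        TOML table (here the hypothetical ``min_length``) arrive as keyword
        arguments. The numeric return value is recorded per seed and averaged
        per judge in the summary output.
        """
        completion = payload.get("completion") or ""
        if not completion.strip():
            return 0.0
        return 1.0 if len(completion) >= min_length else 0.5

With the config sketched earlier, such a file would be referenced via `path` plus `callable`, and its scores appear in the summary table under the judge's table name.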
@@ -2721,7 +3843,6 @@ def eval_command(
                             summary.append(f"mean_return={mean_return}")
                         if outcome is not None:
                             summary.append(f"outcome={outcome}")
-                            # Aggregate outcome stats
                             try:
                                 val = float(outcome)
                                 outcome_sum += val
@@ -2730,7 +3851,6 @@ def eval_command(
                                     outcome_correct += 1
                             except Exception:
                                 pass
-                        # Try to infer tool call count from first trajectory step
                     trajs = (
                         data.get("trajectories")
                         if isinstance(data.get("trajectories"), list)
@@ -2744,38 +3864,163 @@ def eval_command(
                             tool_calls = step0.get("tool_calls") or step0.get("tools") or []
                             if isinstance(tool_calls, list):
                                 summary.append(f"tool_calls={len(tool_calls)}")
+                            obs = step0.get("obs") if isinstance(step0, dict) else None
+                            if isinstance(obs, dict):
+                                idx_val = obs.get("prompt_index")
+                                if isinstance(idx_val, int):
+                                    prompt_index = idx_val
+                                prompt_raw = obs.get("prompt")
+                                if isinstance(prompt_raw, str):
+                                    prompt_text = prompt_raw
+                                if task_id is None:
+                                    candidate_id = obs.get("task_id")
+                                    if isinstance(candidate_id, str) and candidate_id:
+                                        task_id = candidate_id
+                                if task_split is None:
+                                    candidate_split = obs.get("task_split")
+                                    if isinstance(candidate_split, str) and candidate_split:
+                                        task_split = candidate_split
+                                if task_rubric_id is None:
+                                    candidate_rid = obs.get("task_rubric_id")
+                                    if isinstance(candidate_rid, str) and candidate_rid:
+                                        task_rubric_id = candidate_rid
+                        final = first.get("final") if isinstance(first, dict) else None
+                        if isinstance(final, dict):
+                            final_obs = final.get("observation")
+                            if isinstance(final_obs, dict):
+                                comp_val = final_obs.get("completion")
+                                if isinstance(comp_val, str):
+                                    completion = comp_val
+                                if task_id is None:
+                                    candidate_id = final_obs.get("task_id")
+                                    if isinstance(candidate_id, str) and candidate_id:
+                                        task_id = candidate_id
+                                if task_split is None:
+                                    candidate_split = final_obs.get("task_split")
+                                    if isinstance(candidate_split, str) and candidate_split:
+                                        task_split = candidate_split
+                                if task_rubric_id is None:
+                                    candidate_rid = final_obs.get("task_rubric_id")
+                                    if isinstance(candidate_rid, str) and candidate_rid:
+                                        task_rubric_id = candidate_rid
+                            final_info = final.get("info")
+                            if isinstance(final_info, dict):
+                                if task_id is None:
+                                    candidate_id = final_info.get("task_id")
+                                    if isinstance(candidate_id, str) and candidate_id:
+                                        task_id = candidate_id
+                                if task_split is None:
+                                    candidate_split = final_info.get("task_split")
+                                    if isinstance(candidate_split, str) and candidate_split:
+                                        task_split = candidate_split
+                                if task_rubric_id is None:
+                                    candidate_rid = final_info.get("task_rubric_id")
+                                    if isinstance(candidate_rid, str) and candidate_rid:
+                                        task_rubric_id = candidate_rid
+                    if task_id:
+                        summary.append(f"task_id={task_id}")
                     click.echo(" ".join(summary))
-                    # Print the full response JSON (trace, trajectories, metrics)
                     with contextlib.suppress(Exception):
                         click.echo(json.dumps(data, indent=2))
                 else:
                     click.echo(" ".join(summary))
-        except Exception as exc:
-            failures += 1
-            click.echo(f"seed={seed_val} error={exc}")
 
-
-
-
-
-
-
-
-
-    except RuntimeError:
-        # Fallback when already inside a running loop (rare for CLI).
-        new_loop = asyncio.new_event_loop()
+                official_score = None
+                if isinstance(metrics, dict):
+                    for key in ("mean_return", "total_reward", "outcome_score"):
+                        val = metrics.get(key)
+                        if isinstance(val, int | float):
+                            official_score = float(val)
+                            break
+                if official_score is None and isinstance(data, dict):
                     try:
-
-
-
-
-
+                        reward_val = data["trajectories"][0]["steps"][0].get("reward")
+                        if isinstance(reward_val, int | float):
+                            official_score = float(reward_val)
+                    except Exception:
+                        pass
+
+                if official_score is not None:
+                    if official_score < 0.0:
+                        official_score = 0.0
+                    elif official_score > 1.0:
+                        official_score = min(1.0, official_score)
+
+                judge_scores: dict[str, float | None] = {}
+                judges_timings: dict[str, float | None] = {}
+                timings: dict[str, Any] = {
+                    "rollout_s": rollout_elapsed,
+                    "judges": judges_timings,
+                }
+                if judge_specs:
+                    for spec in judge_specs:
+                        score_value: float | None = None
+                        judge_elapsed: float | None = None
+                        if completion is not None:
+                            judge_payload = {
+                                "seed": seed_val,
+                                "prompt_index": prompt_index,
+                                "prompt": prompt_text,
+                                "completion": completion,
+                                "metrics": metrics,
+                                "response": data,
+                                "trace": trace_namespace,
+                            }
+                            try:
+                                judge_start = time.perf_counter()
+                                result = spec.fn(judge_payload, **spec.kwargs)
+                                judge_elapsed = time.perf_counter() - judge_start
+                                if isinstance(result, int | float):
+                                    score_value = float(result)
+                            except Exception as exc:
+                                if judge_elapsed is None:
+                                    judge_elapsed = time.perf_counter() - judge_start
+                                click.echo(f"seed={seed_val} judge[{spec.name}]_error={exc}")
+                        judges_timings[spec.name] = judge_elapsed
+                        judge_scores[spec.name] = score_value
+
+                if trace_tracer is not None and trace_namespace:
+                    storage_metadata = {
+                        "eval_seed": seed_val,
+                        "prompt_index": prompt_index,
+                        "task_id": task_id,
+                        "task_split": task_split,
+                        "task_rubric_id": task_rubric_id,
+                        "official_score": official_score,
+                        "judge_scores": judge_scores,
+                        "model": selected_model,
+                        "prompt": prompt_text,
+                        "completion": completion,
+                    }
+                    await _store_trace(trace_tracer, trace_namespace, storage_metadata)
+
+                records.append(
+                    {
+                        "seed": seed_val,
+                        "prompt_index": prompt_index,
+                        "task_id": task_id,
+                        "task_split": task_split,
+                        "task_rubric_id": task_rubric_id,
+                        "official_score": official_score,
+                        "judge_scores": judge_scores,
+                        "timings": timings,
+                    }
+                )
+
+            await asyncio.gather(*[_run_seed(seed_val) for seed_val in seed_values])
+        finally:
+            await async_client.aclose()
+
+    try:
+        asyncio.run(_run_eval())
+    finally:
+        if trace_tracer is not None and trace_tracer.db is not None:
+            asyncio.run(trace_tracer.db.close())
 
     click.echo(
         f"Eval complete: {successes} ok, {failures} failed; model={selected_model}, split={split}"
     )
-
+
     if outcome_count > 0:
         mean_outcome = outcome_sum / float(outcome_count)
         frac_right = outcome_correct / float(outcome_count)
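Reader's note: the summary block in the next hunk calls a `_pearson(xs, ys)` helper (defined elsewhere in this module) and prints "undefined (zero variance)" when it returns None. That behaviour is consistent with the standard sample Pearson correlation, sketched below under that assumption; this is the textbook formula, not necessarily the package's exact implementation.

    import math


    def pearson(xs: list[float], ys: list[float]) -> float | None:
        """Sample Pearson correlation; None when either series has zero variance."""
        n = len(xs)
        if n != len(ys) or n < 2:
            return None
        mx = sum(xs) / n
        my = sum(ys) / n
        cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
        vx = sum((x - mx) ** 2 for x in xs)
        vy = sum((y - my) ** 2 for y in ys)
        if vx == 0 or vy == 0:
            return None
        return cov / math.sqrt(vx * vy)

    print(pearson([0.0, 0.5, 1.0], [0.1, 0.6, 0.9]))  # ≈ 0.99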
@@ -2783,6 +4028,285 @@ def eval_command(
             f"Outcome summary: correct={outcome_correct}/{outcome_count} ({frac_right:.2%}), mean_outcome={mean_outcome:.3f}"
         )
 
+    if records:
+        judge_specs = judge_specs or []  # ensure iterable
+        official_scores = [
+            r["official_score"] for r in records if r["official_score"] is not None
+        ]
+        if official_scores:
+            click.echo(f" Official mean: {sum(official_scores) / len(official_scores):.3f}")
+        else:
+            click.echo(" Official mean: n/a")
+
+        for spec in judge_specs:
+            spec_scores = [
+                record["judge_scores"].get(spec.name)
+                for record in records
+                if record["judge_scores"].get(spec.name) is not None
+            ]
+            if spec_scores:
+                mean_spec = sum(spec_scores) / len(spec_scores)
+                click.echo(f" [{spec.name}] mean: {mean_spec:.3f}")
+            else:
+                click.echo(f" [{spec.name}] mean: n/a")
+
+            paired = [
+                (
+                    record["official_score"],
+                    record["judge_scores"].get(spec.name),
+                )
+                for record in records
+                if record["official_score"] is not None
+                and record["judge_scores"].get(spec.name) is not None
+            ]
+            if len(paired) >= 2:
+                corr = _pearson(
+                    [p[0] for p in paired if p[0] is not None],
+                    [p[1] for p in paired if p[1] is not None],
+                )
+                if corr is not None:
+                    click.echo(f" Pearson r: {corr:.3f}")
+                else:
+                    click.echo(" Pearson r: undefined (zero variance)")
+            else:
+                click.echo(" Pearson r: n/a (need ≥2 paired scores)")
+
+        header = ["Seed", "Prompt", "Official"]
+        header.extend(spec.name for spec in judge_specs)
+        rows: list[list[str]] = []
+        for record in sorted(records, key=lambda r: (r["seed"], r.get("prompt_index") or -1)):
+            seed_val = str(record["seed"])
+            prompt_idx = (
+                str(record["prompt_index"])
+                if record["prompt_index"] is not None
+                else "-"
+            )
+            official_val = (
+                f"{record['official_score']:.3f}"
+                if record["official_score"] is not None
+                else "-"
+            )
+            row = [seed_val, prompt_idx, official_val]
+            for spec in judge_specs:
+                score_val = record["judge_scores"].get(spec.name)
+                row.append(f"{score_val:.3f}" if isinstance(score_val, int | float) else "-")
+            rows.append(row)
+
+        widths = [len(col) for col in header]
+        for row in rows:
+            for idx, cell in enumerate(row):
+                widths[idx] = max(widths[idx], len(cell))
+
+        click.echo("")
+        click.echo(" ".join(h.ljust(widths[idx]) for idx, h in enumerate(header)))
+        click.echo(" ".join("-" * widths[idx] for idx in range(len(header))))
+        for row in rows:
+            click.echo(" ".join(cell.ljust(widths[idx]) for idx, cell in enumerate(row)))
+
+
+
+@click.command(
+    "filter",
+    help="Export filtered tracing sessions to SFT-ready JSONL based on a TOML config.",
+)
+@click.option(
+    "--config",
+    "config_path",
+    type=click.Path(),
+    required=True,
+    help="Path to TOML config describing the input trace DB, score thresholds, and output JSONL.",
+)
+def filter_command(config_path: str) -> None:
+    """Render tracing sessions that match filter rules into SFT JSONL.
+
+    The TOML file should contain a `[filter]` table with at least:
+
+        db = \"path/to/traces.db\"  # sqlite path or URL (sqlite+aiosqlite://...)
+        output = \"ft_data/out.jsonl\"  # destination JSONL
+
+    Optional keys such as `splits`, `task_ids`, `models`, `min_official_score`, or
+    `min_judge_scores.my_judge = 0.7` allow you to narrow the dataset down to
+    high-quality traces. See `customers/agora_single_file/configs/filter_local.toml`
+    for a working example.
+    """
+    if _toml is None:
+        raise click.ClickException("TOML parser not available; install tomli or use Python 3.11+")
+
+    cfg_path = Path(config_path)
+    if not cfg_path.exists():
+        raise click.ClickException(f"Filter config not found: {cfg_path}")
+
+    try:
+        config_data = _toml.loads(cfg_path.read_text(encoding="utf-8"))
+    except Exception as exc:
+        raise click.ClickException(f"Failed to parse TOML '{cfg_path}': {exc}") from exc
+
+    filter_cfg = config_data.get("filter") if isinstance(config_data, dict) else None
+    if not isinstance(filter_cfg, dict):
+        raise click.ClickException("Config must contain a [filter] table")
+
+    db_value = str(filter_cfg.get("db", "traces/v3/eval_traces.db")).strip()
+    if not db_value:
+        raise click.ClickException("filter.db must be provided")
+    if "://" in db_value:
+        db_url = db_value
+    else:
+        db_path = Path(db_value).expanduser()
+        db_path.parent.mkdir(parents=True, exist_ok=True)
+        db_url = f"sqlite+aiosqlite:///{db_path}"
+
+    output_value = filter_cfg.get("output")
+    if not output_value:
+        raise click.ClickException("filter.output must be provided")
+    output_path = Path(str(output_value)).expanduser()
+
+    splits = set(filter_cfg.get("splits", []) or [])
+    task_ids = set(filter_cfg.get("task_ids", []) or [])
+    models = set(filter_cfg.get("models", []) or [])
+    min_official = filter_cfg.get("min_official_score")
+    max_official = filter_cfg.get("max_official_score")
+    if min_official is not None:
+        try:
+            min_official = float(min_official)
+        except Exception as err:
+            raise click.ClickException("filter.min_official_score must be numeric") from err
+    if max_official is not None:
+        try:
+            max_official = float(max_official)
+        except Exception as err:
+            raise click.ClickException("filter.max_official_score must be numeric") from err
+    min_judge_scores = filter_cfg.get("min_judge_scores", {}) or {}
+    max_judge_scores = filter_cfg.get("max_judge_scores", {}) or {}
+    try:
+        min_judge_scores = {k: float(v) for k, v in min_judge_scores.items()}
+    except Exception as err:
+        raise click.ClickException("filter.min_judge_scores values must be numeric") from err
+    try:
+        max_judge_scores = {k: float(v) for k, v in max_judge_scores.items()}
+    except Exception as err:
+        raise click.ClickException("filter.max_judge_scores values must be numeric") from err
+    min_created = _parse_datetime_for_trace(filter_cfg.get("min_created_at"))
+    max_created = _parse_datetime_for_trace(filter_cfg.get("max_created_at"))
+    limit = filter_cfg.get("limit")
+    if limit is not None:
+        try:
+            limit = int(limit)
+        except Exception as err:
+            raise click.ClickException("filter.limit must be an integer") from err
+
+    def _score_ok(value: Any, min_val: Any, max_val: Any) -> bool:
+        try:
+            if value is None:
+                return min_val is None
+            value = float(value)
+        except Exception:
+            return False
+        if min_val is not None and value < float(min_val):
+            return False
+        return not (max_val is not None and value > float(max_val))
+
+    async def _run_filter() -> None:
+        tracer = SessionTracer(db_url=db_url, auto_save=False)
+        await tracer.initialize()
+
+        df = await tracer.db.query_traces(
+            "SELECT session_id, created_at, metadata FROM session_traces ORDER BY created_at"
+        )
+        if getattr(df, "empty", True):
+            raise click.ClickException("No traces found in database")
+
+        sessions = df.to_dict("records")
+        accepted: list[dict[str, Any]] = []
+
+        for row in sessions:
+            metadata_raw = row.get("metadata")
+            if isinstance(metadata_raw, str):
+                try:
+                    metadata = json.loads(metadata_raw)
+                except Exception:
+                    metadata = {}
+            elif isinstance(metadata_raw, dict):
+                metadata = dict(metadata_raw)
+            else:
+                metadata = {}
+
+            created_at_raw = row.get("created_at")
+            created_at_dt = _parse_datetime_for_trace(created_at_raw)
+
+            session_id = row.get("session_id")
+
+            if splits and metadata.get("task_split") not in splits:
+                continue
+            if task_ids and metadata.get("task_id") not in task_ids:
+                continue
+            if models and metadata.get("model") not in models:
+                continue
+
+            if min_created and (created_at_dt is None or created_at_dt < min_created):
+                continue
+            if max_created and (created_at_dt is None or created_at_dt > max_created):
+                continue
+
+            if not _score_ok(metadata.get("official_score"), min_official, max_official):
+                continue
+
+            judge_scores = metadata.get("judge_scores") or {}
+            include = True
+            for judge_name, threshold in (min_judge_scores or {}).items():
+                if not _score_ok(judge_scores.get(judge_name), threshold, None):
+                    include = False
+                    break
+            if not include:
+                continue
+            for judge_name, threshold in (max_judge_scores or {}).items():
+                if not _score_ok(judge_scores.get(judge_name), None, threshold):
+                    include = False
+                    break
+            if not include:
+                continue
+
+            prompt = metadata.get("prompt") or ""
+            completion = metadata.get("completion") or ""
+            if not prompt or not completion:
+                continue
+
+            record = {
+                "messages": [
+                    {"role": "user", "content": str(prompt)},
+                    {"role": "assistant", "content": str(completion)},
+                ],
+                "metadata": {
+                    "session_id": session_id,
+                    "task_id": metadata.get("task_id"),
+                    "task_split": metadata.get("task_split"),
+                    "task_rubric_id": metadata.get("task_rubric_id"),
+                    "official_score": metadata.get("official_score"),
+                    "judge_scores": judge_scores,
+                    "model": metadata.get("model"),
+                    "created_at": created_at_dt.isoformat() if created_at_dt else created_at_raw,
+                    "prompt": prompt,
+                    "completion": completion,
+                },
+            }
+            accepted.append(record)
+
+        if not accepted:
+            raise click.ClickException("No sessions matched the provided filters")
+
+        if limit is not None and limit > 0:
+            accepted = accepted[:limit]
+
+        output_path.parent.mkdir(parents=True, exist_ok=True)
+        with output_path.open("w", encoding="utf-8") as handle:
+            for item in accepted:
+                handle.write(json.dumps(item, ensure_ascii=False))
+                handle.write("\n")
+
+        click.echo(f"Wrote {len(accepted)} examples -> {output_path}")
+        await tracer.db.close()
+
+    asyncio.run(_run_filter())
+
 
 def register_eval(cli: click.Group) -> None:
     cli.add_command(eval_command)
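Reader's note: the `[filter]` keys consumed in the hunk above are `db`, `output`, `splits`, `task_ids`, `models`, `min_official_score`, `max_official_score`, `min_judge_scores`, `max_judge_scores`, `min_created_at`, `max_created_at`, and `limit`. A config in that shape, with made-up paths and thresholds and parsed inline with stdlib `tomllib` only to illustrate the structure:

    import tomllib

    EXAMPLE_FILTER_TOML = """
    [filter]
    db = "traces/v3/eval_traces.db"        # sqlite path or sqlite+aiosqlite:// URL
    output = "ft_data/crafter_sft.jsonl"   # hypothetical destination JSONL
    splits = ["train"]
    models = ["groq:llama-3.1-70b-versatile"]
    min_official_score = 0.5
    limit = 200

    [filter.min_judge_scores]
    quality = 0.7
    """

    filter_cfg = tomllib.loads(EXAMPLE_FILTER_TOML)["filter"]
    print(filter_cfg["output"], filter_cfg["min_judge_scores"])

Each accepted session is written as one JSONL line containing a two-message `messages` array (user prompt, assistant completion) plus the trace metadata, a chat-style layout commonly used for SFT datasets.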
|