synth-ai 0.2.13.dev1__py3-none-any.whl → 0.2.13.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +12 -1
- examples/swe/task_app/grpo_swe_mini.py +55 -26
- examples/swe/task_app/hosted/rollout.py +40 -0
- examples/swe/task_app/hosted/test_service.py +5 -6
- examples/task_apps/TESTING.md +275 -0
- examples/task_apps/__init__.py +0 -0
- examples/task_apps/crafter/__init__.py +0 -0
- examples/task_apps/crafter/task_app/__init__.py +2 -0
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/grpo_crafter.py +18 -13
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/grpo_crafter_task_app.py +1 -1
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/policy.py +60 -4
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/policy_routes.py +25 -3
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/rollout.py +10 -0
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/test_service.py +5 -6
- examples/task_apps/dev/pokemon_emerald/__init__.py +2 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/README.md +811 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/__init__.py +120 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/action.py +160 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/memory.py +155 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/perception.py +69 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/planning.py +96 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/simple.py +1502 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/system_prompt.py +4 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/grab_map.py +68 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/manual.py +216 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/__init__.py +35 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emerald_utils.py +631 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emulator.py +1544 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/enums.py +1428 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/memory_reader.py +4848 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/types.py +41 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/utils.py +298 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pyproject.toml +95 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/run.py +204 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/__init__.py +0 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/app.py +2152 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/client.py +429 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/frame_server.py +155 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/README.md +78 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/__init__.py +0 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/run_tests.py +122 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_direct.py +76 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_prompts.py +413 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_battle_state_formatting.py +204 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection.py +133 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection_comprehensive.py +229 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_direct_agent_emulator.py +300 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_fps_adjustment_pytest.py +205 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_direct.py +200 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_transition.py +284 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_map_ground_truth_comparison.py +468 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_memory_map.py +575 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_server_map_validation.py +311 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_torchic_state.py +259 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/__init__.py +0 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/anticheat.py +372 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/checkpoint.py +296 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/error_handler.py +275 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/get_local_ip.py +22 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/helpers.py +44 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/llm_logger.py +514 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_formatter.py +415 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher.py +1763 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher_singleton.py +33 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_trimmer.py +106 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_visualizer.py +334 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/ocr_dialogue.py +1020 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/recording.py +188 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/state_formatter.py +1481 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/vlm.py +862 -0
- examples/task_apps/dev/pokemon_emerald/modal_app.py +114 -0
- examples/task_apps/dev/pokemon_emerald/task_app/README.md +81 -0
- examples/task_apps/dev/pokemon_emerald/task_app/__init__.py +6 -0
- examples/task_apps/dev/pokemon_emerald/task_app/pokemon_emerald.py +685 -0
- examples/task_apps/enron/__init__.py +1 -0
- examples/task_apps/enron/eval_groq_qwen32.toml +16 -0
- examples/task_apps/enron/task_app/README.md +14 -0
- examples/task_apps/enron/task_app/__init__.py +1 -0
- examples/task_apps/enron/task_app/grpo_enron.py +906 -0
- examples/task_apps/enron/task_app/grpo_enron_task_app.py +146 -0
- examples/task_apps/enron/tests/__init__.py +2 -0
- examples/task_apps/enron/tests/conftest.py +115 -0
- examples/task_apps/enron/tests/integration/__init__.py +2 -0
- examples/task_apps/enron/tests/integration/test_enron_eval.py +177 -0
- examples/task_apps/enron/tests/integration/test_enron_rollout.py +135 -0
- examples/task_apps/enron/tests/unit/__init__.py +2 -0
- examples/task_apps/enron/tests/unit/test_enron_environment.py +126 -0
- examples/task_apps/math/__init__.py +0 -0
- examples/{rl/task_app → task_apps/math}/math_single_step.py +19 -10
- examples/task_apps/pokemon_battle/__init__.py +2 -0
- examples/task_apps/pokemon_battle/modal_app.py +104 -0
- examples/task_apps/pokemon_battle/task_app/README.md +68 -0
- examples/task_apps/pokemon_battle/task_app/__init__.py +6 -0
- examples/task_apps/pokemon_battle/task_app/pokemon_showdown.py +932 -0
- examples/task_apps/pokemon_red/README.md +357 -0
- examples/task_apps/pokemon_red/__init__.py +3 -0
- examples/task_apps/pokemon_red/eval_pokemon_red_policy.py +225 -0
- examples/task_apps/pokemon_red/pallet_town_rl_config.toml +73 -0
- examples/task_apps/pokemon_red/task_app.py +606 -0
- examples/task_apps/pokemon_red/test_pallet_town_rewards.py +191 -0
- examples/task_apps/sokoban/README.md +307 -0
- examples/task_apps/sokoban/__init__.py +3 -0
- examples/task_apps/sokoban/eval_groq_qwen32.toml +16 -0
- examples/task_apps/sokoban/eval_openai_gpt5.toml +16 -0
- examples/task_apps/sokoban/task_app.py +1058 -0
- examples/task_apps/sokoban/tests/__init__.py +2 -0
- examples/task_apps/sokoban/tests/conftest.py +113 -0
- examples/task_apps/sokoban/tests/integration/__init__.py +2 -0
- examples/task_apps/sokoban/tests/integration/test_sokoban_eval.py +57 -0
- examples/task_apps/sokoban/tests/integration/test_sokoban_rollout.py +198 -0
- examples/task_apps/sokoban/tests/unit/__init__.py +2 -0
- examples/task_apps/sokoban/tests/unit/test_sokoban_environment.py +114 -0
- examples/task_apps/verilog/__init__.py +1 -0
- examples/task_apps/verilog/eval_groq_qwen32b.toml +20 -0
- examples/task_apps/verilog/task_app/README.md +12 -0
- examples/task_apps/verilog/task_app/__init__.py +1 -0
- examples/task_apps/verilog/task_app/grpo_verilog.py +931 -0
- examples/task_apps/verilog/task_app/grpo_verilog_task_app.py +145 -0
- examples/task_apps/verilog/tests/__init__.py +2 -0
- examples/task_apps/verilog/tests/conftest.py +115 -0
- examples/task_apps/verilog/tests/integration/__init__.py +2 -0
- examples/task_apps/verilog/tests/integration/test_verilog_eval.py +179 -0
- examples/task_apps/verilog/tests/integration/test_verilog_rollout.py +55 -0
- examples/task_apps/verilog/tests/unit/__init__.py +2 -0
- examples/task_apps/verilog/tests/unit/test_verilog_scoring.py +118 -0
- examples/vlm/crafter_openai_vlm_agent.py +4 -4
- examples/vlm/run_crafter_vlm_benchmark.py +4 -4
- examples/workflows/__init__.py +0 -0
- examples/workflows/math_rl/__init__.py +0 -0
- examples/workflows/math_rl/download_dataset.py +80 -0
- synth_ai/__init__.py +2 -2
- synth_ai/api/train/builders.py +25 -11
- synth_ai/api/train/cli.py +12 -6
- synth_ai/api/train/configs/__init__.py +10 -10
- synth_ai/api/train/configs/rl.py +5 -4
- synth_ai/api/train/configs/sft.py +4 -3
- synth_ai/api/train/env_resolver.py +5 -2
- synth_ai/api/train/supported_algos.py +10 -5
- synth_ai/api/train/utils.py +7 -4
- synth_ai/cli/__init__.py +7 -51
- synth_ai/cli/_storage.py +4 -3
- synth_ai/cli/_validate_task_app.py +11 -0
- synth_ai/cli/balance.py +4 -3
- synth_ai/cli/calc.py +2 -2
- synth_ai/cli/demo.py +14 -7
- synth_ai/cli/legacy_root_backup.py +1 -1
- synth_ai/cli/rl_demo.py +8 -7
- synth_ai/cli/root.py +0 -97
- synth_ai/cli/task_apps.py +1707 -186
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +28 -16
- synth_ai/environments/examples/enron/engine.py +7 -2
- synth_ai/environments/examples/enron/environment.py +68 -0
- synth_ai/environments/examples/red/engine.py +27 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +7 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_progression.py +477 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +32 -0
- synth_ai/environments/examples/red/environment.py +60 -0
- synth_ai/environments/examples/sokoban/taskset.py +116 -0
- synth_ai/environments/examples/verilog/engine.py +30 -4
- synth_ai/evals/client.py +58 -61
- synth_ai/jobs/client.py +16 -4
- synth_ai/judge_schemas.py +16 -16
- synth_ai/py.typed +0 -0
- synth_ai/task/__init__.py +14 -5
- synth_ai/task/contracts.py +124 -38
- synth_ai/task/proxy.py +48 -56
- synth_ai/task/rubrics/__init__.py +53 -0
- synth_ai/task/rubrics/loaders.py +133 -0
- synth_ai/task/rubrics/models.py +57 -0
- synth_ai/task/rubrics/scoring.py +113 -0
- synth_ai/{rubrics/validators.py → task/rubrics/strict.py} +53 -30
- synth_ai/task/server.py +8 -7
- synth_ai/task/validators.py +269 -6
- synth_ai/tracing_v3/decorators.py +7 -3
- synth_ai/tracing_v3/replica_sync.py +4 -4
- synth_ai/tracing_v3/serialization.py +5 -5
- synth_ai/tracing_v3/trace_utils.py +317 -0
- synth_ai/tracing_v3/turso/native_manager.py +3 -3
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/METADATA +4 -1
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/RECORD +214 -101
- examples/agora_ex/README_MoE.md +0 -224
- examples/agora_ex/__init__.py +0 -7
- examples/agora_ex/agora_ex.py +0 -65
- examples/agora_ex/agora_ex_task_app.py +0 -590
- examples/agora_ex/configs/rl_lora_qwen3_moe_2xh200.toml +0 -121
- examples/agora_ex/reward_fn_grpo-human.py +0 -129
- examples/agora_ex/system_prompt_CURRENT.md +0 -63
- examples/agora_ex/task_app/agora_ex_task_app.py +0 -590
- examples/agora_ex/task_app/reward_fn_grpo-human.py +0 -129
- examples/agora_ex/task_app/system_prompt_CURRENT.md +0 -63
- synth_ai/rubrics/__init__.py +0 -22
- synth_ai/task/rubrics.py +0 -219
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/README.md +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/README.md +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/branching.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/environment_routes.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/app.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/environment.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/react_agent.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/shared.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/tools.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/hosted_app.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/inference/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/inference/openai_client.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/main.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/registry.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/storage/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/storage/volume.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/test_agents.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/utils.py +0 -0
- /examples/{rl/task_app → task_apps/math}/README.md +0 -0
- /examples/{rl/task_app → task_apps/math}/math_task_app.py +0 -0
- /examples/{rl → workflows/math_rl}/configs/eval_base_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/eval_rl_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/rl_from_base_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/rl_from_base_qwen17.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/rl_from_ft_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/run_eval.py +0 -0
- /examples/{rl → workflows/math_rl}/run_rl_and_save.py +0 -0
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/top_level.txt +0 -0
synth_ai/cli/task_apps.py
CHANGED
|
@@ -1,23 +1,30 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import argparse
|
|
3
4
|
import ast
|
|
4
5
|
import asyncio
|
|
5
6
|
import contextlib
|
|
7
|
+
import functools
|
|
6
8
|
import hashlib
|
|
7
9
|
import importlib
|
|
8
10
|
import importlib.util
|
|
9
11
|
import inspect
|
|
10
12
|
import json
|
|
11
13
|
import os
|
|
14
|
+
import shlex
|
|
12
15
|
import shutil
|
|
13
16
|
import signal
|
|
17
|
+
import sqlite3
|
|
14
18
|
import subprocess
|
|
15
19
|
import sys
|
|
16
20
|
import tempfile
|
|
17
21
|
import textwrap
|
|
22
|
+
import time
|
|
18
23
|
import types
|
|
24
|
+
import uuid
|
|
19
25
|
from collections.abc import Callable, Iterable, Iterator, Sequence
|
|
20
26
|
from dataclasses import dataclass
|
|
27
|
+
from datetime import UTC, datetime
|
|
21
28
|
from pathlib import Path
|
|
22
29
|
from typing import Any, cast
|
|
23
30
|
|
|
@@ -25,24 +32,39 @@ try: # Python 3.11+
|
|
|
25
32
|
import tomllib as _toml
|
|
26
33
|
except Exception: # pragma: no cover - fallback
|
|
27
34
|
_toml = None # type: ignore
|
|
28
|
-
import uuid
|
|
29
35
|
|
|
30
36
|
import click
|
|
31
37
|
from click.exceptions import Abort
|
|
32
38
|
|
|
39
|
+
from synth_ai.tracing_v3 import ( # type: ignore[import-untyped]
|
|
40
|
+
BaseEvent,
|
|
41
|
+
EnvironmentEvent,
|
|
42
|
+
RuntimeEvent,
|
|
43
|
+
SessionEventMarkovBlanketMessage,
|
|
44
|
+
SessionMessageContent,
|
|
45
|
+
SessionTimeStep,
|
|
46
|
+
SessionTracer,
|
|
47
|
+
TimeRecord,
|
|
48
|
+
)
|
|
49
|
+
from synth_ai.tracing_v3 import ( # type: ignore[import-untyped]
|
|
50
|
+
SessionTrace as V3SessionTrace,
|
|
51
|
+
)
|
|
52
|
+
|
|
33
53
|
# ---------------------------------------------------------------------------
|
|
34
54
|
# Dynamic imports to avoid hard dependencies during type checking.
|
|
35
55
|
# ---------------------------------------------------------------------------
|
|
36
56
|
ModalDeploymentConfigType = TaskAppConfigType = TaskAppEntryType = Any
|
|
37
57
|
|
|
38
58
|
try: # Resolve base URL defaults lazily
|
|
39
|
-
_config_module =
|
|
59
|
+
_config_module = cast(
|
|
60
|
+
Any, importlib.import_module("synth_ai.config.base_url")
|
|
61
|
+
)
|
|
40
62
|
PROD_BASE_URL_DEFAULT = cast(str, _config_module.PROD_BASE_URL_DEFAULT)
|
|
41
63
|
except Exception: # pragma: no cover - fallback
|
|
42
64
|
PROD_BASE_URL_DEFAULT = "https://agent-learning.onrender.com"
|
|
43
65
|
|
|
44
66
|
try:
|
|
45
|
-
_task_apps_module = importlib.import_module("synth_ai.task.apps")
|
|
67
|
+
_task_apps_module = cast(Any, importlib.import_module("synth_ai.task.apps"))
|
|
46
68
|
ModalDeploymentConfig = cast(
|
|
47
69
|
type[ModalDeploymentConfigType], _task_apps_module.ModalDeploymentConfig
|
|
48
70
|
)
|
|
@@ -53,9 +75,9 @@ except Exception as exc: # pragma: no cover - critical dependency
|
|
|
53
75
|
raise RuntimeError("Unable to load task app registry") from exc
|
|
54
76
|
|
|
55
77
|
try:
|
|
56
|
-
_task_server_module = importlib.import_module("synth_ai.task.server")
|
|
57
|
-
create_task_app = _task_server_module.create_task_app
|
|
58
|
-
run_task_app = _task_server_module.run_task_app
|
|
78
|
+
_task_server_module = cast(Any, importlib.import_module("synth_ai.task.server"))
|
|
79
|
+
create_task_app = cast(Callable[..., Any], _task_server_module.create_task_app)
|
|
80
|
+
run_task_app = cast(Callable[..., Any], _task_server_module.run_task_app)
|
|
59
81
|
except Exception as exc: # pragma: no cover - critical dependency
|
|
60
82
|
raise RuntimeError("Unable to load task app server utilities") from exc
|
|
61
83
|
|
|
@@ -64,10 +86,12 @@ def _load_demo_directory() -> Path | None:
|
|
|
64
86
|
"""Return the demo task apps directory if available."""
|
|
65
87
|
|
|
66
88
|
try:
|
|
67
|
-
module =
|
|
68
|
-
|
|
89
|
+
module = cast(
|
|
90
|
+
Any, importlib.import_module("synth_ai.demos.demo_task_apps.core")
|
|
91
|
+
)
|
|
92
|
+
loader = cast(Callable[[], str | Path | None], module.load_demo_dir)
|
|
69
93
|
demo_dir = loader()
|
|
70
|
-
if isinstance(demo_dir,
|
|
94
|
+
if isinstance(demo_dir, str | Path):
|
|
71
95
|
demo_path = Path(demo_dir)
|
|
72
96
|
if demo_path.exists():
|
|
73
97
|
return demo_path.resolve()
|
|
@@ -105,6 +129,25 @@ DEFAULT_SEARCH_RELATIVE = (
|
|
|
105
129
|
)
|
|
106
130
|
|
|
107
131
|
|
|
132
|
+
def _pearson(xs: Sequence[float], ys: Sequence[float]) -> float | None:
|
|
133
|
+
if len(xs) != len(ys) or len(xs) < 2:
|
|
134
|
+
return None
|
|
135
|
+
mean_x = sum(xs) / len(xs)
|
|
136
|
+
mean_y = sum(ys) / len(ys)
|
|
137
|
+
num = 0.0
|
|
138
|
+
denom_x = 0.0
|
|
139
|
+
denom_y = 0.0
|
|
140
|
+
for x, y in zip(xs, ys, strict=False):
|
|
141
|
+
dx = x - mean_x
|
|
142
|
+
dy = y - mean_y
|
|
143
|
+
num += dx * dy
|
|
144
|
+
denom_x += dx * dx
|
|
145
|
+
denom_y += dy * dy
|
|
146
|
+
if denom_x <= 0 or denom_y <= 0:
|
|
147
|
+
return None
|
|
148
|
+
return num / (denom_x ** 0.5 * denom_y ** 0.5)
|
|
149
|
+
|
|
150
|
+
|
|
108
151
|
@dataclass
|
|
109
152
|
class AppChoice:
|
|
110
153
|
app_id: str
|
|
@@ -128,6 +171,171 @@ class AppChoice:
|
|
|
128
171
|
return entry
|
|
129
172
|
|
|
130
173
|
|
|
174
|
+
@dataclass
|
|
175
|
+
class JudgeSpec:
|
|
176
|
+
name: str
|
|
177
|
+
fn: Callable[..., Any]
|
|
178
|
+
kwargs: dict[str, Any]
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def _parse_datetime_for_trace(value: Any) -> datetime | None:
|
|
182
|
+
if isinstance(value, datetime):
|
|
183
|
+
return value if value.tzinfo else value.replace(tzinfo=UTC)
|
|
184
|
+
if isinstance(value, str):
|
|
185
|
+
value = value.replace("Z", "+00:00")
|
|
186
|
+
try:
|
|
187
|
+
dt = datetime.fromisoformat(value)
|
|
188
|
+
except ValueError:
|
|
189
|
+
try:
|
|
190
|
+
dt = datetime.fromtimestamp(float(value), tz=UTC)
|
|
191
|
+
except Exception:
|
|
192
|
+
return None
|
|
193
|
+
return dt if dt.tzinfo else dt.replace(tzinfo=UTC)
|
|
194
|
+
if isinstance(value, int | float):
|
|
195
|
+
return datetime.fromtimestamp(float(value), tz=UTC)
|
|
196
|
+
return None
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def _time_record_from_dict(payload: dict[str, Any] | None) -> TimeRecord:
|
|
200
|
+
payload = payload or {}
|
|
201
|
+
event_time = payload.get("event_time")
|
|
202
|
+
if not isinstance(event_time, int | float):
|
|
203
|
+
try:
|
|
204
|
+
event_time = float(event_time)
|
|
205
|
+
except Exception:
|
|
206
|
+
event_time = float(time.time())
|
|
207
|
+
message_time = payload.get("message_time")
|
|
208
|
+
if message_time is not None:
|
|
209
|
+
try:
|
|
210
|
+
message_time = int(message_time)
|
|
211
|
+
except Exception:
|
|
212
|
+
message_time = None
|
|
213
|
+
return TimeRecord(event_time=event_time, message_time=message_time)
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
def _event_from_dict(payload: dict[str, Any]) -> BaseEvent:
|
|
217
|
+
base_kwargs = {
|
|
218
|
+
"system_instance_id": payload.get("system_instance_id", ""),
|
|
219
|
+
"time_record": _time_record_from_dict(payload.get("time_record")),
|
|
220
|
+
"metadata": payload.get("metadata") or {},
|
|
221
|
+
"event_metadata": payload.get("event_metadata"),
|
|
222
|
+
}
|
|
223
|
+
if "actions" in payload:
|
|
224
|
+
return RuntimeEvent(actions=payload.get("actions") or [], **base_kwargs)
|
|
225
|
+
if any(key in payload for key in ("reward", "terminated", "truncated")):
|
|
226
|
+
return EnvironmentEvent(
|
|
227
|
+
reward=float(payload.get("reward", 0.0) or 0.0),
|
|
228
|
+
terminated=bool(payload.get("terminated", False)),
|
|
229
|
+
truncated=bool(payload.get("truncated", False)),
|
|
230
|
+
system_state_before=payload.get("system_state_before"),
|
|
231
|
+
system_state_after=payload.get("system_state_after"),
|
|
232
|
+
**base_kwargs,
|
|
233
|
+
)
|
|
234
|
+
return BaseEvent(**base_kwargs)
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
def _markov_message_from_dict(payload: dict[str, Any]) -> SessionEventMarkovBlanketMessage:
|
|
238
|
+
content_payload = payload.get("content") or {}
|
|
239
|
+
content = SessionMessageContent(
|
|
240
|
+
text=content_payload.get("text"),
|
|
241
|
+
json_payload=content_payload.get("json_payload"),
|
|
242
|
+
)
|
|
243
|
+
raw_type = (payload.get("message_type") or "").lower()
|
|
244
|
+
if raw_type == "observation":
|
|
245
|
+
normalized_type = "system"
|
|
246
|
+
elif raw_type == "action":
|
|
247
|
+
normalized_type = "assistant"
|
|
248
|
+
elif raw_type in {"user", "assistant", "system", "tool_use", "tool_result"}:
|
|
249
|
+
normalized_type = raw_type
|
|
250
|
+
else:
|
|
251
|
+
normalized_type = "system"
|
|
252
|
+
|
|
253
|
+
return SessionEventMarkovBlanketMessage(
|
|
254
|
+
content=content,
|
|
255
|
+
message_type=normalized_type,
|
|
256
|
+
time_record=_time_record_from_dict(payload.get("time_record")),
|
|
257
|
+
metadata=payload.get("metadata") or {},
|
|
258
|
+
)
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
def _step_from_dict(payload: dict[str, Any]) -> SessionTimeStep:
|
|
262
|
+
events = [
|
|
263
|
+
_event_from_dict(event)
|
|
264
|
+
for event in payload.get("events", [])
|
|
265
|
+
if isinstance(event, dict)
|
|
266
|
+
]
|
|
267
|
+
messages = [
|
|
268
|
+
_markov_message_from_dict(msg)
|
|
269
|
+
for msg in payload.get("markov_blanket_messages", [])
|
|
270
|
+
if isinstance(msg, dict)
|
|
271
|
+
]
|
|
272
|
+
timestamp = _parse_datetime_for_trace(payload.get("timestamp")) or datetime.now(UTC)
|
|
273
|
+
completed_at = _parse_datetime_for_trace(payload.get("completed_at"))
|
|
274
|
+
return SessionTimeStep(
|
|
275
|
+
step_id=payload.get("step_id", ""),
|
|
276
|
+
step_index=int(payload.get("step_index", 0) or 0),
|
|
277
|
+
timestamp=timestamp,
|
|
278
|
+
turn_number=payload.get("turn_number"),
|
|
279
|
+
events=events,
|
|
280
|
+
markov_blanket_messages=messages,
|
|
281
|
+
step_metadata=payload.get("step_metadata") or {},
|
|
282
|
+
completed_at=completed_at,
|
|
283
|
+
)
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
def _session_trace_from_dict(payload: dict[str, Any]) -> V3SessionTrace | None:
|
|
287
|
+
if not isinstance(payload, dict):
|
|
288
|
+
return None
|
|
289
|
+
steps = [
|
|
290
|
+
_step_from_dict(step)
|
|
291
|
+
for step in payload.get("session_time_steps", [])
|
|
292
|
+
if isinstance(step, dict)
|
|
293
|
+
]
|
|
294
|
+
events = [
|
|
295
|
+
_event_from_dict(event)
|
|
296
|
+
for event in payload.get("event_history", [])
|
|
297
|
+
if isinstance(event, dict)
|
|
298
|
+
]
|
|
299
|
+
markov_history = [
|
|
300
|
+
_markov_message_from_dict(msg)
|
|
301
|
+
for msg in payload.get("markov_blanket_message_history", [])
|
|
302
|
+
if isinstance(msg, dict)
|
|
303
|
+
]
|
|
304
|
+
created_at = _parse_datetime_for_trace(payload.get("created_at")) or datetime.now(UTC)
|
|
305
|
+
metadata = payload.get("metadata") or {}
|
|
306
|
+
session_metadata = payload.get("session_metadata")
|
|
307
|
+
return V3SessionTrace(
|
|
308
|
+
session_id=payload.get("session_id", ""),
|
|
309
|
+
created_at=created_at,
|
|
310
|
+
session_time_steps=steps,
|
|
311
|
+
event_history=events,
|
|
312
|
+
markov_blanket_message_history=markov_history,
|
|
313
|
+
metadata=metadata,
|
|
314
|
+
session_metadata=session_metadata,
|
|
315
|
+
)
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
async def _store_trace(
|
|
319
|
+
tracer: SessionTracer | None,
|
|
320
|
+
trace_namespace: dict[str, Any] | None,
|
|
321
|
+
extra_metadata: dict[str, Any] | None = None,
|
|
322
|
+
):
|
|
323
|
+
if tracer is None or not isinstance(trace_namespace, dict):
|
|
324
|
+
return
|
|
325
|
+
session_payload = trace_namespace.get("session_trace")
|
|
326
|
+
if not isinstance(session_payload, dict):
|
|
327
|
+
return
|
|
328
|
+
trace_obj = _session_trace_from_dict(session_payload)
|
|
329
|
+
if trace_obj is None:
|
|
330
|
+
return
|
|
331
|
+
if tracer.db is None:
|
|
332
|
+
await tracer.initialize()
|
|
333
|
+
meta = dict(trace_obj.metadata or {})
|
|
334
|
+
if extra_metadata:
|
|
335
|
+
meta.update(extra_metadata)
|
|
336
|
+
trace_obj.metadata = meta
|
|
337
|
+
await tracer.db.insert_session_trace(trace_obj)
|
|
338
|
+
|
|
131
339
|
def _temporary_sys_path(paths: Sequence[Path]):
|
|
132
340
|
"""Context manager to prepend entries to sys.path temporarily."""
|
|
133
341
|
|
|
@@ -676,36 +884,44 @@ def _build_modal_config_from_ast(modal_call: ast.Call) -> ModalDeploymentConfigT
|
|
|
676
884
|
elif kw.arg == "pip_packages" and isinstance(kw.value, (ast.List, ast.Tuple)):
|
|
677
885
|
# Handle pip_packages list/tuple
|
|
678
886
|
packages: list[str] = []
|
|
679
|
-
|
|
680
|
-
|
|
681
|
-
|
|
887
|
+
value_node = kw.value
|
|
888
|
+
if isinstance(value_node, (ast.List, ast.Tuple)):
|
|
889
|
+
for elt in value_node.elts:
|
|
890
|
+
if isinstance(elt, ast.Constant):
|
|
891
|
+
packages.append(elt.value)
|
|
682
892
|
kwargs[kw.arg] = tuple(packages)
|
|
683
893
|
elif kw.arg == "extra_local_dirs" and isinstance(kw.value, (ast.List, ast.Tuple)):
|
|
684
894
|
# Handle extra_local_dirs list/tuple of tuples
|
|
685
895
|
dirs = []
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
896
|
+
value_node = kw.value
|
|
897
|
+
if isinstance(value_node, (ast.List, ast.Tuple)):
|
|
898
|
+
for elt in value_node.elts:
|
|
899
|
+
if isinstance(elt, (ast.List, ast.Tuple)) and len(elt.elts) == 2:
|
|
900
|
+
src = elt.elts[0].value if isinstance(elt.elts[0], ast.Constant) else None
|
|
901
|
+
dst = elt.elts[1].value if isinstance(elt.elts[1], ast.Constant) else None
|
|
902
|
+
if src and dst:
|
|
903
|
+
dirs.append((src, dst))
|
|
692
904
|
kwargs[kw.arg] = tuple(dirs)
|
|
693
905
|
elif kw.arg == "secret_names" and isinstance(kw.value, (ast.List, ast.Tuple)):
|
|
694
906
|
# Handle secret_names list/tuple
|
|
695
907
|
secrets = []
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
|
|
908
|
+
value_node = kw.value
|
|
909
|
+
if isinstance(value_node, (ast.List, ast.Tuple)):
|
|
910
|
+
for elt in value_node.elts:
|
|
911
|
+
if isinstance(elt, ast.Constant):
|
|
912
|
+
secrets.append(elt.value)
|
|
699
913
|
kwargs[kw.arg] = tuple(secrets)
|
|
700
914
|
elif kw.arg == "volume_mounts" and isinstance(kw.value, (ast.List, ast.Tuple)):
|
|
701
915
|
# Handle volume_mounts list/tuple of tuples
|
|
702
916
|
mounts = []
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
917
|
+
value_node = kw.value
|
|
918
|
+
if isinstance(value_node, (ast.List, ast.Tuple)):
|
|
919
|
+
for elt in value_node.elts:
|
|
920
|
+
if isinstance(elt, (ast.List, ast.Tuple)) and len(elt.elts) == 2:
|
|
921
|
+
name = elt.elts[0].value if isinstance(elt.elts[0], ast.Constant) else None
|
|
922
|
+
mount = elt.elts[1].value if isinstance(elt.elts[1], ast.Constant) else None
|
|
923
|
+
if name and mount:
|
|
924
|
+
mounts.append((name, mount))
|
|
709
925
|
kwargs[kw.arg] = tuple(mounts)
|
|
710
926
|
|
|
711
927
|
return ModalDeploymentConfig(**kwargs)
|
|
@@ -832,6 +1048,71 @@ def _import_task_app_module(
|
|
|
832
1048
|
return module
|
|
833
1049
|
|
|
834
1050
|
|
|
1051
|
+
@contextlib.contextmanager
|
|
1052
|
+
def _safe_import_context() -> Iterator[None]:
|
|
1053
|
+
"""Guard module imports against argparse/uvicorn side effects."""
|
|
1054
|
+
|
|
1055
|
+
original_argv = sys.argv[:]
|
|
1056
|
+
sys.argv = [original_argv[0]] if original_argv else ["python"]
|
|
1057
|
+
|
|
1058
|
+
parser_cls = argparse.ArgumentParser
|
|
1059
|
+
old_parse_args = parser_cls.parse_args
|
|
1060
|
+
|
|
1061
|
+
def _parse_noargs(self, args=None, namespace=None): # type: ignore[override]
|
|
1062
|
+
if args is None:
|
|
1063
|
+
args = []
|
|
1064
|
+
if namespace is None:
|
|
1065
|
+
namespace = argparse.Namespace()
|
|
1066
|
+
try:
|
|
1067
|
+
return old_parse_args(self, args, namespace)
|
|
1068
|
+
except SystemExit:
|
|
1069
|
+
return namespace
|
|
1070
|
+
|
|
1071
|
+
parser_cls.parse_args = _parse_noargs # type: ignore[assignment]
|
|
1072
|
+
|
|
1073
|
+
uvicorn_run = None
|
|
1074
|
+
run_task_app_orig = None
|
|
1075
|
+
try:
|
|
1076
|
+
import uvicorn # type: ignore
|
|
1077
|
+
|
|
1078
|
+
uvicorn_run = uvicorn.run
|
|
1079
|
+
uvicorn.run = lambda *args, **kwargs: None # type: ignore[assignment]
|
|
1080
|
+
except Exception:
|
|
1081
|
+
uvicorn_run = None
|
|
1082
|
+
|
|
1083
|
+
try:
|
|
1084
|
+
_task_server_patch = cast(
|
|
1085
|
+
Any, importlib.import_module("synth_ai.task.server")
|
|
1086
|
+
)
|
|
1087
|
+
run_task_app_orig = cast(Callable[..., Any], _task_server_patch.run_task_app)
|
|
1088
|
+
_task_server_patch.run_task_app = ( # type: ignore[assignment]
|
|
1089
|
+
lambda *args, **kwargs: None
|
|
1090
|
+
)
|
|
1091
|
+
except Exception:
|
|
1092
|
+
run_task_app_orig = None
|
|
1093
|
+
|
|
1094
|
+
try:
|
|
1095
|
+
yield
|
|
1096
|
+
finally:
|
|
1097
|
+
sys.argv = original_argv
|
|
1098
|
+
parser_cls.parse_args = old_parse_args # type: ignore[assignment]
|
|
1099
|
+
if uvicorn_run is not None:
|
|
1100
|
+
try:
|
|
1101
|
+
import uvicorn # type: ignore
|
|
1102
|
+
|
|
1103
|
+
uvicorn.run = uvicorn_run # type: ignore[assignment]
|
|
1104
|
+
except Exception:
|
|
1105
|
+
pass
|
|
1106
|
+
if run_task_app_orig is not None:
|
|
1107
|
+
try:
|
|
1108
|
+
_task_server_patch = cast(
|
|
1109
|
+
Any, importlib.import_module("synth_ai.task.server")
|
|
1110
|
+
)
|
|
1111
|
+
_task_server_patch.run_task_app = run_task_app_orig # type: ignore[assignment]
|
|
1112
|
+
except Exception:
|
|
1113
|
+
pass
|
|
1114
|
+
|
|
1115
|
+
|
|
835
1116
|
def _load_entry_from_path(
|
|
836
1117
|
path: Path, app_id: str, module_search_roots: Sequence[Path] | None = None
|
|
837
1118
|
) -> TaskAppEntryType:
|
|
@@ -859,13 +1140,14 @@ def _load_entry_from_path(
|
|
|
859
1140
|
|
|
860
1141
|
for module_name, namespace_root in _possible_module_names(resolved, search_roots):
|
|
861
1142
|
try:
|
|
862
|
-
|
|
863
|
-
|
|
864
|
-
|
|
865
|
-
|
|
866
|
-
|
|
867
|
-
|
|
868
|
-
|
|
1143
|
+
with _safe_import_context():
|
|
1144
|
+
module = _import_task_app_module(
|
|
1145
|
+
resolved,
|
|
1146
|
+
module_name,
|
|
1147
|
+
namespace_root=namespace_root,
|
|
1148
|
+
sys_path_roots=search_roots,
|
|
1149
|
+
ensure_namespace=True,
|
|
1150
|
+
)
|
|
869
1151
|
break
|
|
870
1152
|
except Exception as exc: # pragma: no cover - best-effort fallbacks
|
|
871
1153
|
last_error = exc
|
|
@@ -874,13 +1156,14 @@ def _load_entry_from_path(
|
|
|
874
1156
|
if module is None:
|
|
875
1157
|
hashed_name = f"_synth_task_app_{hashlib.md5(str(resolved).encode(), usedforsecurity=False).hexdigest()}"
|
|
876
1158
|
try:
|
|
877
|
-
|
|
878
|
-
|
|
879
|
-
|
|
880
|
-
|
|
881
|
-
|
|
882
|
-
|
|
883
|
-
|
|
1159
|
+
with _safe_import_context():
|
|
1160
|
+
module = _import_task_app_module(
|
|
1161
|
+
resolved,
|
|
1162
|
+
hashed_name,
|
|
1163
|
+
namespace_root=None,
|
|
1164
|
+
sys_path_roots=search_roots,
|
|
1165
|
+
ensure_namespace=False,
|
|
1166
|
+
)
|
|
884
1167
|
except Exception as exc: # pragma: no cover - propagate meaningful error
|
|
885
1168
|
detail = last_error or exc
|
|
886
1169
|
raise click.ClickException(f"Failed to import {resolved}: {detail}") from detail
|
|
@@ -928,7 +1211,10 @@ def _load_entry_from_path(
|
|
|
928
1211
|
if has_required:
|
|
929
1212
|
continue
|
|
930
1213
|
try:
|
|
931
|
-
|
|
1214
|
+
with _safe_import_context():
|
|
1215
|
+
result = attr()
|
|
1216
|
+
except SystemExit:
|
|
1217
|
+
continue
|
|
932
1218
|
except Exception:
|
|
933
1219
|
continue
|
|
934
1220
|
if isinstance(result, TaskAppConfig) and result.app_id == app_id:
|
|
@@ -1024,21 +1310,173 @@ def _resolve_env_paths_for_script(script_path: Path, explicit: Sequence[str]) ->
|
|
|
1024
1310
|
return [env_candidates[choice - 1]]
|
|
1025
1311
|
|
|
1026
1312
|
|
|
1313
|
+
def _path_is_within(child: Path, parent: Path) -> bool:
|
|
1314
|
+
try:
|
|
1315
|
+
child.resolve().relative_to(parent.resolve())
|
|
1316
|
+
return True
|
|
1317
|
+
except Exception:
|
|
1318
|
+
return False
|
|
1319
|
+
|
|
1320
|
+
|
|
1321
|
+
@functools.lru_cache(maxsize=16)
|
|
1322
|
+
def _is_modal_shim(path_str: str) -> bool:
|
|
1323
|
+
"""Return True if the candidate CLI path refers to the synth-ai shim."""
|
|
1324
|
+
|
|
1325
|
+
path = Path(path_str)
|
|
1326
|
+
try:
|
|
1327
|
+
resolved = path.resolve(strict=True)
|
|
1328
|
+
except Exception:
|
|
1329
|
+
resolved = path
|
|
1330
|
+
|
|
1331
|
+
if not resolved.exists() or resolved.is_dir():
|
|
1332
|
+
return False
|
|
1333
|
+
|
|
1334
|
+
snippet = ""
|
|
1335
|
+
try:
|
|
1336
|
+
snippet = resolved.read_bytes()[:4096].decode("utf-8", errors="ignore")
|
|
1337
|
+
except Exception:
|
|
1338
|
+
snippet = ""
|
|
1339
|
+
|
|
1340
|
+
shim_markers = (
|
|
1341
|
+
"synth_ai.cli._modal_wrapper",
|
|
1342
|
+
"from modal.__main__ import main",
|
|
1343
|
+
"import modal.__main__",
|
|
1344
|
+
"run_module('modal.__main__'",
|
|
1345
|
+
)
|
|
1346
|
+
if snippet and any(marker in snippet for marker in shim_markers):
|
|
1347
|
+
return True
|
|
1348
|
+
|
|
1349
|
+
try:
|
|
1350
|
+
size = resolved.stat().st_size
|
|
1351
|
+
except Exception:
|
|
1352
|
+
size = None
|
|
1353
|
+
|
|
1354
|
+
if (
|
|
1355
|
+
size is not None
|
|
1356
|
+
and size < 2048
|
|
1357
|
+
and "python" in (snippet.splitlines() or [""])[0]
|
|
1358
|
+
and (
|
|
1359
|
+
"modal.__main__" in snippet
|
|
1360
|
+
or "modal.__main__" in snippet.replace(" ", "")
|
|
1361
|
+
)
|
|
1362
|
+
):
|
|
1363
|
+
return True
|
|
1364
|
+
|
|
1365
|
+
virtual_env = os.environ.get("VIRTUAL_ENV")
|
|
1366
|
+
if virtual_env and _path_is_within(resolved, Path(virtual_env)):
|
|
1367
|
+
return True
|
|
1368
|
+
|
|
1369
|
+
if _path_is_within(resolved, REPO_ROOT):
|
|
1370
|
+
return True
|
|
1371
|
+
|
|
1372
|
+
uv_tools_dir = Path.home() / ".local" / "share" / "uv" / "tools"
|
|
1373
|
+
return uv_tools_dir.exists() and _path_is_within(resolved, uv_tools_dir)
|
|
1374
|
+
|
|
1375
|
+
|
|
1376
|
+
def _find_modal_executable(modal_cli: str) -> tuple[str | None, str | None]:
|
|
1377
|
+
"""Return the first non-shim executable and the first shim discovered on PATH."""
|
|
1378
|
+
|
|
1379
|
+
if not modal_cli:
|
|
1380
|
+
modal_cli = "modal"
|
|
1381
|
+
|
|
1382
|
+
candidate_path = Path(modal_cli).expanduser()
|
|
1383
|
+
if candidate_path.is_absolute() or len(candidate_path.parts) > 1:
|
|
1384
|
+
resolved_candidate = candidate_path
|
|
1385
|
+
if not resolved_candidate.is_absolute():
|
|
1386
|
+
resolved_candidate = (Path.cwd() / resolved_candidate).resolve()
|
|
1387
|
+
else:
|
|
1388
|
+
resolved_candidate = resolved_candidate.resolve()
|
|
1389
|
+
if not resolved_candidate.exists():
|
|
1390
|
+
raise click.ClickException(f"--modal-cli path does not exist: {resolved_candidate}")
|
|
1391
|
+
if not os.access(resolved_candidate, os.X_OK):
|
|
1392
|
+
raise click.ClickException(f"--modal-cli is not executable: {resolved_candidate}")
|
|
1393
|
+
return str(resolved_candidate), None
|
|
1394
|
+
|
|
1395
|
+
path_env = os.environ.get("PATH", "")
|
|
1396
|
+
if not path_env:
|
|
1397
|
+
return None, None
|
|
1398
|
+
|
|
1399
|
+
seen_dirs: set[str] = set()
|
|
1400
|
+
seen_candidates: set[str] = set()
|
|
1401
|
+
shim_path: str | None = None
|
|
1402
|
+
|
|
1403
|
+
for raw_entry in path_env.split(os.pathsep):
|
|
1404
|
+
if not raw_entry:
|
|
1405
|
+
continue
|
|
1406
|
+
try:
|
|
1407
|
+
resolved_entry = str(Path(raw_entry).resolve())
|
|
1408
|
+
except Exception:
|
|
1409
|
+
resolved_entry = os.path.normpath(raw_entry)
|
|
1410
|
+
if resolved_entry in seen_dirs:
|
|
1411
|
+
continue
|
|
1412
|
+
seen_dirs.add(resolved_entry)
|
|
1413
|
+
|
|
1414
|
+
candidate = shutil.which(modal_cli, path=raw_entry)
|
|
1415
|
+
if candidate is None:
|
|
1416
|
+
continue
|
|
1417
|
+
if candidate in seen_candidates:
|
|
1418
|
+
continue
|
|
1419
|
+
seen_candidates.add(candidate)
|
|
1420
|
+
|
|
1421
|
+
if _is_modal_shim(candidate):
|
|
1422
|
+
if shim_path is None:
|
|
1423
|
+
shim_path = candidate
|
|
1424
|
+
continue
|
|
1425
|
+
return candidate, shim_path
|
|
1426
|
+
|
|
1427
|
+
return None, shim_path
|
|
1428
|
+
|
|
1429
|
+
|
|
1027
1430
|
def _modal_command_prefix(modal_cli: str) -> list[str]:
|
|
1028
1431
|
"""Resolve a command prefix for invoking the Modal CLI within the active environment."""
|
|
1029
|
-
|
|
1432
|
+
|
|
1433
|
+
force_wrapper_env = os.environ.get("SYNTH_FORCE_MODAL_WRAPPER", "").strip().lower()
|
|
1434
|
+
if force_wrapper_env in {"1", "true", "yes"}:
|
|
1435
|
+
click.secho(
|
|
1436
|
+
"[modal-prefix] SYNTH_FORCE_MODAL_WRAPPER=1 -> using in-process wrapper",
|
|
1437
|
+
fg="yellow",
|
|
1438
|
+
)
|
|
1030
1439
|
return [sys.executable, "-m", "synth_ai.cli._modal_wrapper"]
|
|
1031
1440
|
|
|
1032
|
-
|
|
1033
|
-
if
|
|
1034
|
-
|
|
1441
|
+
lookup = modal_cli or "modal"
|
|
1442
|
+
spec = importlib.util.find_spec("modal") if lookup == "modal" else None
|
|
1443
|
+
|
|
1444
|
+
preferred, shim_candidate = _find_modal_executable(lookup)
|
|
1445
|
+
if preferred is not None:
|
|
1446
|
+
detail = f"[modal-prefix] modal_cli={lookup} selected={preferred}"
|
|
1447
|
+
if lookup == "modal":
|
|
1448
|
+
detail += f" spec={'yes' if spec else 'no'}"
|
|
1449
|
+
click.secho(detail, fg="cyan")
|
|
1450
|
+
return [preferred]
|
|
1451
|
+
|
|
1452
|
+
if lookup != "modal":
|
|
1453
|
+
raise click.ClickException(f"Modal CLI not found (looked for '{lookup}')")
|
|
1454
|
+
|
|
1455
|
+
if spec is not None:
|
|
1456
|
+
warning = "[modal-prefix] Using synth-ai modal shim; pass --modal-cli /path/to/modal to override."
|
|
1457
|
+
if shim_candidate is not None:
|
|
1458
|
+
warning = (
|
|
1459
|
+
f"[modal-prefix] Using synth-ai modal shim at {shim_candidate}; "
|
|
1460
|
+
"pass --modal-cli /path/to/modal to override."
|
|
1461
|
+
)
|
|
1462
|
+
click.secho(warning, fg="yellow")
|
|
1463
|
+
click.secho(
|
|
1464
|
+
"[modal-prefix] modal_cli=modal selected=module-wrapper spec=yes",
|
|
1465
|
+
fg="yellow",
|
|
1466
|
+
)
|
|
1467
|
+
return [sys.executable, "-m", "synth_ai.cli._modal_wrapper"]
|
|
1035
1468
|
|
|
1036
|
-
if
|
|
1469
|
+
if shim_candidate is not None:
|
|
1037
1470
|
raise click.ClickException(
|
|
1038
|
-
"Modal CLI
|
|
1039
|
-
"
|
|
1471
|
+
"Modal CLI resolution found the synth-ai shim but the 'modal' package "
|
|
1472
|
+
"is not importable in this environment. Install the official Modal CLI "
|
|
1473
|
+
"or pass --modal-cli with its path."
|
|
1040
1474
|
)
|
|
1041
|
-
|
|
1475
|
+
|
|
1476
|
+
raise click.ClickException(
|
|
1477
|
+
"Modal CLI not found. Install the 'modal' package in this environment or pass "
|
|
1478
|
+
"--modal-cli with an explicit path."
|
|
1479
|
+
)
|
|
1042
1480
|
|
|
1043
1481
|
|
|
1044
1482
|
def _build_modal_app_wrapper(original_script: Path) -> tuple[Path, Path]:
|
|
@@ -1173,8 +1611,15 @@ def _run_modal_script(
|
|
|
1173
1611
|
if modal_name and command == "deploy":
|
|
1174
1612
|
cmd.extend(["--name", modal_name])
|
|
1175
1613
|
if dry_run:
|
|
1176
|
-
click.echo(
|
|
1614
|
+
click.echo(
|
|
1615
|
+
"Dry run: " + " ".join(shlex.quote(component) for component in cmd),
|
|
1616
|
+
err=False,
|
|
1617
|
+
)
|
|
1177
1618
|
return
|
|
1619
|
+
click.secho(
|
|
1620
|
+
"[modal-exec] " + " ".join(shlex.quote(component) for component in cmd),
|
|
1621
|
+
fg="cyan",
|
|
1622
|
+
)
|
|
1178
1623
|
try:
|
|
1179
1624
|
# Stream output live for better diagnostics
|
|
1180
1625
|
proc = subprocess.Popen(
|
|
@@ -1429,7 +1874,6 @@ def _run_modal_with_entry(
|
|
|
1429
1874
|
inline_secret_values=inline_secret_values,
|
|
1430
1875
|
)
|
|
1431
1876
|
cmd = [*_modal_command_prefix(modal_cli), command, str(script_path)]
|
|
1432
|
-
|
|
1433
1877
|
if modal_name and command == "deploy":
|
|
1434
1878
|
cmd.extend(["--name", modal_name])
|
|
1435
1879
|
|
|
@@ -1444,9 +1888,13 @@ def _run_modal_with_entry(
|
|
|
1444
1888
|
proc_env["PYTHONPATH"] = os.pathsep.join(list(dict.fromkeys(pythonpath_entries)))
|
|
1445
1889
|
|
|
1446
1890
|
if dry_run:
|
|
1447
|
-
click.echo("Dry run: " + " ".join(cmd))
|
|
1891
|
+
click.echo("Dry run: " + " ".join(shlex.quote(component) for component in cmd))
|
|
1448
1892
|
script_path.unlink(missing_ok=True)
|
|
1449
1893
|
return
|
|
1894
|
+
click.secho(
|
|
1895
|
+
"[modal-exec] " + " ".join(shlex.quote(component) for component in cmd),
|
|
1896
|
+
fg="cyan",
|
|
1897
|
+
)
|
|
1450
1898
|
|
|
1451
1899
|
try:
|
|
1452
1900
|
# Stream output live for better diagnostics
|
|
@@ -1531,6 +1979,10 @@ def _parse_env_file(path: Path) -> dict[str, str]:
|
|
|
1531
1979
|
|
|
1532
1980
|
|
|
1533
1981
|
def _interactive_fill_env(env_path: Path) -> Path | None:
|
|
1982
|
+
if not sys.stdin.isatty():
|
|
1983
|
+
raise click.ClickException(
|
|
1984
|
+
"ENVIRONMENT_API_KEY missing. Provide --env-file or run `synth-ai setup` in an interactive shell to create one."
|
|
1985
|
+
)
|
|
1534
1986
|
existing = _parse_env_file(env_path) if env_path.exists() else {}
|
|
1535
1987
|
|
|
1536
1988
|
def _prompt(label: str, *, default: str = "", required: bool) -> str | None:
|
|
@@ -1570,6 +2022,10 @@ def _ensure_env_values(env_paths: list[Path], fallback_dir: Path) -> None:
|
|
|
1570
2022
|
if (os.environ.get("ENVIRONMENT_API_KEY") or "").strip():
|
|
1571
2023
|
return
|
|
1572
2024
|
target = env_paths[0] if env_paths else (fallback_dir / ".env").resolve()
|
|
2025
|
+
click.echo(
|
|
2026
|
+
"⚠️ ENVIRONMENT_API_KEY not set. Run `uvx synth-ai setup`, "
|
|
2027
|
+
"or pass --env-file pointing at a .env with ENVIRONMENT_API_KEY."
|
|
2028
|
+
)
|
|
1573
2029
|
result = _interactive_fill_env(target)
|
|
1574
2030
|
if result is None:
|
|
1575
2031
|
raise click.ClickException("ENVIRONMENT_API_KEY required to continue")
|
|
@@ -1593,7 +2049,7 @@ def _deploy_entry(
|
|
|
1593
2049
|
f"Task app '{entry.app_id}' does not define Modal deployment settings"
|
|
1594
2050
|
)
|
|
1595
2051
|
|
|
1596
|
-
env_paths = _determine_env_files(entry, env_file)
|
|
2052
|
+
env_paths = _determine_env_files(entry, env_file, original_path=original_path)
|
|
1597
2053
|
click.echo("Using env file(s): " + ", ".join(str(p.resolve()) for p in env_paths))
|
|
1598
2054
|
_run_modal_with_entry(
|
|
1599
2055
|
entry,
|
|
@@ -1620,7 +2076,7 @@ def _modal_serve_entry(
|
|
|
1620
2076
|
f"Task app '{entry.app_id}' does not define Modal deployment settings"
|
|
1621
2077
|
)
|
|
1622
2078
|
|
|
1623
|
-
env_paths = _determine_env_files(entry, env_file)
|
|
2079
|
+
env_paths = _determine_env_files(entry, env_file, original_path=original_path)
|
|
1624
2080
|
click.echo("Using env file(s): " + ", ".join(str(p.resolve()) for p in env_paths))
|
|
1625
2081
|
_run_modal_with_entry(
|
|
1626
2082
|
entry,
|
|
@@ -1651,6 +2107,255 @@ def list_apps() -> None:
|
|
|
1651
2107
|
click.echo(f"- {entry.app_id}{aliases}: {entry.description}")
|
|
1652
2108
|
|
|
1653
2109
|
|
|
2110
|
+
@task_app_group.command("validate")
|
|
2111
|
+
@click.argument("app_id", type=str, required=True)
|
|
2112
|
+
@click.option(
|
|
2113
|
+
"--url",
|
|
2114
|
+
type=str,
|
|
2115
|
+
default=None,
|
|
2116
|
+
help="Task app URL to validate (if not provided, starts a local server)",
|
|
2117
|
+
)
|
|
2118
|
+
@click.option(
|
|
2119
|
+
"--port",
|
|
2120
|
+
type=int,
|
|
2121
|
+
default=8765,
|
|
2122
|
+
help="Port to use for temporary server (default: 8765)",
|
|
2123
|
+
)
|
|
2124
|
+
@click.option(
|
|
2125
|
+
"--api-key",
|
|
2126
|
+
type=str,
|
|
2127
|
+
default=None,
|
|
2128
|
+
envvar="ENVIRONMENT_API_KEY",
|
|
2129
|
+
help="API key for authentication (default: $ENVIRONMENT_API_KEY)",
|
|
2130
|
+
)
|
|
2131
|
+
@click.option(
|
|
2132
|
+
"--min-instances",
|
|
2133
|
+
type=int,
|
|
2134
|
+
default=10,
|
|
2135
|
+
help="Minimum number of task instances required (default: 10)",
|
|
2136
|
+
)
|
|
2137
|
+
@click.option(
|
|
2138
|
+
"--verbose",
|
|
2139
|
+
"-v",
|
|
2140
|
+
is_flag=True,
|
|
2141
|
+
help="Show detailed information about the task app",
|
|
2142
|
+
)
|
|
2143
|
+
@click.option(
|
|
2144
|
+
"--json",
|
|
2145
|
+
"output_json",
|
|
2146
|
+
is_flag=True,
|
|
2147
|
+
help="Output results as JSON",
|
|
2148
|
+
)
|
|
2149
|
+
def validate_task_app_cmd(
|
|
2150
|
+
app_id: str,
|
|
2151
|
+
url: str | None,
|
|
2152
|
+
port: int,
|
|
2153
|
+
api_key: str | None,
|
|
2154
|
+
min_instances: int,
|
|
2155
|
+
verbose: bool,
|
|
2156
|
+
output_json: bool,
|
|
2157
|
+
) -> None:
|
|
2158
|
+
"""Validate a task app deployment readiness.
|
|
2159
|
+
|
|
2160
|
+
This command verifies that a task app is properly configured and ready to run
|
|
2161
|
+
by checking all required HTTP endpoints, authentication, and task availability.
|
|
2162
|
+
|
|
2163
|
+
By default, it starts a temporary local server for validation. You can also
|
|
2164
|
+
validate a remote deployment by passing --url.
|
|
2165
|
+
|
|
2166
|
+
\b
|
|
2167
|
+
What gets validated:
|
|
2168
|
+
• Root endpoint (/) responds correctly
|
|
2169
|
+
• Health endpoint (/health) is accessible with proper authentication
|
|
2170
|
+
• Info endpoint (/info) returns valid task metadata
|
|
2171
|
+
• Task info endpoint (/task_info) provides task instances
|
|
2172
|
+
• Rollout endpoint (/rollout) is registered
|
|
2173
|
+
• At least N task instances are available (default: 10)
|
|
2174
|
+
|
|
2175
|
+
\b
|
|
2176
|
+
Examples:
|
|
2177
|
+
|
|
2178
|
+
\b
|
|
2179
|
+
Validate grpo-crafter (starts local server automatically):
|
|
2180
|
+
$ synth-ai task-app validate grpo-crafter
|
|
2181
|
+
|
|
2182
|
+
\b
|
|
2183
|
+
Validate sokoban with verbose output:
|
|
2184
|
+
$ synth-ai task-app validate sokoban --verbose
|
|
2185
|
+
|
|
2186
|
+
\b
|
|
2187
|
+
Validate with custom port:
|
|
2188
|
+
$ synth-ai task-app validate sokoban --port 9000
|
|
2189
|
+
|
|
2190
|
+
\b
|
|
2191
|
+
Validate a remote deployment:
|
|
2192
|
+
$ synth-ai task-app validate grpo-crafter --url https://my-crafter.modal.run
|
|
2193
|
+
|
|
2194
|
+
\b
|
|
2195
|
+
Require at least 20 task instances:
|
|
2196
|
+
$ synth-ai task-app validate grpo-crafter --min-instances 20
|
|
2197
|
+
|
|
2198
|
+
\b
|
|
2199
|
+
Get JSON output for automation:
|
|
2200
|
+
$ synth-ai task-app validate sokoban --json
|
|
2201
|
+
|
|
2202
|
+
\b
|
|
2203
|
+
Common use cases:
|
|
2204
|
+
• Pre-deployment verification: Check task app works before deploying to Modal
|
|
2205
|
+
• CI/CD integration: Use --json flag for automated validation in pipelines
|
|
2206
|
+
• Debug failing deployments: Use --verbose to see detailed endpoint responses
|
|
2207
|
+
• Test API key configuration: Verify authentication is set up correctly
|
|
2208
|
+
"""
|
|
2209
|
+
import asyncio
|
|
2210
|
+
import socket
|
|
2211
|
+
import subprocess
|
|
2212
|
+
import tempfile
|
|
2213
|
+
import time
|
|
2214
|
+
|
|
2215
|
+
# Import the validate_task_app function defined in this module
|
|
2216
|
+
from synth_ai.cli._validate_task_app import validate_task_app # type: ignore[attr-defined]
|
|
2217
|
+
|
|
2218
|
+
proc = None
|
|
2219
|
+
task_app_url = url
|
|
2220
|
+
|
|
2221
|
+
try:
|
|
2222
|
+
# If no URL provided, start a temporary server
|
|
2223
|
+
if not task_app_url:
|
|
2224
|
+
# Find an available port
|
|
2225
|
+
def is_port_available(port: int) -> bool:
|
|
2226
|
+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
|
2227
|
+
try:
|
|
2228
|
+
s.bind(("", port))
|
|
2229
|
+
return True
|
|
2230
|
+
except OSError:
|
|
2231
|
+
return False
|
|
2232
|
+
|
|
2233
|
+
while not is_port_available(port):
|
|
2234
|
+
port += 1
|
|
2235
|
+
|
|
2236
|
+
task_app_url = f"http://localhost:{port}"
|
|
2237
|
+
|
|
2238
|
+
if not output_json:
|
|
2239
|
+
click.echo(f"Starting temporary {app_id} server on port {port}...")
|
|
2240
|
+
|
|
2241
|
+
# Start the server in background
|
|
2242
|
+
env = os.environ.copy()
|
|
2243
|
+
if api_key:
|
|
2244
|
+
env["ENVIRONMENT_API_KEY"] = api_key
|
|
2245
|
+
|
|
2246
|
+
# Create a temporary trace DB and trace dir to avoid prompts
|
|
2247
|
+
import tempfile
|
|
2248
|
+
temp_dir = tempfile.mkdtemp()
|
|
2249
|
+
temp_trace_db = os.path.join(temp_dir, "validate_trace.db")
|
|
2250
|
+
temp_trace_dir = os.path.join(temp_dir, "traces")
|
|
2251
|
+
os.makedirs(temp_trace_dir, exist_ok=True)
|
|
2252
|
+
|
|
2253
|
+
proc = subprocess.Popen(
|
|
2254
|
+
[
|
|
2255
|
+
"uv",
|
|
2256
|
+
"run",
|
|
2257
|
+
"synth-ai",
|
|
2258
|
+
"task-app",
|
|
2259
|
+
"serve",
|
|
2260
|
+
app_id,
|
|
2261
|
+
"--port",
|
|
2262
|
+
str(port),
|
|
2263
|
+
"--no-reload",
|
|
2264
|
+
"--trace",
|
|
2265
|
+
temp_trace_dir,
|
|
2266
|
+
"--trace-db",
|
|
2267
|
+
temp_trace_db,
|
|
2268
|
+
],
|
|
2269
|
+
env=env,
|
|
2270
|
+
stdin=subprocess.PIPE, # Add stdin to handle any prompts
|
|
2271
|
+
stdout=subprocess.DEVNULL if output_json else subprocess.PIPE,
|
|
2272
|
+
stderr=subprocess.DEVNULL if output_json else subprocess.PIPE,
|
|
2273
|
+
text=True,
|
|
2274
|
+
)
|
|
2275
|
+
|
|
2276
|
+
# Write empty input to stdin to skip any prompts
|
|
2277
|
+
if proc.stdin:
|
|
2278
|
+
try:
|
|
2279
|
+
proc.stdin.write("\n")
|
|
2280
|
+
proc.stdin.flush()
|
|
2281
|
+
proc.stdin.close()
|
|
2282
|
+
except Exception:
|
|
2283
|
+
pass
|
|
2284
|
+
|
|
2285
|
+
# Wait for server to be ready
|
|
2286
|
+
if not output_json:
|
|
2287
|
+
click.echo("Waiting for server to start...")
|
|
2288
|
+
|
|
2289
|
+
import httpx
|
|
2290
|
+
for _attempt in range(60): # 30 seconds timeout
|
|
2291
|
+
try:
|
|
2292
|
+
async def check_health():
|
|
2293
|
+
async with httpx.AsyncClient(timeout=2.0) as client:
|
|
2294
|
+
resp = await client.get(f"{task_app_url}/")
|
|
2295
|
+
return resp.status_code == 200
|
|
2296
|
+
|
|
2297
|
+
if asyncio.run(check_health()):
|
|
2298
|
+
break
|
|
2299
|
+
except Exception:
|
|
2300
|
+
pass
|
|
2301
|
+
|
|
2302
|
+
# Check if process died
|
|
2303
|
+
if proc.poll() is not None:
|
|
2304
|
+
stderr_output = ""
|
|
2305
|
+
if proc.stderr and not output_json:
|
|
2306
|
+
stderr_output = proc.stderr.read()
|
|
2307
|
+
click.echo(click.style("✗ Server process exited unexpectedly", fg="red"), err=True)
|
|
2308
|
+
if stderr_output and not output_json:
|
|
2309
|
+
click.echo(f"Error output:\n{stderr_output}", err=True)
|
|
2310
|
+
sys.exit(1)
|
|
2311
|
+
|
|
2312
|
+
time.sleep(0.5)
|
|
2313
|
+
else:
|
|
2314
|
+
click.echo(click.style("✗ Server failed to start within 30 seconds", fg="red"), err=True)
|
|
2315
|
+
sys.exit(1)
|
|
2316
|
+
|
|
2317
|
+
if not output_json:
|
|
2318
|
+
click.echo(click.style("✓ Server started", fg="green"))
|
|
2319
|
+
click.echo()
|
|
2320
|
+
|
|
2321
|
+
# Ensure URL doesn't have trailing slash
|
|
2322
|
+
task_app_url = task_app_url.rstrip("/")
|
|
2323
|
+
|
|
2324
|
+
async def _run() -> tuple[bool, dict[str, Any]]:
|
|
2325
|
+
return await validate_task_app(
|
|
2326
|
+
url=task_app_url,
|
|
2327
|
+
api_key=api_key,
|
|
2328
|
+
min_instances=min_instances,
|
|
2329
|
+
verbose=verbose,
|
|
2330
|
+
)
|
|
2331
|
+
|
|
2332
|
+
success, results = asyncio.run(_run())
|
|
2333
|
+
|
|
2334
|
+
if output_json:
|
|
2335
|
+
import json as _json
|
|
2336
|
+
click.echo(_json.dumps(results, indent=2))
|
|
2337
|
+
|
|
2338
|
+
sys.exit(0 if success else 1)
|
|
2339
|
+
|
|
2340
|
+
finally:
|
|
2341
|
+
# Cleanup: stop the temporary server
|
|
2342
|
+
if proc is not None:
|
|
2343
|
+
if not output_json:
|
|
2344
|
+
click.echo("\nStopping temporary server...")
|
|
2345
|
+
try:
|
|
2346
|
+
proc.terminate()
|
|
2347
|
+
proc.wait(timeout=5)
|
|
2348
|
+
except Exception:
|
|
2349
|
+
proc.kill()
|
|
2350
|
+
|
|
2351
|
+
# Cleanup temp trace DB
|
|
2352
|
+
if not url and 'temp_dir' in locals():
|
|
2353
|
+
import contextlib
|
|
2354
|
+
import shutil
|
|
2355
|
+
with contextlib.suppress(Exception):
|
|
2356
|
+
shutil.rmtree(temp_dir, ignore_errors=True)
|
|
2357
|
+
|
|
2358
|
+
|
|
1654
2359
|
def _load_env_files_into_process(paths: Sequence[str]) -> None:
|
|
1655
2360
|
for p in paths:
|
|
1656
2361
|
try:
|
|
@@ -1907,7 +2612,9 @@ def serve_task_group(
|
|
|
1907
2612
|
)
|
|
1908
2613
|
|
|
1909
2614
|
|
|
1910
|
-
def _determine_env_files(
|
|
2615
|
+
def _determine_env_files(
|
|
2616
|
+
entry: TaskAppEntryType, user_env_files: Sequence[str], *, original_path: Path | None = None
|
|
2617
|
+
) -> list[Path]:
|
|
1911
2618
|
resolved: list[Path] = []
|
|
1912
2619
|
for candidate in user_env_files:
|
|
1913
2620
|
p = Path(candidate).expanduser()
|
|
@@ -1917,30 +2624,46 @@ def _determine_env_files(entry: TaskAppEntryType, user_env_files: Sequence[str])
|
|
|
1917
2624
|
if resolved:
|
|
1918
2625
|
return resolved
|
|
1919
2626
|
|
|
1920
|
-
|
|
1921
|
-
|
|
1922
|
-
|
|
1923
|
-
|
|
1924
|
-
|
|
1925
|
-
|
|
1926
|
-
|
|
1927
|
-
|
|
2627
|
+
declared: list[Path] = []
|
|
2628
|
+
for candidate in getattr(entry, "env_files", ()) or ():
|
|
2629
|
+
try:
|
|
2630
|
+
p = Path(candidate).expanduser()
|
|
2631
|
+
except Exception:
|
|
2632
|
+
continue
|
|
2633
|
+
if p.exists() and p.is_file():
|
|
2634
|
+
declared.append(p)
|
|
2635
|
+
if declared:
|
|
2636
|
+
return declared
|
|
1928
2637
|
|
|
1929
|
-
|
|
1930
|
-
|
|
1931
|
-
|
|
1932
|
-
for repo_file in repo_env_files:
|
|
1933
|
-
if repo_file not in env_candidates:
|
|
1934
|
-
env_candidates.append(repo_file)
|
|
2638
|
+
def _append_candidate(collection: list[Path], candidate: Path) -> None:
|
|
2639
|
+
if candidate.exists() and candidate.is_file() and candidate not in collection:
|
|
2640
|
+
collection.append(candidate)
|
|
1935
2641
|
|
|
1936
|
-
|
|
1937
|
-
raise click.ClickException("No env file found. Pass --env-file explicitly.")
|
|
2642
|
+
auto_candidates: list[Path] = []
|
|
1938
2643
|
|
|
1939
|
-
|
|
1940
|
-
|
|
1941
|
-
|
|
1942
|
-
|
|
1943
|
-
|
|
2644
|
+
search_dirs: list[Path] = []
|
|
2645
|
+
if original_path is not None:
|
|
2646
|
+
search_dirs.append(original_path.parent.resolve())
|
|
2647
|
+
for parent in original_path.parent.resolve().parents:
|
|
2648
|
+
search_dirs.append(parent)
|
|
2649
|
+
cwd = Path.cwd().resolve()
|
|
2650
|
+
if cwd not in search_dirs:
|
|
2651
|
+
search_dirs.append(cwd)
|
|
2652
|
+
repo_root = REPO_ROOT.resolve()
|
|
2653
|
+
if repo_root not in search_dirs:
|
|
2654
|
+
search_dirs.append(repo_root)
|
|
2655
|
+
|
|
2656
|
+
for directory in search_dirs:
|
|
2657
|
+
_append_candidate(auto_candidates, directory / ".env")
|
|
2658
|
+
for candidate in sorted(directory.glob("*.env")):
|
|
2659
|
+
_append_candidate(auto_candidates, candidate)
|
|
2660
|
+
|
|
2661
|
+
if auto_candidates:
|
|
2662
|
+
return [auto_candidates[0]]
|
|
2663
|
+
|
|
2664
|
+
raise click.ClickException(
|
|
2665
|
+
"No .env file discovered automatically. Pass --env-file /path/to/.env or generate one with `uvx synth-ai setup`."
|
|
2666
|
+
)
|
|
1944
2667
|
|
|
1945
2668
|
|
|
1946
2669
|
def _ensure_port_free(port: int, host: str, *, force: bool) -> None:
|
|
@@ -2242,7 +2965,14 @@ def deploy_app(
|
|
|
2242
2965
|
def modal_serve_app(
|
|
2243
2966
|
app_id: str | None, modal_cli: str, modal_name: str | None, env_file: Sequence[str]
|
|
2244
2967
|
) -> None:
|
|
2245
|
-
|
|
2968
|
+
click.echo(f"[modal-serve] requested app_id={app_id or '(auto)'} modal_cli={modal_cli}")
|
|
2969
|
+
try:
|
|
2970
|
+
choice = _select_app_choice(app_id, purpose="modal-serve")
|
|
2971
|
+
except SystemExit as exc: # bubble up with context (legacy argparse would trigger this)
|
|
2972
|
+
raise click.ClickException(
|
|
2973
|
+
f"Legacy CLI intercepted modal-serve (exit {exc.code}). "
|
|
2974
|
+
"Make sure you're running the Click CLI (synth_ai.cli:cli)."
|
|
2975
|
+
) from exc
|
|
2246
2976
|
|
|
2247
2977
|
if choice.modal_script:
|
|
2248
2978
|
env_paths = _resolve_env_paths_for_script(choice.modal_script, env_file)
|
|
@@ -2251,6 +2981,7 @@ def modal_serve_app(
|
|
|
2251
2981
|
return
|
|
2252
2982
|
|
|
2253
2983
|
entry = choice.ensure_entry()
|
|
2984
|
+
click.echo(f"[modal-serve] serving entry {entry.app_id} from {choice.path}")
|
|
2254
2985
|
_modal_serve_entry(entry, modal_name, modal_cli, env_file, original_path=choice.path)
|
|
2255
2986
|
|
|
2256
2987
|
|
|
@@ -2480,22 +3211,60 @@ def register(cli: click.Group) -> None:
|
|
|
2480
3211
|
cli.add_command(serve_command)
|
|
2481
3212
|
cli.add_command(task_app_group)
|
|
2482
3213
|
cli.add_command(eval_command)
|
|
3214
|
+
cli.add_command(filter_command)
|
|
2483
3215
|
|
|
2484
3216
|
|
|
2485
|
-
@click.command(
|
|
3217
|
+
@click.command(
|
|
3218
|
+
"eval",
|
|
3219
|
+
help="Run one-off rollouts against a task app and print judge/eval summaries.",
|
|
3220
|
+
)
|
|
2486
3221
|
@click.argument("app_id", type=str, required=False)
|
|
2487
|
-
@click.option(
|
|
3222
|
+
@click.option(
|
|
3223
|
+
"--config",
|
|
3224
|
+
type=click.Path(),
|
|
3225
|
+
default=None,
|
|
3226
|
+
help="Path to eval TOML (short schema). Auto-discovers the first matching file when omitted.",
|
|
3227
|
+
)
|
|
2488
3228
|
@click.option(
|
|
2489
3229
|
"--url",
|
|
2490
3230
|
"task_app_url",
|
|
2491
3231
|
type=str,
|
|
2492
3232
|
default=None,
|
|
2493
|
-
help="Base URL of a running task app (
|
|
3233
|
+
help="Base URL of a running task app instead of spawning locally (requires --env-file for secrets).",
|
|
3234
|
+
)
|
|
3235
|
+
@click.option(
|
|
3236
|
+
"--seeds",
|
|
3237
|
+
default="0,1,2,3,4",
|
|
3238
|
+
help="Comma-separated seeds/indices to evaluate. Use negative numbers to wrap around the dataset.",
|
|
2494
3239
|
)
|
|
2495
|
-
@click.option("--seeds", default="0,1,2,3,4", help="Comma-separated seeds/indices to evaluate")
|
|
2496
3240
|
@click.option("--split", default="train", show_default=True, help="Dataset split to use")
|
|
2497
|
-
@click.option(
|
|
2498
|
-
|
|
3241
|
+
@click.option(
|
|
3242
|
+
"--model",
|
|
3243
|
+
default=None,
|
|
3244
|
+
help="Model identifier. When omitted the CLI will prompt based on task metadata.",
|
|
3245
|
+
)
|
|
3246
|
+
@click.option(
|
|
3247
|
+
"--env-file",
|
|
3248
|
+
multiple=True,
|
|
3249
|
+
type=click.Path(),
|
|
3250
|
+
help="Env file(s) to load (API keys, etc.). Required when using --url or remote judges.",
|
|
3251
|
+
)
|
|
3252
|
+
@click.option(
|
|
3253
|
+
"--trace-db",
|
|
3254
|
+
default="traces/v3/eval_traces.db",
|
|
3255
|
+
show_default=True,
|
|
3256
|
+
help="SQLite/Turso URL for storing rollout traces set to 'none' to disable persistence.",
|
|
3257
|
+
)
|
|
3258
|
+
@click.option(
|
|
3259
|
+
"--metadata",
|
|
3260
|
+
multiple=True,
|
|
3261
|
+
help="Filter tasks by key=value metadata (e.g., --metadata difficulty=easy)",
|
|
3262
|
+
)
|
|
3263
|
+
@click.option(
|
|
3264
|
+
"--metadata-sql",
|
|
3265
|
+
default=None,
|
|
3266
|
+
help="SQLite query that returns seeds to evaluate (e.g., SELECT seed FROM tasks WHERE difficulty='easy' LIMIT 5)",
|
|
3267
|
+
)
|
|
2499
3268
|
def eval_command(
|
|
2500
3269
|
app_id: str | None,
|
|
2501
3270
|
config: str | None,
|
|
@@ -2504,8 +3273,17 @@ def eval_command(
|
|
|
2504
3273
|
split: str,
|
|
2505
3274
|
model: str | None,
|
|
2506
3275
|
env_file: Sequence[str],
|
|
3276
|
+
trace_db: str,
|
|
3277
|
+
metadata: Sequence[str],
|
|
3278
|
+
metadata_sql: str | None,
|
|
2507
3279
|
) -> None:
|
|
2508
|
-
"""Run
|
|
3280
|
+
"""Run rollouts against a task app and report judge statistics.
|
|
3281
|
+
|
|
3282
|
+
By default the command spins up the selected task app in-process, executes the
|
|
3283
|
+
requested seeds, and prints aggregate scores (official and custom judges). When
|
|
3284
|
+
pointing at a remote `--url`, supply matching `--env-file` values so the CLI can
|
|
3285
|
+
forward authentication headers to the running service.
|
|
3286
|
+
"""
|
|
2509
3287
|
cfg: dict[str, Any] = {}
|
|
2510
3288
|
config_path: Path | None = None
|
|
2511
3289
|
if config:
|
|
@@ -2534,6 +3312,50 @@ def eval_command(
|
|
|
2534
3312
|
|
|
2535
3313
|
app_id = app_id or (cfg.get("app_id") if isinstance(cfg.get("app_id"), str) else None) # type: ignore
|
|
2536
3314
|
|
|
3315
|
+
metadata_filters: dict[str, str] = {}
|
|
3316
|
+
cfg_metadata = cfg.get("metadata")
|
|
3317
|
+
if isinstance(cfg_metadata, dict):
|
|
3318
|
+
for key, value in cfg_metadata.items():
|
|
3319
|
+
metadata_filters[str(key)] = str(value)
|
|
3320
|
+
elif isinstance(cfg_metadata, list):
|
|
3321
|
+
for item in cfg_metadata:
|
|
3322
|
+
if isinstance(item, str) and "=" in item:
|
|
3323
|
+
key, value = item.split("=", 1)
|
|
3324
|
+
metadata_filters[key.strip()] = value.strip()
|
|
3325
|
+
|
|
3326
|
+
for item in metadata or ():
|
|
3327
|
+
if "=" not in item:
|
|
3328
|
+
raise click.ClickException(f"Metadata filters must be key=value (got: {item})")
|
|
3329
|
+
key, value = item.split("=", 1)
|
|
3330
|
+
key = key.strip()
|
|
3331
|
+
value = value.strip()
|
|
3332
|
+
if not key or not value:
|
|
3333
|
+
raise click.ClickException(f"Invalid metadata filter: {item}")
|
|
3334
|
+
metadata_filters[key] = value
|
|
3335
|
+
|
|
3336
|
+
metadata_sql_query: str | None = None
|
|
3337
|
+
cfg_metadata_sql = cfg.get("metadata_sql")
|
|
3338
|
+
if isinstance(cfg_metadata_sql, dict):
|
|
3339
|
+
metadata_sql_query = cfg_metadata_sql.get("query") or cfg_metadata_sql.get("sql")
|
|
3340
|
+
elif isinstance(cfg_metadata_sql, str):
|
|
3341
|
+
metadata_sql_query = cfg_metadata_sql
|
|
3342
|
+
|
|
3343
|
+
if metadata_sql:
|
|
3344
|
+
metadata_sql_query = metadata_sql
|
|
3345
|
+
if metadata_sql_query is not None:
|
|
3346
|
+
metadata_sql_query = str(metadata_sql_query)
|
|
3347
|
+
|
|
3348
|
+
trace_db_url: str | None = None
|
|
3349
|
+
trace_db = (trace_db or "").strip()
|
|
3350
|
+
if trace_db and trace_db.lower() not in {"none", "off", "disable"}:
|
|
3351
|
+
if "://" in trace_db:
|
|
3352
|
+
trace_db_url = trace_db
|
|
3353
|
+
else:
|
|
3354
|
+
trace_path = Path(trace_db).expanduser()
|
|
3355
|
+
trace_path.parent.mkdir(parents=True, exist_ok=True)
|
|
3356
|
+
trace_db_url = f"sqlite+aiosqlite:///{trace_path}"
|
|
3357
|
+
trace_tracer: SessionTracer | None = SessionTracer(db_url=trace_db_url, auto_save=True) if trace_db_url else None
|
|
3358
|
+
|
|
2537
3359
|
# Determine selection params (CLI takes precedence; TOML only fills unset model/seeds/env)
|
|
2538
3360
|
if cfg.get("model") and not model:
|
|
2539
3361
|
model = str(cfg["model"]) # type: ignore[index]
|
|
@@ -2553,14 +3375,16 @@ def eval_command(
|
|
|
2553
3375
|
elif isinstance(ef, list):
|
|
2554
3376
|
env_file = tuple(str(x) for x in ef) # type: ignore[assignment]
|
|
2555
3377
|
|
|
3378
|
+
choice_for_env: AppChoice | None = None
|
|
2556
3379
|
entry: TaskAppEntryType | None = None
|
|
2557
3380
|
if task_app_url is None:
|
|
2558
|
-
|
|
2559
|
-
entry =
|
|
3381
|
+
choice_for_env = _select_app_choice(app_id, purpose="eval")
|
|
3382
|
+
entry = choice_for_env.ensure_entry()
|
|
2560
3383
|
|
|
2561
3384
|
env_paths: list[Path] = []
|
|
2562
3385
|
if entry is not None:
|
|
2563
|
-
|
|
3386
|
+
original_env_path = choice_for_env.path if choice_for_env is not None else None
|
|
3387
|
+
env_paths = _determine_env_files(entry, env_file, original_path=original_env_path)
|
|
2564
3388
|
else:
|
|
2565
3389
|
if not env_file:
|
|
2566
3390
|
raise click.ClickException("--env-file is required when using --url")
|
|
@@ -2583,12 +3407,30 @@ def eval_command(
|
|
|
2583
3407
|
app = create_task_app(config)
|
|
2584
3408
|
|
|
2585
3409
|
# Determine supported models
|
|
3410
|
+
inference_meta: dict[str, Any] = {}
|
|
2586
3411
|
supported: list[str] = []
|
|
3412
|
+
seen_models: set[str] = set()
|
|
3413
|
+
|
|
3414
|
+
def _add_supported_model(candidate: Any) -> None:
|
|
3415
|
+
if not candidate:
|
|
3416
|
+
return
|
|
3417
|
+
text = str(candidate).strip()
|
|
3418
|
+
if not text or text in seen_models:
|
|
3419
|
+
return
|
|
3420
|
+
supported.append(text)
|
|
3421
|
+
seen_models.add(text)
|
|
3422
|
+
|
|
2587
3423
|
if task_app_url is None:
|
|
2588
3424
|
try:
|
|
2589
|
-
|
|
3425
|
+
if hasattr(config, "base_task_info") and config.base_task_info:
|
|
3426
|
+
inf_obj = getattr(config.base_task_info, "inference", None)
|
|
3427
|
+
if inf_obj is not None:
|
|
3428
|
+
if hasattr(inf_obj, "model_dump"):
|
|
3429
|
+
inference_meta = dict(inf_obj.model_dump(exclude_none=True)) # type: ignore[attr-defined]
|
|
3430
|
+
elif isinstance(inf_obj, dict):
|
|
3431
|
+
inference_meta = dict(inf_obj)
|
|
2590
3432
|
except Exception:
|
|
2591
|
-
|
|
3433
|
+
inference_meta = {}
|
|
2592
3434
|
else:
|
|
2593
3435
|
try:
|
|
2594
3436
|
import httpx as _hx
|
|
@@ -2601,38 +3443,38 @@ def eval_command(
|
|
|
2601
3443
|
info = c.get("/info").json()
|
|
2602
3444
|
inf = info.get("inference") if isinstance(info, dict) else None
|
|
2603
3445
|
if isinstance(inf, dict):
|
|
2604
|
-
|
|
2605
|
-
if isinstance(m, list):
|
|
2606
|
-
supported = [str(x) for x in m]
|
|
2607
|
-
if not supported:
|
|
2608
|
-
providers = inf.get("providers")
|
|
2609
|
-
if isinstance(providers, list):
|
|
2610
|
-
if "openai" in providers:
|
|
2611
|
-
supported.append("gpt-5")
|
|
2612
|
-
if "groq" in providers:
|
|
2613
|
-
supported.append("groq:llama-3.1-70b-versatile")
|
|
2614
|
-
supported.append("synth:qwen-0.6b")
|
|
3446
|
+
inference_meta = dict(inf)
|
|
2615
3447
|
except Exception:
|
|
2616
|
-
|
|
2617
|
-
|
|
2618
|
-
|
|
2619
|
-
|
|
2620
|
-
|
|
2621
|
-
|
|
2622
|
-
|
|
2623
|
-
|
|
2624
|
-
|
|
2625
|
-
|
|
2626
|
-
|
|
2627
|
-
|
|
2628
|
-
|
|
2629
|
-
|
|
3448
|
+
inference_meta = {}
|
|
3449
|
+
|
|
3450
|
+
default_model = inference_meta.get("model")
|
|
3451
|
+
if isinstance(default_model, str):
|
|
3452
|
+
_add_supported_model(default_model)
|
|
3453
|
+
|
|
3454
|
+
models_field = inference_meta.get("models")
|
|
3455
|
+
if isinstance(models_field, list):
|
|
3456
|
+
for candidate in models_field:
|
|
3457
|
+
_add_supported_model(candidate)
|
|
3458
|
+
|
|
3459
|
+
supported_models = inference_meta.get("supported_models")
|
|
3460
|
+
if isinstance(supported_models, list):
|
|
3461
|
+
for candidate in supported_models:
|
|
3462
|
+
_add_supported_model(candidate)
|
|
3463
|
+
|
|
3464
|
+
providers = inference_meta.get("providers")
|
|
3465
|
+
if isinstance(providers, list):
|
|
3466
|
+
if "openai" in providers:
|
|
3467
|
+
_add_supported_model("gpt-5")
|
|
3468
|
+
if "groq" in providers:
|
|
3469
|
+
_add_supported_model("groq:llama-3.1-70b-versatile")
|
|
3470
|
+
|
|
3471
|
+
_add_supported_model("synth:qwen-0.6b")
|
|
2630
3472
|
|
|
2631
3473
|
selected_model = model
|
|
2632
3474
|
if not selected_model:
|
|
2633
3475
|
if not supported:
|
|
2634
3476
|
raise click.ClickException(
|
|
2635
|
-
"No supported models; supply --model or add base_task_info.inference.
|
|
3477
|
+
"No supported models; supply --model or add base_task_info.inference.model"
|
|
2636
3478
|
)
|
|
2637
3479
|
click.echo("Select model to evaluate:")
|
|
2638
3480
|
for idx, m in enumerate(supported, start=1):
|
|
@@ -2652,70 +3494,347 @@ def eval_command(
|
|
|
2652
3494
|
if api_key:
|
|
2653
3495
|
headers["X-API-Key"] = api_key
|
|
2654
3496
|
|
|
3497
|
+
# Precompute optional policy overrides from TOML
|
|
3498
|
+
policy_overrides: dict[str, Any] = {}
|
|
3499
|
+
try:
|
|
3500
|
+
# Accept [eval.policy] table or top-level keys for convenience
|
|
3501
|
+
if isinstance(cfg.get("policy"), dict):
|
|
3502
|
+
policy_overrides.update(dict(cfg["policy"]))
|
|
3503
|
+
# Back-compat: allow temperature/max_tokens at top level
|
|
3504
|
+
for k in (
|
|
3505
|
+
"temperature",
|
|
3506
|
+
"max_tokens",
|
|
3507
|
+
"reasoning_effort",
|
|
3508
|
+
"system_hint",
|
|
3509
|
+
"tool_choice",
|
|
3510
|
+
"inference_url",
|
|
3511
|
+
):
|
|
3512
|
+
if k in cfg and k not in policy_overrides:
|
|
3513
|
+
policy_overrides[k] = cfg.get(k)
|
|
3514
|
+
except Exception:
|
|
3515
|
+
policy_overrides = {}
|
|
3516
|
+
|
|
3517
|
+
raw_concurrency = cfg.get("concurrency")
|
|
3518
|
+
try:
|
|
3519
|
+
concurrency_limit = int(raw_concurrency) if raw_concurrency is not None else 1
|
|
3520
|
+
except Exception:
|
|
3521
|
+
concurrency_limit = 1
|
|
3522
|
+
if concurrency_limit <= 0:
|
|
3523
|
+
concurrency_limit = 1
|
|
3524
|
+
concurrency_limit = min(concurrency_limit, max(1, len(seed_values)))
|
|
3525
|
+
|
|
3526
|
+
judge_specs: list[JudgeSpec] = []
|
|
3527
|
+
|
|
3528
|
+
def _register_judge(name_hint: str | None, judge_cfg: dict[str, Any]) -> None:
|
|
3529
|
+
if not judge_cfg:
|
|
3530
|
+
return
|
|
3531
|
+
judge_module = judge_cfg.get("module")
|
|
3532
|
+
judge_path = judge_cfg.get("path")
|
|
3533
|
+
judge_callable_name = judge_cfg.get("callable") or judge_cfg.get("function")
|
|
3534
|
+
if judge_module and judge_path:
|
|
3535
|
+
raise click.ClickException("Judge config cannot set both 'module' and 'path'")
|
|
3536
|
+
if not judge_module and not judge_path:
|
|
3537
|
+
raise click.ClickException("Judge config requires 'module' or 'path'")
|
|
3538
|
+
try:
|
|
3539
|
+
if judge_module:
|
|
3540
|
+
module = importlib.import_module(str(judge_module))
|
|
3541
|
+
else:
|
|
3542
|
+
path = Path(str(judge_path)).expanduser()
|
|
3543
|
+
if not path.exists():
|
|
3544
|
+
raise click.ClickException(f"Judge module path not found: {path}")
|
|
3545
|
+
spec = importlib.util.spec_from_file_location(
|
|
3546
|
+
f"_eval_judge_{path.stem}", path
|
|
3547
|
+
)
|
|
3548
|
+
if not spec or not spec.loader:
|
|
3549
|
+
raise click.ClickException(f"Failed to load judge module from {path}")
|
|
3550
|
+
module = importlib.util.module_from_spec(spec)
|
|
3551
|
+
sys.modules[spec.name] = module
|
|
3552
|
+
spec.loader.exec_module(module)
|
|
3553
|
+
except click.ClickException:
|
|
3554
|
+
raise
|
|
3555
|
+
except Exception as exc:
|
|
3556
|
+
raise click.ClickException(f"Unable to load judge module: {exc}") from exc
|
|
3557
|
+
|
|
3558
|
+
if judge_callable_name:
|
|
3559
|
+
try:
|
|
3560
|
+
judge_fn = getattr(module, str(judge_callable_name))
|
|
3561
|
+
except AttributeError as exc:
|
|
3562
|
+
raise click.ClickException(
|
|
3563
|
+
f"Judge callable '{judge_callable_name}' not found in module"
|
|
3564
|
+
) from exc
|
|
3565
|
+
else:
|
|
3566
|
+
if hasattr(module, "judge"):
|
|
3567
|
+
judge_fn = module.judge
|
|
3568
|
+
else:
|
|
3569
|
+
raise click.ClickException("Judge module must expose 'judge' callable")
|
|
3570
|
+
|
|
3571
|
+
if not callable(judge_fn):
|
|
3572
|
+
raise click.ClickException("Judge callable is not callable")
|
|
3573
|
+
|
|
3574
|
+
judge_kwargs = {
|
|
3575
|
+
k: v
|
|
3576
|
+
for k, v in judge_cfg.items()
|
|
3577
|
+
if k not in {"module", "path", "callable", "function", "name"}
|
|
3578
|
+
}
|
|
3579
|
+
display_name = str(
|
|
3580
|
+
judge_cfg.get("name")
|
|
3581
|
+
or name_hint
|
|
3582
|
+
or f"judge{len(judge_specs) + 1}"
|
|
3583
|
+
)
|
|
3584
|
+
judge_specs.append(JudgeSpec(display_name, judge_fn, judge_kwargs))
|
|
3585
|
+
|
|
3586
|
+
raw_judge_cfg = cfg.get("judge")
|
|
3587
|
+
if isinstance(raw_judge_cfg, dict) and raw_judge_cfg:
|
|
3588
|
+
direct_keys = {"module", "path", "callable", "function", "name"}
|
|
3589
|
+
has_direct_keys = any(key in raw_judge_cfg for key in direct_keys)
|
|
3590
|
+
nested_candidates = [
|
|
3591
|
+
(key, value)
|
|
3592
|
+
for key, value in raw_judge_cfg.items()
|
|
3593
|
+
if isinstance(value, dict)
|
|
3594
|
+
]
|
|
3595
|
+
if has_direct_keys and not nested_candidates:
|
|
3596
|
+
_register_judge(None, raw_judge_cfg)
|
|
3597
|
+
else:
|
|
3598
|
+
for sub_name, sub_cfg in nested_candidates:
|
|
3599
|
+
_register_judge(sub_name, sub_cfg)
|
|
3600
|
+
|
|
3601
|
+
raw_judges_list = cfg.get("judges")
|
|
3602
|
+
if isinstance(raw_judges_list, list):
|
|
3603
|
+
for _index, entry in enumerate(raw_judges_list, start=1):
|
|
3604
|
+
if isinstance(entry, dict):
|
|
3605
|
+
_register_judge(entry.get("name") or f"judge{len(judge_specs) + 1}", entry)
|
|
3606
|
+
|
|
3607
|
+
records: list[dict[str, Any]] = []
|
|
3608
|
+
|
|
2655
3609
|
successes = 0
|
|
2656
3610
|
failures = 0
|
|
2657
3611
|
# Aggregate outcome stats across successful seeds
|
|
2658
3612
|
outcome_sum: float = 0.0
|
|
2659
3613
|
outcome_count: int = 0
|
|
2660
3614
|
outcome_correct: int = 0
|
|
2661
|
-
|
|
2662
|
-
|
|
2663
|
-
|
|
2664
|
-
|
|
2665
|
-
|
|
2666
|
-
|
|
2667
|
-
|
|
2668
|
-
|
|
2669
|
-
)
|
|
2670
|
-
|
|
2671
|
-
|
|
2672
|
-
|
|
2673
|
-
|
|
2674
|
-
|
|
2675
|
-
|
|
2676
|
-
|
|
2677
|
-
|
|
2678
|
-
|
|
2679
|
-
|
|
2680
|
-
|
|
2681
|
-
|
|
2682
|
-
|
|
2683
|
-
|
|
2684
|
-
|
|
2685
|
-
"
|
|
2686
|
-
"
|
|
2687
|
-
"
|
|
2688
|
-
|
|
2689
|
-
|
|
2690
|
-
policy_overrides[k] = cfg.get(k)
|
|
2691
|
-
except Exception:
|
|
2692
|
-
policy_overrides = {}
|
|
2693
|
-
|
|
2694
|
-
for seed_val in seed_values:
|
|
2695
|
-
body = {
|
|
2696
|
-
"run_id": str(uuid.uuid4()),
|
|
2697
|
-
"env": {"config": {"split": split, "index": seed_val}, "seed": seed_val},
|
|
2698
|
-
"policy": {
|
|
2699
|
-
"policy_name": selected_model,
|
|
2700
|
-
"config": {"model": selected_model, **policy_overrides},
|
|
2701
|
-
},
|
|
2702
|
-
"ops": [],
|
|
3615
|
+
|
|
3616
|
+
def _build_task_rows(taskset: Any) -> dict[int, dict[str, Any]]:
|
|
3617
|
+
rows: dict[int, dict[str, Any]] = {}
|
|
3618
|
+
if not isinstance(taskset, dict):
|
|
3619
|
+
return rows
|
|
3620
|
+
|
|
3621
|
+
scenario_ids = taskset.get("scenario_ids") or []
|
|
3622
|
+
loop_ids = taskset.get("loop_ids") or []
|
|
3623
|
+
thread_ids = taskset.get("thread_ids") or []
|
|
3624
|
+
difficulty_map = taskset.get("difficulty_map") or {}
|
|
3625
|
+
|
|
3626
|
+
max_len = max(len(scenario_ids), len(loop_ids), len(thread_ids))
|
|
3627
|
+
for seed in range(max_len):
|
|
3628
|
+
scenario_id = scenario_ids[seed] if seed < len(scenario_ids) else None
|
|
3629
|
+
loop_id = loop_ids[seed] if seed < len(loop_ids) else None
|
|
3630
|
+
thread_id = thread_ids[seed] if seed < len(thread_ids) else None
|
|
3631
|
+
difficulty = None
|
|
3632
|
+
if isinstance(difficulty_map, dict):
|
|
3633
|
+
if scenario_id and scenario_id in difficulty_map:
|
|
3634
|
+
difficulty = difficulty_map.get(scenario_id)
|
|
3635
|
+
elif str(seed) in difficulty_map:
|
|
3636
|
+
difficulty = difficulty_map.get(str(seed))
|
|
3637
|
+
|
|
3638
|
+
rows[seed] = {
|
|
3639
|
+
"seed": seed,
|
|
3640
|
+
"scenario_id": scenario_id,
|
|
3641
|
+
"loop_id": loop_id,
|
|
3642
|
+
"thread_id": thread_id,
|
|
3643
|
+
"difficulty": difficulty,
|
|
2703
3644
|
}
|
|
3645
|
+
return rows
|
|
3646
|
+
|
|
3647
|
+
def _apply_metadata_filters(
|
|
3648
|
+
rows: dict[int, dict[str, Any]], seeds_list: list[int], filters: dict[str, str]
|
|
3649
|
+
) -> list[int]:
|
|
3650
|
+
if not filters:
|
|
3651
|
+
return seeds_list
|
|
3652
|
+
filtered: list[int] = []
|
|
3653
|
+
for seed in seeds_list:
|
|
3654
|
+
row = rows.get(seed)
|
|
3655
|
+
if not row:
|
|
3656
|
+
continue
|
|
3657
|
+
include = True
|
|
3658
|
+
for key, expected in filters.items():
|
|
3659
|
+
actual = row.get(key)
|
|
3660
|
+
if actual is None:
|
|
3661
|
+
include = False
|
|
3662
|
+
break
|
|
3663
|
+
if str(actual).lower() != expected.lower():
|
|
3664
|
+
include = False
|
|
3665
|
+
break
|
|
3666
|
+
if include:
|
|
3667
|
+
filtered.append(seed)
|
|
3668
|
+
return filtered
|
|
3669
|
+
|
|
3670
|
+
def _apply_metadata_sql(
|
|
3671
|
+
rows: dict[int, dict[str, Any]], seeds_list: list[int], query: str
|
|
3672
|
+
) -> list[int]:
|
|
3673
|
+
"""Return seeds that satisfy an arbitrary SQL query.
|
|
3674
|
+
|
|
3675
|
+
The query is executed against an in-memory SQLite table named `tasks`
|
|
3676
|
+
with columns (seed INTEGER, scenario_id TEXT, loop_id TEXT, thread_id TEXT, difficulty TEXT).
|
|
3677
|
+
Any rows whose `seed` value (or first column if `seed` is absent) appear in the result set are retained.
|
|
3678
|
+
"""
|
|
3679
|
+
if not query:
|
|
3680
|
+
return seeds_list
|
|
3681
|
+
conn = sqlite3.connect(":memory:")
|
|
3682
|
+
try:
|
|
3683
|
+
cur = conn.cursor()
|
|
3684
|
+
cur.execute(
|
|
3685
|
+
"CREATE TABLE tasks (seed INTEGER, scenario_id TEXT, loop_id TEXT, thread_id TEXT, difficulty TEXT)"
|
|
3686
|
+
)
|
|
3687
|
+
insert_stmt = (
|
|
3688
|
+
"INSERT INTO tasks (seed, scenario_id, loop_id, thread_id, difficulty) VALUES (?,?,?,?,?)"
|
|
3689
|
+
)
|
|
3690
|
+
for seed in seeds_list:
|
|
3691
|
+
row = rows.get(seed, {})
|
|
3692
|
+
cur.execute(
|
|
3693
|
+
insert_stmt,
|
|
3694
|
+
[
|
|
3695
|
+
seed,
|
|
3696
|
+
row.get("scenario_id"),
|
|
3697
|
+
row.get("loop_id"),
|
|
3698
|
+
row.get("thread_id"),
|
|
3699
|
+
row.get("difficulty"),
|
|
3700
|
+
],
|
|
3701
|
+
)
|
|
3702
|
+
|
|
3703
|
+
result = cur.execute(query)
|
|
3704
|
+
fetched = result.fetchall()
|
|
3705
|
+
if not fetched:
|
|
3706
|
+
return []
|
|
3707
|
+
description = result.description or []
|
|
3708
|
+
col_names = [col[0] for col in description]
|
|
3709
|
+
seeds_out: list[int] = []
|
|
3710
|
+
for entry in fetched:
|
|
3711
|
+
value = entry[col_names.index("seed")] if "seed" in col_names else entry[0]
|
|
3712
|
+
try:
|
|
3713
|
+
seeds_out.append(int(value))
|
|
3714
|
+
except Exception as exc:
|
|
3715
|
+
raise click.ClickException(
|
|
3716
|
+
"metadata SQL query must return seed integers"
|
|
3717
|
+
) from exc
|
|
3718
|
+
seeds_set = set(seeds_out)
|
|
3719
|
+
return [seed for seed in seeds_list if seed in seeds_set]
|
|
3720
|
+
except sqlite3.Error as exc:
|
|
3721
|
+
raise click.ClickException(f"Failed to execute metadata SQL query: {exc}") from exc
|
|
3722
|
+
finally:
|
|
3723
|
+
conn.close()
|
|
3724
|
+
|
|
3725
|
+
async def _run_eval() -> None:
|
|
3726
|
+
nonlocal successes, failures, outcome_sum, outcome_count, outcome_correct, records, seed_values
|
|
3727
|
+
|
|
3728
|
+
if trace_tracer is not None and trace_tracer.db is None:
|
|
3729
|
+
await trace_tracer.initialize()
|
|
3730
|
+
|
|
3731
|
+
if task_app_url is None:
|
|
3732
|
+
transport = httpx.ASGITransport(app=app) # type: ignore[name-defined]
|
|
3733
|
+
async_client = httpx.AsyncClient(
|
|
3734
|
+
transport=cast(Any, transport),
|
|
3735
|
+
base_url="http://eval.local",
|
|
3736
|
+
timeout=300.0,
|
|
3737
|
+
follow_redirects=True,
|
|
3738
|
+
headers=headers,
|
|
3739
|
+
)
|
|
3740
|
+
else:
|
|
3741
|
+
async_client = httpx.AsyncClient(
|
|
3742
|
+
base_url=task_app_url,
|
|
3743
|
+
timeout=300.0,
|
|
3744
|
+
follow_redirects=True,
|
|
3745
|
+
headers=headers,
|
|
3746
|
+
)
|
|
3747
|
+
|
|
3748
|
+
try:
|
|
3749
|
+
taskset_payload: dict[str, Any] | None = None
|
|
2704
3750
|
try:
|
|
2705
|
-
|
|
2706
|
-
|
|
3751
|
+
task_info_response = await async_client.get("/task_info")
|
|
3752
|
+
except Exception:
|
|
3753
|
+
task_info_response = None
|
|
3754
|
+
if task_info_response is not None and task_info_response.status_code == 200:
|
|
3755
|
+
with contextlib.suppress(Exception):
|
|
3756
|
+
payload_json = task_info_response.json()
|
|
3757
|
+
if isinstance(payload_json, dict) and "taskset" in payload_json:
|
|
3758
|
+
taskset_payload = payload_json.get("taskset")
|
|
3759
|
+
if not isinstance(taskset_payload, dict):
|
|
3760
|
+
taskset_payload = None
|
|
3761
|
+
elif isinstance(payload_json, dict):
|
|
3762
|
+
taskset_payload = payload_json
|
|
3763
|
+
|
|
3764
|
+
available_seeds = list(seed_values)
|
|
3765
|
+
if metadata_sql_query or metadata_filters:
|
|
3766
|
+
if not taskset_payload:
|
|
3767
|
+
raise click.ClickException(
|
|
3768
|
+
"Task metadata filters require the task app to expose /task_info metadata"
|
|
3769
|
+
)
|
|
3770
|
+
rows = _build_task_rows(taskset_payload)
|
|
3771
|
+
if metadata_sql_query:
|
|
3772
|
+
available_seeds = _apply_metadata_sql(rows, available_seeds, metadata_sql_query)
|
|
3773
|
+
if metadata_filters:
|
|
3774
|
+
available_seeds = _apply_metadata_filters(rows, available_seeds, metadata_filters)
|
|
3775
|
+
if not available_seeds:
|
|
3776
|
+
raise click.ClickException("No seeds match the provided metadata filters")
|
|
3777
|
+
seed_values = available_seeds
|
|
3778
|
+
|
|
3779
|
+
semaphore = asyncio.Semaphore(concurrency_limit)
|
|
3780
|
+
|
|
3781
|
+
async def _run_seed(seed_val: int) -> None:
|
|
3782
|
+
nonlocal successes, failures, outcome_sum, outcome_count, outcome_correct, records
|
|
3783
|
+
body = {
|
|
3784
|
+
"run_id": str(uuid.uuid4()),
|
|
3785
|
+
"env": {"config": {"split": split, "index": seed_val}, "seed": seed_val},
|
|
3786
|
+
"policy": {
|
|
3787
|
+
"policy_name": selected_model,
|
|
3788
|
+
"config": {"model": selected_model, **policy_overrides},
|
|
3789
|
+
},
|
|
3790
|
+
"ops": [],
|
|
3791
|
+
}
|
|
3792
|
+
rollout_elapsed: float | None = None
|
|
3793
|
+
rollout_start = time.perf_counter()
|
|
3794
|
+
try:
|
|
3795
|
+
async with semaphore:
|
|
3796
|
+
response = await async_client.post("/rollout", json=body)
|
|
3797
|
+
rollout_elapsed = time.perf_counter() - rollout_start
|
|
3798
|
+
except Exception as exc:
|
|
3799
|
+
failures += 1
|
|
3800
|
+
click.echo(f"seed={seed_val} error={exc}")
|
|
3801
|
+
return
|
|
3802
|
+
|
|
3803
|
+
ok = 200 <= response.status_code < 300
|
|
2707
3804
|
if ok:
|
|
2708
3805
|
successes += 1
|
|
2709
3806
|
else:
|
|
2710
3807
|
failures += 1
|
|
2711
3808
|
|
|
2712
|
-
|
|
2713
|
-
|
|
3809
|
+
summary = [f"seed={seed_val}", f"status={response.status_code}"]
|
|
3810
|
+
data: Any
|
|
2714
3811
|
try:
|
|
2715
|
-
data =
|
|
3812
|
+
data = response.json()
|
|
2716
3813
|
except Exception:
|
|
2717
3814
|
data = None
|
|
3815
|
+
|
|
3816
|
+
metrics: dict[str, Any] | None = None
|
|
3817
|
+
completion: str | None = None
|
|
3818
|
+
prompt_index: int | None = None
|
|
3819
|
+
prompt_text: str | None = None
|
|
3820
|
+
task_id: str | None = None
|
|
3821
|
+
task_split: str | None = None
|
|
3822
|
+
task_rubric_id: str | None = None
|
|
3823
|
+
|
|
3824
|
+
trace_namespace: dict[str, Any] | None = None
|
|
3825
|
+
session_trace_dict: dict[str, Any] | None = None
|
|
3826
|
+
|
|
2718
3827
|
if isinstance(data, dict):
|
|
3828
|
+
trace_namespace = data.get("trace")
|
|
3829
|
+
if not isinstance(trace_namespace, dict):
|
|
3830
|
+
raise RuntimeError(
|
|
3831
|
+
"rollout response missing trace payload; task app must return tracing_v3 data"
|
|
3832
|
+
)
|
|
3833
|
+
session_trace_dict = trace_namespace.get("session_trace")
|
|
3834
|
+
if not isinstance(session_trace_dict, dict):
|
|
3835
|
+
raise RuntimeError(
|
|
3836
|
+
"rollout response trace missing 'session_trace'; ensure the task app is serving the tracing_v3 build"
|
|
3837
|
+
)
|
|
2719
3838
|
metrics = data.get("metrics") if isinstance(data.get("metrics"), dict) else None
|
|
2720
3839
|
if metrics:
|
|
2721
3840
|
mean_return = metrics.get("mean_return") or metrics.get("total_reward")
|
|
@@ -2724,7 +3843,6 @@ def eval_command(
|
|
|
2724
3843
|
summary.append(f"mean_return={mean_return}")
|
|
2725
3844
|
if outcome is not None:
|
|
2726
3845
|
summary.append(f"outcome={outcome}")
|
|
2727
|
-
# Aggregate outcome stats
|
|
2728
3846
|
try:
|
|
2729
3847
|
val = float(outcome)
|
|
2730
3848
|
outcome_sum += val
|
|
@@ -2733,7 +3851,6 @@ def eval_command(
|
|
|
2733
3851
|
outcome_correct += 1
|
|
2734
3852
|
except Exception:
|
|
2735
3853
|
pass
|
|
2736
|
-
# Try to infer tool call count from first trajectory step
|
|
2737
3854
|
trajs = (
|
|
2738
3855
|
data.get("trajectories")
|
|
2739
3856
|
if isinstance(data.get("trajectories"), list)
|
|
@@ -2747,38 +3864,163 @@ def eval_command(
|
|
|
2747
3864
|
tool_calls = step0.get("tool_calls") or step0.get("tools") or []
|
|
2748
3865
|
if isinstance(tool_calls, list):
|
|
2749
3866
|
summary.append(f"tool_calls={len(tool_calls)}")
|
|
3867
|
+
obs = step0.get("obs") if isinstance(step0, dict) else None
|
|
3868
|
+
if isinstance(obs, dict):
|
|
3869
|
+
idx_val = obs.get("prompt_index")
|
|
3870
|
+
if isinstance(idx_val, int):
|
|
3871
|
+
prompt_index = idx_val
|
|
3872
|
+
prompt_raw = obs.get("prompt")
|
|
3873
|
+
if isinstance(prompt_raw, str):
|
|
3874
|
+
prompt_text = prompt_raw
|
|
3875
|
+
if task_id is None:
|
|
3876
|
+
candidate_id = obs.get("task_id")
|
|
3877
|
+
if isinstance(candidate_id, str) and candidate_id:
|
|
3878
|
+
task_id = candidate_id
|
|
3879
|
+
if task_split is None:
|
|
3880
|
+
candidate_split = obs.get("task_split")
|
|
3881
|
+
if isinstance(candidate_split, str) and candidate_split:
|
|
3882
|
+
task_split = candidate_split
|
|
3883
|
+
if task_rubric_id is None:
|
|
3884
|
+
candidate_rid = obs.get("task_rubric_id")
|
|
3885
|
+
if isinstance(candidate_rid, str) and candidate_rid:
|
|
3886
|
+
task_rubric_id = candidate_rid
|
|
3887
|
+
final = first.get("final") if isinstance(first, dict) else None
|
|
3888
|
+
if isinstance(final, dict):
|
|
3889
|
+
final_obs = final.get("observation")
|
|
3890
|
+
if isinstance(final_obs, dict):
|
|
3891
|
+
comp_val = final_obs.get("completion")
|
|
3892
|
+
if isinstance(comp_val, str):
|
|
3893
|
+
completion = comp_val
|
|
3894
|
+
if task_id is None:
|
|
3895
|
+
candidate_id = final_obs.get("task_id")
|
|
3896
|
+
if isinstance(candidate_id, str) and candidate_id:
|
|
3897
|
+
task_id = candidate_id
|
|
3898
|
+
if task_split is None:
|
|
3899
|
+
candidate_split = final_obs.get("task_split")
|
|
3900
|
+
if isinstance(candidate_split, str) and candidate_split:
|
|
3901
|
+
task_split = candidate_split
|
|
3902
|
+
if task_rubric_id is None:
|
|
3903
|
+
candidate_rid = final_obs.get("task_rubric_id")
|
|
3904
|
+
if isinstance(candidate_rid, str) and candidate_rid:
|
|
3905
|
+
task_rubric_id = candidate_rid
|
|
3906
|
+
final_info = final.get("info")
|
|
3907
|
+
if isinstance(final_info, dict):
|
|
3908
|
+
if task_id is None:
|
|
3909
|
+
candidate_id = final_info.get("task_id")
|
|
3910
|
+
if isinstance(candidate_id, str) and candidate_id:
|
|
3911
|
+
task_id = candidate_id
|
|
3912
|
+
if task_split is None:
|
|
3913
|
+
candidate_split = final_info.get("task_split")
|
|
3914
|
+
if isinstance(candidate_split, str) and candidate_split:
|
|
3915
|
+
task_split = candidate_split
|
|
3916
|
+
if task_rubric_id is None:
|
|
3917
|
+
candidate_rid = final_info.get("task_rubric_id")
|
|
3918
|
+
if isinstance(candidate_rid, str) and candidate_rid:
|
|
3919
|
+
task_rubric_id = candidate_rid
|
|
3920
|
+
if task_id:
|
|
3921
|
+
summary.append(f"task_id={task_id}")
|
|
2750
3922
|
click.echo(" ".join(summary))
|
|
2751
|
-
# Print the full response JSON (trace, trajectories, metrics)
|
|
2752
3923
|
with contextlib.suppress(Exception):
|
|
2753
3924
|
click.echo(json.dumps(data, indent=2))
|
|
2754
3925
|
else:
|
|
2755
3926
|
click.echo(" ".join(summary))
|
|
2756
|
-
except Exception as exc:
|
|
2757
|
-
failures += 1
|
|
2758
|
-
click.echo(f"seed={seed_val} error={exc}")
|
|
2759
3927
|
|
|
2760
|
-
|
|
2761
|
-
|
|
2762
|
-
|
|
2763
|
-
|
|
2764
|
-
|
|
2765
|
-
|
|
2766
|
-
|
|
2767
|
-
|
|
2768
|
-
except RuntimeError:
|
|
2769
|
-
# Fallback when already inside a running loop (rare for CLI).
|
|
2770
|
-
new_loop = asyncio.new_event_loop()
|
|
3928
|
+
official_score = None
|
|
3929
|
+
if isinstance(metrics, dict):
|
|
3930
|
+
for key in ("mean_return", "total_reward", "outcome_score"):
|
|
3931
|
+
val = metrics.get(key)
|
|
3932
|
+
if isinstance(val, int | float):
|
|
3933
|
+
official_score = float(val)
|
|
3934
|
+
break
|
|
3935
|
+
if official_score is None and isinstance(data, dict):
|
|
2771
3936
|
try:
|
|
2772
|
-
|
|
2773
|
-
|
|
2774
|
-
|
|
2775
|
-
|
|
2776
|
-
|
|
3937
|
+
reward_val = data["trajectories"][0]["steps"][0].get("reward")
|
|
3938
|
+
if isinstance(reward_val, int | float):
|
|
3939
|
+
official_score = float(reward_val)
|
|
3940
|
+
except Exception:
|
|
3941
|
+
pass
|
|
3942
|
+
|
|
3943
|
+
if official_score is not None:
|
|
3944
|
+
if official_score < 0.0:
|
|
3945
|
+
official_score = 0.0
|
|
3946
|
+
elif official_score > 1.0:
|
|
3947
|
+
official_score = min(1.0, official_score)
|
|
3948
|
+
|
|
3949
|
+
judge_scores: dict[str, float | None] = {}
|
|
3950
|
+
judges_timings: dict[str, float | None] = {}
|
|
3951
|
+
timings: dict[str, Any] = {
|
|
3952
|
+
"rollout_s": rollout_elapsed,
|
|
3953
|
+
"judges": judges_timings,
|
|
3954
|
+
}
|
|
3955
|
+
if judge_specs:
|
|
3956
|
+
for spec in judge_specs:
|
|
3957
|
+
score_value: float | None = None
|
|
3958
|
+
judge_elapsed: float | None = None
|
|
3959
|
+
if completion is not None:
|
|
3960
|
+
judge_payload = {
|
|
3961
|
+
"seed": seed_val,
|
|
3962
|
+
"prompt_index": prompt_index,
|
|
3963
|
+
"prompt": prompt_text,
|
|
3964
|
+
"completion": completion,
|
|
3965
|
+
"metrics": metrics,
|
|
3966
|
+
"response": data,
|
|
3967
|
+
"trace": trace_namespace,
|
|
3968
|
+
}
|
|
3969
|
+
try:
|
|
3970
|
+
judge_start = time.perf_counter()
|
|
3971
|
+
result = spec.fn(judge_payload, **spec.kwargs)
|
|
3972
|
+
judge_elapsed = time.perf_counter() - judge_start
|
|
3973
|
+
if isinstance(result, int | float):
|
|
3974
|
+
score_value = float(result)
|
|
3975
|
+
except Exception as exc:
|
|
3976
|
+
if judge_elapsed is None:
|
|
3977
|
+
judge_elapsed = time.perf_counter() - judge_start
|
|
3978
|
+
click.echo(f"seed={seed_val} judge[{spec.name}]_error={exc}")
|
|
3979
|
+
judges_timings[spec.name] = judge_elapsed
|
|
3980
|
+
judge_scores[spec.name] = score_value
|
|
3981
|
+
|
|
3982
|
+
if trace_tracer is not None and trace_namespace:
|
|
3983
|
+
storage_metadata = {
|
|
3984
|
+
"eval_seed": seed_val,
|
|
3985
|
+
"prompt_index": prompt_index,
|
|
3986
|
+
"task_id": task_id,
|
|
3987
|
+
"task_split": task_split,
|
|
3988
|
+
"task_rubric_id": task_rubric_id,
|
|
3989
|
+
"official_score": official_score,
|
|
3990
|
+
"judge_scores": judge_scores,
|
|
3991
|
+
"model": selected_model,
|
|
3992
|
+
"prompt": prompt_text,
|
|
3993
|
+
"completion": completion,
|
|
3994
|
+
}
|
|
3995
|
+
await _store_trace(trace_tracer, trace_namespace, storage_metadata)
|
|
3996
|
+
|
|
3997
|
+
records.append(
|
|
3998
|
+
{
|
|
3999
|
+
"seed": seed_val,
|
|
4000
|
+
"prompt_index": prompt_index,
|
|
4001
|
+
"task_id": task_id,
|
|
4002
|
+
"task_split": task_split,
|
|
4003
|
+
"task_rubric_id": task_rubric_id,
|
|
4004
|
+
"official_score": official_score,
|
|
4005
|
+
"judge_scores": judge_scores,
|
|
4006
|
+
"timings": timings,
|
|
4007
|
+
}
|
|
4008
|
+
)
|
|
4009
|
+
|
|
4010
|
+
await asyncio.gather(*[_run_seed(seed_val) for seed_val in seed_values])
|
|
4011
|
+
finally:
|
|
4012
|
+
await async_client.aclose()
|
|
4013
|
+
|
|
4014
|
+
try:
|
|
4015
|
+
asyncio.run(_run_eval())
|
|
4016
|
+
finally:
|
|
4017
|
+
if trace_tracer is not None and trace_tracer.db is not None:
|
|
4018
|
+
asyncio.run(trace_tracer.db.close())
|
|
2777
4019
|
|
|
2778
4020
|
click.echo(
|
|
2779
4021
|
f"Eval complete: {successes} ok, {failures} failed; model={selected_model}, split={split}"
|
|
2780
4022
|
)
|
|
2781
|
-
|
|
4023
|
+
|
|
2782
4024
|
if outcome_count > 0:
|
|
2783
4025
|
mean_outcome = outcome_sum / float(outcome_count)
|
|
2784
4026
|
frac_right = outcome_correct / float(outcome_count)
|
|
@@ -2786,6 +4028,285 @@ def eval_command(
|
|
|
2786
4028
|
f"Outcome summary: correct={outcome_correct}/{outcome_count} ({frac_right:.2%}), mean_outcome={mean_outcome:.3f}"
|
|
2787
4029
|
)
|
|
2788
4030
|
|
|
4031
|
+
if records:
|
|
4032
|
+
judge_specs = judge_specs or [] # ensure iterable
|
|
4033
|
+
official_scores = [
|
|
4034
|
+
r["official_score"] for r in records if r["official_score"] is not None
|
|
4035
|
+
]
|
|
4036
|
+
if official_scores:
|
|
4037
|
+
click.echo(f" Official mean: {sum(official_scores) / len(official_scores):.3f}")
|
|
4038
|
+
else:
|
|
4039
|
+
click.echo(" Official mean: n/a")
|
|
4040
|
+
|
|
4041
|
+
for spec in judge_specs:
|
|
4042
|
+
spec_scores = [
|
|
4043
|
+
record["judge_scores"].get(spec.name)
|
|
4044
|
+
for record in records
|
|
4045
|
+
if record["judge_scores"].get(spec.name) is not None
|
|
4046
|
+
]
|
|
4047
|
+
if spec_scores:
|
|
4048
|
+
mean_spec = sum(spec_scores) / len(spec_scores)
|
|
4049
|
+
click.echo(f" [{spec.name}] mean: {mean_spec:.3f}")
|
|
4050
|
+
else:
|
|
4051
|
+
click.echo(f" [{spec.name}] mean: n/a")
|
|
4052
|
+
|
|
4053
|
+
paired = [
|
|
4054
|
+
(
|
|
4055
|
+
record["official_score"],
|
|
4056
|
+
record["judge_scores"].get(spec.name),
|
|
4057
|
+
)
|
|
4058
|
+
for record in records
|
|
4059
|
+
if record["official_score"] is not None
|
|
4060
|
+
and record["judge_scores"].get(spec.name) is not None
|
|
4061
|
+
]
|
|
4062
|
+
if len(paired) >= 2:
|
|
4063
|
+
corr = _pearson(
|
|
4064
|
+
[p[0] for p in paired if p[0] is not None],
|
|
4065
|
+
[p[1] for p in paired if p[1] is not None],
|
|
4066
|
+
)
|
|
4067
|
+
if corr is not None:
|
|
4068
|
+
click.echo(f" Pearson r: {corr:.3f}")
|
|
4069
|
+
else:
|
|
4070
|
+
click.echo(" Pearson r: undefined (zero variance)")
|
|
4071
|
+
else:
|
|
4072
|
+
click.echo(" Pearson r: n/a (need ≥2 paired scores)")
|
|
4073
|
+
|
|
4074
|
+
header = ["Seed", "Prompt", "Official"]
|
|
4075
|
+
header.extend(spec.name for spec in judge_specs)
|
|
4076
|
+
rows: list[list[str]] = []
|
|
4077
|
+
for record in sorted(records, key=lambda r: (r["seed"], r.get("prompt_index") or -1)):
|
|
4078
|
+
seed_val = str(record["seed"])
|
|
4079
|
+
prompt_idx = (
|
|
4080
|
+
str(record["prompt_index"])
|
|
4081
|
+
if record["prompt_index"] is not None
|
|
4082
|
+
else "-"
|
|
4083
|
+
)
|
|
4084
|
+
official_val = (
|
|
4085
|
+
f"{record['official_score']:.3f}"
|
|
4086
|
+
if record["official_score"] is not None
|
|
4087
|
+
else "-"
|
|
4088
|
+
)
|
|
4089
|
+
row = [seed_val, prompt_idx, official_val]
|
|
4090
|
+
for spec in judge_specs:
|
|
4091
|
+
score_val = record["judge_scores"].get(spec.name)
|
|
4092
|
+
row.append(f"{score_val:.3f}" if isinstance(score_val, int | float) else "-")
|
|
4093
|
+
rows.append(row)
|
|
4094
|
+
|
|
4095
|
+
widths = [len(col) for col in header]
|
|
4096
|
+
for row in rows:
|
|
4097
|
+
for idx, cell in enumerate(row):
|
|
4098
|
+
widths[idx] = max(widths[idx], len(cell))
|
|
4099
|
+
|
|
4100
|
+
click.echo("")
|
|
4101
|
+
click.echo(" ".join(h.ljust(widths[idx]) for idx, h in enumerate(header)))
|
|
4102
|
+
click.echo(" ".join("-" * widths[idx] for idx in range(len(header))))
|
|
4103
|
+
for row in rows:
|
|
4104
|
+
click.echo(" ".join(cell.ljust(widths[idx]) for idx, cell in enumerate(row)))
|
|
4105
|
+
|
|
4106
|
+
|
|
4107
|
+
|
|
4108
|
+
@click.command(
    "filter",
    help="Export filtered tracing sessions to SFT-ready JSONL based on a TOML config.",
)
@click.option(
    "--config",
    "config_path",
    type=click.Path(),
    required=True,
    help="Path to TOML config describing the input trace DB, score thresholds, and output JSONL.",
)
def filter_command(config_path: str) -> None:
    """Render tracing sessions that match filter rules into SFT JSONL.

    The TOML file should contain a `[filter]` table with at least:

        db = \"path/to/traces.db\"  # sqlite path or URL (sqlite+aiosqlite://...)
        output = \"ft_data/out.jsonl\"  # destination JSONL

    Optional keys such as `splits`, `task_ids`, `models`, `min_official_score`, or
    `min_judge_scores.my_judge = 0.7` allow you to narrow the dataset down to
    high-quality traces. See `customers/agora_single_file/configs/filter_local.toml`
    for a working example.

    Raises:
        click.ClickException: on missing/invalid config, missing traces, or when
            no session matches the configured filters.
    """
    if _toml is None:
        raise click.ClickException("TOML parser not available; install tomli or use Python 3.11+")

    cfg_path = Path(config_path)
    if not cfg_path.exists():
        raise click.ClickException(f"Filter config not found: {cfg_path}")

    try:
        config_data = _toml.loads(cfg_path.read_text(encoding="utf-8"))
    except Exception as exc:
        raise click.ClickException(f"Failed to parse TOML '{cfg_path}': {exc}") from exc

    filter_cfg = config_data.get("filter") if isinstance(config_data, dict) else None
    if not isinstance(filter_cfg, dict):
        raise click.ClickException("Config must contain a [filter] table")

    # Resolve the trace database: a bare path becomes a sqlite+aiosqlite URL;
    # anything containing "://" is passed through as a full URL.
    db_value = str(filter_cfg.get("db", "traces/v3/eval_traces.db")).strip()
    if not db_value:
        raise click.ClickException("filter.db must be provided")
    if "://" in db_value:
        db_url = db_value
    else:
        db_path = Path(db_value).expanduser()
        db_path.parent.mkdir(parents=True, exist_ok=True)
        db_url = f"sqlite+aiosqlite:///{db_path}"

    output_value = filter_cfg.get("output")
    if not output_value:
        raise click.ClickException("filter.output must be provided")
    output_path = Path(str(output_value)).expanduser()

    # Categorical filters: empty set means "accept any value".
    splits = set(filter_cfg.get("splits", []) or [])
    task_ids = set(filter_cfg.get("task_ids", []) or [])
    models = set(filter_cfg.get("models", []) or [])

    # Numeric score bounds — validated eagerly so bad configs fail fast.
    min_official = filter_cfg.get("min_official_score")
    max_official = filter_cfg.get("max_official_score")
    if min_official is not None:
        try:
            min_official = float(min_official)
        except Exception as err:
            raise click.ClickException("filter.min_official_score must be numeric") from err
    if max_official is not None:
        try:
            max_official = float(max_official)
        except Exception as err:
            raise click.ClickException("filter.max_official_score must be numeric") from err
    min_judge_scores = filter_cfg.get("min_judge_scores", {}) or {}
    max_judge_scores = filter_cfg.get("max_judge_scores", {}) or {}
    try:
        min_judge_scores = {k: float(v) for k, v in min_judge_scores.items()}
    except Exception as err:
        raise click.ClickException("filter.min_judge_scores values must be numeric") from err
    try:
        max_judge_scores = {k: float(v) for k, v in max_judge_scores.items()}
    except Exception as err:
        raise click.ClickException("filter.max_judge_scores values must be numeric") from err

    min_created = _parse_datetime_for_trace(filter_cfg.get("min_created_at"))
    max_created = _parse_datetime_for_trace(filter_cfg.get("max_created_at"))
    limit = filter_cfg.get("limit")
    if limit is not None:
        try:
            limit = int(limit)
        except Exception as err:
            raise click.ClickException("filter.limit must be an integer") from err

    def _score_ok(value: Any, min_val: Any, max_val: Any) -> bool:
        """Return True when *value* satisfies the optional [min_val, max_val] bounds.

        A missing score (None) passes only when no minimum is required; a
        non-numeric score never passes.
        """
        try:
            if value is None:
                return min_val is None
            value = float(value)
        except Exception:
            return False
        if min_val is not None and value < float(min_val):
            return False
        return not (max_val is not None and value > float(max_val))

    async def _run_filter() -> None:
        tracer = SessionTracer(db_url=db_url, auto_save=False)
        await tracer.initialize()
        # Close the tracer DB even when we bail out early with a ClickException
        # (e.g. empty database / no matching sessions) — the original success-only
        # close leaked the connection on those paths.
        try:
            df = await tracer.db.query_traces(
                "SELECT session_id, created_at, metadata FROM session_traces ORDER BY created_at"
            )
            if getattr(df, "empty", True):
                raise click.ClickException("No traces found in database")

            sessions = df.to_dict("records")
            accepted: list[dict[str, Any]] = []

            for row in sessions:
                # Metadata may be stored as a JSON string or a dict; anything
                # else (or unparseable JSON) is treated as empty metadata.
                metadata_raw = row.get("metadata")
                if isinstance(metadata_raw, str):
                    try:
                        metadata = json.loads(metadata_raw)
                    except Exception:
                        metadata = {}
                elif isinstance(metadata_raw, dict):
                    metadata = dict(metadata_raw)
                else:
                    metadata = {}

                created_at_raw = row.get("created_at")
                created_at_dt = _parse_datetime_for_trace(created_at_raw)

                session_id = row.get("session_id")

                # Categorical filters (skip rows that don't match).
                if splits and metadata.get("task_split") not in splits:
                    continue
                if task_ids and metadata.get("task_id") not in task_ids:
                    continue
                if models and metadata.get("model") not in models:
                    continue

                # Time-window filters: rows without a parseable timestamp are
                # excluded whenever a bound is configured.
                if min_created and (created_at_dt is None or created_at_dt < min_created):
                    continue
                if max_created and (created_at_dt is None or created_at_dt > max_created):
                    continue

                if not _score_ok(metadata.get("official_score"), min_official, max_official):
                    continue

                # Per-judge score thresholds: every configured judge must pass.
                judge_scores = metadata.get("judge_scores") or {}
                include = True
                for judge_name, threshold in (min_judge_scores or {}).items():
                    if not _score_ok(judge_scores.get(judge_name), threshold, None):
                        include = False
                        break
                if not include:
                    continue
                for judge_name, threshold in (max_judge_scores or {}).items():
                    if not _score_ok(judge_scores.get(judge_name), None, threshold):
                        include = False
                        break
                if not include:
                    continue

                # SFT training requires both sides of the exchange.
                prompt = metadata.get("prompt") or ""
                completion = metadata.get("completion") or ""
                if not prompt or not completion:
                    continue

                record = {
                    "messages": [
                        {"role": "user", "content": str(prompt)},
                        {"role": "assistant", "content": str(completion)},
                    ],
                    "metadata": {
                        "session_id": session_id,
                        "task_id": metadata.get("task_id"),
                        "task_split": metadata.get("task_split"),
                        "task_rubric_id": metadata.get("task_rubric_id"),
                        "official_score": metadata.get("official_score"),
                        "judge_scores": judge_scores,
                        "model": metadata.get("model"),
                        "created_at": created_at_dt.isoformat() if created_at_dt else created_at_raw,
                        "prompt": prompt,
                        "completion": completion,
                    },
                }
                accepted.append(record)

            if not accepted:
                raise click.ClickException("No sessions matched the provided filters")

            if limit is not None and limit > 0:
                accepted = accepted[:limit]

            output_path.parent.mkdir(parents=True, exist_ok=True)
            with output_path.open("w", encoding="utf-8") as handle:
                for item in accepted:
                    handle.write(json.dumps(item, ensure_ascii=False))
                    handle.write("\n")

            click.echo(f"Wrote {len(accepted)} examples -> {output_path}")
        finally:
            await tracer.db.close()

    asyncio.run(_run_filter())
|
|
2789
4310
|
|
|
2790
4311
|
def register_eval(cli: click.Group) -> None:
|
|
2791
4312
|
cli.add_command(eval_command)
|