synth-ai 0.2.16__py3-none-any.whl → 0.2.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic.
- examples/analyze_semantic_words.sh +2 -2
- examples/blog_posts/pokemon_vl/README.md +98 -0
- examples/blog_posts/pokemon_vl/configs/eval_qwen3_vl.toml +25 -0
- examples/blog_posts/pokemon_vl/configs/eval_rl_final.toml +24 -0
- examples/blog_posts/pokemon_vl/configs/filter_high_reward.toml +10 -0
- examples/blog_posts/pokemon_vl/configs/train_rl_from_sft.toml +42 -0
- examples/blog_posts/pokemon_vl/configs/train_sft_qwen4b_vl.toml +40 -0
- examples/blog_posts/warming_up_to_rl/README.md +158 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b.toml +25 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_groq_qwen32b.toml +25 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_openai_gpt_oss_120b.toml +29 -0
- examples/blog_posts/warming_up_to_rl/configs/filter_high_reward_dataset.toml +10 -0
- examples/blog_posts/warming_up_to_rl/configs/train_rl_from_sft.toml +41 -0
- examples/blog_posts/warming_up_to_rl/configs/train_sft_qwen4b.toml +40 -0
- examples/dev/qwen3_32b_qlora_4xh100.toml +5 -0
- examples/multi_step/configs/crafter_rl_outcome.toml +1 -1
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +65 -107
- examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +1 -1
- examples/multi_step/configs/crafter_rl_stepwise_simple.toml +1 -1
- examples/multi_step/configs/crafter_rl_stepwise_simple_NEW_FORMAT.toml +105 -0
- examples/multi_step/configs/verilog_rl_lora.toml +80 -123
- examples/qwen_coder/configs/coder_lora_30b.toml +1 -3
- examples/qwen_coder/configs/coder_lora_4b.toml +4 -1
- examples/qwen_coder/configs/coder_lora_small.toml +1 -3
- examples/qwen_vl/README.md +10 -12
- examples/qwen_vl/SETUP_COMPLETE.md +7 -8
- examples/qwen_vl/VISION_TESTS_COMPLETE.md +2 -3
- examples/qwen_vl/collect_data_via_cli.md +76 -84
- examples/qwen_vl/collect_vision_traces.py +4 -4
- examples/qwen_vl/configs/crafter_rl_vision_qwen3vl4b.toml +40 -57
- examples/qwen_vl/configs/crafter_vlm_sft_example.toml +1 -2
- examples/qwen_vl/configs/eval_gpt4o_mini_vision.toml +20 -37
- examples/qwen_vl/configs/eval_gpt5nano_vision.toml +21 -40
- examples/qwen_vl/configs/eval_qwen3vl_vision.toml +26 -0
- examples/qwen_vl/configs/{filter_qwen2vl_sft.toml → filter_qwen3vl_sft.toml} +4 -5
- examples/qwen_vl/configs/filter_vision_sft.toml +2 -3
- examples/qwen_vl/crafter_qwen_vl_agent.py +5 -5
- examples/qwen_vl/run_vision_comparison.sh +6 -7
- examples/rl/README.md +5 -5
- examples/rl/configs/rl_from_base_qwen.toml +26 -1
- examples/rl/configs/rl_from_base_qwen17.toml +5 -2
- examples/rl/task_app/README.md +1 -2
- examples/rl/task_app/math_single_step.py +2 -2
- examples/run_crafter_demo.sh +2 -2
- examples/sft/README.md +1 -1
- examples/sft/configs/crafter_fft_qwen0p6b.toml +4 -1
- examples/sft/configs/crafter_lora_qwen0p6b.toml +4 -1
- examples/swe/task_app/README.md +32 -2
- examples/swe/task_app/grpo_swe_mini.py +4 -0
- examples/swe/task_app/hosted/envs/crafter/react_agent.py +1 -1
- examples/swe/task_app/hosted/envs/mini_swe/environment.py +37 -10
- examples/swe/task_app/hosted/inference/openai_client.py +4 -4
- examples/swe/task_app/morph_backend.py +178 -0
- examples/task_apps/crafter/task_app/README.md +1 -1
- examples/task_apps/crafter/task_app/grpo_crafter.py +66 -3
- examples/task_apps/crafter/task_app/grpo_crafter_task_app.py +1 -1
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/policy.py +4 -26
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/react_agent.py +1 -2
- examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +17 -49
- examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +13 -5
- examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +15 -1
- examples/task_apps/enron/task_app/grpo_enron_task_app.py +1 -1
- examples/task_apps/math/README.md +1 -2
- examples/task_apps/pokemon_red/README.md +3 -4
- examples/task_apps/pokemon_red/eval_image_only_gpt4o.toml +6 -5
- examples/task_apps/pokemon_red/eval_pokemon_red_policy.py +1 -2
- examples/task_apps/pokemon_red/task_app.py +36 -5
- examples/task_apps/sokoban/README.md +2 -3
- examples/task_apps/verilog/eval_groq_qwen32b.toml +12 -14
- examples/task_apps/verilog/task_app/grpo_verilog_task_app.py +1 -1
- examples/vlm/configs/crafter_vlm_gpt4o.toml +4 -1
- examples/warming_up_to_rl/configs/crafter_fft.toml +4 -1
- examples/warming_up_to_rl/configs/crafter_fft_4b.toml +0 -2
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +2 -2
- examples/warming_up_to_rl/run_local_rollout_traced.py +1 -1
- examples/warming_up_to_rl/task_app/README.md +1 -1
- examples/warming_up_to_rl/task_app/grpo_crafter.py +134 -3
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +3 -27
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +4 -4
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +6 -3
- examples/workflows/math_rl/configs/rl_from_base_qwen.toml +27 -0
- examples/workflows/math_rl/configs/rl_from_base_qwen17.toml +5 -0
- synth_ai/api/train/builders.py +9 -3
- synth_ai/api/train/cli.py +125 -10
- synth_ai/api/train/configs/__init__.py +8 -1
- synth_ai/api/train/configs/rl.py +32 -7
- synth_ai/api/train/configs/sft.py +6 -2
- synth_ai/api/train/configs/shared.py +59 -2
- synth_ai/auth/credentials.py +119 -0
- synth_ai/cli/__init__.py +12 -4
- synth_ai/cli/commands/__init__.py +17 -0
- synth_ai/cli/commands/demo/__init__.py +6 -0
- synth_ai/cli/commands/demo/core.py +163 -0
- synth_ai/cli/commands/deploy/__init__.py +23 -0
- synth_ai/cli/commands/deploy/core.py +614 -0
- synth_ai/cli/commands/deploy/errors.py +72 -0
- synth_ai/cli/commands/deploy/validation.py +11 -0
- synth_ai/cli/commands/eval/__init__.py +19 -0
- synth_ai/cli/commands/eval/core.py +1109 -0
- synth_ai/cli/commands/eval/errors.py +81 -0
- synth_ai/cli/commands/eval/validation.py +133 -0
- synth_ai/cli/commands/filter/__init__.py +12 -0
- synth_ai/cli/commands/filter/core.py +388 -0
- synth_ai/cli/commands/filter/errors.py +55 -0
- synth_ai/cli/commands/filter/validation.py +77 -0
- synth_ai/cli/commands/help/__init__.py +177 -0
- synth_ai/cli/commands/help/core.py +73 -0
- synth_ai/cli/commands/status/__init__.py +64 -0
- synth_ai/cli/commands/status/client.py +192 -0
- synth_ai/cli/commands/status/config.py +92 -0
- synth_ai/cli/commands/status/errors.py +20 -0
- synth_ai/cli/commands/status/formatters.py +164 -0
- synth_ai/cli/commands/status/subcommands/__init__.py +9 -0
- synth_ai/cli/commands/status/subcommands/files.py +79 -0
- synth_ai/cli/commands/status/subcommands/jobs.py +334 -0
- synth_ai/cli/commands/status/subcommands/models.py +79 -0
- synth_ai/cli/commands/status/subcommands/runs.py +81 -0
- synth_ai/cli/commands/status/subcommands/summary.py +47 -0
- synth_ai/cli/commands/status/utils.py +114 -0
- synth_ai/cli/commands/train/__init__.py +53 -0
- synth_ai/cli/commands/train/core.py +21 -0
- synth_ai/cli/commands/train/errors.py +117 -0
- synth_ai/cli/commands/train/judge_schemas.py +199 -0
- synth_ai/cli/commands/train/judge_validation.py +304 -0
- synth_ai/cli/commands/train/validation.py +443 -0
- synth_ai/cli/demo.py +2 -162
- synth_ai/cli/deploy/__init__.py +28 -0
- synth_ai/cli/deploy/core.py +5 -0
- synth_ai/cli/deploy/errors.py +23 -0
- synth_ai/cli/deploy/validation.py +5 -0
- synth_ai/cli/eval/__init__.py +36 -0
- synth_ai/cli/eval/core.py +5 -0
- synth_ai/cli/eval/errors.py +31 -0
- synth_ai/cli/eval/validation.py +5 -0
- synth_ai/cli/filter/__init__.py +28 -0
- synth_ai/cli/filter/core.py +5 -0
- synth_ai/cli/filter/errors.py +23 -0
- synth_ai/cli/filter/validation.py +5 -0
- synth_ai/cli/modal_serve/__init__.py +12 -0
- synth_ai/cli/modal_serve/core.py +14 -0
- synth_ai/cli/modal_serve/errors.py +8 -0
- synth_ai/cli/modal_serve/validation.py +11 -0
- synth_ai/cli/serve/__init__.py +12 -0
- synth_ai/cli/serve/core.py +14 -0
- synth_ai/cli/serve/errors.py +8 -0
- synth_ai/cli/serve/validation.py +11 -0
- synth_ai/cli/setup.py +20 -265
- synth_ai/cli/status.py +7 -126
- synth_ai/cli/task_app_deploy.py +1 -10
- synth_ai/cli/task_app_modal_serve.py +4 -9
- synth_ai/cli/task_app_serve.py +4 -11
- synth_ai/cli/task_apps.py +58 -1487
- synth_ai/cli/train/__init__.py +12 -0
- synth_ai/cli/train/core.py +21 -0
- synth_ai/cli/train/errors.py +8 -0
- synth_ai/cli/train/validation.py +24 -0
- synth_ai/cli/train.py +1 -14
- synth_ai/demos/crafter/grpo_crafter_task_app.py +1 -1
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +1 -1
- synth_ai/environments/examples/red/engine.py +33 -12
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +151 -179
- synth_ai/environments/examples/red/environment.py +26 -0
- synth_ai/environments/examples/red/trace_hooks_v3.py +168 -0
- synth_ai/http.py +12 -0
- synth_ai/judge_schemas.py +10 -11
- synth_ai/learning/rl/client.py +3 -1
- synth_ai/streaming/__init__.py +29 -0
- synth_ai/streaming/config.py +94 -0
- synth_ai/streaming/handlers.py +469 -0
- synth_ai/streaming/streamer.py +301 -0
- synth_ai/streaming/types.py +95 -0
- synth_ai/task/validators.py +2 -2
- synth_ai/tracing_v3/migration_helper.py +1 -2
- synth_ai/utils/env.py +25 -18
- synth_ai/utils/http.py +4 -1
- synth_ai/utils/modal.py +2 -2
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.17.dist-info}/METADATA +8 -3
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.17.dist-info}/RECORD +184 -109
- examples/qwen_vl/configs/eval_qwen2vl_vision.toml +0 -44
- synth_ai/cli/tui.py +0 -62
- synth_ai/tui/__init__.py +0 -5
- synth_ai/tui/__main__.py +0 -13
- synth_ai/tui/cli/__init__.py +0 -1
- synth_ai/tui/cli/query_experiments.py +0 -164
- synth_ai/tui/cli/query_experiments_v3.py +0 -164
- synth_ai/tui/dashboard.py +0 -911
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.17.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.17.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.17.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.17.dist-info}/top_level.txt +0 -0
synth_ai/cli/commands/eval/errors.py (new file):

@@ -0,0 +1,81 @@
from __future__ import annotations

from dataclasses import dataclass


class EvalCliError(RuntimeError):
    """Base exception for eval CLI failures."""


@dataclass(slots=True)
class TomlUnavailableError(EvalCliError):
    hint: str | None = None


@dataclass(slots=True)
class EvalConfigNotFoundError(EvalCliError):
    path: str


@dataclass(slots=True)
class EvalConfigParseError(EvalCliError):
    path: str
    detail: str


@dataclass(slots=True)
class MissingEvalTableError(EvalCliError):
    """Raised when the eval config lacks an [eval] table."""


@dataclass(slots=True)
class InvalidEvalConfigError(EvalCliError):
    detail: str


@dataclass(slots=True)
class SeedParseError(EvalCliError):
    value: str


@dataclass(slots=True)
class MetadataFilterFormatError(EvalCliError):
    entry: str


@dataclass(slots=True)
class TaskInfoUnavailableError(EvalCliError):
    """Raised when metadata filters require task info but the task app does not expose it."""


@dataclass(slots=True)
class NoSeedsMatchedError(EvalCliError):
    hint: str | None = None


@dataclass(slots=True)
class MetadataSQLExecutionError(EvalCliError):
    query: str
    detail: str


@dataclass(slots=True)
class MetadataSQLResultError(EvalCliError):
    query: str
    detail: str


__all__ = [
    "EvalCliError",
    "TomlUnavailableError",
    "EvalConfigNotFoundError",
    "EvalConfigParseError",
    "MissingEvalTableError",
    "InvalidEvalConfigError",
    "SeedParseError",
    "MetadataFilterFormatError",
    "TaskInfoUnavailableError",
    "NoSeedsMatchedError",
    "MetadataSQLExecutionError",
    "MetadataSQLResultError",
]
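
A minimal sketch of how these dataclass-based errors carry structured fields; illustrative only, assumes synth-ai 0.2.17 is installed, and the config path shown is made up:

```python
# Illustrative sketch, not part of the diff: raise and inspect an eval CLI error.
from synth_ai.cli.commands.eval.errors import EvalCliError, EvalConfigNotFoundError

try:
    raise EvalConfigNotFoundError(path="configs/eval_missing.toml")
except EvalCliError as exc:
    # Dataclass fields ride along with the exception for message formatting.
    print(type(exc).__name__, exc.path)
```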
synth_ai/cli/commands/eval/validation.py (new file):

@@ -0,0 +1,133 @@
from __future__ import annotations

import re
from collections.abc import MutableMapping
from typing import Any

__all__ = ["validate_eval_options"]

_SEED_RANGE = re.compile(r"^\s*(-?\d+)\s*-\s*(-?\d+)\s*$")


def _coerce_bool(value: Any) -> bool:
    if isinstance(value, str):
        return value.strip().lower() in {"1", "true", "yes", "on"}
    return bool(value)


def _coerce_int(value: Any) -> int | None:
    if value is None or value == "":
        return None
    return int(value)


def _parse_seeds(value: Any) -> list[int]:
    if value is None:
        return []
    if isinstance(value, str):
        chunks = [chunk.strip() for chunk in value.split(",") if chunk.strip()]
    elif isinstance(value, list | tuple | set):
        chunks = list(value)
    else:
        chunks = [value]
    seeds: list[int] = []
    for chunk in chunks:
        if isinstance(chunk, int):
            seeds.append(chunk)
        else:
            text = str(chunk).strip()
            if not text:
                continue
            match = _SEED_RANGE.match(text)
            if match:
                start = int(match.group(1))
                end = int(match.group(2))
                if start > end:
                    raise ValueError(f"Invalid seed range '{text}': start must be <= end")
                seeds.extend(range(start, end + 1))
            else:
                seeds.append(int(text))
    return seeds


def _normalize_metadata(value: Any) -> dict[str, str]:
    if value is None:
        return {}
    if isinstance(value, MutableMapping):
        return {str(k): str(v) for k, v in value.items()}
    if isinstance(value, list | tuple):
        result: dict[str, str] = {}
        for item in value:
            if isinstance(item, str) and "=" in item:
                key, val = item.split("=", 1)
                result[key.strip()] = val.strip()
        return result
    if isinstance(value, str) and "=" in value:
        key, val = value.split("=", 1)
        return {key.strip(): val.strip()}
    return {}


def _ensure_list(value: Any) -> list[str] | None:
    if value is None:
        return None
    if isinstance(value, list | tuple | set):
        return [str(item) for item in value]
    return [str(value)]


def _ensure_dict(value: Any) -> dict[str, Any]:
    if isinstance(value, MutableMapping):
        return dict(value)
    return {}


def validate_eval_options(options: MutableMapping[str, Any]) -> MutableMapping[str, Any]:
    """Validate and normalise eval configuration options."""

    result: dict[str, Any] = dict(options)

    if "seeds" in result:
        result["seeds"] = _parse_seeds(result.get("seeds"))

    for field in ("max_turns", "max_llm_calls", "concurrency"):
        try:
            result[field] = _coerce_int(result.get(field))
        except Exception as exc:
            raise ValueError(f"Invalid value for {field}: {result.get(field)}") from exc

    if result.get("max_llm_calls") is None:
        result["max_llm_calls"] = 10
    if result.get("concurrency") is None:
        result["concurrency"] = 1

    if "return_trace" in result:
        result["return_trace"] = _coerce_bool(result.get("return_trace"))

    metadata_value = result.get("metadata")
    result["metadata"] = _normalize_metadata(metadata_value)

    if "ops" in result:
        ops_list = _ensure_list(result.get("ops"))
        result["ops"] = ops_list

    result["env_config"] = _ensure_dict(result.get("env_config"))
    result["policy_config"] = _ensure_dict(result.get("policy_config"))

    trace_format = result.get("trace_format")
    if trace_format is not None:
        result["trace_format"] = str(trace_format)

    metadata_sql = result.get("metadata_sql")
    if metadata_sql is not None and not isinstance(metadata_sql, str):
        result["metadata_sql"] = str(metadata_sql)

    model = result.get("model")
    if model is not None:
        result["model"] = str(model)

    app_id = result.get("app_id")
    if app_id is not None:
        result["app_id"] = str(app_id)

    return result
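
A small usage sketch of the normalisation above; illustrative only, assumes synth-ai 0.2.17 is installed, and the option values are invented:

```python
# Illustrative sketch, not part of the diff: how validate_eval_options normalises raw values.
from synth_ai.cli.commands.eval.validation import validate_eval_options

opts = validate_eval_options(
    {
        "seeds": "0-2,5",                 # comma lists and ranges are expanded
        "metadata": ["difficulty=easy"],  # key=value entries become a dict
        "return_trace": "yes",            # strings are coerced to bool
    }
)
print(opts["seeds"])       # [0, 1, 2, 5]
print(opts["metadata"])    # {'difficulty': 'easy'}
print(opts["max_llm_calls"], opts["concurrency"])  # defaults applied: 10 1
```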
synth_ai/cli/commands/filter/__init__.py (new file):

@@ -0,0 +1,12 @@
from __future__ import annotations

from .core import command, get_command
from .errors import FilterCliError
from .validation import validate_filter_options

__all__ = [
    "command",
    "get_command",
    "FilterCliError",
    "validate_filter_options",
]
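
The re-exports make the subcommand reachable without importing `core` directly; a minimal sketch, assuming synth-ai 0.2.17 is installed:

```python
# Illustrative sketch, not part of the diff: fetch the click command via the package facade.
from synth_ai.cli.commands.filter import get_command

cmd = get_command()
print(cmd.name)  # "filter"
```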
synth_ai/cli/commands/filter/core.py (new file):

@@ -0,0 +1,388 @@
from __future__ import annotations

import asyncio
import json
from datetime import UTC, datetime
from pathlib import Path
from typing import Any, Sequence

import click

try:  # Python 3.11+
    import tomllib as _toml  # type: ignore[attr-defined]
except Exception:  # pragma: no cover
    _toml = None  # type: ignore[assignment]

from synth_ai.task.config import FilterConfig
from synth_ai.tracing_v3 import SessionTracer  # type: ignore[import-untyped]

from .errors import (
    FilterCliError,
    FilterConfigNotFoundError,
    FilterConfigParseError,
    InvalidFilterConfigError,
    MissingFilterTableError,
    NoSessionsMatchedError,
    NoTracesFoundError,
    TomlUnavailableError,
)
from .validation import validate_filter_options

__all__ = ["command", "get_command", "filter_command"]


def _parse_datetime_for_trace(value: Any) -> datetime | None:
    if isinstance(value, datetime):
        return value if value.tzinfo else value.replace(tzinfo=UTC)
    if isinstance(value, str):
        value = value.replace("Z", "+00:00")
        try:
            dt = datetime.fromisoformat(value)
        except ValueError:
            try:
                dt = datetime.fromtimestamp(float(value), tz=UTC)
            except Exception:
                return None
        return dt if dt.tzinfo else dt.replace(tzinfo=UTC)
    if isinstance(value, int | float):
        try:
            return datetime.fromtimestamp(float(value), tz=UTC)
        except Exception:
            return None
    return None


def _score_ok(value: Any, min_val: Any, max_val: Any) -> bool:
    try:
        if value is None:
            return min_val is None
        value = float(value)
    except Exception:
        return False
    if min_val is not None and value < float(min_val):
        return False
    return not (max_val is not None and value > float(max_val))


def _load_filter_config(config_path: Path) -> tuple[FilterConfig, dict[str, Any]]:
    if _toml is None:
        raise TomlUnavailableError(hint="Install tomli or use Python 3.11+")

    if not config_path.exists():
        raise FilterConfigNotFoundError(path=str(config_path))

    try:
        config_data = _toml.loads(config_path.read_text(encoding="utf-8"))
    except Exception as exc:  # pragma: no cover - validation tests cover common cases
        raise FilterConfigParseError(path=str(config_path), detail=str(exc)) from exc

    filter_cfg_dict = config_data.get("filter") if isinstance(config_data, dict) else None
    if not isinstance(filter_cfg_dict, dict):
        raise MissingFilterTableError()

    try:
        normalized = validate_filter_options(filter_cfg_dict)
        normalized_dict = dict(normalized)
        filter_cfg = FilterConfig.from_dict(normalized_dict)
    except (ValueError, TypeError) as validation_error:
        raise InvalidFilterConfigError(detail=str(validation_error)) from validation_error

    click.echo(
        f"✓ Config validated: db={filter_cfg.db}, output={filter_cfg.output}"
    )
    if filter_cfg.min_official_score is not None:
        click.echo(
            f" → Filtering for official score >= {filter_cfg.min_official_score}"
        )
    if filter_cfg.limit:
        click.echo(f" → Limiting to {filter_cfg.limit} examples")

    return filter_cfg, normalized_dict


def _extract_content(content: Any) -> Any:
    if isinstance(content, dict) and "content" in content:
        return content["content"]
    return content


def _extract_text(content: Any) -> str:
    if isinstance(content, str):
        return content
    if isinstance(content, dict):
        payload = content.get("payload") if isinstance(content.get("payload"), dict) else None
        if payload and "content" in payload:
            return _extract_text(payload["content"])
        for key in ("text", "content", "content_text"):
            if key in content:
                value = content[key]
                if isinstance(value, str):
                    return value
        try:
            return json.dumps(content)
        except Exception:  # pragma: no cover - defensive
            return str(content)
    if isinstance(content, list):
        parts = []
        for item in content:
            if isinstance(item, dict) and item.get("type") == "text":
                parts.append(item.get("text", ""))
        return " ".join(parts) if parts else str(content)
    return str(content)


def _select_messages(message_rows: Sequence[dict[str, Any]]) -> list[dict[str, Any]]:
    records: list[dict[str, Any]] = []
    for index, msg_row in enumerate(message_rows):
        msg_type = msg_row.get("message_type")
        content_raw = msg_row.get("content")
        if msg_type not in {"user", "policy_user_prompt"}:
            continue

        assistant_msg = None
        for follow in range(index + 1, len(message_rows)):
            next_type = message_rows[follow].get("message_type")
            if next_type in {"assistant", "policy_system_prompt"}:
                if next_type == "assistant":
                    assistant_msg = message_rows[follow]
                break

        try:
            user_content = json.loads(content_raw) if isinstance(content_raw, str) else content_raw
        except Exception:
            user_content = content_raw

        user_content = _extract_content(user_content)
        user_text = _extract_text(user_content)
        if not user_text:
            continue

        assistant_content = None
        if assistant_msg is not None:
            raw = assistant_msg.get("content")
            try:
                assistant_content = json.loads(raw) if isinstance(raw, str) else raw
            except Exception:
                assistant_content = raw
            assistant_content = _extract_content(assistant_content)

        assistant_text = _extract_text(assistant_content) if assistant_content is not None else ""
        user_payload = user_content if isinstance(user_content, list) else user_text
        assistant_payload = (
            assistant_content
            if isinstance(assistant_content, list)
            else (assistant_text or "[no response recorded]")
        )

        records.append(
            {
                "messages": [
                    {"role": "user", "content": user_payload},
                    {"role": "assistant", "content": assistant_payload},
                ]
            }
        )
    return records


@click.command(
    "filter",
    help="Export filtered tracing sessions to SFT-ready JSONL based on a TOML config.",
)
@click.option(
    "--config",
    "config_path",
    type=click.Path(),
    required=True,
    help="Path to TOML config describing the input trace DB, score thresholds, and output JSONL.",
)
def filter_command(config_path: str) -> None:
    try:
        filter_cfg, raw_cfg = _load_filter_config(Path(config_path))
    except FilterCliError as exc:
        raise click.ClickException(_format_filter_error(exc)) from exc

    db_url = filter_cfg.get_db_url()
    output_path = filter_cfg.get_output_path()

    splits = set(filter_cfg.splits)
    task_ids = set(filter_cfg.task_ids)
    models = set(filter_cfg.models)
    min_official = filter_cfg.min_official_score
    max_official = filter_cfg.max_official_score
    min_judge_scores = filter_cfg.min_judge_scores
    max_judge_scores = filter_cfg.max_judge_scores
    min_created = _parse_datetime_for_trace(raw_cfg.get("min_created_at"))
    max_created = _parse_datetime_for_trace(raw_cfg.get("max_created_at"))
    limit = filter_cfg.limit

    async def _run() -> None:
        tracer = SessionTracer(db_url=db_url, auto_save=False)
        await tracer.initialize()
        assert tracer.db is not None, "Database should be initialized"

        df = await tracer.db.query_traces(
            "SELECT session_id, created_at, metadata FROM session_traces ORDER BY created_at"
        )
        if getattr(df, "empty", True):
            raise NoTracesFoundError(db_url=db_url)

        sessions = df.to_dict("records")
        accepted: list[dict[str, Any]] = []

        for row in sessions:
            metadata_raw = row.get("metadata")
            if isinstance(metadata_raw, str):
                try:
                    metadata = json.loads(metadata_raw)
                except Exception:
                    metadata = {}
            elif isinstance(metadata_raw, dict):
                metadata = dict(metadata_raw)
            else:
                metadata = {}

            created_at_raw = row.get("created_at")
            created_at_dt = _parse_datetime_for_trace(created_at_raw)
            session_id = row.get("session_id")

            if splits and metadata.get("task_split") not in splits:
                continue
            if task_ids and metadata.get("task_id") not in task_ids:
                continue
            if models and metadata.get("model") not in models:
                continue

            if min_created and (created_at_dt is None or created_at_dt < min_created):
                continue
            if max_created and (created_at_dt is None or created_at_dt > max_created):
                continue

            total_reward = None
            achievements_count = None
            if min_official is not None or max_official is not None:
                reward_rows = await tracer.db.query_traces(
                    "SELECT total_reward, achievements_count FROM outcome_rewards WHERE session_id = :session_id",
                    {"session_id": session_id},
                )
                reward_records = (
                    reward_rows.to_dict("records")
                    if hasattr(reward_rows, "to_dict")
                    else []
                )
                if reward_records:
                    total_reward = reward_records[0].get("total_reward")
                    achievements_count = reward_records[0].get("achievements_count")
                    if not _score_ok(total_reward, min_official, max_official):
                        continue
                elif min_official is not None:
                    continue

            judge_scores = metadata.get("judge_scores") or {}
            include = True
            for judge_name, threshold in (min_judge_scores or {}).items():
                if not _score_ok(judge_scores.get(judge_name), threshold, None):
                    include = False
                    break
            if not include:
                continue
            for judge_name, threshold in (max_judge_scores or {}).items():
                if not _score_ok(judge_scores.get(judge_name), None, threshold):
                    include = False
                    break
            if not include:
                continue

            messages_query = (
                "\n SELECT message_type, content, timestamp \n FROM messages \n WHERE session_id = :session_id\n ORDER BY timestamp ASC, id ASC\n "
            )
            msg_df = await tracer.db.query_traces(messages_query, {"session_id": session_id})
            message_rows = (
                msg_df.to_dict("records") if hasattr(msg_df, "to_dict") else []
            )

            if not message_rows:
                prompt = metadata.get("prompt") or ""
                completion = metadata.get("completion") or ""
                if prompt and completion:
                    accepted.append(
                        {
                            "messages": [
                                {"role": "user", "content": str(prompt)},
                                {"role": "assistant", "content": str(completion)},
                            ],
                            "metadata": {
                                "session_id": session_id,
                                "env_name": metadata.get("env_name"),
                                "policy_name": metadata.get("policy_name"),
                                "seed": metadata.get("seed"),
                                "total_reward": total_reward,
                                "achievements_count": achievements_count,
                                "model": metadata.get("model"),
                                "created_at": created_at_dt.isoformat()
                                if created_at_dt
                                else created_at_raw,
                            },
                        }
                    )
                continue

            for record in _select_messages(message_rows):
                record["metadata"] = {
                    "session_id": session_id,
                    "env_name": metadata.get("env_name"),
                    "policy_name": metadata.get("policy_name"),
                    "seed": metadata.get("seed"),
                    "total_reward": total_reward,
                    "achievements_count": achievements_count,
                    "model": metadata.get("model"),
                    "created_at": created_at_dt.isoformat() if created_at_dt else created_at_raw,
                }
                accepted.append(record)

        if not accepted:
            raise NoSessionsMatchedError()

        if limit is not None and limit > 0:
            accepted[:] = accepted[:limit]

        output_path.parent.mkdir(parents=True, exist_ok=True)
        with output_path.open("w", encoding="utf-8") as handle:
            for item in accepted:
                handle.write(json.dumps(item, ensure_ascii=False))
                handle.write("\n")

        click.echo(f"Wrote {len(accepted)} examples -> {output_path}")
        await tracer.db.close()

    try:
        asyncio.run(_run())
    except FilterCliError as exc:
        raise click.ClickException(_format_filter_error(exc)) from exc


def _format_filter_error(err: FilterCliError) -> str:
    if isinstance(err, TomlUnavailableError):
        hint = err.hint or "Install tomli or use Python 3.11+."
        return f"TOML parser not available. {hint}"
    if isinstance(err, FilterConfigNotFoundError):
        return f"Filter config not found: {err.path}"
    if isinstance(err, FilterConfigParseError):
        return f"Failed to parse TOML '{err.path}': {err.detail}"
    if isinstance(err, MissingFilterTableError):
        return "Config must contain a [filter] table."
    if isinstance(err, InvalidFilterConfigError):
        return f"Invalid filter config: {err.detail}"
    if isinstance(err, NoTracesFoundError):
        return f"No traces found in database ({err.db_url})."
    if isinstance(err, NoSessionsMatchedError):
        hint = err.hint or "Adjust the filter thresholds or choose a different dataset."
        return f"No sessions matched the provided filters. {hint}"
    return str(err)


command = filter_command


def get_command() -> click.Command:
    return command
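
For reference, a sketch of the JSONL record shape this command writes, inferred from the hunk above; the field values below are placeholders, not real trace data:

```python
# Illustrative sketch, not part of the diff: one output line as assembled by
# _select_messages plus the per-session metadata enrichment in filter_command.
import json

record = {
    "messages": [
        {"role": "user", "content": "<policy_user_prompt text>"},
        {"role": "assistant", "content": "<assistant reply or '[no response recorded]'>"},
    ],
    "metadata": {
        "session_id": "sess-123",          # placeholder values
        "env_name": "example-env",
        "policy_name": "example-policy",
        "seed": 7,
        "total_reward": 3.0,
        "achievements_count": 2,
        "model": "example-model",
        "created_at": "2025-01-01T00:00:00+00:00",
    },
}
print(json.dumps(record, ensure_ascii=False))
```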
synth_ai/cli/commands/filter/errors.py (new file):

@@ -0,0 +1,55 @@
from __future__ import annotations

from dataclasses import dataclass


class FilterCliError(RuntimeError):
    """Base exception for filter CLI failures."""


@dataclass(slots=True)
class TomlUnavailableError(FilterCliError):
    hint: str | None = None


@dataclass(slots=True)
class FilterConfigNotFoundError(FilterCliError):
    path: str


@dataclass(slots=True)
class FilterConfigParseError(FilterCliError):
    path: str
    detail: str


@dataclass(slots=True)
class MissingFilterTableError(FilterCliError):
    """Raised when the filter config lacks a [filter] table."""


@dataclass(slots=True)
class InvalidFilterConfigError(FilterCliError):
    detail: str


@dataclass(slots=True)
class NoTracesFoundError(FilterCliError):
    db_url: str


@dataclass(slots=True)
class NoSessionsMatchedError(FilterCliError):
    hint: str | None = None


__all__ = [
    "FilterCliError",
    "TomlUnavailableError",
    "FilterConfigNotFoundError",
    "FilterConfigParseError",
    "MissingFilterTableError",
    "InvalidFilterConfigError",
    "NoTracesFoundError",
    "NoSessionsMatchedError",
]