synth-ai 0.2.13.dev2__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/multi_step/configs/README_verilog_rl.md +77 -0
- examples/multi_step/configs/VERILOG_REWARDS.md +90 -0
- examples/multi_step/configs/VERILOG_RL_CHECKLIST.md +183 -0
- examples/multi_step/configs/crafter_eval_synth_qwen4b.toml +35 -0
- examples/multi_step/configs/crafter_eval_text_only_groq_qwen32b.toml +36 -0
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +5 -4
- examples/multi_step/configs/crafter_synth_backend.md +40 -0
- examples/multi_step/configs/verilog_eval_groq_qwen32b.toml +31 -0
- examples/multi_step/configs/verilog_eval_synth_qwen8b.toml +33 -0
- examples/multi_step/configs/verilog_rl_lora.toml +190 -0
- examples/multi_step/judges/crafter_backend_judge.py +220 -0
- examples/multi_step/judges/verilog_backend_judge.py +234 -0
- examples/multi_step/readme.md +48 -0
- examples/multi_step/verilog_rl_lora.md +218 -0
- examples/qwen_coder/configs/coder_lora_30b.toml +1 -1
- examples/sft/evaluate.py +2 -0
- examples/sft/generate_traces.py +2 -0
- examples/swe/task_app/grpo_swe_mini.py +1 -0
- examples/swe/task_app/hosted/rollout.py +2 -0
- examples/task_apps/IMAGE_ONLY_EVAL_QUICKSTART.md +258 -0
- examples/task_apps/crafter/CREATE_SFT_DATASET.md +273 -0
- examples/task_apps/crafter/EVAL_IMAGE_ONLY_RESULTS.md +152 -0
- examples/task_apps/crafter/FILTER_COMMAND_STATUS.md +174 -0
- examples/task_apps/crafter/FILTER_COMMAND_SUCCESS.md +268 -0
- examples/task_apps/crafter/QUERY_EXAMPLES.md +203 -0
- examples/task_apps/crafter/README_IMAGE_ONLY_EVAL.md +316 -0
- examples/task_apps/crafter/eval_image_only_gpt4o.toml +28 -0
- examples/task_apps/crafter/eval_text_only_groq_llama.toml +36 -0
- examples/task_apps/crafter/filter_sft_dataset.toml +16 -0
- examples/task_apps/crafter/task_app/__init__.py +3 -0
- examples/task_apps/crafter/task_app/grpo_crafter.py +306 -8
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/environment.py +10 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/policy.py +16 -3
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/react_agent.py +17 -2
- examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +25 -3
- examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +52 -1
- examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +111 -13
- examples/task_apps/crafter/task_app/synth_envs_hosted/utils.py +156 -0
- examples/task_apps/enron/filter_sft.toml +5 -0
- examples/task_apps/enron/tests/__init__.py +2 -0
- examples/task_apps/enron/tests/integration/__init__.py +2 -0
- examples/task_apps/enron/tests/integration/test_enron_eval.py +2 -0
- examples/task_apps/enron/tests/unit/__init__.py +2 -0
- examples/task_apps/pokemon_red/EVAL_IMAGE_ONLY_COMPLETE.md +283 -0
- examples/task_apps/pokemon_red/EVAL_IMAGE_ONLY_STATUS.md +155 -0
- examples/task_apps/pokemon_red/README_IMAGE_ONLY_EVAL.md +415 -0
- examples/task_apps/pokemon_red/eval_image_only_gpt4o.toml +29 -0
- examples/task_apps/pokemon_red/pallet_town_rl_config.toml +2 -0
- examples/task_apps/pokemon_red/task_app.py +199 -6
- examples/task_apps/pokemon_red/test_pallet_town_rewards.py +2 -0
- examples/task_apps/sokoban/filter_sft.toml +5 -0
- examples/task_apps/sokoban/tests/__init__.py +2 -0
- examples/task_apps/sokoban/tests/integration/__init__.py +2 -0
- examples/task_apps/sokoban/tests/unit/__init__.py +2 -0
- examples/task_apps/verilog/eval_groq_qwen32b.toml +8 -4
- examples/task_apps/verilog/filter_sft.toml +5 -0
- examples/task_apps/verilog/task_app/grpo_verilog.py +258 -23
- examples/task_apps/verilog/tests/__init__.py +2 -0
- examples/task_apps/verilog/tests/integration/__init__.py +2 -0
- examples/task_apps/verilog/tests/integration/test_verilog_eval.py +2 -0
- examples/task_apps/verilog/tests/unit/__init__.py +2 -0
- examples/warming_up_to_rl/groq_test.py +2 -0
- examples/warming_up_to_rl/run_local_rollout.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_modal.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_parallel.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_traced.py +2 -0
- examples/warming_up_to_rl/run_rollout_remote.py +2 -0
- synth_ai/api/models/supported.py +1 -0
- synth_ai/cli/__init__.py +46 -13
- synth_ai/cli/_modal_wrapper.py +3 -2
- synth_ai/cli/recent.py +1 -1
- synth_ai/cli/status.py +1 -1
- synth_ai/cli/task_apps.py +354 -143
- synth_ai/cli/traces.py +1 -1
- synth_ai/cli/tui.py +57 -0
- synth_ai/cli/turso.py +1 -1
- synth_ai/cli/watch.py +1 -1
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +1 -1
- synth_ai/environments/examples/crafter_classic/environment.py +1 -1
- synth_ai/environments/examples/verilog/engine.py +76 -10
- synth_ai/judge_schemas.py +8 -8
- synth_ai/task/__init__.py +11 -1
- synth_ai/task/apps/__init__.py +1 -0
- synth_ai/task/config.py +257 -0
- synth_ai/task/contracts.py +15 -2
- synth_ai/task/rubrics/__init__.py +3 -0
- synth_ai/task/rubrics/loaders.py +22 -3
- synth_ai/task/rubrics/scoring.py +3 -0
- synth_ai/task/trace_correlation_helpers.py +315 -0
- synth_ai/task/validators.py +144 -0
- synth_ai/tracing_v3/abstractions.py +3 -3
- synth_ai/tracing_v3/llm_call_record_helpers.py +5 -5
- synth_ai/tracing_v3/session_tracer.py +16 -6
- synth_ai/tracing_v3/storage/base.py +29 -29
- synth_ai/tracing_v3/storage/config.py +3 -3
- synth_ai/tracing_v3/turso/daemon.py +8 -7
- synth_ai/tracing_v3/turso/native_manager.py +63 -40
- synth_ai/tracing_v3/utils.py +3 -3
- synth_ai/tui/__init__.py +5 -0
- synth_ai/tui/__main__.py +13 -0
- synth_ai/tui/cli/__init__.py +1 -0
- synth_ai/tui/cli/query_experiments.py +164 -0
- synth_ai/tui/cli/query_experiments_v3.py +164 -0
- synth_ai/tui/dashboard.py +906 -0
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.14.dist-info}/METADATA +1 -1
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.14.dist-info}/RECORD +110 -71
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.14.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.14.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.14.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.14.dist-info}/top_level.txt +0 -0
synth_ai/cli/traces.py
CHANGED
synth_ai/cli/tui.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
CLI: Interactive TUI dashboard for Synth AI.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import os
|
|
7
|
+
|
|
8
|
+
import click
|
|
9
|
+
from rich.console import Console
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def register(cli):
    """Attach the ``tui`` command to the given Click group.

    Args:
        cli: Click group that receives the ``tui`` command.
    """

    @cli.command()
    @click.option(
        "--url",
        "db_url",
        default="sqlite+libsql://http://127.0.0.1:8080",
        envvar="TUI_DB_URL",
        show_envvar=True,
        help="Database URL (default: sqlite+libsql://http://127.0.0.1:8080)",
    )
    @click.option("--debug", is_flag=True, help="Enable debug logging")
    def tui(db_url: str, debug: bool):
        """Launch interactive TUI dashboard showing experiments, balance, and active runs."""
        console = Console()

        # Import here to avoid circular imports and handle optional dependencies
        try:
            from synth_ai.tui.dashboard import main as tui_main
        except ImportError as e:  # also covers ModuleNotFoundError (a subclass)
            console.print("[red]Error:[/red] TUI dashboard not available.")
            console.print(f"Missing dependencies: {e}")
            console.print("Install with: pip install textual")
            return
        except Exception as e:
            # Handle other import errors (like missing libsql, type annotation issues, etc.)
            console.print("[red]Error:[/red] TUI dashboard not available.")
            console.print("This may be due to missing dependencies or Python version compatibility.")
            console.print(f"Details: {e}")
            console.print("Try: pip install textual libsql")
            console.print("If using Python < 3.10, you may need to update Python or install eval_type_backport.")
            if debug:
                raise
            return

        # Overwrite (not setdefault) so the env var always matches the URL passed
        # on argv below; an existing TUI_DB_URL is already honored via ``envvar``.
        os.environ["TUI_DB_URL"] = db_url
        if debug:
            os.environ["TUI_DEBUG"] = "1"

        # Run the TUI by calling the module directly with sanitized argv
        try:
            tui_args = ["--url", db_url]
            if debug:
                tui_args.append("--debug")
            tui_main(tui_args)
        except KeyboardInterrupt:
            console.print("\n[blue]TUI closed.[/blue]")
        except Exception as e:
            console.print(f"\n[red]Error running TUI:[/red] {e}")
            if debug:
                raise
|
synth_ai/cli/turso.py
CHANGED
synth_ai/cli/watch.py
CHANGED
|
@@ -397,7 +397,7 @@ class CrafterClassicEnvironment(StatefulEnvironment, ReproducibleEnvironment[Cra
|
|
|
397
397
|
priv_state, pub_state, self.custom_step_observation_callable
|
|
398
398
|
)
|
|
399
399
|
total_step_time = time.time() - step_start_time
|
|
400
|
-
logger.
|
|
400
|
+
logger.debug(
|
|
401
401
|
f"CrafterClassic step completed in {total_step_time:.3f}s (interact: {interact_time:.3f}s)"
|
|
402
402
|
)
|
|
403
403
|
return obs
|
|
@@ -46,7 +46,7 @@ class VerilogCompileSuccessComponent(RewardComponent):
|
|
|
46
46
|
if hasattr(action, "get") and action.get("type") == "compile":
|
|
47
47
|
# Check if compilation was successful (returncode 0)
|
|
48
48
|
if action.get("returncode") == 0:
|
|
49
|
-
return 0.1
|
|
49
|
+
return 0.01 # Normalized: 0.1 / 10.0 = 0.01
|
|
50
50
|
return 0.0
|
|
51
51
|
|
|
52
52
|
|
|
@@ -55,12 +55,12 @@ class VerilogSimulationPassComponent(RewardComponent):
|
|
|
55
55
|
if hasattr(action, "get") and action.get("type") == "simulate":
|
|
56
56
|
# Check if simulation passed
|
|
57
57
|
if action.get("passed", False):
|
|
58
|
-
return 1.0
|
|
58
|
+
return 0.1 # Normalized: 1.0 / 10.0 = 0.1
|
|
59
59
|
return 0.0
|
|
60
60
|
|
|
61
61
|
|
|
62
62
|
class VerilogStepPenaltyComponent(RewardComponent):
|
|
63
|
-
def __init__(self, penalty: float =
|
|
63
|
+
def __init__(self, penalty: float = 0.0): # No per-step reward - only reward accomplishments
|
|
64
64
|
self.penalty = penalty
|
|
65
65
|
|
|
66
66
|
async def score(self, state: Any, action: Any) -> float:
|
|
@@ -68,12 +68,12 @@ class VerilogStepPenaltyComponent(RewardComponent):
|
|
|
68
68
|
|
|
69
69
|
|
|
70
70
|
class VerilogSubmitSuccessComponent(RewardComponent):
|
|
71
|
-
"""Reward for successful submission (tests passed)."""
|
|
71
|
+
"""Reward for successful submission (tests passed). Max reward = 1.0 (normalized)."""
|
|
72
72
|
async def score(self, state: VerilogPublicState, action: Any) -> float:
|
|
73
73
|
if hasattr(action, "get") and action.get("type") == "submit":
|
|
74
74
|
# Check if submission passed
|
|
75
75
|
if action.get("passed", False):
|
|
76
|
-
return
|
|
76
|
+
return 1.0 # Normalized: Maximum reward is now 1.0
|
|
77
77
|
return 0.0
|
|
78
78
|
|
|
79
79
|
|
|
@@ -83,6 +83,9 @@ class VerilogEngine(StatefulEngine):
|
|
|
83
83
|
"""
|
|
84
84
|
|
|
85
85
|
def __init__(self, task_instance: TaskInstance):
|
|
86
|
+
# Validate required Verilog tools are available
|
|
87
|
+
self._validate_verilog_tools()
|
|
88
|
+
|
|
86
89
|
self.task_instance = task_instance
|
|
87
90
|
self._total_reward = 0.0
|
|
88
91
|
self._current_action_for_reward: Optional[Dict[str, Any]] = None
|
|
@@ -92,7 +95,7 @@ class VerilogEngine(StatefulEngine):
|
|
|
92
95
|
VerilogCompileSuccessComponent(),
|
|
93
96
|
VerilogSimulationPassComponent(),
|
|
94
97
|
VerilogSubmitSuccessComponent(),
|
|
95
|
-
VerilogStepPenaltyComponent(penalty
|
|
98
|
+
VerilogStepPenaltyComponent(penalty=0.0), # No per-step reward
|
|
96
99
|
]
|
|
97
100
|
)
|
|
98
101
|
|
|
@@ -103,6 +106,39 @@ class VerilogEngine(StatefulEngine):
|
|
|
103
106
|
# Track last compile/simulate outputs
|
|
104
107
|
self._last_compile_output: Optional[str] = None
|
|
105
108
|
self._last_simulate_output: Optional[str] = None
|
|
109
|
+
|
|
110
|
+
@staticmethod
|
|
111
|
+
def _validate_verilog_tools() -> None:
|
|
112
|
+
"""Validate that required Verilog tools (iverilog, vvp) are available."""
|
|
113
|
+
missing_tools = []
|
|
114
|
+
|
|
115
|
+
if not shutil.which("iverilog"):
|
|
116
|
+
missing_tools.append("iverilog")
|
|
117
|
+
if not shutil.which("vvp"):
|
|
118
|
+
missing_tools.append("vvp")
|
|
119
|
+
|
|
120
|
+
if missing_tools:
|
|
121
|
+
error_msg = (
|
|
122
|
+
f"🚨🚨🚨 CRITICAL CONFIGURATION ERROR 🚨🚨🚨\n"
|
|
123
|
+
f"\n"
|
|
124
|
+
f"Missing required Verilog tools: {', '.join(missing_tools)}\n"
|
|
125
|
+
f"\n"
|
|
126
|
+
f"The Verilog environment CANNOT function without these tools.\n"
|
|
127
|
+
f"ALL compile/simulate operations will FAIL.\n"
|
|
128
|
+
f"ALL rewards will be ZERO.\n"
|
|
129
|
+
f"Training or evaluation will be COMPLETELY BROKEN.\n"
|
|
130
|
+
f"\n"
|
|
131
|
+
f"🔧 FIX THIS NOW:\n"
|
|
132
|
+
f"1. Add 'iverilog' to apt_packages in Modal deployment config\n"
|
|
133
|
+
f"2. Location: examples/task_apps/verilog/task_app/grpo_verilog.py\n"
|
|
134
|
+
f"3. Look for: modal=ModalDeploymentConfig(\n"
|
|
135
|
+
f"4. Add: apt_packages=('iverilog',) # Provides both iverilog and vvp\n"
|
|
136
|
+
f"5. Redeploy: uvx synth-ai modal-serve grpo-verilog\n"
|
|
137
|
+
f"\n"
|
|
138
|
+
f"{'='*80}"
|
|
139
|
+
)
|
|
140
|
+
print(f"\n{'='*80}\n{error_msg}\n{'='*80}\n", flush=True)
|
|
141
|
+
raise RuntimeError(error_msg)
|
|
106
142
|
|
|
107
143
|
async def _reset_engine(
|
|
108
144
|
self, *, seed: Optional[int] = None
|
|
@@ -133,6 +169,13 @@ class VerilogEngine(StatefulEngine):
|
|
|
133
169
|
) -> Tuple[VerilogPrivateState, VerilogPublicState]:
|
|
134
170
|
"""Process an action result and update engine state."""
|
|
135
171
|
self._current_action_for_reward = action_result
|
|
172
|
+
|
|
173
|
+
# DEBUG: Print action_result
|
|
174
|
+
print(f"\n[ENGINE DEBUG] _step_engine called")
|
|
175
|
+
print(f" action_result: {action_result}")
|
|
176
|
+
print(f" action_result.type: {action_result.get('type')}")
|
|
177
|
+
print(f" action_result.returncode: {action_result.get('returncode')}")
|
|
178
|
+
print(f" action_result.ok: {action_result.get('ok')}")
|
|
136
179
|
|
|
137
180
|
# Update last outputs if this is a compile or simulate action
|
|
138
181
|
if action_result.get("type") == "compile":
|
|
@@ -147,18 +190,21 @@ class VerilogEngine(StatefulEngine):
|
|
|
147
190
|
current_pub_state = VerilogPublicState(
|
|
148
191
|
files=self._get_file_contents(),
|
|
149
192
|
build_dir=str(self.build_dir),
|
|
150
|
-
task_completed=action_result.get("passed", False),
|
|
193
|
+
task_completed=action_result.get("submitted", False) and action_result.get("passed", False),
|
|
151
194
|
)
|
|
152
195
|
|
|
153
196
|
reward_from_stack = await self.reward_stack.step_reward(
|
|
154
197
|
state=current_pub_state, action=self._current_action_for_reward
|
|
155
198
|
)
|
|
156
199
|
self._current_action_for_reward = None
|
|
200
|
+
|
|
201
|
+
# DEBUG: Print reward
|
|
202
|
+
print(f"[ENGINE DEBUG] reward_from_stack: {reward_from_stack}")
|
|
157
203
|
|
|
158
204
|
self._total_reward += reward_from_stack
|
|
159
205
|
|
|
160
|
-
# Check termination conditions
|
|
161
|
-
terminated = action_result.get("
|
|
206
|
+
# Check termination conditions - only terminate if submitted (regardless of pass/fail)
|
|
207
|
+
terminated = action_result.get("submitted", False)
|
|
162
208
|
|
|
163
209
|
priv = VerilogPrivateState(
|
|
164
210
|
reward_last=reward_from_stack,
|
|
@@ -170,7 +216,7 @@ class VerilogEngine(StatefulEngine):
|
|
|
170
216
|
pub = VerilogPublicState(
|
|
171
217
|
files=self._get_file_contents(),
|
|
172
218
|
build_dir=str(self.build_dir),
|
|
173
|
-
task_completed=action_result.get("passed", False),
|
|
219
|
+
task_completed=action_result.get("submitted", False) and action_result.get("passed", False),
|
|
174
220
|
last_compile_output=self._last_compile_output,
|
|
175
221
|
last_simulate_output=self._last_simulate_output,
|
|
176
222
|
)
|
|
@@ -259,6 +305,16 @@ class VerilogEngine(StatefulEngine):
|
|
|
259
305
|
}
|
|
260
306
|
except subprocess.TimeoutExpired:
|
|
261
307
|
return {"ok": False, "error": "Compilation timeout", "type": "compile"}
|
|
308
|
+
except FileNotFoundError:
|
|
309
|
+
error_msg = (
|
|
310
|
+
"🚨 CRITICAL ERROR: 'iverilog' executable not found! 🚨\n"
|
|
311
|
+
"The Verilog compiler (iverilog) is not installed in this environment.\n"
|
|
312
|
+
"This will cause ALL compile operations to fail and result in ZERO rewards.\n"
|
|
313
|
+
"Fix: Add 'iverilog' to apt_packages in the Modal deployment config.\n"
|
|
314
|
+
"Location: examples/task_apps/verilog/task_app/grpo_verilog.py -> modal=ModalDeploymentConfig(apt_packages=('iverilog',))"
|
|
315
|
+
)
|
|
316
|
+
print(f"\n{'='*80}\n{error_msg}\n{'='*80}\n", flush=True)
|
|
317
|
+
raise RuntimeError(error_msg) from None
|
|
262
318
|
except Exception as e:
|
|
263
319
|
return {"ok": False, "error": str(e), "type": "compile"}
|
|
264
320
|
|
|
@@ -290,6 +346,16 @@ class VerilogEngine(StatefulEngine):
|
|
|
290
346
|
}
|
|
291
347
|
except subprocess.TimeoutExpired:
|
|
292
348
|
return {"ok": False, "error": "Simulation timeout", "type": "simulate"}
|
|
349
|
+
except FileNotFoundError:
|
|
350
|
+
error_msg = (
|
|
351
|
+
"🚨 CRITICAL ERROR: 'vvp' executable not found! 🚨\n"
|
|
352
|
+
"The Verilog simulator (vvp) is not installed in this environment.\n"
|
|
353
|
+
"This will cause ALL simulate operations to fail and result in ZERO rewards.\n"
|
|
354
|
+
"Fix: Add 'iverilog' to apt_packages in the Modal deployment config (provides both iverilog and vvp).\n"
|
|
355
|
+
"Location: examples/task_apps/verilog/task_app/grpo_verilog.py -> modal=ModalDeploymentConfig(apt_packages=('iverilog',))"
|
|
356
|
+
)
|
|
357
|
+
print(f"\n{'='*80}\n{error_msg}\n{'='*80}\n", flush=True)
|
|
358
|
+
raise RuntimeError(error_msg) from None
|
|
293
359
|
except Exception as e:
|
|
294
360
|
return {"ok": False, "error": str(e), "type": "simulate"}
|
|
295
361
|
|
synth_ai/judge_schemas.py
CHANGED
|
@@ -9,7 +9,7 @@ This is the canonical contract that the backend MUST conform to.
|
|
|
9
9
|
|
|
10
10
|
from __future__ import annotations
|
|
11
11
|
|
|
12
|
-
from typing import Any, Literal
|
|
12
|
+
from typing import Any, Dict, Literal, Optional
|
|
13
13
|
|
|
14
14
|
from pydantic import BaseModel, Field
|
|
15
15
|
|
|
@@ -31,7 +31,7 @@ class ReviewPayload(BaseModel):
|
|
|
31
31
|
description="Map of criterion keys to their scores"
|
|
32
32
|
)
|
|
33
33
|
total: float = Field(default=0.0, description="Aggregated total score")
|
|
34
|
-
summary: str
|
|
34
|
+
summary: Optional[str] = Field(None, description="Optional text summary")
|
|
35
35
|
|
|
36
36
|
|
|
37
37
|
class JudgeScoreResponse(BaseModel):
|
|
@@ -46,7 +46,7 @@ class JudgeScoreResponse(BaseModel):
|
|
|
46
46
|
default_factory=list,
|
|
47
47
|
description="List of per-event rubric reviews (one per step)"
|
|
48
48
|
)
|
|
49
|
-
outcome_review: ReviewPayload
|
|
49
|
+
outcome_review: Optional[ReviewPayload] = Field(
|
|
50
50
|
None,
|
|
51
51
|
description="Optional outcome-level rubric review"
|
|
52
52
|
)
|
|
@@ -92,15 +92,15 @@ class JudgeTaskApp(BaseModel):
|
|
|
92
92
|
"""Task application metadata."""
|
|
93
93
|
|
|
94
94
|
id: str = Field(..., description="Task app identifier")
|
|
95
|
-
base_url: str
|
|
95
|
+
base_url: Optional[str] = Field(None, description="Optional base URL for task app")
|
|
96
96
|
|
|
97
97
|
|
|
98
98
|
class JudgeOptions(BaseModel):
|
|
99
99
|
"""Judge provider and configuration options."""
|
|
100
100
|
|
|
101
|
-
provider: str
|
|
102
|
-
model: str
|
|
103
|
-
rubric_id: str
|
|
101
|
+
provider: Optional[str] = Field(None, description="Judge provider (e.g., 'openai', 'groq')")
|
|
102
|
+
model: Optional[str] = Field(None, description="Model identifier")
|
|
103
|
+
rubric_id: Optional[str] = Field(None, description="Rubric identifier")
|
|
104
104
|
event: bool = Field(True, description="Enable event-level judging")
|
|
105
105
|
outcome: bool = Field(True, description="Enable outcome-level judging")
|
|
106
106
|
|
|
@@ -123,5 +123,5 @@ class JudgeScoreRequest(BaseModel):
|
|
|
123
123
|
task_app: JudgeTaskApp = Field(..., description="Task application metadata")
|
|
124
124
|
trace: JudgeTracePayload = Field(..., description="Trajectory trace to evaluate")
|
|
125
125
|
options: JudgeOptions = Field(default_factory=lambda: JudgeOptions(), description="Judge options")
|
|
126
|
-
rubric:
|
|
126
|
+
rubric: Optional[Dict[str, Any]] = Field(None, description="Optional explicit rubric criteria")
|
|
127
127
|
|
synth_ai/task/__init__.py
CHANGED
|
@@ -4,6 +4,7 @@ from .auth import (
|
|
|
4
4
|
require_api_key_dependency,
|
|
5
5
|
)
|
|
6
6
|
from .client import TaskAppClient
|
|
7
|
+
from .config import EvalConfig, FilterConfig
|
|
7
8
|
from .contracts import (
|
|
8
9
|
DatasetInfo,
|
|
9
10
|
InferenceInfo,
|
|
@@ -50,7 +51,12 @@ from .server import (
|
|
|
50
51
|
create_task_app,
|
|
51
52
|
run_task_app,
|
|
52
53
|
)
|
|
53
|
-
from .validators import
|
|
54
|
+
from .validators import (
|
|
55
|
+
normalize_inference_url,
|
|
56
|
+
validate_rollout_response_for_rl,
|
|
57
|
+
validate_task_app_endpoint,
|
|
58
|
+
validate_task_app_url,
|
|
59
|
+
)
|
|
54
60
|
from .vendors import (
|
|
55
61
|
get_groq_key_or_503,
|
|
56
62
|
get_openai_key_or_503,
|
|
@@ -58,9 +64,13 @@ from .vendors import (
|
|
|
58
64
|
)
|
|
59
65
|
|
|
60
66
|
__all__ = [
|
|
67
|
+
"normalize_inference_url",
|
|
68
|
+
"validate_rollout_response_for_rl",
|
|
61
69
|
"validate_task_app_url",
|
|
62
70
|
"validate_task_app_endpoint",
|
|
63
71
|
"task_app_health",
|
|
72
|
+
"EvalConfig",
|
|
73
|
+
"FilterConfig",
|
|
64
74
|
"TaskAppEndpoints",
|
|
65
75
|
"RolloutEnvSpec",
|
|
66
76
|
"RolloutPolicySpec",
|
synth_ai/task/apps/__init__.py
CHANGED
|
@@ -18,6 +18,7 @@ class ModalDeploymentConfig:
|
|
|
18
18
|
app_name: str
|
|
19
19
|
python_version: str = "3.11"
|
|
20
20
|
pip_packages: Sequence[str] = field(default_factory=tuple)
|
|
21
|
+
apt_packages: Sequence[str] = field(default_factory=tuple)
|
|
21
22
|
extra_local_dirs: Sequence[tuple[str, str]] = field(default_factory=tuple)
|
|
22
23
|
secret_names: Sequence[str] = field(default_factory=tuple)
|
|
23
24
|
volume_mounts: Sequence[tuple[str, str]] = field(default_factory=tuple)
|
synth_ai/task/config.py
ADDED
|
@@ -0,0 +1,257 @@
|
|
|
1
|
+
"""Configuration dataclasses for task app CLI commands (eval, filter)."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any, Literal
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass(slots=True)
|
|
11
|
+
class EvalConfig:
|
|
12
|
+
"""Configuration for 'synth-ai eval' command.
|
|
13
|
+
|
|
14
|
+
Validates and provides defaults for evaluation runs against task apps.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
# Required: Task app identifier
|
|
18
|
+
app_id: str
|
|
19
|
+
|
|
20
|
+
# Required: Model to evaluate
|
|
21
|
+
model: str
|
|
22
|
+
|
|
23
|
+
# Required: Seeds to run
|
|
24
|
+
seeds: list[int]
|
|
25
|
+
|
|
26
|
+
# Optional: Task app URL (None = spawn in-process)
|
|
27
|
+
task_app_url: str | None = None
|
|
28
|
+
|
|
29
|
+
# Optional: Data split to use
|
|
30
|
+
split: str = "train"
|
|
31
|
+
|
|
32
|
+
# Optional: Maximum turns/steps per episode
|
|
33
|
+
max_turns: int | None = None
|
|
34
|
+
|
|
35
|
+
# Optional: Maximum LLM calls per episode
|
|
36
|
+
max_llm_calls: int = 10
|
|
37
|
+
|
|
38
|
+
# Optional: Concurrency for parallel rollouts
|
|
39
|
+
concurrency: int = 1
|
|
40
|
+
|
|
41
|
+
# Optional: Environment name
|
|
42
|
+
env_name: str | None = None
|
|
43
|
+
|
|
44
|
+
# Optional: Policy name
|
|
45
|
+
policy_name: str | None = None
|
|
46
|
+
|
|
47
|
+
# Optional: Trace format ("compact", "full", "structured")
|
|
48
|
+
trace_format: Literal["compact", "full", "structured"] = "compact"
|
|
49
|
+
|
|
50
|
+
# Optional: Whether to return traces in response
|
|
51
|
+
return_trace: bool = False
|
|
52
|
+
|
|
53
|
+
# Optional: Operations sequence (if not provided, generates default)
|
|
54
|
+
ops: list[str] | None = None
|
|
55
|
+
|
|
56
|
+
# Optional: Environment config overrides
|
|
57
|
+
env_config: dict[str, Any] = field(default_factory=dict)
|
|
58
|
+
|
|
59
|
+
# Optional: Policy config overrides
|
|
60
|
+
policy_config: dict[str, Any] = field(default_factory=dict)
|
|
61
|
+
|
|
62
|
+
# Optional: Metadata for traces
|
|
63
|
+
metadata: dict[str, str] = field(default_factory=dict)
|
|
64
|
+
|
|
65
|
+
# Optional: SQL query for metadata filtering
|
|
66
|
+
metadata_sql: str | None = None
|
|
67
|
+
|
|
68
|
+
def __post_init__(self):
|
|
69
|
+
"""Validate configuration after initialization."""
|
|
70
|
+
if not self.app_id:
|
|
71
|
+
raise ValueError("app_id is required")
|
|
72
|
+
|
|
73
|
+
if not self.model:
|
|
74
|
+
raise ValueError("model is required")
|
|
75
|
+
|
|
76
|
+
if not self.seeds:
|
|
77
|
+
raise ValueError("seeds list cannot be empty")
|
|
78
|
+
|
|
79
|
+
if not isinstance(self.seeds, list):
|
|
80
|
+
raise ValueError("seeds must be a list of integers")
|
|
81
|
+
|
|
82
|
+
if self.concurrency < 1:
|
|
83
|
+
raise ValueError("concurrency must be >= 1")
|
|
84
|
+
|
|
85
|
+
if self.max_llm_calls < 1:
|
|
86
|
+
raise ValueError("max_llm_calls must be >= 1")
|
|
87
|
+
|
|
88
|
+
if self.max_turns is not None and self.max_turns < 1:
|
|
89
|
+
raise ValueError("max_turns must be >= 1")
|
|
90
|
+
|
|
91
|
+
if self.trace_format not in ("compact", "full", "structured"):
|
|
92
|
+
raise ValueError(f"trace_format must be 'compact', 'full', or 'structured', got: {self.trace_format}")
|
|
93
|
+
|
|
94
|
+
@classmethod
|
|
95
|
+
def from_dict(cls, data: dict[str, Any]) -> EvalConfig:
|
|
96
|
+
"""Create EvalConfig from a dictionary (e.g. from TOML).
|
|
97
|
+
|
|
98
|
+
Args:
|
|
99
|
+
data: Dictionary with eval configuration
|
|
100
|
+
|
|
101
|
+
Returns:
|
|
102
|
+
Validated EvalConfig instance
|
|
103
|
+
"""
|
|
104
|
+
# Extract known fields
|
|
105
|
+
config_dict = {
|
|
106
|
+
"app_id": data.get("app_id"),
|
|
107
|
+
"model": data.get("model"),
|
|
108
|
+
"seeds": data.get("seeds", []),
|
|
109
|
+
"task_app_url": data.get("task_app_url"),
|
|
110
|
+
"split": data.get("split", "train"),
|
|
111
|
+
"max_turns": data.get("max_turns"),
|
|
112
|
+
"max_llm_calls": data.get("max_llm_calls", 10),
|
|
113
|
+
"concurrency": data.get("concurrency", 1),
|
|
114
|
+
"env_name": data.get("env_name"),
|
|
115
|
+
"policy_name": data.get("policy_name"),
|
|
116
|
+
"trace_format": data.get("trace_format", "compact"),
|
|
117
|
+
"return_trace": data.get("return_trace", False),
|
|
118
|
+
"ops": data.get("ops"),
|
|
119
|
+
"env_config": data.get("env_config", {}),
|
|
120
|
+
"policy_config": data.get("policy_config", {}),
|
|
121
|
+
"metadata": data.get("metadata", {}),
|
|
122
|
+
"metadata_sql": data.get("metadata_sql"),
|
|
123
|
+
}
|
|
124
|
+
|
|
125
|
+
return cls(**config_dict)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
@dataclass(slots=True)
|
|
129
|
+
class FilterConfig:
|
|
130
|
+
"""Configuration for 'synth-ai filter' command.
|
|
131
|
+
|
|
132
|
+
Validates and provides defaults for filtering traces into SFT datasets.
|
|
133
|
+
"""
|
|
134
|
+
|
|
135
|
+
# Required: Database path or URL
|
|
136
|
+
db: str
|
|
137
|
+
|
|
138
|
+
# Required: Output JSONL path
|
|
139
|
+
output: str
|
|
140
|
+
|
|
141
|
+
# Optional: Filter by data splits
|
|
142
|
+
splits: list[str] = field(default_factory=list)
|
|
143
|
+
|
|
144
|
+
# Optional: Filter by task IDs
|
|
145
|
+
task_ids: list[str] = field(default_factory=list)
|
|
146
|
+
|
|
147
|
+
# Optional: Filter by models
|
|
148
|
+
models: list[str] = field(default_factory=list)
|
|
149
|
+
|
|
150
|
+
# Optional: Minimum official score threshold
|
|
151
|
+
min_official_score: float | None = None
|
|
152
|
+
|
|
153
|
+
# Optional: Maximum official score threshold
|
|
154
|
+
max_official_score: float | None = None
|
|
155
|
+
|
|
156
|
+
# Optional: Minimum judge scores (judge_name -> min_score)
|
|
157
|
+
min_judge_scores: dict[str, float] = field(default_factory=dict)
|
|
158
|
+
|
|
159
|
+
# Optional: Maximum judge scores (judge_name -> max_score)
|
|
160
|
+
max_judge_scores: dict[str, float] = field(default_factory=dict)
|
|
161
|
+
|
|
162
|
+
# Optional: Limit number of examples
|
|
163
|
+
limit: int | None = None
|
|
164
|
+
|
|
165
|
+
# Optional: Offset for pagination
|
|
166
|
+
offset: int | None = None
|
|
167
|
+
|
|
168
|
+
# Optional: Whether to shuffle results
|
|
169
|
+
shuffle: bool = False
|
|
170
|
+
|
|
171
|
+
# Optional: Random seed for shuffling
|
|
172
|
+
shuffle_seed: int | None = None
|
|
173
|
+
|
|
174
|
+
def __post_init__(self):
|
|
175
|
+
"""Validate configuration after initialization."""
|
|
176
|
+
if not self.db:
|
|
177
|
+
raise ValueError("db (database path or URL) is required")
|
|
178
|
+
|
|
179
|
+
if not self.output:
|
|
180
|
+
raise ValueError("output (JSONL file path) is required")
|
|
181
|
+
|
|
182
|
+
# Validate output has .jsonl extension
|
|
183
|
+
output_path = Path(self.output)
|
|
184
|
+
if output_path.suffix.lower() not in (".jsonl", ".json"):
|
|
185
|
+
raise ValueError(f"output must be a .jsonl or .json file, got: {self.output}")
|
|
186
|
+
|
|
187
|
+
# Validate score thresholds
|
|
188
|
+
if self.min_official_score is not None and self.max_official_score is not None:
|
|
189
|
+
if self.min_official_score > self.max_official_score:
|
|
190
|
+
raise ValueError("min_official_score cannot be greater than max_official_score")
|
|
191
|
+
|
|
192
|
+
# Validate limit/offset
|
|
193
|
+
if self.limit is not None and self.limit < 1:
|
|
194
|
+
raise ValueError("limit must be >= 1")
|
|
195
|
+
|
|
196
|
+
if self.offset is not None and self.offset < 0:
|
|
197
|
+
raise ValueError("offset must be >= 0")
|
|
198
|
+
|
|
199
|
+
# Validate shuffle seed requires shuffle
|
|
200
|
+
if self.shuffle_seed is not None and not self.shuffle:
|
|
201
|
+
raise ValueError("shuffle_seed requires shuffle=true")
|
|
202
|
+
|
|
203
|
+
@classmethod
|
|
204
|
+
def from_dict(cls, data: dict[str, Any]) -> FilterConfig:
|
|
205
|
+
"""Create FilterConfig from a dictionary (e.g. from TOML).
|
|
206
|
+
|
|
207
|
+
Args:
|
|
208
|
+
data: Dictionary with filter configuration
|
|
209
|
+
|
|
210
|
+
Returns:
|
|
211
|
+
Validated FilterConfig instance
|
|
212
|
+
"""
|
|
213
|
+
# Extract known fields
|
|
214
|
+
config_dict = {
|
|
215
|
+
"db": data.get("db"),
|
|
216
|
+
"output": data.get("output"),
|
|
217
|
+
"splits": data.get("splits", []),
|
|
218
|
+
"task_ids": data.get("task_ids", []),
|
|
219
|
+
"models": data.get("models", []),
|
|
220
|
+
"min_official_score": data.get("min_official_score"),
|
|
221
|
+
"max_official_score": data.get("max_official_score"),
|
|
222
|
+
"min_judge_scores": data.get("min_judge_scores", {}),
|
|
223
|
+
"max_judge_scores": data.get("max_judge_scores", {}),
|
|
224
|
+
"limit": data.get("limit"),
|
|
225
|
+
"offset": data.get("offset"),
|
|
226
|
+
"shuffle": data.get("shuffle", False),
|
|
227
|
+
"shuffle_seed": data.get("shuffle_seed"),
|
|
228
|
+
}
|
|
229
|
+
|
|
230
|
+
return cls(**config_dict)
|
|
231
|
+
|
|
232
|
+
def get_db_url(self) -> str:
|
|
233
|
+
"""Convert db path to proper SQLite URL if needed.
|
|
234
|
+
|
|
235
|
+
Returns:
|
|
236
|
+
Database URL suitable for SQLAlchemy/aiosqlite
|
|
237
|
+
"""
|
|
238
|
+
db_value = self.db.strip()
|
|
239
|
+
if "://" in db_value:
|
|
240
|
+
return db_value
|
|
241
|
+
else:
|
|
242
|
+
db_path = Path(db_value).expanduser().resolve()
|
|
243
|
+
# Ensure parent directory exists
|
|
244
|
+
db_path.parent.mkdir(parents=True, exist_ok=True)
|
|
245
|
+
return f"sqlite+aiosqlite:///{db_path}"
|
|
246
|
+
|
|
247
|
+
def get_output_path(self) -> Path:
|
|
248
|
+
"""Get resolved output path with parent directory created.
|
|
249
|
+
|
|
250
|
+
Returns:
|
|
251
|
+
Resolved Path object with parent directory created
|
|
252
|
+
"""
|
|
253
|
+
output_path = Path(self.output).expanduser().resolve()
|
|
254
|
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
|
255
|
+
return output_path
|
|
256
|
+
|
|
257
|
+
|