synth-ai 0.2.13.dev2__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/multi_step/configs/README_verilog_rl.md +77 -0
- examples/multi_step/configs/VERILOG_REWARDS.md +90 -0
- examples/multi_step/configs/VERILOG_RL_CHECKLIST.md +183 -0
- examples/multi_step/configs/crafter_eval_synth_qwen4b.toml +35 -0
- examples/multi_step/configs/crafter_eval_text_only_groq_qwen32b.toml +36 -0
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +5 -4
- examples/multi_step/configs/crafter_synth_backend.md +40 -0
- examples/multi_step/configs/verilog_eval_groq_qwen32b.toml +31 -0
- examples/multi_step/configs/verilog_eval_synth_qwen8b.toml +33 -0
- examples/multi_step/configs/verilog_rl_lora.toml +190 -0
- examples/multi_step/judges/crafter_backend_judge.py +220 -0
- examples/multi_step/judges/verilog_backend_judge.py +234 -0
- examples/multi_step/readme.md +48 -0
- examples/multi_step/verilog_rl_lora.md +218 -0
- examples/qwen_coder/configs/coder_lora_30b.toml +1 -1
- examples/sft/evaluate.py +2 -0
- examples/sft/generate_traces.py +2 -0
- examples/swe/task_app/grpo_swe_mini.py +1 -0
- examples/swe/task_app/hosted/rollout.py +2 -0
- examples/task_apps/IMAGE_ONLY_EVAL_QUICKSTART.md +258 -0
- examples/task_apps/crafter/CREATE_SFT_DATASET.md +273 -0
- examples/task_apps/crafter/EVAL_IMAGE_ONLY_RESULTS.md +152 -0
- examples/task_apps/crafter/FILTER_COMMAND_STATUS.md +174 -0
- examples/task_apps/crafter/FILTER_COMMAND_SUCCESS.md +268 -0
- examples/task_apps/crafter/QUERY_EXAMPLES.md +203 -0
- examples/task_apps/crafter/README_IMAGE_ONLY_EVAL.md +316 -0
- examples/task_apps/crafter/eval_image_only_gpt4o.toml +28 -0
- examples/task_apps/crafter/eval_text_only_groq_llama.toml +36 -0
- examples/task_apps/crafter/filter_sft_dataset.toml +16 -0
- examples/task_apps/crafter/task_app/__init__.py +3 -0
- examples/task_apps/crafter/task_app/grpo_crafter.py +306 -8
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/environment.py +10 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/policy.py +16 -3
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/react_agent.py +17 -2
- examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +25 -3
- examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +52 -1
- examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +111 -13
- examples/task_apps/crafter/task_app/synth_envs_hosted/utils.py +156 -0
- examples/task_apps/enron/filter_sft.toml +5 -0
- examples/task_apps/enron/tests/__init__.py +2 -0
- examples/task_apps/enron/tests/integration/__init__.py +2 -0
- examples/task_apps/enron/tests/integration/test_enron_eval.py +2 -0
- examples/task_apps/enron/tests/unit/__init__.py +2 -0
- examples/task_apps/pokemon_red/EVAL_IMAGE_ONLY_COMPLETE.md +283 -0
- examples/task_apps/pokemon_red/EVAL_IMAGE_ONLY_STATUS.md +155 -0
- examples/task_apps/pokemon_red/README_IMAGE_ONLY_EVAL.md +415 -0
- examples/task_apps/pokemon_red/eval_image_only_gpt4o.toml +29 -0
- examples/task_apps/pokemon_red/pallet_town_rl_config.toml +2 -0
- examples/task_apps/pokemon_red/task_app.py +199 -6
- examples/task_apps/pokemon_red/test_pallet_town_rewards.py +2 -0
- examples/task_apps/sokoban/filter_sft.toml +5 -0
- examples/task_apps/sokoban/tests/__init__.py +2 -0
- examples/task_apps/sokoban/tests/integration/__init__.py +2 -0
- examples/task_apps/sokoban/tests/unit/__init__.py +2 -0
- examples/task_apps/verilog/eval_groq_qwen32b.toml +8 -4
- examples/task_apps/verilog/filter_sft.toml +5 -0
- examples/task_apps/verilog/task_app/grpo_verilog.py +258 -23
- examples/task_apps/verilog/tests/__init__.py +2 -0
- examples/task_apps/verilog/tests/integration/__init__.py +2 -0
- examples/task_apps/verilog/tests/integration/test_verilog_eval.py +2 -0
- examples/task_apps/verilog/tests/unit/__init__.py +2 -0
- examples/warming_up_to_rl/groq_test.py +2 -0
- examples/warming_up_to_rl/run_local_rollout.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_modal.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_parallel.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_traced.py +2 -0
- examples/warming_up_to_rl/run_rollout_remote.py +2 -0
- synth_ai/api/models/supported.py +1 -0
- synth_ai/cli/__init__.py +46 -13
- synth_ai/cli/_modal_wrapper.py +3 -2
- synth_ai/cli/recent.py +1 -1
- synth_ai/cli/status.py +1 -1
- synth_ai/cli/task_apps.py +354 -143
- synth_ai/cli/traces.py +1 -1
- synth_ai/cli/tui.py +57 -0
- synth_ai/cli/turso.py +1 -1
- synth_ai/cli/watch.py +1 -1
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +1 -1
- synth_ai/environments/examples/crafter_classic/environment.py +1 -1
- synth_ai/environments/examples/verilog/engine.py +76 -10
- synth_ai/judge_schemas.py +8 -8
- synth_ai/task/__init__.py +11 -1
- synth_ai/task/apps/__init__.py +1 -0
- synth_ai/task/config.py +257 -0
- synth_ai/task/contracts.py +15 -2
- synth_ai/task/rubrics/__init__.py +3 -0
- synth_ai/task/rubrics/loaders.py +22 -3
- synth_ai/task/rubrics/scoring.py +3 -0
- synth_ai/task/trace_correlation_helpers.py +315 -0
- synth_ai/task/validators.py +144 -0
- synth_ai/tracing_v3/abstractions.py +3 -3
- synth_ai/tracing_v3/llm_call_record_helpers.py +5 -5
- synth_ai/tracing_v3/session_tracer.py +16 -6
- synth_ai/tracing_v3/storage/base.py +29 -29
- synth_ai/tracing_v3/storage/config.py +3 -3
- synth_ai/tracing_v3/turso/daemon.py +8 -7
- synth_ai/tracing_v3/turso/native_manager.py +63 -40
- synth_ai/tracing_v3/utils.py +3 -3
- synth_ai/tui/__init__.py +5 -0
- synth_ai/tui/__main__.py +13 -0
- synth_ai/tui/cli/__init__.py +1 -0
- synth_ai/tui/cli/query_experiments.py +164 -0
- synth_ai/tui/cli/query_experiments_v3.py +164 -0
- synth_ai/tui/dashboard.py +906 -0
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.14.dist-info}/METADATA +1 -1
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.14.dist-info}/RECORD +110 -71
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.14.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.14.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.14.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.14.dist-info}/top_level.txt +0 -0
synth_ai/cli/task_apps.py
CHANGED
|
@@ -24,9 +24,9 @@ import types
|
|
|
24
24
|
import uuid
|
|
25
25
|
from collections.abc import Callable, Iterable, Iterator, Sequence
|
|
26
26
|
from dataclasses import dataclass
|
|
27
|
-
from datetime import
|
|
27
|
+
from datetime import datetime, timezone
|
|
28
28
|
from pathlib import Path
|
|
29
|
-
from typing import Any, cast
|
|
29
|
+
from typing import Any, Optional, cast
|
|
30
30
|
|
|
31
31
|
try: # Python 3.11+
|
|
32
32
|
import tomllib as _toml
|
|
@@ -36,19 +36,29 @@ except Exception: # pragma: no cover - fallback
|
|
|
36
36
|
import click
|
|
37
37
|
from click.exceptions import Abort
|
|
38
38
|
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
39
|
+
# Tracing imports - make conditional for optional dependencies
|
|
40
|
+
try:
|
|
41
|
+
from synth_ai.tracing_v3 import ( # type: ignore[import-untyped]
|
|
42
|
+
BaseEvent,
|
|
43
|
+
EnvironmentEvent,
|
|
44
|
+
RuntimeEvent,
|
|
45
|
+
SessionEventMarkovBlanketMessage,
|
|
46
|
+
SessionMessageContent,
|
|
47
|
+
SessionTimeStep,
|
|
48
|
+
SessionTracer,
|
|
49
|
+
TimeRecord,
|
|
50
|
+
)
|
|
51
|
+
from synth_ai.tracing_v3 import ( # type: ignore[import-untyped]
|
|
52
|
+
SessionTrace as V3SessionTrace,
|
|
53
|
+
)
|
|
54
|
+
_TRACING_AVAILABLE = True
|
|
55
|
+
except (ImportError, ModuleNotFoundError, TypeError):
|
|
56
|
+
# Tracing system not available (missing optional dependencies)
|
|
57
|
+
BaseEvent = EnvironmentEvent = RuntimeEvent = None # type: ignore
|
|
58
|
+
SessionEventMarkovBlanketMessage = SessionMessageContent = None # type: ignore
|
|
59
|
+
SessionTimeStep = SessionTracer = TimeRecord = None # type: ignore
|
|
60
|
+
V3SessionTrace = None # type: ignore
|
|
61
|
+
_TRACING_AVAILABLE = False
|
|
52
62
|
|
|
53
63
|
# ---------------------------------------------------------------------------
|
|
54
64
|
# Dynamic imports to avoid hard dependencies during type checking.
|
|
@@ -82,14 +92,14 @@ except Exception as exc: # pragma: no cover - critical dependency
|
|
|
82
92
|
raise RuntimeError("Unable to load task app server utilities") from exc
|
|
83
93
|
|
|
84
94
|
|
|
85
|
-
def _load_demo_directory() -> Path
|
|
95
|
+
def _load_demo_directory() -> Optional[Path]:
|
|
86
96
|
"""Return the demo task apps directory if available."""
|
|
87
97
|
|
|
88
98
|
try:
|
|
89
99
|
module = cast(
|
|
90
100
|
Any, importlib.import_module("synth_ai.demos.demo_task_apps.core")
|
|
91
101
|
)
|
|
92
|
-
loader = cast(Callable[[], str | Path
|
|
102
|
+
loader = cast(Callable[[], Optional[str | Path]], module.load_demo_dir)
|
|
93
103
|
demo_dir = loader()
|
|
94
104
|
if isinstance(demo_dir, str | Path):
|
|
95
105
|
demo_path = Path(demo_dir)
|
|
@@ -129,7 +139,7 @@ DEFAULT_SEARCH_RELATIVE = (
|
|
|
129
139
|
)
|
|
130
140
|
|
|
131
141
|
|
|
132
|
-
def _pearson(xs: Sequence[float], ys: Sequence[float]) -> float
|
|
142
|
+
def _pearson(xs: Sequence[float], ys: Sequence[float]) -> Optional[float]:
|
|
133
143
|
if len(xs) != len(ys) or len(xs) < 2:
|
|
134
144
|
return None
|
|
135
145
|
mean_x = sum(xs) / len(xs)
|
|
@@ -154,7 +164,7 @@ class AppChoice:
|
|
|
154
164
|
label: str
|
|
155
165
|
path: Path
|
|
156
166
|
source: str
|
|
157
|
-
description: str
|
|
167
|
+
description: Optional[str] = None
|
|
158
168
|
aliases: tuple[str, ...] = ()
|
|
159
169
|
entry: TaskAppEntryType | None = None
|
|
160
170
|
entry_loader: Callable[[], TaskAppEntryType] | None = None
|
|
@@ -178,21 +188,21 @@ class JudgeSpec:
|
|
|
178
188
|
kwargs: dict[str, Any]
|
|
179
189
|
|
|
180
190
|
|
|
181
|
-
def _parse_datetime_for_trace(value: Any) -> datetime
|
|
191
|
+
def _parse_datetime_for_trace(value: Any) -> Optional[datetime]:
|
|
182
192
|
if isinstance(value, datetime):
|
|
183
|
-
return value if value.tzinfo else value.replace(tzinfo=
|
|
193
|
+
return value if value.tzinfo else value.replace(tzinfo=timezone.utc)
|
|
184
194
|
if isinstance(value, str):
|
|
185
195
|
value = value.replace("Z", "+00:00")
|
|
186
196
|
try:
|
|
187
197
|
dt = datetime.fromisoformat(value)
|
|
188
198
|
except ValueError:
|
|
189
199
|
try:
|
|
190
|
-
dt = datetime.fromtimestamp(float(value), tz=
|
|
200
|
+
dt = datetime.fromtimestamp(float(value), tz=timezone.utc)
|
|
191
201
|
except Exception:
|
|
192
202
|
return None
|
|
193
|
-
return dt if dt.tzinfo else dt.replace(tzinfo=
|
|
203
|
+
return dt if dt.tzinfo else dt.replace(tzinfo=timezone.utc)
|
|
194
204
|
if isinstance(value, int | float):
|
|
195
|
-
return datetime.fromtimestamp(float(value), tz=
|
|
205
|
+
return datetime.fromtimestamp(float(value), tz=timezone.utc)
|
|
196
206
|
return None
|
|
197
207
|
|
|
198
208
|
|
|
@@ -269,7 +279,7 @@ def _step_from_dict(payload: dict[str, Any]) -> SessionTimeStep:
|
|
|
269
279
|
for msg in payload.get("markov_blanket_messages", [])
|
|
270
280
|
if isinstance(msg, dict)
|
|
271
281
|
]
|
|
272
|
-
timestamp = _parse_datetime_for_trace(payload.get("timestamp")) or datetime.now(
|
|
282
|
+
timestamp = _parse_datetime_for_trace(payload.get("timestamp")) or datetime.now(timezone.utc)
|
|
273
283
|
completed_at = _parse_datetime_for_trace(payload.get("completed_at"))
|
|
274
284
|
return SessionTimeStep(
|
|
275
285
|
step_id=payload.get("step_id", ""),
|
|
@@ -283,7 +293,7 @@ def _step_from_dict(payload: dict[str, Any]) -> SessionTimeStep:
|
|
|
283
293
|
)
|
|
284
294
|
|
|
285
295
|
|
|
286
|
-
def _session_trace_from_dict(payload: dict[str, Any]) -> V3SessionTrace
|
|
296
|
+
def _session_trace_from_dict(payload: dict[str, Any]) -> Optional[V3SessionTrace]:
|
|
287
297
|
if not isinstance(payload, dict):
|
|
288
298
|
return None
|
|
289
299
|
steps = [
|
|
@@ -301,7 +311,7 @@ def _session_trace_from_dict(payload: dict[str, Any]) -> V3SessionTrace | None:
|
|
|
301
311
|
for msg in payload.get("markov_blanket_message_history", [])
|
|
302
312
|
if isinstance(msg, dict)
|
|
303
313
|
]
|
|
304
|
-
created_at = _parse_datetime_for_trace(payload.get("created_at")) or datetime.now(
|
|
314
|
+
created_at = _parse_datetime_for_trace(payload.get("created_at")) or datetime.now(timezone.utc)
|
|
305
315
|
metadata = payload.get("metadata") or {}
|
|
306
316
|
session_metadata = payload.get("session_metadata")
|
|
307
317
|
return V3SessionTrace(
|
|
@@ -320,21 +330,43 @@ async def _store_trace(
|
|
|
320
330
|
trace_namespace: dict[str, Any] | None,
|
|
321
331
|
extra_metadata: dict[str, Any] | None = None,
|
|
322
332
|
):
|
|
333
|
+
import logging
|
|
334
|
+
_logger = logging.getLogger(__name__)
|
|
335
|
+
|
|
336
|
+
_logger.info(f"[STORE_TRACE_DEBUG] Called with tracer={tracer is not None}, trace_namespace={trace_namespace is not None}")
|
|
337
|
+
|
|
323
338
|
if tracer is None or not isinstance(trace_namespace, dict):
|
|
339
|
+
_logger.warning(f"[STORE_TRACE_DEBUG] Early return: tracer={tracer is not None}, trace_namespace type={type(trace_namespace)}")
|
|
324
340
|
return
|
|
341
|
+
|
|
342
|
+
_logger.info(f"[STORE_TRACE_DEBUG] trace_namespace keys: {list(trace_namespace.keys())}")
|
|
343
|
+
|
|
325
344
|
session_payload = trace_namespace.get("session_trace")
|
|
326
345
|
if not isinstance(session_payload, dict):
|
|
346
|
+
_logger.warning(f"[STORE_TRACE_DEBUG] No session_trace found or wrong type: {type(session_payload)}")
|
|
327
347
|
return
|
|
348
|
+
|
|
349
|
+
_logger.info(f"[STORE_TRACE_DEBUG] session_payload keys: {list(session_payload.keys())}")
|
|
350
|
+
msg_count = len(session_payload.get("markov_blanket_message_history", []))
|
|
351
|
+
_logger.info(f"[STORE_TRACE_DEBUG] Found {msg_count} messages in session_payload")
|
|
352
|
+
|
|
328
353
|
trace_obj = _session_trace_from_dict(session_payload)
|
|
329
354
|
if trace_obj is None:
|
|
355
|
+
_logger.warning(f"[STORE_TRACE_DEBUG] _session_trace_from_dict returned None")
|
|
330
356
|
return
|
|
357
|
+
|
|
358
|
+
_logger.info(f"[STORE_TRACE_DEBUG] Created SessionTrace object with {len(trace_obj.markov_blanket_message_history)} messages")
|
|
359
|
+
|
|
331
360
|
if tracer.db is None:
|
|
332
361
|
await tracer.initialize()
|
|
333
362
|
meta = dict(trace_obj.metadata or {})
|
|
334
363
|
if extra_metadata:
|
|
335
364
|
meta.update(extra_metadata)
|
|
336
365
|
trace_obj.metadata = meta
|
|
366
|
+
|
|
367
|
+
_logger.info(f"[STORE_TRACE_DEBUG] Calling insert_session_trace for session_id={trace_obj.session_id}")
|
|
337
368
|
await tracer.db.insert_session_trace(trace_obj)
|
|
369
|
+
_logger.info(f"[STORE_TRACE_DEBUG] Successfully inserted trace")
|
|
338
370
|
|
|
339
371
|
def _temporary_sys_path(paths: Sequence[Path]):
|
|
340
372
|
"""Context manager to prepend entries to sys.path temporarily."""
|
|
@@ -3044,6 +3076,11 @@ def _write_modal_entrypoint(
|
|
|
3044
3076
|
if not any(str(p).startswith("synth-ai") for p in pip_packages):
|
|
3045
3077
|
pip_packages.insert(0, synth_pkg)
|
|
3046
3078
|
|
|
3079
|
+
apt_packages = list(modal_cfg.apt_packages)
|
|
3080
|
+
click.echo(f"[DEBUG] modal_cfg.apt_packages type: {type(modal_cfg.apt_packages)}")
|
|
3081
|
+
click.echo(f"[DEBUG] modal_cfg.apt_packages value: {modal_cfg.apt_packages}")
|
|
3082
|
+
click.echo(f"[DEBUG] apt_packages after list(): {apt_packages}")
|
|
3083
|
+
|
|
3047
3084
|
local_dirs = [(str(Path(src)), dst) for src, dst in modal_cfg.extra_local_dirs]
|
|
3048
3085
|
# Also mount the host synth_ai source if available to ensure latest code is used
|
|
3049
3086
|
if host_synth is not None:
|
|
@@ -3090,6 +3127,15 @@ INLINE_SECRET_VALUES = {inline_secret_values!r}
|
|
|
3090
3127
|
|
|
3091
3128
|
image = Image.debian_slim(python_version={modal_cfg.python_version!r})
|
|
3092
3129
|
|
|
3130
|
+
# CRITICAL: Install iverilog for Verilog task app (hardcoded to prevent config issues)
|
|
3131
|
+
if {entry.app_id!r} == "grpo-verilog":
|
|
3132
|
+
image = image.apt_install("iverilog")
|
|
3133
|
+
|
|
3134
|
+
# Install apt packages first (before pip)
|
|
3135
|
+
apt_packages = {apt_packages!r}
|
|
3136
|
+
if apt_packages:
|
|
3137
|
+
image = image.apt_install(*apt_packages)
|
|
3138
|
+
|
|
3093
3139
|
pip_packages = {pip_packages!r}
|
|
3094
3140
|
if pip_packages:
|
|
3095
3141
|
image = image.pip_install(*pip_packages)
|
|
@@ -3251,7 +3297,7 @@ def register(cli: click.Group) -> None:
|
|
|
3251
3297
|
)
|
|
3252
3298
|
@click.option(
|
|
3253
3299
|
"--trace-db",
|
|
3254
|
-
default="traces/v3/
|
|
3300
|
+
default="traces/v3/synth_ai.db",
|
|
3255
3301
|
show_default=True,
|
|
3256
3302
|
help="SQLite/Turso URL for storing rollout traces set to 'none' to disable persistence.",
|
|
3257
3303
|
)
|
|
@@ -3284,8 +3330,13 @@ def eval_command(
|
|
|
3284
3330
|
pointing at a remote `--url`, supply matching `--env-file` values so the CLI can
|
|
3285
3331
|
forward authentication headers to the running service.
|
|
3286
3332
|
"""
|
|
3333
|
+
# Parse and validate TOML config
|
|
3334
|
+
from synth_ai.task.config import EvalConfig
|
|
3335
|
+
|
|
3287
3336
|
cfg: dict[str, Any] = {}
|
|
3337
|
+
eval_cfg: EvalConfig | None = None
|
|
3288
3338
|
config_path: Path | None = None
|
|
3339
|
+
|
|
3289
3340
|
if config:
|
|
3290
3341
|
config_path = Path(config)
|
|
3291
3342
|
else:
|
|
@@ -3307,21 +3358,37 @@ def eval_command(
|
|
|
3307
3358
|
if isinstance(parsed, dict):
|
|
3308
3359
|
section = parsed.get("eval")
|
|
3309
3360
|
cfg = dict(section) if isinstance(section, dict) else dict(parsed)
|
|
3361
|
+
|
|
3362
|
+
# Validate config with dataclass
|
|
3363
|
+
try:
|
|
3364
|
+
eval_cfg = EvalConfig.from_dict(cfg)
|
|
3365
|
+
click.echo(f"✓ Config validated: {len(eval_cfg.seeds)} seeds, model={eval_cfg.model}")
|
|
3366
|
+
except (ValueError, TypeError) as validation_error:
|
|
3367
|
+
raise click.ClickException(f"Invalid eval config: {validation_error}") from validation_error
|
|
3368
|
+
except click.ClickException:
|
|
3369
|
+
raise
|
|
3310
3370
|
except Exception as exc:
|
|
3311
3371
|
raise click.ClickException(f"Failed to parse TOML '{config_path}': {exc}") from exc
|
|
3312
3372
|
|
|
3313
|
-
|
|
3373
|
+
# CLI args override config
|
|
3374
|
+
if eval_cfg:
|
|
3375
|
+
app_id = app_id or eval_cfg.app_id
|
|
3376
|
+
else:
|
|
3377
|
+
app_id = app_id or (cfg.get("app_id") if isinstance(cfg.get("app_id"), str) else None) # type: ignore
|
|
3314
3378
|
|
|
3315
3379
|
metadata_filters: dict[str, str] = {}
|
|
3316
|
-
|
|
3317
|
-
|
|
3318
|
-
|
|
3319
|
-
|
|
3320
|
-
|
|
3321
|
-
|
|
3322
|
-
|
|
3323
|
-
|
|
3324
|
-
|
|
3380
|
+
if eval_cfg:
|
|
3381
|
+
metadata_filters.update(eval_cfg.metadata)
|
|
3382
|
+
else:
|
|
3383
|
+
cfg_metadata = cfg.get("metadata")
|
|
3384
|
+
if isinstance(cfg_metadata, dict):
|
|
3385
|
+
for key, value in cfg_metadata.items():
|
|
3386
|
+
metadata_filters[str(key)] = str(value)
|
|
3387
|
+
elif isinstance(cfg_metadata, list):
|
|
3388
|
+
for item in cfg_metadata:
|
|
3389
|
+
if isinstance(item, str) and "=" in item:
|
|
3390
|
+
key, value = item.split("=", 1)
|
|
3391
|
+
metadata_filters[key.strip()] = value.strip()
|
|
3325
3392
|
|
|
3326
3393
|
for item in metadata or ():
|
|
3327
3394
|
if "=" not in item:
|
|
@@ -3334,11 +3401,14 @@ def eval_command(
|
|
|
3334
3401
|
metadata_filters[key] = value
|
|
3335
3402
|
|
|
3336
3403
|
metadata_sql_query: str | None = None
|
|
3337
|
-
|
|
3338
|
-
|
|
3339
|
-
|
|
3340
|
-
|
|
3341
|
-
|
|
3404
|
+
if eval_cfg and eval_cfg.metadata_sql:
|
|
3405
|
+
metadata_sql_query = eval_cfg.metadata_sql
|
|
3406
|
+
else:
|
|
3407
|
+
cfg_metadata_sql = cfg.get("metadata_sql")
|
|
3408
|
+
if isinstance(cfg_metadata_sql, dict):
|
|
3409
|
+
metadata_sql_query = cfg_metadata_sql.get("query") or cfg_metadata_sql.get("sql")
|
|
3410
|
+
elif isinstance(cfg_metadata_sql, str):
|
|
3411
|
+
metadata_sql_query = cfg_metadata_sql
|
|
3342
3412
|
|
|
3343
3413
|
if metadata_sql:
|
|
3344
3414
|
metadata_sql_query = metadata_sql
|
|
@@ -3780,18 +3850,52 @@ def eval_command(
|
|
|
3780
3850
|
|
|
3781
3851
|
async def _run_seed(seed_val: int) -> None:
|
|
3782
3852
|
nonlocal successes, failures, outcome_sum, outcome_count, outcome_correct, records
|
|
3853
|
+
# Read env_name and policy_name from config if available
|
|
3854
|
+
env_name = cfg.get("env_name") or (cfg.get("env", {}).get("env_name") if isinstance(cfg.get("env"), dict) else None)
|
|
3855
|
+
policy_name = cfg.get("policy_name") or (cfg.get("policy", {}).get("policy_name") if isinstance(cfg.get("policy"), dict) else None)
|
|
3856
|
+
env_config_overrides = cfg.get("env_config", {}) if isinstance(cfg.get("env_config"), dict) else {}
|
|
3857
|
+
policy_config_overrides = cfg.get("policy_config", {}) if isinstance(cfg.get("policy_config"), dict) else {}
|
|
3858
|
+
|
|
3859
|
+
# Debug: print config parsing
|
|
3860
|
+
if seed_val == 0:
|
|
3861
|
+
click.echo(f"[DEBUG] env_name from config: {env_name}")
|
|
3862
|
+
click.echo(f"[DEBUG] policy_name from config: {policy_name}")
|
|
3863
|
+
|
|
3864
|
+
# Generate default ops sequence if not provided
|
|
3865
|
+
max_llm_calls = policy_config_overrides.get("max_llm_calls", 10)
|
|
3866
|
+
ops_list = cfg.get("ops", [])
|
|
3867
|
+
if not ops_list:
|
|
3868
|
+
# Generate default "agent, env" pairs for max_llm_calls
|
|
3869
|
+
ops_list = ["agent", "env"] * int(max_llm_calls)
|
|
3870
|
+
|
|
3783
3871
|
body = {
|
|
3784
3872
|
"run_id": str(uuid.uuid4()),
|
|
3785
|
-
"env": {"config": {"split": split, "index": seed_val}, "seed": seed_val},
|
|
3873
|
+
"env": {"config": {"split": split, "index": seed_val, **env_config_overrides}, "seed": seed_val},
|
|
3786
3874
|
"policy": {
|
|
3787
|
-
"policy_name": selected_model,
|
|
3788
|
-
"config": {"model": selected_model, **policy_overrides},
|
|
3875
|
+
"policy_name": policy_name or selected_model,
|
|
3876
|
+
"config": {"model": selected_model, **policy_overrides, **policy_config_overrides},
|
|
3877
|
+
},
|
|
3878
|
+
"ops": ops_list,
|
|
3879
|
+
"record": {
|
|
3880
|
+
"return_trace": cfg.get("return_trace", True),
|
|
3881
|
+
"trace_format": cfg.get("trace_format", "structured"),
|
|
3789
3882
|
},
|
|
3790
|
-
"
|
|
3883
|
+
"mode": "eval", # RolloutMode.EVAL: use inference URLs as-is, no transformations
|
|
3791
3884
|
}
|
|
3885
|
+
if env_name:
|
|
3886
|
+
body["env"]["env_name"] = env_name
|
|
3887
|
+
|
|
3888
|
+
# Debug: print the body being sent
|
|
3889
|
+
if seed_val == 0:
|
|
3890
|
+
click.echo(f"[DEBUG] rollout body env: {body['env']}")
|
|
3891
|
+
click.echo(f"[DEBUG] rollout body policy: {body['policy']}")
|
|
3892
|
+
click.echo(f"[DEBUG] rollout body mode: {body.get('mode', 'NOT SET')}")
|
|
3792
3893
|
rollout_elapsed: float | None = None
|
|
3793
3894
|
rollout_start = time.perf_counter()
|
|
3794
3895
|
try:
|
|
3896
|
+
import logging
|
|
3897
|
+
_log = logging.getLogger(__name__)
|
|
3898
|
+
_log.info(f"[EVAL_BODY_DEBUG] Sending body with mode={body.get('mode')}")
|
|
3795
3899
|
async with semaphore:
|
|
3796
3900
|
response = await async_client.post("/rollout", json=body)
|
|
3797
3901
|
rollout_elapsed = time.perf_counter() - rollout_start
|
|
@@ -3812,6 +3916,10 @@ def eval_command(
|
|
|
3812
3916
|
data = response.json()
|
|
3813
3917
|
except Exception:
|
|
3814
3918
|
data = None
|
|
3919
|
+
|
|
3920
|
+
# Debug: print validation errors
|
|
3921
|
+
if response.status_code == 422 and data:
|
|
3922
|
+
click.echo(f"[DEBUG] 422 Validation Error: {data}")
|
|
3815
3923
|
|
|
3816
3924
|
metrics: dict[str, Any] | None = None
|
|
3817
3925
|
completion: str | None = None
|
|
@@ -3825,16 +3933,33 @@ def eval_command(
|
|
|
3825
3933
|
session_trace_dict: dict[str, Any] | None = None
|
|
3826
3934
|
|
|
3827
3935
|
if isinstance(data, dict):
|
|
3936
|
+
import logging
|
|
3937
|
+
_logger = logging.getLogger(__name__)
|
|
3938
|
+
_logger.info(f"[EVAL_DEBUG] Response data keys: {list(data.keys())}")
|
|
3939
|
+
if "detail" in data:
|
|
3940
|
+
_logger.error(f"[EVAL_DEBUG] Task app returned error: {data['detail']}")
|
|
3828
3941
|
trace_namespace = data.get("trace")
|
|
3942
|
+
_logger.info(f"[EVAL_DEBUG] trace_namespace type: {type(trace_namespace)}, value: {trace_namespace if not isinstance(trace_namespace, dict) else 'dict with keys: ' + str(list(trace_namespace.keys()) if trace_namespace else 'None')}")
|
|
3829
3943
|
if not isinstance(trace_namespace, dict):
|
|
3830
3944
|
raise RuntimeError(
|
|
3831
|
-
"
|
|
3945
|
+
"The 'synth-ai eval' command requires trace payloads in rollout responses. "
|
|
3946
|
+
"Ensure the rollout request includes 'trace_format': 'structured' and 'return_trace': true, "
|
|
3947
|
+
"and that task app tracing is enabled (TASKAPP_TRACING_ENABLED=1). "
|
|
3948
|
+
"Note: This is specific to the eval command - general rollout endpoints don't require traces."
|
|
3832
3949
|
)
|
|
3950
|
+
# Handle both "compact" and "full" trace formats:
|
|
3951
|
+
# - compact: trace_namespace contains {session_id, metadata, ...}
|
|
3952
|
+
# - full: trace_namespace IS the full session_trace dict
|
|
3833
3953
|
session_trace_dict = trace_namespace.get("session_trace")
|
|
3834
3954
|
if not isinstance(session_trace_dict, dict):
|
|
3835
|
-
|
|
3836
|
-
|
|
3837
|
-
|
|
3955
|
+
# If no session_trace key, assume "full" format where trace itself is the session_trace
|
|
3956
|
+
if "session_id" in trace_namespace:
|
|
3957
|
+
session_trace_dict = trace_namespace
|
|
3958
|
+
else:
|
|
3959
|
+
raise RuntimeError(
|
|
3960
|
+
"The 'synth-ai eval' command requires 'session_trace' in the trace payload or a valid full trace format. "
|
|
3961
|
+
"Ensure the task app is using tracing_v3 and returning structured trace data."
|
|
3962
|
+
)
|
|
3838
3963
|
metrics = data.get("metrics") if isinstance(data.get("metrics"), dict) else None
|
|
3839
3964
|
if metrics:
|
|
3840
3965
|
mean_return = metrics.get("mean_return") or metrics.get("total_reward")
|
|
@@ -3956,26 +4081,27 @@ def eval_command(
|
|
|
3956
4081
|
for spec in judge_specs:
|
|
3957
4082
|
score_value: float | None = None
|
|
3958
4083
|
judge_elapsed: float | None = None
|
|
3959
|
-
|
|
3960
|
-
|
|
3961
|
-
|
|
3962
|
-
|
|
3963
|
-
|
|
3964
|
-
|
|
3965
|
-
|
|
3966
|
-
|
|
3967
|
-
|
|
3968
|
-
|
|
3969
|
-
|
|
3970
|
-
|
|
3971
|
-
|
|
4084
|
+
# Run judges for all tasks (text-based and trajectory-based)
|
|
4085
|
+
# Text-based tasks have completion, trajectory-based tasks use response
|
|
4086
|
+
judge_payload = {
|
|
4087
|
+
"seed": seed_val,
|
|
4088
|
+
"prompt_index": prompt_index,
|
|
4089
|
+
"prompt": prompt_text,
|
|
4090
|
+
"completion": completion,
|
|
4091
|
+
"metrics": metrics,
|
|
4092
|
+
"response": data,
|
|
4093
|
+
"trace": trace_namespace,
|
|
4094
|
+
}
|
|
4095
|
+
try:
|
|
4096
|
+
judge_start = time.perf_counter()
|
|
4097
|
+
result = spec.fn(judge_payload, **spec.kwargs)
|
|
4098
|
+
judge_elapsed = time.perf_counter() - judge_start
|
|
4099
|
+
if isinstance(result, int | float):
|
|
4100
|
+
score_value = float(result)
|
|
4101
|
+
except Exception as exc:
|
|
4102
|
+
if judge_elapsed is None:
|
|
3972
4103
|
judge_elapsed = time.perf_counter() - judge_start
|
|
3973
|
-
|
|
3974
|
-
score_value = float(result)
|
|
3975
|
-
except Exception as exc:
|
|
3976
|
-
if judge_elapsed is None:
|
|
3977
|
-
judge_elapsed = time.perf_counter() - judge_start
|
|
3978
|
-
click.echo(f"seed={seed_val} judge[{spec.name}]_error={exc}")
|
|
4104
|
+
click.echo(f"seed={seed_val} judge[{spec.name}]_error={exc}")
|
|
3979
4105
|
judges_timings[spec.name] = judge_elapsed
|
|
3980
4106
|
judge_scores[spec.name] = score_value
|
|
3981
4107
|
|
|
@@ -4129,6 +4255,9 @@ def filter_command(config_path: str) -> None:
|
|
|
4129
4255
|
high-quality traces. See `customers/agora_single_file/configs/filter_local.toml`
|
|
4130
4256
|
for a working example.
|
|
4131
4257
|
"""
|
|
4258
|
+
# Parse and validate TOML config
|
|
4259
|
+
from synth_ai.task.config import FilterConfig
|
|
4260
|
+
|
|
4132
4261
|
if _toml is None:
|
|
4133
4262
|
raise click.ClickException("TOML parser not available; install tomli or use Python 3.11+")
|
|
4134
4263
|
|
|
@@ -4141,58 +4270,37 @@ def filter_command(config_path: str) -> None:
|
|
|
4141
4270
|
except Exception as exc:
|
|
4142
4271
|
raise click.ClickException(f"Failed to parse TOML '{cfg_path}': {exc}") from exc
|
|
4143
4272
|
|
|
4144
|
-
|
|
4145
|
-
if not isinstance(
|
|
4273
|
+
filter_cfg_dict = config_data.get("filter") if isinstance(config_data, dict) else None
|
|
4274
|
+
if not isinstance(filter_cfg_dict, dict):
|
|
4146
4275
|
raise click.ClickException("Config must contain a [filter] table")
|
|
4147
4276
|
|
|
4148
|
-
|
|
4149
|
-
if not db_value:
|
|
4150
|
-
raise click.ClickException("filter.db must be provided")
|
|
4151
|
-
if "://" in db_value:
|
|
4152
|
-
db_url = db_value
|
|
4153
|
-
else:
|
|
4154
|
-
db_path = Path(db_value).expanduser()
|
|
4155
|
-
db_path.parent.mkdir(parents=True, exist_ok=True)
|
|
4156
|
-
db_url = f"sqlite+aiosqlite:///{db_path}"
|
|
4157
|
-
|
|
4158
|
-
output_value = filter_cfg.get("output")
|
|
4159
|
-
if not output_value:
|
|
4160
|
-
raise click.ClickException("filter.output must be provided")
|
|
4161
|
-
output_path = Path(str(output_value)).expanduser()
|
|
4162
|
-
|
|
4163
|
-
splits = set(filter_cfg.get("splits", []) or [])
|
|
4164
|
-
task_ids = set(filter_cfg.get("task_ids", []) or [])
|
|
4165
|
-
models = set(filter_cfg.get("models", []) or [])
|
|
4166
|
-
min_official = filter_cfg.get("min_official_score")
|
|
4167
|
-
max_official = filter_cfg.get("max_official_score")
|
|
4168
|
-
if min_official is not None:
|
|
4169
|
-
try:
|
|
4170
|
-
min_official = float(min_official)
|
|
4171
|
-
except Exception as err:
|
|
4172
|
-
raise click.ClickException("filter.min_official_score must be numeric") from err
|
|
4173
|
-
if max_official is not None:
|
|
4174
|
-
try:
|
|
4175
|
-
max_official = float(max_official)
|
|
4176
|
-
except Exception as err:
|
|
4177
|
-
raise click.ClickException("filter.max_official_score must be numeric") from err
|
|
4178
|
-
min_judge_scores = filter_cfg.get("min_judge_scores", {}) or {}
|
|
4179
|
-
max_judge_scores = filter_cfg.get("max_judge_scores", {}) or {}
|
|
4277
|
+
# Validate config with dataclass
|
|
4180
4278
|
try:
|
|
4181
|
-
|
|
4182
|
-
|
|
4183
|
-
|
|
4184
|
-
|
|
4185
|
-
|
|
4186
|
-
|
|
4187
|
-
|
|
4188
|
-
|
|
4189
|
-
|
|
4190
|
-
|
|
4191
|
-
|
|
4192
|
-
|
|
4193
|
-
|
|
4194
|
-
|
|
4195
|
-
|
|
4279
|
+
filter_cfg = FilterConfig.from_dict(filter_cfg_dict)
|
|
4280
|
+
click.echo(f"✓ Config validated: db={filter_cfg.db}, output={filter_cfg.output}")
|
|
4281
|
+
if filter_cfg.min_official_score is not None:
|
|
4282
|
+
click.echo(f" → Filtering for official score >= {filter_cfg.min_official_score}")
|
|
4283
|
+
if filter_cfg.limit:
|
|
4284
|
+
click.echo(f" → Limiting to {filter_cfg.limit} examples")
|
|
4285
|
+
except (ValueError, TypeError) as validation_error:
|
|
4286
|
+
raise click.ClickException(f"Invalid filter config: {validation_error}") from validation_error
|
|
4287
|
+
|
|
4288
|
+
# Use validated config
|
|
4289
|
+
db_url = filter_cfg.get_db_url()
|
|
4290
|
+
output_path = filter_cfg.get_output_path()
|
|
4291
|
+
|
|
4292
|
+
# Extract validated fields from dataclass
|
|
4293
|
+
splits = set(filter_cfg.splits)
|
|
4294
|
+
task_ids = set(filter_cfg.task_ids)
|
|
4295
|
+
models = set(filter_cfg.models)
|
|
4296
|
+
min_official = filter_cfg.min_official_score
|
|
4297
|
+
max_official = filter_cfg.max_official_score
|
|
4298
|
+
min_judge_scores = filter_cfg.min_judge_scores
|
|
4299
|
+
max_judge_scores = filter_cfg.max_judge_scores
|
|
4300
|
+
# Note: min_created_at and max_created_at not yet in FilterConfig dataclass
|
|
4301
|
+
min_created = _parse_datetime_for_trace(filter_cfg_dict.get("min_created_at"))
|
|
4302
|
+
max_created = _parse_datetime_for_trace(filter_cfg_dict.get("max_created_at"))
|
|
4303
|
+
limit = filter_cfg.limit
|
|
4196
4304
|
|
|
4197
4305
|
def _score_ok(value: Any, min_val: Any, max_val: Any) -> bool:
|
|
4198
4306
|
try:
|
|
@@ -4247,8 +4355,21 @@ def filter_command(config_path: str) -> None:
|
|
|
4247
4355
|
if max_created and (created_at_dt is None or created_at_dt > max_created):
|
|
4248
4356
|
continue
|
|
4249
4357
|
|
|
4250
|
-
if
|
|
4251
|
-
|
|
4358
|
+
# Check against outcome_rewards if score filter is set
|
|
4359
|
+
total_reward = None
|
|
4360
|
+
achievements_count = None
|
|
4361
|
+
if min_official is not None or max_official is not None:
|
|
4362
|
+
reward_query = "SELECT total_reward, achievements_count FROM outcome_rewards WHERE session_id = :session_id"
|
|
4363
|
+
reward_rows = await tracer.db.query_traces(reward_query, {"session_id": session_id})
|
|
4364
|
+
reward_records = reward_rows.to_dict("records") if hasattr(reward_rows, "to_dict") else []
|
|
4365
|
+
if reward_records:
|
|
4366
|
+
total_reward = reward_records[0].get("total_reward")
|
|
4367
|
+
achievements_count = reward_records[0].get("achievements_count")
|
|
4368
|
+
if not _score_ok(total_reward, min_official, max_official):
|
|
4369
|
+
continue
|
|
4370
|
+
elif min_official is not None:
|
|
4371
|
+
# No reward found, but score filter requires it
|
|
4372
|
+
continue
|
|
4252
4373
|
|
|
4253
4374
|
judge_scores = metadata.get("judge_scores") or {}
|
|
4254
4375
|
include = True
|
|
@@ -4265,30 +4386,120 @@ def filter_command(config_path: str) -> None:
|
|
|
4265
4386
|
if not include:
|
|
4266
4387
|
continue
|
|
4267
4388
|
|
|
4268
|
-
|
|
4269
|
-
|
|
4270
|
-
|
|
4389
|
+
# Query messages for this session
|
|
4390
|
+
messages_query = """
|
|
4391
|
+
SELECT message_type, content, timestamp
|
|
4392
|
+
FROM messages
|
|
4393
|
+
WHERE session_id = :session_id
|
|
4394
|
+
ORDER BY timestamp ASC, id ASC
|
|
4395
|
+
"""
|
|
4396
|
+
msg_df = await tracer.db.query_traces(messages_query, {"session_id": session_id})
|
|
4397
|
+
message_rows = msg_df.to_dict("records") if hasattr(msg_df, "to_dict") else []
|
|
4398
|
+
|
|
4399
|
+
if not message_rows:
|
|
4400
|
+
# Fallback: check if prompt/completion in metadata (old format)
|
|
4401
|
+
prompt = metadata.get("prompt") or ""
|
|
4402
|
+
completion = metadata.get("completion") or ""
|
|
4403
|
+
if prompt and completion:
|
|
4404
|
+
record = {
|
|
4405
|
+
"messages": [
|
|
4406
|
+
{"role": "user", "content": str(prompt)},
|
|
4407
|
+
{"role": "assistant", "content": str(completion)},
|
|
4408
|
+
],
|
|
4409
|
+
"metadata": {
|
|
4410
|
+
"session_id": session_id,
|
|
4411
|
+
"env_name": metadata.get("env_name"),
|
|
4412
|
+
"policy_name": metadata.get("policy_name"),
|
|
4413
|
+
"seed": metadata.get("seed"),
|
|
4414
|
+
"total_reward": total_reward,
|
|
4415
|
+
"achievements_count": achievements_count,
|
|
4416
|
+
"model": metadata.get("model"),
|
|
4417
|
+
"created_at": created_at_dt.isoformat() if created_at_dt else created_at_raw,
|
|
4418
|
+
},
|
|
4419
|
+
}
|
|
4420
|
+
accepted.append(record)
|
|
4271
4421
|
continue
|
|
4272
4422
|
|
|
4273
|
-
|
|
4274
|
-
|
|
4275
|
-
|
|
4276
|
-
|
|
4277
|
-
|
|
4278
|
-
|
|
4279
|
-
|
|
4280
|
-
|
|
4281
|
-
|
|
4282
|
-
|
|
4283
|
-
|
|
4284
|
-
|
|
4285
|
-
|
|
4286
|
-
|
|
4287
|
-
|
|
4288
|
-
|
|
4289
|
-
|
|
4290
|
-
|
|
4291
|
-
|
|
4423
|
+
# Extract user/assistant pairs from messages
|
|
4424
|
+
for i, msg_row in enumerate(message_rows):
|
|
4425
|
+
msg_type = msg_row.get("message_type")
|
|
4426
|
+
content_raw = msg_row.get("content")
|
|
4427
|
+
|
|
4428
|
+
# Look for user message
|
|
4429
|
+
if msg_type in ("user", "policy_user_prompt"):
|
|
4430
|
+
# Find next policy_system_prompt or assistant
|
|
4431
|
+
assistant_msg = None
|
|
4432
|
+
for j in range(i + 1, len(message_rows)):
|
|
4433
|
+
next_type = message_rows[j].get("message_type")
|
|
4434
|
+
if next_type in ("assistant", "policy_system_prompt"):
|
|
4435
|
+
if next_type == "assistant":
|
|
4436
|
+
assistant_msg = message_rows[j]
|
|
4437
|
+
break
|
|
4438
|
+
|
|
4439
|
+
# Parse content
|
|
4440
|
+
try:
|
|
4441
|
+
user_content = json.loads(content_raw) if isinstance(content_raw, str) else content_raw
|
|
4442
|
+
except Exception:
|
|
4443
|
+
user_content = content_raw
|
|
4444
|
+
|
|
4445
|
+
# Extract text from structured content
|
|
4446
|
+
def extract_text(content: Any) -> str:
|
|
4447
|
+
if isinstance(content, str):
|
|
4448
|
+
return content
|
|
4449
|
+
if isinstance(content, dict):
|
|
4450
|
+
# Try payload.content for user prompts
|
|
4451
|
+
if "payload" in content and isinstance(content["payload"], dict):
|
|
4452
|
+
payload = content["payload"]
|
|
4453
|
+
if "content" in payload:
|
|
4454
|
+
return extract_text(payload["content"])
|
|
4455
|
+
# Try common keys
|
|
4456
|
+
for key in ["text", "content", "content_text"]:
|
|
4457
|
+
if key in content:
|
|
4458
|
+
val = content[key]
|
|
4459
|
+
if isinstance(val, str):
|
|
4460
|
+
return val
|
|
4461
|
+
return json.dumps(content)
|
|
4462
|
+
if isinstance(content, list):
|
|
4463
|
+
# Multimodal content - concatenate text parts
|
|
4464
|
+
parts = []
|
|
4465
|
+
for item in content:
|
|
4466
|
+
if isinstance(item, dict) and item.get("type") == "text":
|
|
4467
|
+
parts.append(item.get("text", ""))
|
|
4468
|
+
return " ".join(parts) if parts else str(content)
|
|
4469
|
+
return str(content)
|
|
4470
|
+
|
|
4471
|
+
user_text = extract_text(user_content)
|
|
4472
|
+
|
|
4473
|
+
# For assistant, we might not have it recorded, so use tool calls as completion
|
|
4474
|
+
assistant_text = ""
|
|
4475
|
+
if assistant_msg:
|
|
4476
|
+
assistant_content_raw = assistant_msg.get("content")
|
|
4477
|
+
try:
|
|
4478
|
+
assistant_content = json.loads(assistant_content_raw) if isinstance(assistant_content_raw, str) else assistant_content_raw
|
|
4479
|
+
except Exception:
|
|
4480
|
+
assistant_content = assistant_content_raw
|
|
4481
|
+
assistant_text = extract_text(assistant_content)
|
|
4482
|
+
|
|
4483
|
+
if not user_text:
|
|
4484
|
+
continue
|
|
4485
|
+
|
|
4486
|
+
record = {
|
|
4487
|
+
"messages": [
|
|
4488
|
+
{"role": "user", "content": user_text},
|
|
4489
|
+
{"role": "assistant", "content": assistant_text if assistant_text else "[no response recorded]"},
|
|
4490
|
+
],
|
|
4491
|
+
"metadata": {
|
|
4492
|
+
"session_id": session_id,
|
|
4493
|
+
"env_name": metadata.get("env_name"),
|
|
4494
|
+
"policy_name": metadata.get("policy_name"),
|
|
4495
|
+
"seed": metadata.get("seed"),
|
|
4496
|
+
"total_reward": total_reward,
|
|
4497
|
+
"achievements_count": achievements_count,
|
|
4498
|
+
"model": metadata.get("model"),
|
|
4499
|
+
"created_at": created_at_dt.isoformat() if created_at_dt else created_at_raw,
|
|
4500
|
+
},
|
|
4501
|
+
}
|
|
4502
|
+
accepted.append(record)
|
|
4292
4503
|
|
|
4293
4504
|
if not accepted:
|
|
4294
4505
|
raise click.ClickException("No sessions matched the provided filters")
|