synth-ai 0.4.1__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- synth_ai/__init__.py +13 -13
- synth_ai/cli/__init__.py +6 -15
- synth_ai/cli/commands/eval/__init__.py +6 -15
- synth_ai/cli/commands/eval/config.py +338 -0
- synth_ai/cli/commands/eval/core.py +236 -1091
- synth_ai/cli/commands/eval/runner.py +704 -0
- synth_ai/cli/commands/eval/validation.py +44 -117
- synth_ai/cli/commands/filter/core.py +7 -7
- synth_ai/cli/commands/filter/validation.py +2 -2
- synth_ai/cli/commands/smoke/core.py +7 -17
- synth_ai/cli/commands/status/__init__.py +1 -64
- synth_ai/cli/commands/status/client.py +50 -151
- synth_ai/cli/commands/status/config.py +3 -83
- synth_ai/cli/commands/status/errors.py +4 -13
- synth_ai/cli/commands/status/subcommands/__init__.py +2 -8
- synth_ai/cli/commands/status/subcommands/config.py +13 -0
- synth_ai/cli/commands/status/subcommands/files.py +18 -63
- synth_ai/cli/commands/status/subcommands/jobs.py +28 -311
- synth_ai/cli/commands/status/subcommands/models.py +18 -62
- synth_ai/cli/commands/status/subcommands/runs.py +16 -63
- synth_ai/cli/commands/status/subcommands/session.py +67 -172
- synth_ai/cli/commands/status/subcommands/summary.py +24 -32
- synth_ai/cli/commands/status/subcommands/utils.py +41 -0
- synth_ai/cli/commands/status/utils.py +16 -107
- synth_ai/cli/commands/train/__init__.py +18 -20
- synth_ai/cli/commands/train/errors.py +3 -3
- synth_ai/cli/commands/train/prompt_learning_validation.py +15 -16
- synth_ai/cli/commands/train/validation.py +7 -7
- synth_ai/cli/commands/train/{judge_schemas.py → verifier_schemas.py} +33 -34
- synth_ai/cli/commands/train/verifier_validation.py +235 -0
- synth_ai/cli/demo_apps/demo_task_apps/math/config.toml +0 -1
- synth_ai/cli/demo_apps/demo_task_apps/math/modal_task_app.py +2 -6
- synth_ai/cli/demo_apps/math/config.toml +0 -1
- synth_ai/cli/demo_apps/math/modal_task_app.py +2 -6
- synth_ai/cli/demo_apps/mipro/task_app.py +25 -47
- synth_ai/cli/lib/apps/task_app.py +12 -13
- synth_ai/cli/lib/task_app_discovery.py +6 -6
- synth_ai/cli/lib/train_cfgs.py +10 -10
- synth_ai/cli/task_apps/__init__.py +11 -0
- synth_ai/cli/task_apps/commands.py +7 -15
- synth_ai/core/env.py +12 -1
- synth_ai/core/errors.py +1 -2
- synth_ai/core/integrations/cloudflare.py +209 -33
- synth_ai/core/tracing_v3/abstractions.py +46 -0
- synth_ai/data/__init__.py +3 -30
- synth_ai/data/enums.py +1 -20
- synth_ai/data/rewards.py +100 -3
- synth_ai/products/graph_evolve/__init__.py +1 -2
- synth_ai/products/graph_evolve/config.py +16 -16
- synth_ai/products/graph_evolve/converters/__init__.py +3 -3
- synth_ai/products/graph_evolve/converters/openai_sft.py +7 -7
- synth_ai/products/graph_evolve/examples/hotpotqa/config.toml +1 -1
- synth_ai/products/graph_gepa/__init__.py +23 -0
- synth_ai/products/graph_gepa/converters/__init__.py +19 -0
- synth_ai/products/graph_gepa/converters/openai_sft.py +29 -0
- synth_ai/sdk/__init__.py +45 -35
- synth_ai/sdk/api/eval/__init__.py +33 -0
- synth_ai/sdk/api/eval/job.py +732 -0
- synth_ai/sdk/api/research_agent/__init__.py +276 -66
- synth_ai/sdk/api/train/builders.py +181 -0
- synth_ai/sdk/api/train/cli.py +41 -33
- synth_ai/sdk/api/train/configs/__init__.py +6 -4
- synth_ai/sdk/api/train/configs/prompt_learning.py +127 -33
- synth_ai/sdk/api/train/configs/rl.py +264 -16
- synth_ai/sdk/api/train/configs/sft.py +165 -1
- synth_ai/sdk/api/train/graph_validators.py +12 -12
- synth_ai/sdk/api/train/graphgen.py +169 -51
- synth_ai/sdk/api/train/graphgen_models.py +95 -45
- synth_ai/sdk/api/train/local_api.py +10 -0
- synth_ai/sdk/api/train/pollers.py +36 -0
- synth_ai/sdk/api/train/prompt_learning.py +390 -60
- synth_ai/sdk/api/train/rl.py +41 -5
- synth_ai/sdk/api/train/sft.py +2 -0
- synth_ai/sdk/api/train/task_app.py +20 -0
- synth_ai/sdk/api/train/validators.py +17 -17
- synth_ai/sdk/graphs/completions.py +239 -33
- synth_ai/sdk/{judging/schemas.py → graphs/verifier_schemas.py} +23 -23
- synth_ai/sdk/learning/__init__.py +35 -5
- synth_ai/sdk/learning/context_learning_client.py +531 -0
- synth_ai/sdk/learning/context_learning_types.py +294 -0
- synth_ai/sdk/learning/prompt_learning_client.py +1 -1
- synth_ai/sdk/learning/prompt_learning_types.py +2 -1
- synth_ai/sdk/learning/rl/__init__.py +0 -4
- synth_ai/sdk/learning/rl/contracts.py +0 -4
- synth_ai/sdk/localapi/__init__.py +40 -0
- synth_ai/sdk/localapi/apps/__init__.py +28 -0
- synth_ai/sdk/localapi/client.py +10 -0
- synth_ai/sdk/localapi/contracts.py +10 -0
- synth_ai/sdk/localapi/helpers.py +519 -0
- synth_ai/sdk/localapi/rollouts.py +93 -0
- synth_ai/sdk/localapi/server.py +29 -0
- synth_ai/sdk/localapi/template.py +49 -0
- synth_ai/sdk/streaming/handlers.py +6 -6
- synth_ai/sdk/streaming/streamer.py +10 -6
- synth_ai/sdk/task/__init__.py +18 -5
- synth_ai/sdk/task/apps/__init__.py +37 -1
- synth_ai/sdk/task/client.py +9 -1
- synth_ai/sdk/task/config.py +6 -11
- synth_ai/sdk/task/contracts.py +137 -95
- synth_ai/sdk/task/in_process.py +32 -22
- synth_ai/sdk/task/in_process_runner.py +9 -4
- synth_ai/sdk/task/rubrics/__init__.py +2 -3
- synth_ai/sdk/task/rubrics/loaders.py +4 -4
- synth_ai/sdk/task/rubrics/strict.py +3 -4
- synth_ai/sdk/task/server.py +76 -16
- synth_ai/sdk/task/trace_correlation_helpers.py +190 -139
- synth_ai/sdk/task/validators.py +34 -49
- synth_ai/sdk/training/__init__.py +7 -16
- synth_ai/sdk/tunnels/__init__.py +118 -0
- synth_ai/sdk/tunnels/cleanup.py +83 -0
- synth_ai/sdk/tunnels/ports.py +120 -0
- synth_ai/sdk/tunnels/tunneled_api.py +363 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/METADATA +71 -4
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/RECORD +118 -128
- synth_ai/cli/commands/baseline/__init__.py +0 -12
- synth_ai/cli/commands/baseline/core.py +0 -636
- synth_ai/cli/commands/baseline/list.py +0 -94
- synth_ai/cli/commands/eval/errors.py +0 -81
- synth_ai/cli/commands/status/formatters.py +0 -164
- synth_ai/cli/commands/status/subcommands/pricing.py +0 -23
- synth_ai/cli/commands/status/subcommands/usage.py +0 -203
- synth_ai/cli/commands/train/judge_validation.py +0 -305
- synth_ai/cli/usage.py +0 -159
- synth_ai/data/specs.py +0 -36
- synth_ai/sdk/api/research_agent/cli.py +0 -428
- synth_ai/sdk/api/research_agent/config.py +0 -357
- synth_ai/sdk/api/research_agent/job.py +0 -717
- synth_ai/sdk/baseline/__init__.py +0 -25
- synth_ai/sdk/baseline/config.py +0 -209
- synth_ai/sdk/baseline/discovery.py +0 -216
- synth_ai/sdk/baseline/execution.py +0 -154
- synth_ai/sdk/judging/__init__.py +0 -15
- synth_ai/sdk/judging/base.py +0 -24
- synth_ai/sdk/judging/client.py +0 -191
- synth_ai/sdk/judging/types.py +0 -42
- synth_ai/sdk/research_agent/__init__.py +0 -34
- synth_ai/sdk/research_agent/container_builder.py +0 -328
- synth_ai/sdk/research_agent/container_spec.py +0 -198
- synth_ai/sdk/research_agent/defaults.py +0 -34
- synth_ai/sdk/research_agent/results_collector.py +0 -69
- synth_ai/sdk/specs/__init__.py +0 -46
- synth_ai/sdk/specs/dataclasses.py +0 -149
- synth_ai/sdk/specs/loader.py +0 -144
- synth_ai/sdk/specs/serializer.py +0 -199
- synth_ai/sdk/specs/validation.py +0 -250
- synth_ai/sdk/tracing/__init__.py +0 -39
- synth_ai/sdk/usage/__init__.py +0 -37
- synth_ai/sdk/usage/client.py +0 -171
- synth_ai/sdk/usage/models.py +0 -261
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/WHEEL +0 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/top_level.txt +0 -0
|
@@ -335,10 +335,10 @@ class IntegrationTestHandler(StreamHandler):
|
|
|
335
335
|
|
|
336
336
|
|
|
337
337
|
class GraphGenHandler(StreamHandler):
|
|
338
|
-
"""Handler for
|
|
338
|
+
"""Handler for Graph Opt jobs that delegate child job streams to an underlying handler.
|
|
339
339
|
|
|
340
|
-
|
|
341
|
-
provides light
|
|
340
|
+
Graph Opt jobs emit events from child jobs (GEPA, MIPRO, RL, SFT, etc.). This handler
|
|
341
|
+
provides light Graph Opt-aware filtering and routing while keeping child job output
|
|
342
342
|
intact via a delegate handler. The delegate can be supplied directly or created
|
|
343
343
|
via a factory; by default we choose a prompt-learning handler for GEPA/MIPRO and
|
|
344
344
|
a basic CLI handler for other job types.
|
|
@@ -365,7 +365,7 @@ class GraphGenHandler(StreamHandler):
|
|
|
365
365
|
self._pl_show_validation = show_validation
|
|
366
366
|
|
|
367
367
|
self.filter_verbose_events = filter_verbose_events
|
|
368
|
-
# If False, skip
|
|
368
|
+
# If False, skip Graph Opt-specific filtering/transformations and just pass through.
|
|
369
369
|
self.wrap_child_events = wrap_child_events
|
|
370
370
|
|
|
371
371
|
# Detected child job type (gepa/mipro/rl/sft/etc.)
|
|
@@ -436,7 +436,7 @@ class GraphGenHandler(StreamHandler):
|
|
|
436
436
|
elif event_type.startswith("sft.") or ".sft." in event_type:
|
|
437
437
|
self.child_job_type = "sft"
|
|
438
438
|
else:
|
|
439
|
-
# Fall back to the first segment as a hint (e.g., "
|
|
439
|
+
# Fall back to the first segment as a hint (e.g., "graphgen.child_type")
|
|
440
440
|
parts = event_type.split(".")
|
|
441
441
|
if parts:
|
|
442
442
|
self.child_job_type = parts[0]
|
|
@@ -504,7 +504,7 @@ class GraphGenHandler(StreamHandler):
|
|
|
504
504
|
return any(pattern in event_type_lower for pattern in verbose_patterns)
|
|
505
505
|
|
|
506
506
|
def _transform_event_message(self, message: StreamMessage) -> StreamMessage:
|
|
507
|
-
"""Transform event messages for
|
|
507
|
+
"""Transform event messages for Graph Opt context (currently passthrough)."""
|
|
508
508
|
return message
|
|
509
509
|
|
|
510
510
|
def flush(self) -> None:
|
|
@@ -142,11 +142,11 @@ class StreamEndpoints:
|
|
|
142
142
|
)
|
|
143
143
|
|
|
144
144
|
@classmethod
|
|
145
|
-
def
|
|
146
|
-
"""Endpoints for GraphGen
|
|
145
|
+
def graphgen(cls, job_id: str) -> StreamEndpoints:
|
|
146
|
+
"""Endpoints for GraphGen workflow optimization jobs.
|
|
147
147
|
|
|
148
148
|
GraphGen jobs use /api/graphgen/jobs/{job_id} endpoints.
|
|
149
|
-
The backend handles GraphGen
|
|
149
|
+
The backend handles GraphGen -> graph_evolve job ID resolution internally using job_relationships.
|
|
150
150
|
No fallbacks needed - GraphGen endpoints resolve everything.
|
|
151
151
|
"""
|
|
152
152
|
base = f"/graphgen/jobs/{job_id}"
|
|
@@ -158,6 +158,7 @@ class StreamEndpoints:
|
|
|
158
158
|
)
|
|
159
159
|
|
|
160
160
|
|
|
161
|
+
|
|
161
162
|
class JobStreamer:
|
|
162
163
|
"""Poll job endpoints and dispatch messages to configured handlers."""
|
|
163
164
|
|
|
@@ -503,14 +504,17 @@ class JobStreamer:
|
|
|
503
504
|
except Exception as e:
|
|
504
505
|
error_str = str(e)
|
|
505
506
|
print(f"[DEBUG] Error polling {path}: {e}", file=sys.stderr)
|
|
506
|
-
# Fail fast if we get 404 on
|
|
507
|
-
if "404" in error_str and (
|
|
507
|
+
# Fail fast if we get 404 on GraphGen and fallback endpoints (indicates job ID mapping issue)
|
|
508
|
+
if "404" in error_str and (
|
|
509
|
+
"graphgen" in path.lower()
|
|
510
|
+
or "prompt-learning" in path.lower()
|
|
511
|
+
):
|
|
508
512
|
# Check if this is the last fallback path - if so, raise to fail fast
|
|
509
513
|
if path == self._event_paths[-1]: # Last fallback path
|
|
510
514
|
raise RuntimeError(
|
|
511
515
|
f"Failed to poll events: All endpoints returned 404. "
|
|
512
516
|
f"This likely indicates a job ID mapping issue. "
|
|
513
|
-
f"
|
|
517
|
+
f"GraphGen endpoints need the GraphGen job ID; GEPA fallback endpoints need the GEPA job ID. "
|
|
514
518
|
f"Last error: {error_str}"
|
|
515
519
|
)
|
|
516
520
|
continue
|
synth_ai/sdk/task/__init__.py
CHANGED
|
@@ -1,14 +1,21 @@
|
|
|
1
|
+
"""Task namespace (legacy).
|
|
2
|
+
|
|
3
|
+
Prefer synth_ai.sdk.localapi.* moving forward. This module remains for backward
|
|
4
|
+
compatibility during the naming transition.
|
|
5
|
+
"""
|
|
6
|
+
|
|
1
7
|
from .auth import (
|
|
2
8
|
is_api_key_header_authorized,
|
|
3
9
|
normalize_environment_api_key,
|
|
4
10
|
require_api_key_dependency,
|
|
5
11
|
)
|
|
6
|
-
from .client import TaskAppClient
|
|
12
|
+
from .client import LocalAPIClient, TaskAppClient
|
|
7
13
|
from .config import EvalConfig, FilterConfig
|
|
8
14
|
from .contracts import (
|
|
9
15
|
DatasetInfo,
|
|
10
16
|
InferenceInfo,
|
|
11
17
|
LimitsInfo,
|
|
18
|
+
LocalAPIEndpoints,
|
|
12
19
|
RolloutEnvSpec,
|
|
13
20
|
RolloutMetrics,
|
|
14
21
|
RolloutPolicySpec,
|
|
@@ -16,8 +23,6 @@ from .contracts import (
|
|
|
16
23
|
RolloutRequest,
|
|
17
24
|
RolloutResponse,
|
|
18
25
|
RolloutSafetyConfig,
|
|
19
|
-
RolloutStep,
|
|
20
|
-
RolloutTrajectory,
|
|
21
26
|
RubricInfo,
|
|
22
27
|
RubricSection,
|
|
23
28
|
TaskAppEndpoints,
|
|
@@ -54,15 +59,19 @@ from .rubrics import (
|
|
|
54
59
|
score_outcome_against_rubric,
|
|
55
60
|
)
|
|
56
61
|
from .server import (
|
|
62
|
+
LocalAPIConfig,
|
|
57
63
|
ProxyConfig,
|
|
58
64
|
RubricBundle,
|
|
59
65
|
TaskAppConfig,
|
|
60
66
|
create_task_app,
|
|
67
|
+
run_server_background,
|
|
61
68
|
run_task_app,
|
|
62
69
|
)
|
|
63
70
|
from .trace_correlation_helpers import (
|
|
71
|
+
build_trace_payload,
|
|
64
72
|
build_trajectory_trace,
|
|
65
73
|
extract_trace_correlation_id,
|
|
74
|
+
include_event_history_in_response,
|
|
66
75
|
include_event_history_in_trajectories,
|
|
67
76
|
include_trace_correlation_id_in_response,
|
|
68
77
|
validate_trace_correlation_id,
|
|
@@ -89,14 +98,13 @@ __all__ = [
|
|
|
89
98
|
"EvalConfig",
|
|
90
99
|
"FilterConfig",
|
|
91
100
|
"TaskAppEndpoints",
|
|
101
|
+
"LocalAPIEndpoints",
|
|
92
102
|
"RolloutEnvSpec",
|
|
93
103
|
"RolloutPolicySpec",
|
|
94
104
|
"RolloutRecordConfig",
|
|
95
105
|
"RolloutSafetyConfig",
|
|
96
106
|
"RolloutRequest",
|
|
97
107
|
"RolloutResponse",
|
|
98
|
-
"RolloutTrajectory",
|
|
99
|
-
"RolloutStep",
|
|
100
108
|
"RolloutMetrics",
|
|
101
109
|
"TaskDescriptor",
|
|
102
110
|
"DatasetInfo",
|
|
@@ -127,14 +135,17 @@ __all__ = [
|
|
|
127
135
|
"score_events_against_rubric",
|
|
128
136
|
"score_outcome_against_rubric",
|
|
129
137
|
"TaskAppClient",
|
|
138
|
+
"LocalAPIClient",
|
|
130
139
|
"error_payload",
|
|
131
140
|
"http_exception",
|
|
132
141
|
"json_error_response",
|
|
133
142
|
"run_task_app",
|
|
143
|
+
"run_server_background",
|
|
134
144
|
"create_task_app",
|
|
135
145
|
"RubricBundle",
|
|
136
146
|
"ProxyConfig",
|
|
137
147
|
"TaskAppConfig",
|
|
148
|
+
"LocalAPIConfig",
|
|
138
149
|
"InferenceAPIClient",
|
|
139
150
|
"InProcessTaskApp",
|
|
140
151
|
"InProcessJobResult",
|
|
@@ -143,7 +154,9 @@ __all__ = [
|
|
|
143
154
|
"run_in_process_job",
|
|
144
155
|
"run_in_process_job_sync",
|
|
145
156
|
"build_trajectory_trace",
|
|
157
|
+
"build_trace_payload",
|
|
146
158
|
"extract_trace_correlation_id",
|
|
159
|
+
"include_event_history_in_response",
|
|
147
160
|
"include_event_history_in_trajectories",
|
|
148
161
|
"include_trace_correlation_id_in_response",
|
|
149
162
|
"validate_trace_correlation_id",
|
|
@@ -1,4 +1,8 @@
|
|
|
1
|
-
"""Registry for Task Apps exposed via the shared FastAPI harness.
|
|
1
|
+
"""Registry for Task Apps exposed via the shared FastAPI harness.
|
|
2
|
+
|
|
3
|
+
Prefer synth_ai.sdk.localapi.apps moving forward. This module remains for
|
|
4
|
+
backward compatibility during the naming transition.
|
|
5
|
+
"""
|
|
2
6
|
|
|
3
7
|
from __future__ import annotations
|
|
4
8
|
|
|
@@ -43,6 +47,22 @@ class TaskAppEntry:
|
|
|
43
47
|
modal: ModalDeploymentConfig | None = None
|
|
44
48
|
|
|
45
49
|
|
|
50
|
+
@dataclass(slots=True)
|
|
51
|
+
class LocalAPIEntry:
|
|
52
|
+
"""Metadata describing a registered local API."""
|
|
53
|
+
|
|
54
|
+
api_id: str
|
|
55
|
+
description: str
|
|
56
|
+
config_factory: Callable[[], TaskAppConfig]
|
|
57
|
+
aliases: Sequence[str] = field(default_factory=tuple)
|
|
58
|
+
env_files: Sequence[str] = field(default_factory=tuple)
|
|
59
|
+
modal: ModalDeploymentConfig | None = None
|
|
60
|
+
|
|
61
|
+
@property
|
|
62
|
+
def app_id(self) -> str:
|
|
63
|
+
return self.api_id
|
|
64
|
+
|
|
65
|
+
|
|
46
66
|
class TaskAppRegistry:
|
|
47
67
|
"""In-memory registry of known task apps."""
|
|
48
68
|
|
|
@@ -86,6 +106,22 @@ def register_task_app(*, entry: TaskAppEntry) -> None:
|
|
|
86
106
|
registry.register(entry)
|
|
87
107
|
|
|
88
108
|
|
|
109
|
+
def register_local_api(*, entry: LocalAPIEntry | TaskAppEntry) -> None:
|
|
110
|
+
if isinstance(entry, LocalAPIEntry):
|
|
111
|
+
registry.register(
|
|
112
|
+
TaskAppEntry(
|
|
113
|
+
app_id=entry.api_id,
|
|
114
|
+
description=entry.description,
|
|
115
|
+
config_factory=entry.config_factory,
|
|
116
|
+
aliases=entry.aliases,
|
|
117
|
+
env_files=entry.env_files,
|
|
118
|
+
modal=entry.modal,
|
|
119
|
+
)
|
|
120
|
+
)
|
|
121
|
+
return
|
|
122
|
+
registry.register(entry)
|
|
123
|
+
|
|
124
|
+
|
|
89
125
|
def discover_task_apps_from_cwd() -> None:
|
|
90
126
|
"""Discover and register task apps from the current working directory and subdirectories."""
|
|
91
127
|
cwd = Path.cwd()
|
synth_ai/sdk/task/client.py
CHANGED
|
@@ -1,4 +1,8 @@
|
|
|
1
|
-
"""Async HTTP client for interacting with Task Apps.
|
|
1
|
+
"""Async HTTP client for interacting with Task Apps.
|
|
2
|
+
|
|
3
|
+
Prefer synth_ai.sdk.localapi.client moving forward. This module remains for
|
|
4
|
+
backward compatibility during the naming transition.
|
|
5
|
+
"""
|
|
2
6
|
|
|
3
7
|
from __future__ import annotations
|
|
4
8
|
|
|
@@ -142,6 +146,10 @@ class TaskAppClient:
|
|
|
142
146
|
return RolloutResponse.model_validate(data)
|
|
143
147
|
|
|
144
148
|
|
|
149
|
+
class LocalAPIClient(TaskAppClient):
|
|
150
|
+
"""Alias for TaskAppClient with LocalAPI naming."""
|
|
151
|
+
|
|
152
|
+
|
|
145
153
|
class _TaskAppEnvironmentClient:
|
|
146
154
|
def __init__(self, client: TaskAppClient) -> None:
|
|
147
155
|
self._client = client
|
synth_ai/sdk/task/config.py
CHANGED
|
@@ -50,9 +50,6 @@ class EvalConfig:
|
|
|
50
50
|
# Optional: Whether to return traces in response
|
|
51
51
|
return_trace: bool = False
|
|
52
52
|
|
|
53
|
-
# Optional: Operations sequence (if not provided, generates default)
|
|
54
|
-
ops: list[str] | None = None
|
|
55
|
-
|
|
56
53
|
# Optional: Environment config overrides
|
|
57
54
|
env_config: dict[str, Any] = field(default_factory=dict)
|
|
58
55
|
|
|
@@ -115,7 +112,6 @@ class EvalConfig:
|
|
|
115
112
|
"policy_name": data.get("policy_name"),
|
|
116
113
|
"trace_format": data.get("trace_format", "compact"),
|
|
117
114
|
"return_trace": data.get("return_trace", False),
|
|
118
|
-
"ops": data.get("ops"),
|
|
119
115
|
"env_config": data.get("env_config", {}),
|
|
120
116
|
"policy_config": data.get("policy_config", {}),
|
|
121
117
|
"metadata": data.get("metadata", {}),
|
|
@@ -153,11 +149,11 @@ class FilterConfig:
|
|
|
153
149
|
# Optional: Maximum official score threshold
|
|
154
150
|
max_official_score: float | None = None
|
|
155
151
|
|
|
156
|
-
# Optional: Minimum
|
|
157
|
-
|
|
152
|
+
# Optional: Minimum verifier scores (verifier_name -> min_score)
|
|
153
|
+
min_verifier_scores: dict[str, float] = field(default_factory=dict)
|
|
158
154
|
|
|
159
|
-
# Optional: Maximum
|
|
160
|
-
|
|
155
|
+
# Optional: Maximum verifier scores (verifier_name -> max_score)
|
|
156
|
+
max_verifier_scores: dict[str, float] = field(default_factory=dict)
|
|
161
157
|
|
|
162
158
|
# Optional: Limit number of examples
|
|
163
159
|
limit: int | None = None
|
|
@@ -222,8 +218,8 @@ class FilterConfig:
|
|
|
222
218
|
"models": data.get("models", []),
|
|
223
219
|
"min_official_score": data.get("min_official_score"),
|
|
224
220
|
"max_official_score": data.get("max_official_score"),
|
|
225
|
-
"
|
|
226
|
-
"
|
|
221
|
+
"min_verifier_scores": data.get("min_verifier_scores", {}),
|
|
222
|
+
"max_verifier_scores": data.get("max_verifier_scores", {}),
|
|
227
223
|
"limit": data.get("limit"),
|
|
228
224
|
"offset": data.get("offset"),
|
|
229
225
|
"shuffle": data.get("shuffle", False),
|
|
@@ -258,4 +254,3 @@ class FilterConfig:
|
|
|
258
254
|
return output_path
|
|
259
255
|
|
|
260
256
|
|
|
261
|
-
|
synth_ai/sdk/task/contracts.py
CHANGED
|
@@ -1,3 +1,9 @@
|
|
|
1
|
+
"""Contracts for Task Apps.
|
|
2
|
+
|
|
3
|
+
Prefer synth_ai.sdk.localapi.contracts moving forward. This module remains for
|
|
4
|
+
backward compatibility during the naming transition.
|
|
5
|
+
"""
|
|
6
|
+
|
|
1
7
|
from __future__ import annotations
|
|
2
8
|
|
|
3
9
|
from dataclasses import dataclass
|
|
@@ -64,6 +70,11 @@ class TaskAppEndpoints:
|
|
|
64
70
|
rollout: str = "/rollout"
|
|
65
71
|
|
|
66
72
|
|
|
73
|
+
@dataclass(frozen=True)
|
|
74
|
+
class LocalAPIEndpoints(TaskAppEndpoints):
|
|
75
|
+
"""Alias for TaskAppEndpoints with LocalAPI naming."""
|
|
76
|
+
|
|
77
|
+
|
|
67
78
|
# --- Unified rollout schema used by Task App services and SDK utilities ---
|
|
68
79
|
|
|
69
80
|
|
|
@@ -91,7 +102,6 @@ class RolloutPolicySpec(BaseModel):
|
|
|
91
102
|
|
|
92
103
|
|
|
93
104
|
class RolloutRecordConfig(BaseModel):
|
|
94
|
-
trajectories: bool = True
|
|
95
105
|
logprobs: bool = False
|
|
96
106
|
value: bool = False
|
|
97
107
|
return_trace: bool = False
|
|
@@ -99,7 +109,6 @@ class RolloutRecordConfig(BaseModel):
|
|
|
99
109
|
|
|
100
110
|
|
|
101
111
|
class RolloutSafetyConfig(BaseModel):
|
|
102
|
-
max_ops: int = 100000
|
|
103
112
|
max_time_s: float = 3600.0
|
|
104
113
|
|
|
105
114
|
|
|
@@ -107,121 +116,148 @@ class RolloutRequest(BaseModel):
|
|
|
107
116
|
run_id: str
|
|
108
117
|
env: RolloutEnvSpec
|
|
109
118
|
policy: RolloutPolicySpec
|
|
110
|
-
ops: list[dict[str, Any]] | list[str]
|
|
111
119
|
record: RolloutRecordConfig = RolloutRecordConfig()
|
|
112
120
|
on_done: str = "reset"
|
|
113
121
|
safety: RolloutSafetyConfig = RolloutSafetyConfig()
|
|
114
122
|
training_session_id: str | None = None
|
|
115
123
|
synth_base_url: str | None = None
|
|
116
|
-
mode: RolloutMode #
|
|
124
|
+
mode: RolloutMode = RolloutMode.RL # Default to RL mode for training/optimization
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
class RolloutMetrics(BaseModel):
|
|
128
|
+
"""Metrics from a rollout execution.
|
|
129
|
+
|
|
130
|
+
## Preferred Fields (New - Normalized)
|
|
131
|
+
|
|
132
|
+
- `outcome_reward`: The reward for this rollout (PREFERRED)
|
|
133
|
+
- `event_rewards`: Optional per-step rewards
|
|
134
|
+
|
|
135
|
+
## Legacy Fields (Backward Compatibility)
|
|
136
|
+
|
|
137
|
+
- `episode_rewards`, `reward_mean`, `num_steps`: Still supported for backward
|
|
138
|
+
compatibility. For new implementations, just use `outcome_reward`.
|
|
139
|
+
- `outcome_score`: Alias for `outcome_reward` (deprecated)
|
|
140
|
+
|
|
141
|
+
## Example - Minimal (New Style)
|
|
117
142
|
|
|
143
|
+
metrics = RolloutMetrics(
|
|
144
|
+
outcome_reward=1.0, # PREFERRED - just provide the reward
|
|
145
|
+
)
|
|
118
146
|
|
|
119
|
-
|
|
120
|
-
"""Single step in a rollout trajectory.
|
|
147
|
+
## Example - Full (Backward Compatible)
|
|
121
148
|
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
149
|
+
metrics = RolloutMetrics(
|
|
150
|
+
episode_rewards=[1.0],
|
|
151
|
+
reward_mean=1.0,
|
|
152
|
+
num_steps=1,
|
|
153
|
+
outcome_reward=1.0, # PREFERRED
|
|
154
|
+
)
|
|
125
155
|
"""
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
info: dict[str, Any] | None = None
|
|
132
|
-
|
|
133
|
-
# Unified output fields (supports all output modes)
|
|
134
|
-
output: dict[str, Any] | str | None = Field(
|
|
156
|
+
|
|
157
|
+
# =========================================================================
|
|
158
|
+
# PREFERRED FIELDS (New - Normalized)
|
|
159
|
+
# =========================================================================
|
|
160
|
+
outcome_reward: float | None = Field(
|
|
135
161
|
default=None,
|
|
136
|
-
description="
|
|
162
|
+
description="The reward for this rollout. PREFERRED field for scoring.",
|
|
137
163
|
)
|
|
138
|
-
|
|
164
|
+
event_rewards: list[float] | None = Field(
|
|
139
165
|
default=None,
|
|
140
|
-
description="
|
|
166
|
+
description="Optional per-step/event rewards for multi-step tasks.",
|
|
141
167
|
)
|
|
142
168
|
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
"""
|
|
168
|
-
env_id: str
|
|
169
|
-
policy_id: str
|
|
170
|
-
steps: list[RolloutStep]
|
|
171
|
-
final: dict[str, Any] | None = None
|
|
172
|
-
length: int
|
|
173
|
-
|
|
174
|
-
# Required for trace correlation with inference mesh (optional initially for backward compat)
|
|
175
|
-
# See: monorepo/INFERENCE_URL_REQUIREMENT_PLAN.md and trace_creation_and_judgement.txt
|
|
176
|
-
inference_url: str
|
|
177
|
-
|
|
178
|
-
# Required by monorepo trace_validation.py: trajectory-level trace with event_history
|
|
179
|
-
# The event_history contains LM call records for input/output extraction
|
|
180
|
-
trace: dict[str, Any] | None = Field(
|
|
169
|
+
# =========================================================================
|
|
170
|
+
# LEGACY FIELDS (Backward Compatibility)
|
|
171
|
+
# =========================================================================
|
|
172
|
+
episode_rewards: list[float] = Field(
|
|
173
|
+
default_factory=list,
|
|
174
|
+
description="[LEGACY] Per-episode rewards. Use outcome_reward instead.",
|
|
175
|
+
)
|
|
176
|
+
reward_mean: float = Field(
|
|
177
|
+
default=0.0,
|
|
178
|
+
description="[LEGACY] Mean reward. Use outcome_reward instead.",
|
|
179
|
+
)
|
|
180
|
+
num_steps: int = Field(
|
|
181
|
+
default=1,
|
|
182
|
+
description="[LEGACY] Step count. Can be derived from event_rewards or trace.",
|
|
183
|
+
)
|
|
184
|
+
num_episodes: int = Field(
|
|
185
|
+
default=1,
|
|
186
|
+
description="[LEGACY] Episode count. Usually 1 for GEPA tasks.",
|
|
187
|
+
)
|
|
188
|
+
outcome_score: float | None = Field(
|
|
189
|
+
default=None,
|
|
190
|
+
description="[DEPRECATED] Alias for outcome_reward. Use outcome_reward instead.",
|
|
191
|
+
)
|
|
192
|
+
events_score: float | None = Field(
|
|
181
193
|
default=None,
|
|
182
|
-
description="
|
|
194
|
+
description="[LEGACY] Aggregate event score. Use event_rewards instead.",
|
|
195
|
+
)
|
|
196
|
+
details: dict[str, Any] = Field(
|
|
197
|
+
default_factory=dict,
|
|
198
|
+
description="Metadata only. Do NOT use details.correct for rewards.",
|
|
183
199
|
)
|
|
184
200
|
|
|
185
|
-
decision_samples: list[dict[str, Any]] | None = None
|
|
186
201
|
|
|
202
|
+
class RolloutResponse(BaseModel):
|
|
203
|
+
"""Response from a rollout execution.
|
|
187
204
|
|
|
188
|
-
|
|
189
|
-
episode_returns: list[float]
|
|
190
|
-
mean_return: float
|
|
191
|
-
num_steps: int
|
|
192
|
-
num_episodes: int = 0
|
|
193
|
-
outcome_score: float | None = None
|
|
194
|
-
events_score: float | None = None
|
|
195
|
-
details: dict[str, Any] = Field(default_factory=dict)
|
|
205
|
+
## Key Fields
|
|
196
206
|
|
|
207
|
+
- `run_id`: Echo from request (required)
|
|
208
|
+
- `metrics`: Rollout metrics with `outcome_reward` (required)
|
|
209
|
+
- `trace`: v3 trace payload (required for verifier scoring)
|
|
197
210
|
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
211
|
+
## Canonical Locations (Top-Level)
|
|
212
|
+
|
|
213
|
+
- `trace_correlation_id`: Correlation ID for trace recovery (TOP-LEVEL CANONICAL)
|
|
214
|
+
- `inference_url`: Inference URL used for this rollout (TOP-LEVEL CANONICAL)
|
|
215
|
+
|
|
216
|
+
These fields SHOULD be at top-level. The monorepo parses from top-level first,
|
|
217
|
+
with fallback to nested locations for backward compatibility.
|
|
218
|
+
|
|
219
|
+
## Example
|
|
220
|
+
|
|
221
|
+
response = RolloutResponse(
|
|
222
|
+
run_id=request.run_id,
|
|
223
|
+
metrics=RolloutMetrics(outcome_reward=1.0),
|
|
224
|
+
trace=trace_payload,
|
|
225
|
+
trace_correlation_id="trace_abc123",
|
|
226
|
+
inference_url="https://api.usesynth.ai/v1/trial-xyz",
|
|
227
|
+
)
|
|
203
228
|
"""
|
|
229
|
+
|
|
204
230
|
run_id: str
|
|
205
|
-
|
|
206
|
-
# DEPRECATED: Legacy format maintained for training code compatibility.
|
|
207
|
-
# Will be removed once training migrates to reading from `trace` field.
|
|
208
|
-
# See: monorepo/trace_single_source.txt for migration plan.
|
|
209
|
-
trajectories: list[RolloutTrajectory]
|
|
210
|
-
|
|
211
|
-
branches: dict[str, list[str]] = Field(default_factory=dict)
|
|
212
231
|
metrics: RolloutMetrics
|
|
213
|
-
aborted: bool = False
|
|
214
|
-
ops_executed: int = 0
|
|
215
|
-
|
|
216
|
-
# OPTIONAL: correlation ID for linking rollout to inference traces
|
|
217
|
-
# If not provided, trainer will infer it from trajectory.inference_url ?cid=... parameter
|
|
218
|
-
trace_correlation_id: str | None = None
|
|
219
|
-
|
|
220
|
-
# PREFERRED: v3 trace format (SessionTrace). This is the single source of truth
|
|
221
|
-
# for rollout data and should be used by all new code. Contains richer data than
|
|
222
|
-
# trajectories including token IDs, logprobs, timing, and multimodal content.
|
|
223
232
|
trace: dict[str, Any] | None = None
|
|
224
|
-
|
|
233
|
+
|
|
234
|
+
# =========================================================================
|
|
235
|
+
# CANONICAL LOCATIONS (Top-Level - Preferred for Parsing)
|
|
236
|
+
# =========================================================================
|
|
237
|
+
trace_correlation_id: str | None = Field(
|
|
238
|
+
default=None,
|
|
239
|
+
description="Correlation ID for trace recovery. TOP-LEVEL CANONICAL location.",
|
|
240
|
+
)
|
|
241
|
+
inference_url: str | None = Field(
|
|
242
|
+
default=None,
|
|
243
|
+
description="Inference URL used for this rollout. TOP-LEVEL CANONICAL location.",
|
|
244
|
+
)
|
|
245
|
+
|
|
246
|
+
# =========================================================================
|
|
247
|
+
# LEGACY FIELDS (Backward Compatibility)
|
|
248
|
+
# =========================================================================
|
|
249
|
+
branches: dict[str, list[str]] = Field(
|
|
250
|
+
default_factory=dict,
|
|
251
|
+
description="[LEGACY] Branch tracking. Usually empty for single-path rollouts.",
|
|
252
|
+
)
|
|
253
|
+
aborted: bool = Field(
|
|
254
|
+
default=False,
|
|
255
|
+
description="Whether the rollout was aborted early.",
|
|
256
|
+
)
|
|
257
|
+
pipeline_metadata: dict[str, Any] = Field(
|
|
258
|
+
default_factory=dict,
|
|
259
|
+
description="[LEGACY] Additional metadata. Prefer top-level fields instead.",
|
|
260
|
+
)
|
|
225
261
|
|
|
226
262
|
|
|
227
263
|
class _ExtraAllowModel(BaseModel):
|
|
@@ -262,7 +298,7 @@ class RubricSection(_ExtraAllowModel):
|
|
|
262
298
|
|
|
263
299
|
|
|
264
300
|
class RubricInfo(_ExtraAllowModel):
|
|
265
|
-
"""Outcome and event scoring definitions used by
|
|
301
|
+
"""Outcome and event scoring definitions used by verifiers."""
|
|
266
302
|
|
|
267
303
|
outcome: RubricSection | None = None
|
|
268
304
|
events: RubricSection | None = None
|
|
@@ -287,11 +323,17 @@ class TaskInfo(_ExtraAllowModel):
|
|
|
287
323
|
"""Static metadata describing the capabilities of a Task App task."""
|
|
288
324
|
|
|
289
325
|
task: TaskDescriptor
|
|
290
|
-
environment: str
|
|
291
326
|
dataset: DatasetInfo
|
|
292
|
-
rubric: RubricInfo
|
|
293
327
|
inference: InferenceInfo
|
|
294
328
|
limits: LimitsInfo
|
|
329
|
+
environment: str | None = Field(
|
|
330
|
+
default=None,
|
|
331
|
+
description="[DEPRECATED] Legacy field not read by server. Will be removed in future version.",
|
|
332
|
+
)
|
|
333
|
+
rubric: RubricInfo | None = Field(
|
|
334
|
+
default=None,
|
|
335
|
+
description="[DEPRECATED] Use LocalAPIConfig.rubrics (RubricBundle) instead. Server ignores this field.",
|
|
336
|
+
)
|
|
295
337
|
task_metadata: dict[str, Any] = Field(
|
|
296
338
|
default_factory=dict,
|
|
297
339
|
description="Task-specific extras (e.g. prompt version info, documentation links).",
|