synth-ai 0.4.1__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of synth-ai has been flagged as potentially problematic.
- synth_ai/__init__.py +13 -13
- synth_ai/cli/__init__.py +6 -15
- synth_ai/cli/commands/eval/__init__.py +6 -15
- synth_ai/cli/commands/eval/config.py +338 -0
- synth_ai/cli/commands/eval/core.py +236 -1091
- synth_ai/cli/commands/eval/runner.py +704 -0
- synth_ai/cli/commands/eval/validation.py +44 -117
- synth_ai/cli/commands/filter/core.py +7 -7
- synth_ai/cli/commands/filter/validation.py +2 -2
- synth_ai/cli/commands/smoke/core.py +7 -17
- synth_ai/cli/commands/status/__init__.py +1 -64
- synth_ai/cli/commands/status/client.py +50 -151
- synth_ai/cli/commands/status/config.py +3 -83
- synth_ai/cli/commands/status/errors.py +4 -13
- synth_ai/cli/commands/status/subcommands/__init__.py +2 -8
- synth_ai/cli/commands/status/subcommands/config.py +13 -0
- synth_ai/cli/commands/status/subcommands/files.py +18 -63
- synth_ai/cli/commands/status/subcommands/jobs.py +28 -311
- synth_ai/cli/commands/status/subcommands/models.py +18 -62
- synth_ai/cli/commands/status/subcommands/runs.py +16 -63
- synth_ai/cli/commands/status/subcommands/session.py +67 -172
- synth_ai/cli/commands/status/subcommands/summary.py +24 -32
- synth_ai/cli/commands/status/subcommands/utils.py +41 -0
- synth_ai/cli/commands/status/utils.py +16 -107
- synth_ai/cli/commands/train/__init__.py +18 -20
- synth_ai/cli/commands/train/errors.py +3 -3
- synth_ai/cli/commands/train/prompt_learning_validation.py +15 -16
- synth_ai/cli/commands/train/validation.py +7 -7
- synth_ai/cli/commands/train/{judge_schemas.py → verifier_schemas.py} +33 -34
- synth_ai/cli/commands/train/verifier_validation.py +235 -0
- synth_ai/cli/demo_apps/demo_task_apps/math/config.toml +0 -1
- synth_ai/cli/demo_apps/demo_task_apps/math/modal_task_app.py +2 -6
- synth_ai/cli/demo_apps/math/config.toml +0 -1
- synth_ai/cli/demo_apps/math/modal_task_app.py +2 -6
- synth_ai/cli/demo_apps/mipro/task_app.py +25 -47
- synth_ai/cli/lib/apps/task_app.py +12 -13
- synth_ai/cli/lib/task_app_discovery.py +6 -6
- synth_ai/cli/lib/train_cfgs.py +10 -10
- synth_ai/cli/task_apps/__init__.py +11 -0
- synth_ai/cli/task_apps/commands.py +7 -15
- synth_ai/core/env.py +12 -1
- synth_ai/core/errors.py +1 -2
- synth_ai/core/integrations/cloudflare.py +209 -33
- synth_ai/core/tracing_v3/abstractions.py +46 -0
- synth_ai/data/__init__.py +3 -30
- synth_ai/data/enums.py +1 -20
- synth_ai/data/rewards.py +100 -3
- synth_ai/products/graph_evolve/__init__.py +1 -2
- synth_ai/products/graph_evolve/config.py +16 -16
- synth_ai/products/graph_evolve/converters/__init__.py +3 -3
- synth_ai/products/graph_evolve/converters/openai_sft.py +7 -7
- synth_ai/products/graph_evolve/examples/hotpotqa/config.toml +1 -1
- synth_ai/products/graph_gepa/__init__.py +23 -0
- synth_ai/products/graph_gepa/converters/__init__.py +19 -0
- synth_ai/products/graph_gepa/converters/openai_sft.py +29 -0
- synth_ai/sdk/__init__.py +45 -35
- synth_ai/sdk/api/eval/__init__.py +33 -0
- synth_ai/sdk/api/eval/job.py +732 -0
- synth_ai/sdk/api/research_agent/__init__.py +276 -66
- synth_ai/sdk/api/train/builders.py +181 -0
- synth_ai/sdk/api/train/cli.py +41 -33
- synth_ai/sdk/api/train/configs/__init__.py +6 -4
- synth_ai/sdk/api/train/configs/prompt_learning.py +127 -33
- synth_ai/sdk/api/train/configs/rl.py +264 -16
- synth_ai/sdk/api/train/configs/sft.py +165 -1
- synth_ai/sdk/api/train/graph_validators.py +12 -12
- synth_ai/sdk/api/train/graphgen.py +169 -51
- synth_ai/sdk/api/train/graphgen_models.py +95 -45
- synth_ai/sdk/api/train/local_api.py +10 -0
- synth_ai/sdk/api/train/pollers.py +36 -0
- synth_ai/sdk/api/train/prompt_learning.py +390 -60
- synth_ai/sdk/api/train/rl.py +41 -5
- synth_ai/sdk/api/train/sft.py +2 -0
- synth_ai/sdk/api/train/task_app.py +20 -0
- synth_ai/sdk/api/train/validators.py +17 -17
- synth_ai/sdk/graphs/completions.py +239 -33
- synth_ai/sdk/{judging/schemas.py → graphs/verifier_schemas.py} +23 -23
- synth_ai/sdk/learning/__init__.py +35 -5
- synth_ai/sdk/learning/context_learning_client.py +531 -0
- synth_ai/sdk/learning/context_learning_types.py +294 -0
- synth_ai/sdk/learning/prompt_learning_client.py +1 -1
- synth_ai/sdk/learning/prompt_learning_types.py +2 -1
- synth_ai/sdk/learning/rl/__init__.py +0 -4
- synth_ai/sdk/learning/rl/contracts.py +0 -4
- synth_ai/sdk/localapi/__init__.py +40 -0
- synth_ai/sdk/localapi/apps/__init__.py +28 -0
- synth_ai/sdk/localapi/client.py +10 -0
- synth_ai/sdk/localapi/contracts.py +10 -0
- synth_ai/sdk/localapi/helpers.py +519 -0
- synth_ai/sdk/localapi/rollouts.py +93 -0
- synth_ai/sdk/localapi/server.py +29 -0
- synth_ai/sdk/localapi/template.py +49 -0
- synth_ai/sdk/streaming/handlers.py +6 -6
- synth_ai/sdk/streaming/streamer.py +10 -6
- synth_ai/sdk/task/__init__.py +18 -5
- synth_ai/sdk/task/apps/__init__.py +37 -1
- synth_ai/sdk/task/client.py +9 -1
- synth_ai/sdk/task/config.py +6 -11
- synth_ai/sdk/task/contracts.py +137 -95
- synth_ai/sdk/task/in_process.py +32 -22
- synth_ai/sdk/task/in_process_runner.py +9 -4
- synth_ai/sdk/task/rubrics/__init__.py +2 -3
- synth_ai/sdk/task/rubrics/loaders.py +4 -4
- synth_ai/sdk/task/rubrics/strict.py +3 -4
- synth_ai/sdk/task/server.py +76 -16
- synth_ai/sdk/task/trace_correlation_helpers.py +190 -139
- synth_ai/sdk/task/validators.py +34 -49
- synth_ai/sdk/training/__init__.py +7 -16
- synth_ai/sdk/tunnels/__init__.py +118 -0
- synth_ai/sdk/tunnels/cleanup.py +83 -0
- synth_ai/sdk/tunnels/ports.py +120 -0
- synth_ai/sdk/tunnels/tunneled_api.py +363 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/METADATA +71 -4
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/RECORD +118 -128
- synth_ai/cli/commands/baseline/__init__.py +0 -12
- synth_ai/cli/commands/baseline/core.py +0 -636
- synth_ai/cli/commands/baseline/list.py +0 -94
- synth_ai/cli/commands/eval/errors.py +0 -81
- synth_ai/cli/commands/status/formatters.py +0 -164
- synth_ai/cli/commands/status/subcommands/pricing.py +0 -23
- synth_ai/cli/commands/status/subcommands/usage.py +0 -203
- synth_ai/cli/commands/train/judge_validation.py +0 -305
- synth_ai/cli/usage.py +0 -159
- synth_ai/data/specs.py +0 -36
- synth_ai/sdk/api/research_agent/cli.py +0 -428
- synth_ai/sdk/api/research_agent/config.py +0 -357
- synth_ai/sdk/api/research_agent/job.py +0 -717
- synth_ai/sdk/baseline/__init__.py +0 -25
- synth_ai/sdk/baseline/config.py +0 -209
- synth_ai/sdk/baseline/discovery.py +0 -216
- synth_ai/sdk/baseline/execution.py +0 -154
- synth_ai/sdk/judging/__init__.py +0 -15
- synth_ai/sdk/judging/base.py +0 -24
- synth_ai/sdk/judging/client.py +0 -191
- synth_ai/sdk/judging/types.py +0 -42
- synth_ai/sdk/research_agent/__init__.py +0 -34
- synth_ai/sdk/research_agent/container_builder.py +0 -328
- synth_ai/sdk/research_agent/container_spec.py +0 -198
- synth_ai/sdk/research_agent/defaults.py +0 -34
- synth_ai/sdk/research_agent/results_collector.py +0 -69
- synth_ai/sdk/specs/__init__.py +0 -46
- synth_ai/sdk/specs/dataclasses.py +0 -149
- synth_ai/sdk/specs/loader.py +0 -144
- synth_ai/sdk/specs/serializer.py +0 -199
- synth_ai/sdk/specs/validation.py +0 -250
- synth_ai/sdk/tracing/__init__.py +0 -39
- synth_ai/sdk/usage/__init__.py +0 -37
- synth_ai/sdk/usage/client.py +0 -171
- synth_ai/sdk/usage/models.py +0 -261
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/WHEEL +0 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/top_level.txt +0 -0
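The headline change in this diff is a systematic rename of the "judge" surface to "verifier": `synth_ai/sdk/judging/` is removed, `judging/schemas.py` moves to `graphs/verifier_schemas.py`, and `judge_schemas.py` becomes `verifier_schemas.py` under the train commands. A minimal migration sketch for downstream imports (the old path is taken from the removed side of the diff below):

```python
# synth-ai 0.4.1 (old import path, removed in this release):
# from synth_ai.sdk.judging.schemas import CalibrationExampleInput, GoldExampleInput

# synth-ai 0.4.4 (new location of the same classes):
from synth_ai.sdk.graphs.verifier_schemas import (
    CalibrationExampleInput,
    GoldExampleInput,
)
```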
@@ -547,37 +547,37 @@ def validate_prompt_learning_config(config_data: dict[str, Any], config_path: Pa
             )
         errors.extend(lo_errors)

-    # Validate judge config (shared by GEPA and MIPRO)
-    judge_section = pl_section.get("judge") or {}
-    if judge_section:
-        if not isinstance(judge_section, dict):
-            errors.append(f"prompt_learning.judge must be a table/dict, got {type(judge_section).__name__}")
+    # Validate verifier config (shared by GEPA and MIPRO)
+    verifier_section = pl_section.get("verifier") or {}
+    if verifier_section:
+        if not isinstance(verifier_section, dict):
+            errors.append(f"prompt_learning.verifier must be a table/dict, got {type(verifier_section).__name__}")
         else:
-            reward_source = str(judge_section.get("reward_source", "task_app")).strip().lower()
-            enabled = bool(judge_section.get("enabled"))
-            if reward_source and reward_source not in {"task_app", "judge", "fused"}:
-                errors.append("prompt_learning.judge.reward_source must be 'task_app', 'judge', or 'fused'")
-            backend_base = str(judge_section.get("backend_base", "") or "").strip()
-            backend_provider = str(judge_section.get("backend_provider", "") or "").strip()
-            backend_model = str(judge_section.get("backend_model", "") or "").strip()
+            reward_source = str(verifier_section.get("reward_source", "task_app")).strip().lower()
+            enabled = bool(verifier_section.get("enabled"))
+            if reward_source and reward_source not in {"task_app", "verifier", "fused"}:
+                errors.append("prompt_learning.verifier.reward_source must be 'task_app', 'verifier', or 'fused'")
+            backend_base = str(verifier_section.get("backend_base", "") or "").strip()
+            backend_provider = str(verifier_section.get("backend_provider", "") or "").strip()
+            backend_model = str(verifier_section.get("backend_model", "") or "").strip()
             if enabled:
                 pass
             if reward_source == "fused":
-                weight_event = judge_section.get("weight_event", 0.0)
-                weight_outcome = judge_section.get("weight_outcome", 0.0)
+                weight_event = verifier_section.get("weight_event", 0.0)
+                weight_outcome = verifier_section.get("weight_outcome", 0.0)
                 try:
                     weight_event_f = float(weight_event)
                 except (TypeError, ValueError):
-                    errors.append("prompt_learning.judge.weight_event must be numeric")
+                    errors.append("prompt_learning.verifier.weight_event must be numeric")
                     weight_event_f = 0.0
                 try:
                     weight_outcome_f = float(weight_outcome)
                 except (TypeError, ValueError):
-                    errors.append("prompt_learning.judge.weight_outcome must be numeric")
+                    errors.append("prompt_learning.verifier.weight_outcome must be numeric")
                     weight_outcome_f = 0.0
                 if weight_event_f <= 0 and weight_outcome_f <= 0:
                     errors.append(
-                        "prompt_learning.judge.reward_source='fused' requires weight_event > 0 or weight_outcome > 0"
+                        "prompt_learning.verifier.reward_source='fused' requires weight_event > 0 or weight_outcome > 0"
                     )

     # Check for multi-stage/multi-module pipeline config
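A sketch of a parsed config that the renamed validation above would accept; the table and key names come from the validator itself, while the values are illustrative:

```python
# Hypothetical [prompt_learning.verifier] config as a parsed-TOML dict.
# With reward_source="fused", at least one weight must be > 0 or the
# validator above appends an error.
config_data = {
    "prompt_learning": {
        "verifier": {
            "enabled": True,
            "reward_source": "fused",  # allowed: "task_app", "verifier", "fused"
            "weight_event": 0.5,       # illustrative
            "weight_outcome": 0.5,     # illustrative
        }
    }
}
```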
@@ -1,13 +1,27 @@
-"""Graph completions client for graph inference (policies, verifiers, RLM)."""
+"""Graph completions client for graph inference (policies, verifiers, RLM).
+
+**Status:** Alpha
+
+This module provides the client for running inference on trained graphs,
+including policy graphs, verifier graphs, and Reasoning Language Models (RLM).
+
+Provides both sync and async clients:
+- GraphCompletionsSyncClient: Synchronous client using httpx
+- GraphCompletionsAsyncClient: Asynchronous client using AsyncHttpClient
+- GraphCompletionsClient: Alias for GraphCompletionsAsyncClient (backward compat)
+"""

 from __future__ import annotations

 import json
+from dataclasses import dataclass
 from typing import Any, Literal, List, Mapping, Optional, TypedDict, Union

+import httpx
+
 from synth_ai.core.http import AsyncHttpClient, HTTPError
 from synth_ai.core.tracing_v3.serialization import normalize_for_json
-from synth_ai.sdk.judging.schemas import (
+from synth_ai.sdk.graphs.verifier_schemas import (
     CalibrationExampleInput,
     GoldExampleInput,
 )
@@ -20,7 +34,7 @@ class GraphTarget(TypedDict, total=False):
     job_id: str
     graph_name: str
     graphgen_job_id: str
-    judge_shape: str
+    verifier_shape: str


 class GraphInfo(TypedDict, total=False):
@@ -29,7 +43,7 @@ class GraphInfo(TypedDict, total=False):
     graph_id: str
     name: str
     version: int
-    kind: str  # "policy", "verifier"…
+    kind: str  # "policy", "verifier"
     best_score: float | None
     job_id: str | None  # Source job that created this graph
     created_at: str
@@ -42,8 +56,194 @@ class ListGraphsResponse(TypedDict):
     total: int


-class GraphCompletionsClient:
-    """…"""
+@dataclass
+class GraphCompletionResponse:
+    """Response from graph completion endpoint."""
+
+    output: dict[str, Any]
+    """The graph output data."""
+
+    usage: dict[str, Any] | None = None
+    """Token usage statistics."""
+
+    cache_status: str | None = None
+    """Cache hit status: 'warm', 'cold', or None."""
+
+    latency_ms: float | None = None
+    """Request latency in milliseconds."""
+
+    raw: dict[str, Any] | None = None
+    """Raw response dict for accessing additional fields."""
+
+    @classmethod
+    def from_dict(cls, data: dict[str, Any]) -> "GraphCompletionResponse":
+        """Create from API response dict."""
+        return cls(
+            output=data.get("output", {}),
+            usage=data.get("usage"),
+            cache_status=data.get("cache_status"),
+            latency_ms=data.get("latency_ms"),
+            raw=data,
+        )
+
+
+class GraphCompletionsSyncClient:
+    """Synchronous client for graph completions using httpx.
+
+    Example:
+        ```python
+        client = GraphCompletionsSyncClient(base_url, api_key)
+
+        # Run inference on a GraphGen job
+        response = client.run(job_id="graphgen_xxx", input_data={"query": "hello"})
+        print(response.output)
+
+        # Just get the output
+        output = client.run_output(job_id="graphgen_xxx", input_data={"query": "hello"})
+        ```
+    """
+
+    def __init__(self, base_url: str, api_key: str, *, timeout: float = 60.0) -> None:
+        self._base = base_url.rstrip("/")
+        self._key = api_key
+        self._timeout = timeout
+
+    def _resolve_job_id(self, *, job_id: str | None, graph: GraphTarget | None) -> str:
+        if job_id:
+            return job_id
+        if not graph:
+            raise ValueError("graph_completions_missing_job_id")
+        if graph.get("job_id"):
+            return str(graph["job_id"])
+        kind = graph.get("kind")
+        if kind == "zero_shot":
+            verifier_shape = graph.get("verifier_shape") or graph.get("graph_name")
+            if not verifier_shape:
+                raise ValueError("graph_completions_missing_verifier_shape")
+            return str(verifier_shape)
+        if kind == "graphgen":
+            graphgen_job_id = graph.get("graphgen_job_id")
+            if not graphgen_job_id:
+                raise ValueError("graph_completions_missing_graphgen_job_id")
+            return str(graphgen_job_id)
+        graph_name = graph.get("graph_name")
+        if graph_name:
+            return str(graph_name)
+        raise ValueError("graph_completions_missing_graph_target")
+
+    def run(
+        self,
+        *,
+        input_data: Mapping[str, Any],
+        job_id: str | None = None,
+        graph: GraphTarget | None = None,
+        model: str | None = None,
+        prompt_snapshot_id: str | None = None,
+        timeout: float | None = None,
+    ) -> GraphCompletionResponse:
+        """Run graph completion and return typed response.
+
+        Args:
+            input_data: Input data for the graph
+            job_id: GraphGen job ID or graph name
+            graph: Alternative graph target specification
+            model: Optional model override
+            prompt_snapshot_id: Specific snapshot to use
+            timeout: Request timeout (overrides client default)
+
+        Returns:
+            GraphCompletionResponse with output, usage, cache_status, etc.
+        """
+        payload: dict[str, Any] = {
+            "job_id": self._resolve_job_id(job_id=job_id, graph=graph),
+            "input": normalize_for_json(dict(input_data)),
+        }
+        if model:
+            payload["model"] = model
+        if prompt_snapshot_id:
+            payload["prompt_snapshot_id"] = prompt_snapshot_id
+
+        url = f"{self._base}/api/graphs/completions"
+        headers = {"X-API-Key": self._key, "Content-Type": "application/json"}
+
+        with httpx.Client(timeout=timeout or self._timeout) as client:
+            resp = client.post(url, headers=headers, json=payload)
+
+        if resp.status_code == 400 or resp.status_code == 422:
+            raise ValueError(f"graph_completions_validation_error: {resp.text[:500]}")
+        if resp.status_code in (401, 403):
+            raise PermissionError(f"graph_completions_auth_error: {resp.text[:500]}")
+        if resp.status_code == 404:
+            raise FileNotFoundError(f"graph_completions_not_found: {resp.text[:500]}")
+        if resp.status_code == 429:
+            raise Exception("graph_completions_rate_limited")
+
+        resp.raise_for_status()
+        return GraphCompletionResponse.from_dict(resp.json())
+
+    def run_output(
+        self,
+        *,
+        input_data: Mapping[str, Any],
+        job_id: str | None = None,
+        graph: GraphTarget | None = None,
+        model: str | None = None,
+        prompt_snapshot_id: str | None = None,
+        timeout: float | None = None,
+    ) -> dict[str, Any]:
+        """Run graph completion and return just the output dict.
+
+        Convenience method that returns only the output field.
+        """
+        result = self.run(
+            input_data=input_data,
+            job_id=job_id,
+            graph=graph,
+            model=model,
+            prompt_snapshot_id=prompt_snapshot_id,
+            timeout=timeout,
+        )
+        return result.output
+
+    def complete(
+        self,
+        graph_id: str,
+        input_data: Mapping[str, Any],
+        *,
+        model: str | None = None,
+        timeout: float | None = None,
+    ) -> GraphCompletionResponse:
+        """Execute any graph with arbitrary input.
+
+        Args:
+            graph_id: Built-in graph name, GraphGen job_id, or snapshot UUID
+            input_data: Graph-specific input data
+            model: Optional model override
+            timeout: Request timeout
+
+        Returns:
+            GraphCompletionResponse
+        """
+        return self.run(
+            input_data=input_data,
+            job_id=graph_id,
+            model=model,
+            timeout=timeout,
+        )
+
+
+class GraphCompletionsAsyncClient:
+    """Asynchronous client for graph completions.
+
+    Example:
+        ```python
+        client = GraphCompletionsAsyncClient(base_url, api_key)
+
+        # Run inference on a GraphGen job
+        result = await client.run(job_id="graphgen_xxx", input_data={"query": "hello"})
+        print(result["output"])
+        ```
+    """

     def __init__(self, base_url: str, api_key: str, *, timeout: float = 60.0) -> None:
         self._base = base_url.rstrip("/")
@@ -63,7 +263,7 @@ class GraphCompletionsClient:
         (determined by API key).

         Args:
-            kind: Optional filter by graph kind ("policy", "verifier"…
+            kind: Optional filter by graph kind ("policy", "verifier")
             limit: Maximum number of graphs to return (default: 50)

         Returns:
@@ -112,10 +312,10 @@ class GraphCompletionsClient:
             return str(graph["job_id"])
         kind = graph.get("kind")
         if kind == "zero_shot":
-            judge_shape = graph.get("judge_shape") or graph.get("graph_name")
-            if not judge_shape:
-                raise ValueError("graph_completions_missing_judge_shape")
-            return str(judge_shape)
+            verifier_shape = graph.get("verifier_shape") or graph.get("graph_name")
+            if not verifier_shape:
+                raise ValueError("graph_completions_missing_verifier_shape")
+            return str(verifier_shape)
         if kind == "graphgen":
             graphgen_job_id = graph.get("graphgen_job_id")
             if not graphgen_job_id:
@@ -228,7 +428,7 @@ class GraphCompletionsClient:
         rubric: Mapping[str, Any],
         system_prompt: str | None = None,
         user_prompt: str | None = None,
-        judge_shape: str | None = None,
+        verifier_shape: str | None = None,
         options: Mapping[str, Any] | None = None,
         model: str | None = None,
     ) -> dict[str, Any]:
@@ -239,7 +439,7 @@ class GraphCompletionsClient:
             rubric: Rubric with event/outcome criteria
             system_prompt: Optional custom system prompt
             user_prompt: Optional custom user prompt
-            judge_shape: "single", "mapreduce", or "rlm" (auto-detects if None)
+            verifier_shape: "single", "mapreduce", or "rlm" (auto-detects if None)
             options: Optional execution options (event, outcome, etc.)
             model: Optional model override

@@ -247,11 +447,11 @@ class GraphCompletionsClient:
             Verification result with event_reviews, outcome_review, etc.
         """
         # Auto-select graph shape based on trace size
-        if judge_shape is None:
-            judge_shape = self._select_graph_shape(session_trace)
+        if verifier_shape is None:
+            verifier_shape = self._select_graph_shape(session_trace)

         # Use composable naming: zero_shot_verifier_{gold_output_format}_{graph_shape}
-        graph_id = f"zero_shot_verifier_rubric_{judge_shape}"
+        graph_id = f"zero_shot_verifier_rubric_{verifier_shape}"

         input_data: dict[str, Any] = {
             "session_trace": normalize_for_json(session_trace),
@@ -279,7 +479,7 @@ class GraphCompletionsClient:
         expected_rubric: str | None = None,
         system_prompt: str | None = None,
         user_prompt: str | None = None,
-        judge_shape: str | None = None,
+        verifier_shape: str | None = None,
         options: Mapping[str, Any] | None = None,
         model: str | None = None,
     ) -> dict[str, Any]:
@@ -295,7 +495,7 @@ class GraphCompletionsClient:
             expected_rubric: Optional rubric/ground truth for the trace being evaluated
             system_prompt: Optional custom system prompt
             user_prompt: Optional custom user prompt
-            judge_shape: "single", "mapreduce", or "rlm" (auto-detects if None)
+            verifier_shape: "single", "mapreduce", or "rlm" (auto-detects if None)
             options: Optional execution options
             model: Optional model override

@@ -317,10 +517,10 @@ class GraphCompletionsClient:
                 f"and outcome_reward (float 0.0-1.0). event_rewards length must match trace events."
             ) from e

-        if judge_shape is None:
-            judge_shape = self._select_graph_shape(session_trace)
+        if verifier_shape is None:
+            verifier_shape = self._select_graph_shape(session_trace)

-        graph_id = f"zero_shot_verifier_fewshot_{judge_shape}"
+        graph_id = f"zero_shot_verifier_fewshot_{verifier_shape}"

         # Convert validated examples back to dict for serialization
         input_data: dict[str, Any] = {
@@ -354,7 +554,7 @@ class GraphCompletionsClient:
         expected_rubric: str | None = None,
         system_prompt: str | None = None,
         user_prompt: str | None = None,
-        judge_shape: str | None = None,
+        verifier_shape: str | None = None,
         options: Mapping[str, Any] | None = None,
         model: str | None = None,
     ) -> dict[str, Any]:
@@ -374,7 +574,7 @@ class GraphCompletionsClient:
             expected_rubric: Optional rubric/ground truth for this trace
             system_prompt: Optional custom system prompt
             user_prompt: Optional custom user prompt
-            judge_shape: "single", "mapreduce", or "rlm" (auto-detects if None)
+            verifier_shape: "single", "mapreduce", or "rlm" (auto-detects if None)
             options: Optional execution options
             model: Optional model override

@@ -408,10 +608,10 @@ class GraphCompletionsClient:
                 f"candidate_reasoning must be a non-empty string, got {type(candidate_reasoning).__name__}"
             )

-        if judge_shape is None:
-            judge_shape = self._select_graph_shape(session_trace)
+        if verifier_shape is None:
+            verifier_shape = self._select_graph_shape(session_trace)

-        graph_id = f"zero_shot_verifier_contrastive_{judge_shape}"
+        graph_id = f"zero_shot_verifier_contrastive_{verifier_shape}"

         # Convert validated examples back to dict for serialization
         input_data: dict[str, Any] = {
@@ -441,7 +641,7 @@ class GraphCompletionsClient:
         session_trace: Mapping[str, Any],
         system_prompt: str,
         user_prompt: str,
-        judge_shape: str | None = None,
+        verifier_shape: str | None = None,
         options: Mapping[str, Any] | None = None,
         model: str | None = None,
     ) -> dict[str, Any]:
@@ -451,19 +651,19 @@ class GraphCompletionsClient:
             session_trace: V3 trace format
             system_prompt: Custom system prompt (required)
             user_prompt: Custom user prompt (required)
-            judge_shape: "single", "mapreduce", or "rlm" (auto-detects if None)
+            verifier_shape: "single", "mapreduce", or "rlm" (auto-detects if None)
             options: Optional execution options
             model: Optional model override

         Returns:
             Verification result
         """
-        if judge_shape is None:
-            judge_shape = self._select_graph_shape(session_trace)
+        if verifier_shape is None:
+            verifier_shape = self._select_graph_shape(session_trace)

         # For custom prompts, use rubric single graph but with custom prompts
         # The graph will use the prompts instead of rubric
-        graph_id = f"zero_shot_verifier_rubric_{judge_shape}"
+        graph_id = f"zero_shot_verifier_rubric_{verifier_shape}"

         input_data: dict[str, Any] = {
             "session_trace": normalize_for_json(session_trace),
@@ -529,7 +729,7 @@ class GraphCompletionsClient:
         return result


-class VerifierClient(GraphCompletionsClient):
+class VerifierAsyncClient(GraphCompletionsAsyncClient):
     """Verifier graph client that builds standard verifier inputs."""

     async def evaluate(
@@ -554,7 +754,6 @@ class VerifierClient(GraphCompletionsClient):
         input_data: dict[str, Any] = {
             "policy_name": policy_name,
             "task_app": task_app_payload,
-            "session_trace": trace_payload,
             "trace": trace_payload,
             "options": dict(options or {}),
         }
@@ -568,3 +767,10 @@ class VerifierClient(GraphCompletionsClient):
             model=model,
             prompt_snapshot_id=prompt_snapshot_id,
         )
+
+
+GraphCompletionsClient = GraphCompletionsAsyncClient
+"""Alias for GraphCompletionsAsyncClient."""
+
+VerifierClient = VerifierAsyncClient
+"""Alias for VerifierAsyncClient."""
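A usage sketch of the new synchronous client and the backward-compat aliases introduced at the bottom of this file, following the docstring example above; the base URL and API key are placeholders:

```python
from synth_ai.sdk.graphs.completions import (
    GraphCompletionsClient,  # now an alias for GraphCompletionsAsyncClient
    GraphCompletionsSyncClient,
)

# Placeholder endpoint and key.
client = GraphCompletionsSyncClient("https://backend.example.com", "sk-placeholder")

# Run inference on a GraphGen job and inspect the typed response.
response = client.run(job_id="graphgen_xxx", input_data={"query": "hello"})
print(response.output, response.cache_status, response.latency_ms)
```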
@@ -2,21 +2,21 @@
 Verifier API Contract Schemas

 These schemas define the expected structure for requests and responses
-to the verifier scoring endpoint at POST /api/…
-verifier graphs use the same response format via POST /api/graphs/completions.
+to the verifier scoring endpoint at POST /api/graphs/verifiers/completions.
+Zero-shot verifier graphs use the same response format via POST /api/graphs/completions.

 This is the canonical contract that the backend MUST conform to.
 """

 from __future__ import annotations

-from typing import Any, Literal, Optional
+from typing import Annotated, Any, Literal, Optional

 from pydantic import BaseModel, Field, model_validator


 class CriterionScorePayload(BaseModel):
-    """Per-criterion score returned by the judge."""
+    """Per-criterion score returned by the verifier."""

     score: float = Field(..., description="Numeric score for this criterion")
     reason: str = Field(default="", description="Explanation for the score")
@@ -35,11 +35,11 @@ class ReviewPayload(BaseModel):
     summary: Optional[str] = Field(None, description="Optional text summary")


-class JudgeScoreResponse(BaseModel):
+class VerifierScoreResponse(BaseModel):
     """
-    Response body for POST /api/…
+    Response body for POST /api/graphs/verifiers/completions.

-    This is the canonical contract that judge backends MUST return and is
+    This is the canonical contract that verifier backends MUST return and is
     also used as the zero-shot verifier graph output.
     """

@@ -90,24 +90,24 @@ class JudgeScoreResponse(BaseModel):

 # Request schemas for completeness

-class JudgeTaskApp(BaseModel):
+class VerifierTaskApp(BaseModel):
     """Task application metadata."""

     id: str = Field(..., description="Task app identifier")
     base_url: Optional[str] = Field(None, description="Optional base URL for task app")


-class JudgeOptions(BaseModel):
-    """Judge provider and configuration options."""
+class VerifierOptions(BaseModel):
+    """Verifier provider and configuration options."""

-    provider: Optional[str] = Field(None, description="Judge provider (e.g., 'openai', 'groq')")
+    provider: Optional[str] = Field(None, description="Verifier provider (e.g., 'openai', 'groq')")
     model: Optional[str] = Field(None, description="Model identifier")
     rubric_id: Optional[str] = Field(None, description="Rubric identifier")
-    event: bool = Field(True, description="Enable event-level judging")
-    outcome: bool = Field(True, description="Enable outcome-level judging")
+    event: bool = Field(True, description="Enable event-level verification")
+    outcome: bool = Field(True, description="Enable outcome-level verification")


-class JudgeTracePayload(BaseModel):
+class VerifierTracePayload(BaseModel):
     """Trace payload containing trajectory context."""

     event_history: list[dict[str, Any]] = Field(..., description="List of events/steps")
@@ -118,13 +118,13 @@ class JudgeTracePayload(BaseModel):
     metadata: dict[str, Any] = Field(default_factory=dict, description="Trace metadata")


-class JudgeScoreRequest(BaseModel):
-    """Request body for POST /api/…"""
+class VerifierScoreRequest(BaseModel):
+    """Request body for POST /api/graphs/verifiers/completions."""

     policy_name: str = Field(..., description="Name of the policy being evaluated")
-    task_app: JudgeTaskApp = Field(..., description="Task application metadata")
-    trace: JudgeTracePayload = Field(..., description="Trajectory trace to evaluate")
-    options: JudgeOptions = Field(default_factory=lambda: JudgeOptions(), description="Judge options")
+    task_app: VerifierTaskApp = Field(..., description="Task application metadata")
+    trace: VerifierTracePayload = Field(..., description="Trajectory trace to evaluate")
+    options: VerifierOptions = Field(default_factory=lambda: VerifierOptions(), description="Verifier options")
     rubric: Optional[dict[str, Any]] = Field(None, description="Optional explicit rubric criteria")


@@ -139,11 +139,11 @@ class CalibrationExampleInput(BaseModel):

     session_trace: dict[str, Any] = Field(..., description="V3 SessionTrace format (validated separately)")
     event_rewards: list[Annotated[float, Field(ge=0.0, le=1.0)]] = Field(
-        ...,
+        ...,
         description="List of rewards per event (0.0-1.0), must match number of events in trace"
     )
     outcome_reward: Annotated[float, Field(ge=0.0, le=1.0)] = Field(
-        ...,
+        ...,
         description="Overall outcome reward (0.0-1.0)"
     )
     metadata: dict[str, Any] = Field(default_factory=dict, description="Optional metadata")
@@ -200,12 +200,12 @@ class GoldExampleInput(BaseModel):

     summary: str = Field(..., min_length=1, description="Summary of the trace being evaluated")
     gold_score: Annotated[float, Field(ge=0.0, le=1.0)] = Field(
-        ...,
+        ...,
         description="Gold-standard score (0.0-1.0)"
     )
     gold_reasoning: str = Field(..., min_length=1, description="Gold-standard reasoning/explanation")
     session_trace: Optional[dict[str, Any]] = Field(
-        None,
+        None,
         description="Optional full trace (for richer evaluation)"
     )
     metadata: dict[str, Any] = Field(default_factory=dict, description="Optional metadata")
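A sketch constructing the renamed request model; field names and types are taken from the schema diff above, the values are placeholders, and any required fields in the lines this diff does not show would still need to be supplied:

```python
from synth_ai.sdk.graphs.verifier_schemas import (
    VerifierOptions,
    VerifierScoreRequest,
    VerifierTaskApp,
    VerifierTracePayload,
)

# Placeholder identifiers and a minimal trace; event_history shape is task-specific.
request = VerifierScoreRequest(
    policy_name="example-policy",
    task_app=VerifierTaskApp(id="example-task-app"),
    trace=VerifierTracePayload(event_history=[{"step": 0, "action": "noop"}]),
    options=VerifierOptions(provider="openai", model="gpt-4o-mini"),
)
```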
@@ -1,6 +1,23 @@
 from synth_ai.sdk.task import task_app_health, validate_task_app_url

 from .client import LearningClient
+from .context_learning_client import (
+    ContextLearningClient,
+    create_job as create_context_learning_job,
+    get_best_script as get_context_learning_best_script,
+    get_job_status as get_context_learning_status,
+    run_job as run_context_learning_job,
+)
+from .context_learning_types import (
+    AlgorithmConfig,
+    BestScriptResult,
+    ContextLearningEvent,
+    ContextLearningJobConfig,
+    ContextLearningJobStatus,
+    ContextLearningMetric,
+    ContextLearningResults,
+    EnvironmentConfig,
+)
 from .health import backend_health, balance_autumn_normalized, pricing_preflight
 from .jobs import JobHandle, JobsApiResolver
 from .prompt_learning_client import (
@@ -20,8 +37,6 @@ from .rl import (
     RolloutRequest,
     RolloutResponse,
     RolloutSafetyConfig,
-    RolloutStep,
-    RolloutTrajectory,
     encrypt_for_backend,
     mint_environment_api_key,
     setup_environment_api_key,
@@ -32,30 +47,45 @@ from .sse import stream_events as stream_job_events
 from .validators import validate_trainer_cfg_rl, validate_training_jsonl

 __all__ = [
+    # Learning clients
     "LearningClient",
     "RlClient",
     "RLJobConfig",
     "FtClient",
     "SFTJobConfig",
     "prepare_sft_job_payload",
+    # Prompt Learning
     "PromptLearningClient",
     "get_prompts",
     "get_prompt_text",
     "get_scoring_summary",
+    # Context Learning
+    "ContextLearningClient",
+    "ContextLearningJobConfig",
+    "ContextLearningJobStatus",
+    "ContextLearningEvent",
+    "ContextLearningMetric",
+    "ContextLearningResults",
+    "BestScriptResult",
+    "EnvironmentConfig",
+    "AlgorithmConfig",
+    "create_context_learning_job",
+    "get_context_learning_status",
+    "get_context_learning_best_script",
+    "run_context_learning_job",
+    # RL types
     "RolloutEnvSpec",
     "RolloutPolicySpec",
     "RolloutRecordConfig",
     "RolloutSafetyConfig",
     "RolloutRequest",
-    "RolloutStep",
-    "RolloutTrajectory",
     "RolloutMetrics",
     "RolloutResponse",
     "mint_environment_api_key",
     "encrypt_for_backend",
     "setup_environment_api_key",
     "MAX_ENVIRONMENT_API_KEY_BYTES",
-    # …
+    # Utilities
     "validate_training_jsonl",
     "validate_trainer_cfg_rl",
     "validate_task_app_url",