synth-ai 0.4.1__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- synth_ai/__init__.py +13 -13
- synth_ai/cli/__init__.py +6 -15
- synth_ai/cli/commands/eval/__init__.py +6 -15
- synth_ai/cli/commands/eval/config.py +338 -0
- synth_ai/cli/commands/eval/core.py +236 -1091
- synth_ai/cli/commands/eval/runner.py +704 -0
- synth_ai/cli/commands/eval/validation.py +44 -117
- synth_ai/cli/commands/filter/core.py +7 -7
- synth_ai/cli/commands/filter/validation.py +2 -2
- synth_ai/cli/commands/smoke/core.py +7 -17
- synth_ai/cli/commands/status/__init__.py +1 -64
- synth_ai/cli/commands/status/client.py +50 -151
- synth_ai/cli/commands/status/config.py +3 -83
- synth_ai/cli/commands/status/errors.py +4 -13
- synth_ai/cli/commands/status/subcommands/__init__.py +2 -8
- synth_ai/cli/commands/status/subcommands/config.py +13 -0
- synth_ai/cli/commands/status/subcommands/files.py +18 -63
- synth_ai/cli/commands/status/subcommands/jobs.py +28 -311
- synth_ai/cli/commands/status/subcommands/models.py +18 -62
- synth_ai/cli/commands/status/subcommands/runs.py +16 -63
- synth_ai/cli/commands/status/subcommands/session.py +67 -172
- synth_ai/cli/commands/status/subcommands/summary.py +24 -32
- synth_ai/cli/commands/status/subcommands/utils.py +41 -0
- synth_ai/cli/commands/status/utils.py +16 -107
- synth_ai/cli/commands/train/__init__.py +18 -20
- synth_ai/cli/commands/train/errors.py +3 -3
- synth_ai/cli/commands/train/prompt_learning_validation.py +15 -16
- synth_ai/cli/commands/train/validation.py +7 -7
- synth_ai/cli/commands/train/{judge_schemas.py → verifier_schemas.py} +33 -34
- synth_ai/cli/commands/train/verifier_validation.py +235 -0
- synth_ai/cli/demo_apps/demo_task_apps/math/config.toml +0 -1
- synth_ai/cli/demo_apps/demo_task_apps/math/modal_task_app.py +2 -6
- synth_ai/cli/demo_apps/math/config.toml +0 -1
- synth_ai/cli/demo_apps/math/modal_task_app.py +2 -6
- synth_ai/cli/demo_apps/mipro/task_app.py +25 -47
- synth_ai/cli/lib/apps/task_app.py +12 -13
- synth_ai/cli/lib/task_app_discovery.py +6 -6
- synth_ai/cli/lib/train_cfgs.py +10 -10
- synth_ai/cli/task_apps/__init__.py +11 -0
- synth_ai/cli/task_apps/commands.py +7 -15
- synth_ai/core/env.py +12 -1
- synth_ai/core/errors.py +1 -2
- synth_ai/core/integrations/cloudflare.py +209 -33
- synth_ai/core/tracing_v3/abstractions.py +46 -0
- synth_ai/data/__init__.py +3 -30
- synth_ai/data/enums.py +1 -20
- synth_ai/data/rewards.py +100 -3
- synth_ai/products/graph_evolve/__init__.py +1 -2
- synth_ai/products/graph_evolve/config.py +16 -16
- synth_ai/products/graph_evolve/converters/__init__.py +3 -3
- synth_ai/products/graph_evolve/converters/openai_sft.py +7 -7
- synth_ai/products/graph_evolve/examples/hotpotqa/config.toml +1 -1
- synth_ai/products/graph_gepa/__init__.py +23 -0
- synth_ai/products/graph_gepa/converters/__init__.py +19 -0
- synth_ai/products/graph_gepa/converters/openai_sft.py +29 -0
- synth_ai/sdk/__init__.py +45 -35
- synth_ai/sdk/api/eval/__init__.py +33 -0
- synth_ai/sdk/api/eval/job.py +732 -0
- synth_ai/sdk/api/research_agent/__init__.py +276 -66
- synth_ai/sdk/api/train/builders.py +181 -0
- synth_ai/sdk/api/train/cli.py +41 -33
- synth_ai/sdk/api/train/configs/__init__.py +6 -4
- synth_ai/sdk/api/train/configs/prompt_learning.py +127 -33
- synth_ai/sdk/api/train/configs/rl.py +264 -16
- synth_ai/sdk/api/train/configs/sft.py +165 -1
- synth_ai/sdk/api/train/graph_validators.py +12 -12
- synth_ai/sdk/api/train/graphgen.py +169 -51
- synth_ai/sdk/api/train/graphgen_models.py +95 -45
- synth_ai/sdk/api/train/local_api.py +10 -0
- synth_ai/sdk/api/train/pollers.py +36 -0
- synth_ai/sdk/api/train/prompt_learning.py +390 -60
- synth_ai/sdk/api/train/rl.py +41 -5
- synth_ai/sdk/api/train/sft.py +2 -0
- synth_ai/sdk/api/train/task_app.py +20 -0
- synth_ai/sdk/api/train/validators.py +17 -17
- synth_ai/sdk/graphs/completions.py +239 -33
- synth_ai/sdk/{judging/schemas.py → graphs/verifier_schemas.py} +23 -23
- synth_ai/sdk/learning/__init__.py +35 -5
- synth_ai/sdk/learning/context_learning_client.py +531 -0
- synth_ai/sdk/learning/context_learning_types.py +294 -0
- synth_ai/sdk/learning/prompt_learning_client.py +1 -1
- synth_ai/sdk/learning/prompt_learning_types.py +2 -1
- synth_ai/sdk/learning/rl/__init__.py +0 -4
- synth_ai/sdk/learning/rl/contracts.py +0 -4
- synth_ai/sdk/localapi/__init__.py +40 -0
- synth_ai/sdk/localapi/apps/__init__.py +28 -0
- synth_ai/sdk/localapi/client.py +10 -0
- synth_ai/sdk/localapi/contracts.py +10 -0
- synth_ai/sdk/localapi/helpers.py +519 -0
- synth_ai/sdk/localapi/rollouts.py +93 -0
- synth_ai/sdk/localapi/server.py +29 -0
- synth_ai/sdk/localapi/template.py +49 -0
- synth_ai/sdk/streaming/handlers.py +6 -6
- synth_ai/sdk/streaming/streamer.py +10 -6
- synth_ai/sdk/task/__init__.py +18 -5
- synth_ai/sdk/task/apps/__init__.py +37 -1
- synth_ai/sdk/task/client.py +9 -1
- synth_ai/sdk/task/config.py +6 -11
- synth_ai/sdk/task/contracts.py +137 -95
- synth_ai/sdk/task/in_process.py +32 -22
- synth_ai/sdk/task/in_process_runner.py +9 -4
- synth_ai/sdk/task/rubrics/__init__.py +2 -3
- synth_ai/sdk/task/rubrics/loaders.py +4 -4
- synth_ai/sdk/task/rubrics/strict.py +3 -4
- synth_ai/sdk/task/server.py +76 -16
- synth_ai/sdk/task/trace_correlation_helpers.py +190 -139
- synth_ai/sdk/task/validators.py +34 -49
- synth_ai/sdk/training/__init__.py +7 -16
- synth_ai/sdk/tunnels/__init__.py +118 -0
- synth_ai/sdk/tunnels/cleanup.py +83 -0
- synth_ai/sdk/tunnels/ports.py +120 -0
- synth_ai/sdk/tunnels/tunneled_api.py +363 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/METADATA +71 -4
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/RECORD +118 -128
- synth_ai/cli/commands/baseline/__init__.py +0 -12
- synth_ai/cli/commands/baseline/core.py +0 -636
- synth_ai/cli/commands/baseline/list.py +0 -94
- synth_ai/cli/commands/eval/errors.py +0 -81
- synth_ai/cli/commands/status/formatters.py +0 -164
- synth_ai/cli/commands/status/subcommands/pricing.py +0 -23
- synth_ai/cli/commands/status/subcommands/usage.py +0 -203
- synth_ai/cli/commands/train/judge_validation.py +0 -305
- synth_ai/cli/usage.py +0 -159
- synth_ai/data/specs.py +0 -36
- synth_ai/sdk/api/research_agent/cli.py +0 -428
- synth_ai/sdk/api/research_agent/config.py +0 -357
- synth_ai/sdk/api/research_agent/job.py +0 -717
- synth_ai/sdk/baseline/__init__.py +0 -25
- synth_ai/sdk/baseline/config.py +0 -209
- synth_ai/sdk/baseline/discovery.py +0 -216
- synth_ai/sdk/baseline/execution.py +0 -154
- synth_ai/sdk/judging/__init__.py +0 -15
- synth_ai/sdk/judging/base.py +0 -24
- synth_ai/sdk/judging/client.py +0 -191
- synth_ai/sdk/judging/types.py +0 -42
- synth_ai/sdk/research_agent/__init__.py +0 -34
- synth_ai/sdk/research_agent/container_builder.py +0 -328
- synth_ai/sdk/research_agent/container_spec.py +0 -198
- synth_ai/sdk/research_agent/defaults.py +0 -34
- synth_ai/sdk/research_agent/results_collector.py +0 -69
- synth_ai/sdk/specs/__init__.py +0 -46
- synth_ai/sdk/specs/dataclasses.py +0 -149
- synth_ai/sdk/specs/loader.py +0 -144
- synth_ai/sdk/specs/serializer.py +0 -199
- synth_ai/sdk/specs/validation.py +0 -250
- synth_ai/sdk/tracing/__init__.py +0 -39
- synth_ai/sdk/usage/__init__.py +0 -37
- synth_ai/sdk/usage/client.py +0 -171
- synth_ai/sdk/usage/models.py +0 -261
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/WHEEL +0 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/top_level.txt +0 -0
|
@@ -1,13 +1,15 @@
|
|
|
1
|
-
"""First-class SDK API for GraphGen (
|
|
1
|
+
"""First-class SDK API for GraphGen (Graph Opt).
|
|
2
|
+
|
|
3
|
+
**Status:** Alpha
|
|
2
4
|
|
|
3
5
|
GraphGen is a simplified "Workflows API" for prompt optimization that:
|
|
4
6
|
- Uses a simple JSON dataset format (GraphGenTaskSet) instead of TOML configs
|
|
5
7
|
- Auto-generates task apps from the dataset (no user-managed task apps)
|
|
6
|
-
- Has built-in
|
|
8
|
+
- Has built-in verifier configurations (rubric, contrastive, gold_examples)
|
|
7
9
|
- Wraps GEPA internally for the actual optimization
|
|
8
10
|
|
|
9
11
|
Example CLI usage:
|
|
10
|
-
uvx synth-ai train --type
|
|
12
|
+
uvx synth-ai train --type graphgen --dataset my_tasks.json --poll
|
|
11
13
|
|
|
12
14
|
Example SDK usage:
|
|
13
15
|
from synth_ai.sdk.api.train.graphgen import GraphGenJob
|
|
@@ -46,14 +48,39 @@ from .graphgen_models import (
|
|
|
46
48
|
load_graphgen_taskset,
|
|
47
49
|
parse_graphgen_taskset,
|
|
48
50
|
SessionTraceInput,
|
|
49
|
-
|
|
51
|
+
GraphGenGraphVerifierResponse,
|
|
50
52
|
)
|
|
51
53
|
from .utils import ensure_api_base, http_get, http_post
|
|
52
54
|
|
|
53
55
|
|
|
54
56
|
@dataclass
|
|
55
57
|
class GraphGenJobResult:
|
|
56
|
-
"""Result from
|
|
58
|
+
"""Result from a GraphGen job.
|
|
59
|
+
|
|
60
|
+
Contains the final status and results of a completed GraphGen workflow
|
|
61
|
+
optimization job, including the best score and snapshot ID for the
|
|
62
|
+
optimized graph.
|
|
63
|
+
|
|
64
|
+
Attributes:
|
|
65
|
+
graphgen_job_id: Unique identifier for the GraphGen job (e.g.,
|
|
66
|
+
"graphgen_abc123def456").
|
|
67
|
+
status: Current job status. One of: "pending", "running", "succeeded",
|
|
68
|
+
"failed", "cancelled".
|
|
69
|
+
best_score: Best evaluation score achieved during optimization. Higher
|
|
70
|
+
is better. None if job hasn't completed successfully.
|
|
71
|
+
best_snapshot_id: ID of the graph snapshot with the best score. Use this
|
|
72
|
+
to download or deploy the optimized graph.
|
|
73
|
+
error: Error message if the job failed, None otherwise.
|
|
74
|
+
dataset_name: Name of the dataset used for optimization.
|
|
75
|
+
task_count: Number of tasks in the dataset.
|
|
76
|
+
graph_evolve_job_id: ID of the underlying graph evolution job, if applicable.
|
|
77
|
+
|
|
78
|
+
Example:
|
|
79
|
+
>>> result = job.get_result()
|
|
80
|
+
>>> if result.status == "succeeded":
|
|
81
|
+
... print(f"Best score: {result.best_score}")
|
|
82
|
+
... print(f"Snapshot ID: {result.best_snapshot_id}")
|
|
83
|
+
"""
|
|
57
84
|
|
|
58
85
|
graphgen_job_id: str
|
|
59
86
|
status: str
|
|
@@ -67,7 +94,29 @@ class GraphGenJobResult:
|
|
|
67
94
|
|
|
68
95
|
@dataclass
|
|
69
96
|
class GraphGenSubmitResult:
|
|
70
|
-
"""Result from submitting
|
|
97
|
+
"""Result from submitting a GraphGen job.
|
|
98
|
+
|
|
99
|
+
Returned immediately after job submission with initial job metadata
|
|
100
|
+
and configuration details.
|
|
101
|
+
|
|
102
|
+
Attributes:
|
|
103
|
+
graphgen_job_id: Unique identifier for the GraphGen job.
|
|
104
|
+
status: Initial job status (typically "pending" or "running").
|
|
105
|
+
dataset_name: Name of the dataset being used for optimization.
|
|
106
|
+
task_count: Number of tasks in the dataset.
|
|
107
|
+
rollout_budget: Total number of rollouts (evaluations) budgeted for
|
|
108
|
+
this optimization job.
|
|
109
|
+
policy_model: Name of the LLM model being used for the policy
|
|
110
|
+
(e.g., "gpt-4o-mini", "claude-3-5-sonnet").
|
|
111
|
+
verifier_mode: Evaluation mode being used. One of: "rubric", "contrastive",
|
|
112
|
+
"gold_examples", "verifier_graph".
|
|
113
|
+
graph_evolve_job_id: ID of the underlying graph evolution job, if applicable.
|
|
114
|
+
|
|
115
|
+
Example:
|
|
116
|
+
>>> submit_result = job.submit()
|
|
117
|
+
>>> print(f"Job {submit_result.graphgen_job_id} started")
|
|
118
|
+
>>> print(f"Optimizing {submit_result.task_count} tasks with {submit_result.rollout_budget} rollouts")
|
|
119
|
+
"""
|
|
71
120
|
|
|
72
121
|
graphgen_job_id: str
|
|
73
122
|
status: str
|
|
@@ -75,20 +124,20 @@ class GraphGenSubmitResult:
|
|
|
75
124
|
task_count: int
|
|
76
125
|
rollout_budget: int
|
|
77
126
|
policy_model: str
|
|
78
|
-
|
|
127
|
+
verifier_mode: str
|
|
79
128
|
graph_evolve_job_id: Optional[str] = None
|
|
80
129
|
|
|
81
130
|
|
|
82
131
|
class GraphGenJob:
|
|
83
132
|
"""High-level SDK class for running GraphGen workflow optimization jobs.
|
|
84
133
|
|
|
85
|
-
GraphGen (
|
|
134
|
+
GraphGen (Graph Opt) provides a simplified API for
|
|
86
135
|
graph/workflow optimization that doesn't require users to manage task apps.
|
|
87
136
|
|
|
88
137
|
Key differences from PromptLearningJob:
|
|
89
138
|
- Uses JSON dataset format (GraphGenTaskSet) instead of TOML configs
|
|
90
139
|
- No task app management required - GraphGen builds it internally
|
|
91
|
-
- Built-in
|
|
140
|
+
- Built-in verifier modes (rubric, contrastive, gold_examples)
|
|
92
141
|
- Graph-first: trains multi-node workflows by default (Graph-GEPA)
|
|
93
142
|
- Public graph downloads are redacted `.txt` exports only
|
|
94
143
|
- Simpler configuration with sensible defaults
|
|
@@ -103,7 +152,7 @@ class GraphGenJob:
|
|
|
103
152
|
... rollout_budget=100,
|
|
104
153
|
... )
|
|
105
154
|
>>>
|
|
106
|
-
>>> # Train a verifier graph
|
|
155
|
+
>>> # Train a verifier graph
|
|
107
156
|
>>> verifier_job = GraphGenJob.from_dataset(
|
|
108
157
|
... dataset="verifier_dataset.json",
|
|
109
158
|
... graph_type="verifier",
|
|
@@ -133,9 +182,9 @@ class GraphGenJob:
|
|
|
133
182
|
>>> # Run inference with optimized prompt
|
|
134
183
|
>>> output = job.run_inference({"question": "What is 2+2?"})
|
|
135
184
|
>>>
|
|
136
|
-
>>> # Run
|
|
137
|
-
>>>
|
|
138
|
-
>>> print(f"
|
|
185
|
+
>>> # Run verifier with optimized verifier graph
|
|
186
|
+
>>> verification = verifier_job.run_verifier(trace_data)
|
|
187
|
+
>>> print(f"Outcome reward: {verification.outcome_reward}")
|
|
139
188
|
"""
|
|
140
189
|
|
|
141
190
|
def __init__(
|
|
@@ -178,8 +227,8 @@ class GraphGenJob:
|
|
|
178
227
|
policy_model: str = "gpt-4o-mini",
|
|
179
228
|
rollout_budget: int = 100,
|
|
180
229
|
proposer_effort: Literal["low", "medium", "high"] = "medium",
|
|
181
|
-
|
|
182
|
-
|
|
230
|
+
verifier_model: Optional[str] = None,
|
|
231
|
+
verifier_provider: Optional[str] = None,
|
|
183
232
|
population_size: int = 4,
|
|
184
233
|
num_generations: Optional[int] = None,
|
|
185
234
|
problem_spec: Optional[str] = None,
|
|
@@ -196,15 +245,15 @@ class GraphGenJob:
|
|
|
196
245
|
dataset: Dataset as file path, dict, or GraphGenTaskSet object
|
|
197
246
|
graph_type: Type of graph to train:
|
|
198
247
|
- "policy": Maps inputs to outputs (default).
|
|
199
|
-
- "verifier":
|
|
248
|
+
- "verifier": Verifies/scores traces (requires verifier-compliant dataset).
|
|
200
249
|
- "rlm": Recursive Language Model - handles massive contexts via tool-based search
|
|
201
250
|
and recursive LLM calls. Requires configured_tools parameter.
|
|
202
251
|
policy_model: Model to use for policy inference
|
|
203
252
|
rollout_budget: Total number of rollouts for optimization
|
|
204
253
|
proposer_effort: Proposer effort level ("medium" or "high").
|
|
205
254
|
"low" is not allowed as gpt-4.1-mini is too weak for graph generation.
|
|
206
|
-
|
|
207
|
-
|
|
255
|
+
verifier_model: Override verifier model from dataset
|
|
256
|
+
verifier_provider: Override verifier provider from dataset
|
|
208
257
|
population_size: Population size for GEPA
|
|
209
258
|
num_generations: Number of generations (auto-calculated if not specified)
|
|
210
259
|
problem_spec: Detailed problem specification for the graph proposer.
|
|
@@ -270,8 +319,8 @@ class GraphGenJob:
|
|
|
270
319
|
policy_model=policy_model,
|
|
271
320
|
rollout_budget=rollout_budget,
|
|
272
321
|
proposer_effort=proposer_effort,
|
|
273
|
-
|
|
274
|
-
|
|
322
|
+
verifier_model=verifier_model,
|
|
323
|
+
verifier_provider=verifier_provider,
|
|
275
324
|
population_size=population_size,
|
|
276
325
|
num_generations=num_generations,
|
|
277
326
|
problem_spec=problem_spec,
|
|
@@ -405,8 +454,8 @@ class GraphGenJob:
|
|
|
405
454
|
"policy_provider": self.config.policy_provider,
|
|
406
455
|
"rollout_budget": self.config.rollout_budget,
|
|
407
456
|
"proposer_effort": self.config.proposer_effort,
|
|
408
|
-
"
|
|
409
|
-
"
|
|
457
|
+
"verifier_model": self.config.verifier_model,
|
|
458
|
+
"verifier_provider": self.config.verifier_provider,
|
|
410
459
|
"problem_spec": self.config.problem_spec,
|
|
411
460
|
"target_llm_calls": self.config.target_llm_calls,
|
|
412
461
|
"configured_tools": self.config.configured_tools,
|
|
@@ -423,10 +472,10 @@ class GraphGenJob:
|
|
|
423
472
|
payload.pop("feedback_sample_size", None)
|
|
424
473
|
if payload.get("policy_provider") is None:
|
|
425
474
|
payload.pop("policy_provider", None)
|
|
426
|
-
if payload.get("
|
|
427
|
-
payload.pop("
|
|
428
|
-
if payload.get("
|
|
429
|
-
payload.pop("
|
|
475
|
+
if payload.get("verifier_model") is None:
|
|
476
|
+
payload.pop("verifier_model", None)
|
|
477
|
+
if payload.get("verifier_provider") is None:
|
|
478
|
+
payload.pop("verifier_provider", None)
|
|
430
479
|
if payload.get("problem_spec") is None:
|
|
431
480
|
payload.pop("problem_spec", None)
|
|
432
481
|
if payload.get("target_llm_calls") is None:
|
|
@@ -458,7 +507,7 @@ class GraphGenJob:
|
|
|
458
507
|
|
|
459
508
|
payload = self._build_payload()
|
|
460
509
|
|
|
461
|
-
# Submit job - use /graphgen/jobs endpoint
|
|
510
|
+
# Submit job - use /graphgen/jobs endpoint
|
|
462
511
|
create_url = f"{self.backend_url}/graphgen/jobs"
|
|
463
512
|
headers = {
|
|
464
513
|
"X-API-Key": self.api_key,
|
|
@@ -501,7 +550,7 @@ class GraphGenJob:
|
|
|
501
550
|
task_count=js.get("task_count", len(self.dataset.tasks)),
|
|
502
551
|
rollout_budget=js.get("rollout_budget", self.config.rollout_budget),
|
|
503
552
|
policy_model=js.get("policy_model", self.config.policy_model),
|
|
504
|
-
|
|
553
|
+
verifier_mode=js.get("verifier_mode", self.dataset.verifier_config.mode),
|
|
505
554
|
graph_evolve_job_id=self._graph_evolve_job_id,
|
|
506
555
|
)
|
|
507
556
|
|
|
@@ -703,7 +752,7 @@ class GraphGenJob:
|
|
|
703
752
|
base_url=self.backend_url,
|
|
704
753
|
api_key=self.api_key,
|
|
705
754
|
job_id=self.job_id, # Only GraphGen job ID - backend resolves to GEPA internally
|
|
706
|
-
endpoints=StreamEndpoints.
|
|
755
|
+
endpoints=StreamEndpoints.graphgen(self.job_id),
|
|
707
756
|
config=config,
|
|
708
757
|
handlers=list(handlers),
|
|
709
758
|
interval_seconds=interval,
|
|
@@ -715,6 +764,88 @@ class GraphGenJob:
|
|
|
715
764
|
|
|
716
765
|
return final_status
|
|
717
766
|
|
|
767
|
+
def poll_until_complete(
|
|
768
|
+
self,
|
|
769
|
+
*,
|
|
770
|
+
timeout: float = 3600.0,
|
|
771
|
+
interval: float = 5.0,
|
|
772
|
+
progress: bool = False,
|
|
773
|
+
on_status: Optional[Callable[[Dict[str, Any]], None]] = None,
|
|
774
|
+
) -> Dict[str, Any]:
|
|
775
|
+
"""Poll job until it reaches a terminal state.
|
|
776
|
+
|
|
777
|
+
Similar to PromptLearningJob.poll_until_complete(), this method polls
|
|
778
|
+
the backend periodically instead of using SSE streaming. Useful for
|
|
779
|
+
notebooks and environments where SSE may not work reliably.
|
|
780
|
+
|
|
781
|
+
Args:
|
|
782
|
+
timeout: Maximum seconds to wait (default: 3600 = 1 hour)
|
|
783
|
+
interval: Seconds between poll attempts (default: 5)
|
|
784
|
+
progress: If True, print status updates during polling (useful for notebooks)
|
|
785
|
+
on_status: Optional callback called on each status update
|
|
786
|
+
|
|
787
|
+
Returns:
|
|
788
|
+
Final job status dictionary containing 'status', 'best_score', etc.
|
|
789
|
+
|
|
790
|
+
Raises:
|
|
791
|
+
RuntimeError: If job hasn't been submitted yet
|
|
792
|
+
TimeoutError: If timeout is exceeded
|
|
793
|
+
|
|
794
|
+
Example:
|
|
795
|
+
>>> result = job.poll_until_complete(progress=True)
|
|
796
|
+
[00:15] running | score: 0.72
|
|
797
|
+
[00:30] running | score: 0.78
|
|
798
|
+
[00:45] succeeded | score: 0.85
|
|
799
|
+
"""
|
|
800
|
+
if not self.job_id:
|
|
801
|
+
raise RuntimeError("Job not yet submitted. Call submit() first.")
|
|
802
|
+
|
|
803
|
+
import time
|
|
804
|
+
|
|
805
|
+
start_time = time.time()
|
|
806
|
+
elapsed = 0.0
|
|
807
|
+
last_data: Dict[str, Any] = {}
|
|
808
|
+
|
|
809
|
+
while elapsed <= timeout:
|
|
810
|
+
try:
|
|
811
|
+
status_data = self.get_status()
|
|
812
|
+
last_data = dict(status_data) if isinstance(status_data, dict) else {}
|
|
813
|
+
|
|
814
|
+
status = last_data.get("status", "unknown")
|
|
815
|
+
best_score = last_data.get("best_score")
|
|
816
|
+
|
|
817
|
+
# Progress output
|
|
818
|
+
if progress:
|
|
819
|
+
mins, secs = divmod(int(elapsed), 60)
|
|
820
|
+
score_str = f"score: {best_score:.2f}" if best_score is not None else "score: --"
|
|
821
|
+
print(f"[{mins:02d}:{secs:02d}] {status} | {score_str}")
|
|
822
|
+
|
|
823
|
+
# Callback for custom handling
|
|
824
|
+
if on_status:
|
|
825
|
+
on_status(last_data)
|
|
826
|
+
|
|
827
|
+
# Check terminal state
|
|
828
|
+
if status in ("succeeded", "completed", "failed", "error", "cancelled"):
|
|
829
|
+
return last_data
|
|
830
|
+
|
|
831
|
+
# Sleep before next poll
|
|
832
|
+
time.sleep(interval)
|
|
833
|
+
elapsed = time.time() - start_time
|
|
834
|
+
|
|
835
|
+
except Exception as e:
|
|
836
|
+
# On error, continue polling (might be transient network issue)
|
|
837
|
+
import logging
|
|
838
|
+
logger = logging.getLogger(__name__)
|
|
839
|
+
logger.warning(f"Error polling job status: {e}")
|
|
840
|
+
time.sleep(interval)
|
|
841
|
+
elapsed = time.time() - start_time
|
|
842
|
+
|
|
843
|
+
# Timeout exceeded
|
|
844
|
+
raise TimeoutError(
|
|
845
|
+
f"Job {self.job_id} did not complete within {timeout}s timeout. "
|
|
846
|
+
f"Current status: {last_data.get('status', 'unknown')}"
|
|
847
|
+
)
|
|
848
|
+
|
|
718
849
|
def download_prompt(self) -> str:
|
|
719
850
|
"""Download the optimized prompt from a completed job.
|
|
720
851
|
|
|
@@ -773,6 +904,7 @@ class GraphGenJob:
|
|
|
773
904
|
model: Optional[str] = None,
|
|
774
905
|
prompt_snapshot_id: Optional[str] = None,
|
|
775
906
|
graph_snapshot_id: Optional[str] = None,
|
|
907
|
+
timeout: float = 120.0,
|
|
776
908
|
) -> Dict[str, Any]:
|
|
777
909
|
"""Run inference with the optimized graph/workflow.
|
|
778
910
|
|
|
@@ -783,6 +915,7 @@ class GraphGenJob:
|
|
|
783
915
|
graph_snapshot_id: Specific GraphSnapshot to use (default: best).
|
|
784
916
|
Preferred for graph-first jobs. If provided, it is sent as
|
|
785
917
|
`prompt_snapshot_id` for backward-compatible backend routing.
|
|
918
|
+
timeout: Request timeout in seconds (default: 120.0 = 2 minutes for image generation tasks)
|
|
786
919
|
|
|
787
920
|
Returns:
|
|
788
921
|
Output dictionary containing 'output', 'usage', etc.
|
|
@@ -813,7 +946,8 @@ class GraphGenJob:
|
|
|
813
946
|
if snapshot_id:
|
|
814
947
|
payload["prompt_snapshot_id"] = snapshot_id
|
|
815
948
|
|
|
816
|
-
|
|
949
|
+
# Use longer timeout for image generation tasks (can take 2-3 minutes)
|
|
950
|
+
resp = http_post(url, headers=headers, json_body=payload, timeout=timeout)
|
|
817
951
|
|
|
818
952
|
if resp.status_code != 200:
|
|
819
953
|
raise RuntimeError(
|
|
@@ -848,11 +982,11 @@ class GraphGenJob:
|
|
|
848
982
|
context: Optional[Dict[str, Any]] = None,
|
|
849
983
|
prompt_snapshot_id: Optional[str] = None,
|
|
850
984
|
graph_snapshot_id: Optional[str] = None,
|
|
851
|
-
) ->
|
|
985
|
+
) -> GraphGenGraphVerifierResponse:
|
|
852
986
|
"""Run a verifier graph on an execution trace.
|
|
853
987
|
|
|
854
988
|
This method is specifically for graphs trained with graph_type=\"verifier\".
|
|
855
|
-
It accepts a V3 trace and returns structured rewards
|
|
989
|
+
It accepts a V3 trace and returns structured rewards.
|
|
856
990
|
|
|
857
991
|
Args:
|
|
858
992
|
session_trace: V3 session trace to evaluate. Can be a dict or SessionTraceInput.
|
|
@@ -862,7 +996,7 @@ class GraphGenJob:
|
|
|
862
996
|
Preferred for graph-first jobs.
|
|
863
997
|
|
|
864
998
|
Returns:
|
|
865
|
-
|
|
999
|
+
GraphGenGraphVerifierResponse containing structured rewards.
|
|
866
1000
|
|
|
867
1001
|
Raises:
|
|
868
1002
|
RuntimeError: If job hasn't been submitted or inference fails.
|
|
@@ -873,7 +1007,7 @@ class GraphGenJob:
|
|
|
873
1007
|
if prompt_snapshot_id and graph_snapshot_id:
|
|
874
1008
|
raise ValueError("Provide only one of prompt_snapshot_id or graph_snapshot_id.")
|
|
875
1009
|
|
|
876
|
-
url = f"{self.backend_url}/graphgen/graph/
|
|
1010
|
+
url = f"{self.backend_url}/graphgen/graph/verifier"
|
|
877
1011
|
headers = {
|
|
878
1012
|
"X-API-Key": self.api_key,
|
|
879
1013
|
"Content-Type": "application/json",
|
|
@@ -902,23 +1036,7 @@ class GraphGenJob:
|
|
|
902
1036
|
f"Verifier inference failed: {resp.status_code} - {resp.text[:500]}"
|
|
903
1037
|
)
|
|
904
1038
|
|
|
905
|
-
return
|
|
906
|
-
|
|
907
|
-
def run_judge(
|
|
908
|
-
self,
|
|
909
|
-
session_trace: Dict[str, Any] | SessionTraceInput,
|
|
910
|
-
*,
|
|
911
|
-
context: Optional[Dict[str, Any]] = None,
|
|
912
|
-
prompt_snapshot_id: Optional[str] = None,
|
|
913
|
-
graph_snapshot_id: Optional[str] = None,
|
|
914
|
-
) -> GraphGenGraphJudgeResponse:
|
|
915
|
-
"""Deprecated: use run_verifier instead."""
|
|
916
|
-
return self.run_verifier(
|
|
917
|
-
session_trace=session_trace,
|
|
918
|
-
context=context,
|
|
919
|
-
prompt_snapshot_id=prompt_snapshot_id,
|
|
920
|
-
graph_snapshot_id=graph_snapshot_id,
|
|
921
|
-
)
|
|
1039
|
+
return GraphGenGraphVerifierResponse.model_validate(resp.json())
|
|
922
1040
|
|
|
923
1041
|
def get_graph_record(
|
|
924
1042
|
self,
|