synth-ai 0.4.1__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- synth_ai/__init__.py +13 -13
- synth_ai/cli/__init__.py +6 -15
- synth_ai/cli/commands/eval/__init__.py +6 -15
- synth_ai/cli/commands/eval/config.py +338 -0
- synth_ai/cli/commands/eval/core.py +236 -1091
- synth_ai/cli/commands/eval/runner.py +704 -0
- synth_ai/cli/commands/eval/validation.py +44 -117
- synth_ai/cli/commands/filter/core.py +7 -7
- synth_ai/cli/commands/filter/validation.py +2 -2
- synth_ai/cli/commands/smoke/core.py +7 -17
- synth_ai/cli/commands/status/__init__.py +1 -64
- synth_ai/cli/commands/status/client.py +50 -151
- synth_ai/cli/commands/status/config.py +3 -83
- synth_ai/cli/commands/status/errors.py +4 -13
- synth_ai/cli/commands/status/subcommands/__init__.py +2 -8
- synth_ai/cli/commands/status/subcommands/config.py +13 -0
- synth_ai/cli/commands/status/subcommands/files.py +18 -63
- synth_ai/cli/commands/status/subcommands/jobs.py +28 -311
- synth_ai/cli/commands/status/subcommands/models.py +18 -62
- synth_ai/cli/commands/status/subcommands/runs.py +16 -63
- synth_ai/cli/commands/status/subcommands/session.py +67 -172
- synth_ai/cli/commands/status/subcommands/summary.py +24 -32
- synth_ai/cli/commands/status/subcommands/utils.py +41 -0
- synth_ai/cli/commands/status/utils.py +16 -107
- synth_ai/cli/commands/train/__init__.py +18 -20
- synth_ai/cli/commands/train/errors.py +3 -3
- synth_ai/cli/commands/train/prompt_learning_validation.py +15 -16
- synth_ai/cli/commands/train/validation.py +7 -7
- synth_ai/cli/commands/train/{judge_schemas.py → verifier_schemas.py} +33 -34
- synth_ai/cli/commands/train/verifier_validation.py +235 -0
- synth_ai/cli/demo_apps/demo_task_apps/math/config.toml +0 -1
- synth_ai/cli/demo_apps/demo_task_apps/math/modal_task_app.py +2 -6
- synth_ai/cli/demo_apps/math/config.toml +0 -1
- synth_ai/cli/demo_apps/math/modal_task_app.py +2 -6
- synth_ai/cli/demo_apps/mipro/task_app.py +25 -47
- synth_ai/cli/lib/apps/task_app.py +12 -13
- synth_ai/cli/lib/task_app_discovery.py +6 -6
- synth_ai/cli/lib/train_cfgs.py +10 -10
- synth_ai/cli/task_apps/__init__.py +11 -0
- synth_ai/cli/task_apps/commands.py +7 -15
- synth_ai/core/env.py +12 -1
- synth_ai/core/errors.py +1 -2
- synth_ai/core/integrations/cloudflare.py +209 -33
- synth_ai/core/tracing_v3/abstractions.py +46 -0
- synth_ai/data/__init__.py +3 -30
- synth_ai/data/enums.py +1 -20
- synth_ai/data/rewards.py +100 -3
- synth_ai/products/graph_evolve/__init__.py +1 -2
- synth_ai/products/graph_evolve/config.py +16 -16
- synth_ai/products/graph_evolve/converters/__init__.py +3 -3
- synth_ai/products/graph_evolve/converters/openai_sft.py +7 -7
- synth_ai/products/graph_evolve/examples/hotpotqa/config.toml +1 -1
- synth_ai/products/graph_gepa/__init__.py +23 -0
- synth_ai/products/graph_gepa/converters/__init__.py +19 -0
- synth_ai/products/graph_gepa/converters/openai_sft.py +29 -0
- synth_ai/sdk/__init__.py +45 -35
- synth_ai/sdk/api/eval/__init__.py +33 -0
- synth_ai/sdk/api/eval/job.py +732 -0
- synth_ai/sdk/api/research_agent/__init__.py +276 -66
- synth_ai/sdk/api/train/builders.py +181 -0
- synth_ai/sdk/api/train/cli.py +41 -33
- synth_ai/sdk/api/train/configs/__init__.py +6 -4
- synth_ai/sdk/api/train/configs/prompt_learning.py +127 -33
- synth_ai/sdk/api/train/configs/rl.py +264 -16
- synth_ai/sdk/api/train/configs/sft.py +165 -1
- synth_ai/sdk/api/train/graph_validators.py +12 -12
- synth_ai/sdk/api/train/graphgen.py +169 -51
- synth_ai/sdk/api/train/graphgen_models.py +95 -45
- synth_ai/sdk/api/train/local_api.py +10 -0
- synth_ai/sdk/api/train/pollers.py +36 -0
- synth_ai/sdk/api/train/prompt_learning.py +390 -60
- synth_ai/sdk/api/train/rl.py +41 -5
- synth_ai/sdk/api/train/sft.py +2 -0
- synth_ai/sdk/api/train/task_app.py +20 -0
- synth_ai/sdk/api/train/validators.py +17 -17
- synth_ai/sdk/graphs/completions.py +239 -33
- synth_ai/sdk/{judging/schemas.py → graphs/verifier_schemas.py} +23 -23
- synth_ai/sdk/learning/__init__.py +35 -5
- synth_ai/sdk/learning/context_learning_client.py +531 -0
- synth_ai/sdk/learning/context_learning_types.py +294 -0
- synth_ai/sdk/learning/prompt_learning_client.py +1 -1
- synth_ai/sdk/learning/prompt_learning_types.py +2 -1
- synth_ai/sdk/learning/rl/__init__.py +0 -4
- synth_ai/sdk/learning/rl/contracts.py +0 -4
- synth_ai/sdk/localapi/__init__.py +40 -0
- synth_ai/sdk/localapi/apps/__init__.py +28 -0
- synth_ai/sdk/localapi/client.py +10 -0
- synth_ai/sdk/localapi/contracts.py +10 -0
- synth_ai/sdk/localapi/helpers.py +519 -0
- synth_ai/sdk/localapi/rollouts.py +93 -0
- synth_ai/sdk/localapi/server.py +29 -0
- synth_ai/sdk/localapi/template.py +49 -0
- synth_ai/sdk/streaming/handlers.py +6 -6
- synth_ai/sdk/streaming/streamer.py +10 -6
- synth_ai/sdk/task/__init__.py +18 -5
- synth_ai/sdk/task/apps/__init__.py +37 -1
- synth_ai/sdk/task/client.py +9 -1
- synth_ai/sdk/task/config.py +6 -11
- synth_ai/sdk/task/contracts.py +137 -95
- synth_ai/sdk/task/in_process.py +32 -22
- synth_ai/sdk/task/in_process_runner.py +9 -4
- synth_ai/sdk/task/rubrics/__init__.py +2 -3
- synth_ai/sdk/task/rubrics/loaders.py +4 -4
- synth_ai/sdk/task/rubrics/strict.py +3 -4
- synth_ai/sdk/task/server.py +76 -16
- synth_ai/sdk/task/trace_correlation_helpers.py +190 -139
- synth_ai/sdk/task/validators.py +34 -49
- synth_ai/sdk/training/__init__.py +7 -16
- synth_ai/sdk/tunnels/__init__.py +118 -0
- synth_ai/sdk/tunnels/cleanup.py +83 -0
- synth_ai/sdk/tunnels/ports.py +120 -0
- synth_ai/sdk/tunnels/tunneled_api.py +363 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/METADATA +71 -4
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/RECORD +118 -128
- synth_ai/cli/commands/baseline/__init__.py +0 -12
- synth_ai/cli/commands/baseline/core.py +0 -636
- synth_ai/cli/commands/baseline/list.py +0 -94
- synth_ai/cli/commands/eval/errors.py +0 -81
- synth_ai/cli/commands/status/formatters.py +0 -164
- synth_ai/cli/commands/status/subcommands/pricing.py +0 -23
- synth_ai/cli/commands/status/subcommands/usage.py +0 -203
- synth_ai/cli/commands/train/judge_validation.py +0 -305
- synth_ai/cli/usage.py +0 -159
- synth_ai/data/specs.py +0 -36
- synth_ai/sdk/api/research_agent/cli.py +0 -428
- synth_ai/sdk/api/research_agent/config.py +0 -357
- synth_ai/sdk/api/research_agent/job.py +0 -717
- synth_ai/sdk/baseline/__init__.py +0 -25
- synth_ai/sdk/baseline/config.py +0 -209
- synth_ai/sdk/baseline/discovery.py +0 -216
- synth_ai/sdk/baseline/execution.py +0 -154
- synth_ai/sdk/judging/__init__.py +0 -15
- synth_ai/sdk/judging/base.py +0 -24
- synth_ai/sdk/judging/client.py +0 -191
- synth_ai/sdk/judging/types.py +0 -42
- synth_ai/sdk/research_agent/__init__.py +0 -34
- synth_ai/sdk/research_agent/container_builder.py +0 -328
- synth_ai/sdk/research_agent/container_spec.py +0 -198
- synth_ai/sdk/research_agent/defaults.py +0 -34
- synth_ai/sdk/research_agent/results_collector.py +0 -69
- synth_ai/sdk/specs/__init__.py +0 -46
- synth_ai/sdk/specs/dataclasses.py +0 -149
- synth_ai/sdk/specs/loader.py +0 -144
- synth_ai/sdk/specs/serializer.py +0 -199
- synth_ai/sdk/specs/validation.py +0 -250
- synth_ai/sdk/tracing/__init__.py +0 -39
- synth_ai/sdk/usage/__init__.py +0 -37
- synth_ai/sdk/usage/client.py +0 -171
- synth_ai/sdk/usage/models.py +0 -261
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/WHEEL +0 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,235 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Validation logic for verifier/rubric configuration from TOML.
|
|
3
|
+
|
|
4
|
+
This module validates and normalizes verifier/rubric config, rejecting deprecated fields
|
|
5
|
+
and ensuring only the fields actually used by the backend are present.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from collections.abc import MutableMapping
|
|
11
|
+
from typing import Any, Optional, Tuple
|
|
12
|
+
|
|
13
|
+
from pydantic import ValidationError
|
|
14
|
+
|
|
15
|
+
from .errors import InvalidRubricConfigError, InvalidVerifierConfigError
|
|
16
|
+
from .verifier_schemas import (
|
|
17
|
+
RubricConfig,
|
|
18
|
+
RubricWeightsConfig,
|
|
19
|
+
VerifierConfig,
|
|
20
|
+
VerifierOptionsConfig,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
__all__ = [
|
|
24
|
+
"validate_verifier_config",
|
|
25
|
+
"validate_rubric_config",
|
|
26
|
+
"extract_and_validate_verifier_rubric",
|
|
27
|
+
]
|
|
28
|
+
|
|
29
|
+
# Keys that are no longer supported in each config section. Their presence is a
# hard error raised by ``_reject_deprecated_fields`` (not a deprecation warning).
DEPRECATED_RUBRIC_FIELDS = {
    "model",
    "api_base",
    "api_key_env",
    "event",
    "outcome",
}

# Unsupported keys directly under [verifier].
DEPRECATED_VERIFIER_FIELDS = {
    "type",
    "timeout_s",  # Moved to verifier.options.timeout_s
}

# Unsupported keys under [verifier.options].
DEPRECATED_VERIFIER_OPTIONS_FIELDS = {
    "max_concurrency",
    "tracks",
}
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _reject_deprecated_fields(
    section: str,
    fields: set[str],
    present_fields: set[str],
    error_cls: type[Exception],
) -> None:
    """Raise ``error_cls`` when a config section contains any unsupported key.

    Args:
        section: Section name interpolated into the message, e.g. ``"rubric"``.
        fields: Keys that are no longer supported in this section.
        present_fields: Keys actually present in the user's section.
        error_cls: Exception type to raise; it must accept a ``detail`` keyword.

    Raises:
        error_cls: If ``fields`` and ``present_fields`` overlap. The offending
            keys are listed alphabetically in the message.
    """
    offending = sorted(fields.intersection(present_fields))
    if not offending:
        return

    joined = ", ".join(offending)
    raise error_cls(
        detail=f"[{section}] contains deprecated fields that are not supported: {joined}."
    )
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def validate_rubric_config(config: MutableMapping[str, Any]) -> RubricConfig:
    """
    Validate and normalize rubric configuration from TOML.

    Only ``enabled`` and ``weights`` are carried into the returned model; any
    key listed in ``DEPRECATED_RUBRIC_FIELDS`` is rejected.

    Args:
        config: Raw [rubric] section from TOML. An empty or missing section
            yields a disabled rubric.

    Returns:
        Validated RubricConfig instance

    Raises:
        InvalidRubricConfigError: If a deprecated field is present, if
            ``weights`` is not a dictionary, or if Pydantic validation fails.
    """
    if not config:
        # Default: rubric disabled
        return RubricConfig(enabled=False)

    config_dict = dict(config)

    # These two keys have a named replacement, so report them with a targeted
    # message. They must be checked BEFORE the generic rejection below:
    # "event" and "outcome" are also members of DEPRECATED_RUBRIC_FIELDS, and
    # checking them afterwards would make these messages unreachable.
    if "event" in config_dict:
        raise InvalidRubricConfigError(
            detail="[rubric.event] is not supported. Use [verifier.options.rubric_overrides] instead."
        )

    if "outcome" in config_dict:
        raise InvalidRubricConfigError(
            detail="[rubric.outcome] is not supported. Use [verifier.options.rubric_overrides] instead."
        )

    _reject_deprecated_fields(
        "rubric",
        DEPRECATED_RUBRIC_FIELDS,
        set(config_dict.keys()),
        InvalidRubricConfigError,
    )

    # Extract only valid fields
    enabled = config_dict.get("enabled", False)
    weights_dict = config_dict.get("weights", {})

    # Validate using Pydantic
    try:
        if not isinstance(weights_dict, dict):
            raise ValueError("[rubric.weights] must be a dictionary")

        weights = RubricWeightsConfig(**weights_dict)
        return RubricConfig(enabled=enabled, weights=weights)

    except ValidationError as exc:
        # One line per Pydantic error, prefixed with its dotted field path.
        errors = []
        for error in exc.errors():
            loc = ".".join(str(x) for x in error["loc"])
            msg = error["msg"]
            errors.append(f" • rubric.{loc}: {msg}")
        raise InvalidRubricConfigError(
            detail="Rubric validation failed:\n" + "\n".join(errors)
        ) from exc
    except Exception as exc:
        # Covers the non-dict weights ValueError above and any other
        # construction failure (e.g. unexpected keyword arguments).
        raise InvalidRubricConfigError(
            detail=f"Rubric validation failed: {exc}"
        ) from exc
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def validate_verifier_config(config: MutableMapping[str, Any]) -> Optional[VerifierConfig]:
    """
    Validate and normalize the [verifier] section of a TOML config.

    Only ``[verifier.options]`` is carried into the result; deprecated keys at
    either the [verifier] or [verifier.options] level are rejected.

    Args:
        config: Raw [verifier] section from TOML.

    Returns:
        A validated VerifierConfig, or None when the section is empty or absent.

    Raises:
        InvalidVerifierConfigError: If deprecated keys are present, if
            [verifier.options] is missing, empty, or not a dictionary, or if the
            options fail validation.
    """
    if not config:
        return None

    verifier_raw = dict(config)
    _reject_deprecated_fields(
        "verifier",
        DEPRECATED_VERIFIER_FIELDS,
        set(verifier_raw),
        InvalidVerifierConfigError,
    )

    # [verifier.options] is mandatory once [verifier] exists. Because the
    # truthiness test runs first, an empty options value is reported as
    # "required" rather than as a type error.
    options_raw = verifier_raw.get("options")
    if not options_raw:
        raise InvalidVerifierConfigError(
            detail="[verifier.options] section is required when [verifier] is present"
        )
    if not isinstance(options_raw, dict):
        raise InvalidVerifierConfigError(detail="[verifier.options] must be a dictionary")

    _reject_deprecated_fields(
        "verifier.options",
        DEPRECATED_VERIFIER_OPTIONS_FIELDS,
        set(options_raw),
        InvalidVerifierConfigError,
    )

    try:
        return VerifierConfig(options=VerifierOptionsConfig(**options_raw))
    except ValidationError as exc:
        # Turn Pydantic's structured errors into one line per error.
        bullets = []
        for issue in exc.errors():
            location = ".".join(map(str, issue["loc"]))
            bullets.append(f" • verifier.options.{location}: {issue['msg']}")
        raise InvalidVerifierConfigError(
            detail="Verifier validation failed:\n" + "\n".join(bullets)
        ) from exc
    except Exception as exc:
        raise InvalidVerifierConfigError(detail=f"Verifier validation failed: {exc}") from exc
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def extract_and_validate_verifier_rubric(
    toml_config: MutableMapping[str, Any]
) -> Tuple[RubricConfig, Optional[VerifierConfig]]:
    """
    Extract the [rubric] and [verifier] sections from a full TOML config,
    validate each one, and check that they are consistent with each other.

    Cross-section rules, applied only when the rubric is enabled:
      * a [verifier] section must be present;
      * ``rubric.weights.event > 0`` requires ``verifier.options.event``;
      * ``rubric.weights.outcome > 0`` requires ``verifier.options.outcome``.

    Args:
        toml_config: Full TOML configuration dict

    Returns:
        Tuple of (validated_rubric, validated_verifier_or_none)

    Raises:
        InvalidRubricConfigError: If rubric validation fails
        InvalidVerifierConfigError: If verifier validation fails or the
            rubric/verifier combination is inconsistent
    """
    rubric_section = toml_config.get("rubric", {})
    verifier_section = toml_config.get("verifier", {})

    rubric_config = validate_rubric_config(rubric_section)
    verifier_config = (
        validate_verifier_config(verifier_section) if verifier_section else None
    )

    # A disabled rubric places no requirements on the verifier.
    if not rubric_config.enabled:
        return rubric_config, verifier_config

    if not verifier_config:
        raise InvalidVerifierConfigError(
            detail="[rubric].enabled=true requires a [verifier] section."
        )

    weights, options = rubric_config.weights, verifier_config.options

    if weights.event > 0 and not options.event:
        raise InvalidVerifierConfigError(
            detail="[rubric.weights].event > 0 requires [verifier.options].event=true."
        )

    if weights.outcome > 0 and not options.outcome:
        raise InvalidVerifierConfigError(
            detail="[rubric.weights].outcome > 0 requires [verifier.options].outcome=true."
        )

    return rubric_config, verifier_config
|
|
235
|
+
|
|
@@ -480,9 +480,6 @@ def fastapi_app():
|
|
|
480
480
|
data = request if isinstance(request, dict) else {}
|
|
481
481
|
env = data.get("env") if isinstance(data, dict) else {}
|
|
482
482
|
policy = data.get("policy") if isinstance(data, dict) else {}
|
|
483
|
-
ops = data.get("ops") if isinstance(data, dict) else []
|
|
484
|
-
if not isinstance(ops, list):
|
|
485
|
-
ops = []
|
|
486
483
|
env_name = (env or {}).get("env_name") or "math" # type: ignore[misc]
|
|
487
484
|
policy_cfg = (policy or {}).get("config") or {} # type: ignore[misc]
|
|
488
485
|
model = policy_cfg.get("model") # type: ignore[misc]
|
|
@@ -730,13 +727,12 @@ def fastapi_app():
|
|
|
730
727
|
],
|
|
731
728
|
"branches": {},
|
|
732
729
|
"metrics": {
|
|
733
|
-
"
|
|
734
|
-
"
|
|
730
|
+
"episode_rewards": [total_reward],
|
|
731
|
+
"reward_mean": float(total_reward),
|
|
735
732
|
"num_steps": len(steps),
|
|
736
733
|
"num_episodes": 1,
|
|
737
734
|
},
|
|
738
735
|
"aborted": False,
|
|
739
|
-
"ops_executed": len(steps),
|
|
740
736
|
}
|
|
741
737
|
|
|
742
738
|
return api
|
|
@@ -469,9 +469,6 @@ def fastapi_app():
|
|
|
469
469
|
data = request if isinstance(request, dict) else {}
|
|
470
470
|
env = data.get("env") if isinstance(data, dict) else {}
|
|
471
471
|
policy = data.get("policy") if isinstance(data, dict) else {}
|
|
472
|
-
ops = data.get("ops") if isinstance(data, dict) else []
|
|
473
|
-
if not isinstance(ops, list):
|
|
474
|
-
ops = []
|
|
475
472
|
env_name = (env or {}).get("env_name") or "math" # type: ignore[misc]
|
|
476
473
|
policy_cfg = (policy or {}).get("config") or {} # type: ignore[misc]
|
|
477
474
|
model = policy_cfg.get("model") # type: ignore[misc]
|
|
@@ -690,13 +687,12 @@ def fastapi_app():
|
|
|
690
687
|
],
|
|
691
688
|
"branches": {},
|
|
692
689
|
"metrics": {
|
|
693
|
-
"
|
|
694
|
-
"
|
|
690
|
+
"episode_rewards": [total_reward],
|
|
691
|
+
"reward_mean": float(total_reward),
|
|
695
692
|
"num_steps": len(steps),
|
|
696
693
|
"num_episodes": 1,
|
|
697
694
|
},
|
|
698
695
|
"aborted": False,
|
|
699
|
-
"ops_executed": len(steps),
|
|
700
696
|
}
|
|
701
697
|
|
|
702
698
|
return api
|
|
@@ -5,7 +5,6 @@ import inspect
|
|
|
5
5
|
import json
|
|
6
6
|
import os
|
|
7
7
|
import socket
|
|
8
|
-
import uuid
|
|
9
8
|
from collections.abc import Iterable, Sequence
|
|
10
9
|
from pathlib import Path
|
|
11
10
|
from typing import Any, Mapping, cast
|
|
@@ -21,8 +20,6 @@ from synth_ai.sdk.task.contracts import (
|
|
|
21
20
|
RolloutMetrics,
|
|
22
21
|
RolloutRequest,
|
|
23
22
|
RolloutResponse,
|
|
24
|
-
RolloutStep,
|
|
25
|
-
RolloutTrajectory,
|
|
26
23
|
TaskInfo,
|
|
27
24
|
)
|
|
28
25
|
from synth_ai.sdk.task.datasets import TaskDatasetRegistry, TaskDatasetSpec
|
|
@@ -34,6 +31,10 @@ from synth_ai.sdk.task.server import (
|
|
|
34
31
|
create_task_app,
|
|
35
32
|
run_task_app,
|
|
36
33
|
)
|
|
34
|
+
from synth_ai.sdk.task.trace_correlation_helpers import (
|
|
35
|
+
build_trace_payload,
|
|
36
|
+
extract_trace_correlation_id,
|
|
37
|
+
)
|
|
37
38
|
from synth_ai.sdk.task.vendors import normalize_vendor_keys
|
|
38
39
|
|
|
39
40
|
# Dataset configuration
|
|
@@ -593,60 +594,37 @@ async def rollout_executor(request: RolloutRequest, fastapi_request: Request) ->
|
|
|
593
594
|
flush=True,
|
|
594
595
|
)
|
|
595
596
|
|
|
596
|
-
step = RolloutStep(
|
|
597
|
-
obs=observation,
|
|
598
|
-
tool_calls=tool_calls,
|
|
599
|
-
reward=reward,
|
|
600
|
-
done=True,
|
|
601
|
-
info=info_payload,
|
|
602
|
-
)
|
|
603
|
-
|
|
604
597
|
inference_url = (request.policy.config or {}).get("inference_url")
|
|
605
|
-
trajectory = RolloutTrajectory( # type: ignore[call-overload]
|
|
606
|
-
env_id=f"banking77::{sample['split']}::{sample['index']}",
|
|
607
|
-
policy_id=request.policy.policy_id or request.policy.policy_name or "policy",
|
|
608
|
-
steps=[step],
|
|
609
|
-
final={"observation": observation, "reward": reward}, # type: ignore[arg-type]
|
|
610
|
-
length=1,
|
|
611
|
-
inference_url=str(inference_url or ""),
|
|
612
|
-
)
|
|
613
598
|
|
|
614
599
|
metrics = RolloutMetrics(
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
num_steps=1,
|
|
618
|
-
num_episodes=1,
|
|
619
|
-
outcome_score=reward,
|
|
620
|
-
events_score=reward,
|
|
621
|
-
details={"correct": is_correct},
|
|
600
|
+
outcome_reward=reward,
|
|
601
|
+
details={"predicted": predicted_intent, "expected": expected_intent},
|
|
622
602
|
)
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
(
|
|
627
|
-
|
|
603
|
+
policy_config = request.policy.config or {}
|
|
604
|
+
trace_correlation_id = extract_trace_correlation_id(
|
|
605
|
+
policy_config=policy_config,
|
|
606
|
+
inference_url=str(inference_url or ""),
|
|
607
|
+
mode=request.mode,
|
|
608
|
+
)
|
|
609
|
+
trace_metadata = {
|
|
610
|
+
"env": "banking77",
|
|
611
|
+
"split": sample["split"],
|
|
612
|
+
"index": sample["index"],
|
|
613
|
+
"correct": is_correct,
|
|
614
|
+
}
|
|
615
|
+
trace_payload = build_trace_payload(
|
|
616
|
+
messages=rendered_messages,
|
|
617
|
+
response=response_json if isinstance(response_json, dict) else None,
|
|
618
|
+
correlation_id=trace_correlation_id,
|
|
619
|
+
metadata=trace_metadata,
|
|
628
620
|
)
|
|
629
|
-
if include_trace:
|
|
630
|
-
trace_payload = {
|
|
631
|
-
"session_id": str(uuid.uuid4()),
|
|
632
|
-
"events_count": 1,
|
|
633
|
-
"decision_rewards": [reward],
|
|
634
|
-
"metadata": {
|
|
635
|
-
"env": "banking77",
|
|
636
|
-
"split": sample["split"],
|
|
637
|
-
"index": sample["index"],
|
|
638
|
-
"correct": is_correct,
|
|
639
|
-
},
|
|
640
|
-
}
|
|
641
621
|
|
|
642
622
|
return RolloutResponse(
|
|
643
623
|
run_id=request.run_id,
|
|
644
|
-
trajectories=[trajectory],
|
|
645
|
-
branches={},
|
|
646
624
|
metrics=metrics,
|
|
647
|
-
aborted=False,
|
|
648
|
-
ops_executed=2,
|
|
649
625
|
trace=trace_payload,
|
|
626
|
+
trace_correlation_id=trace_correlation_id,
|
|
627
|
+
inference_url=str(inference_url or ""),
|
|
650
628
|
)
|
|
651
629
|
|
|
652
630
|
|
|
@@ -144,7 +144,7 @@ def _validate_rollout_payload(payload: Any) -> None:
|
|
|
144
144
|
if not isinstance(trajectories, list):
|
|
145
145
|
raise ValueError(
|
|
146
146
|
f"`/rollout` response field 'trajectories' must be a list, got {type(trajectories).__name__}. "
|
|
147
|
-
f"Make sure your rollout executor returns a proper RolloutResponse with a
|
|
147
|
+
f"Make sure your rollout executor returns a proper RolloutResponse with a v3 trace payload."
|
|
148
148
|
)
|
|
149
149
|
|
|
150
150
|
# Ensure trajectories list is not empty (training will fail if it's empty)
|
|
@@ -265,27 +265,27 @@ def _validate_rollout_payload(payload: Any) -> None:
|
|
|
265
265
|
)
|
|
266
266
|
|
|
267
267
|
# Metrics can be either:
|
|
268
|
-
# 1. Full RolloutMetrics with
|
|
269
|
-
# 2. Simple dict with scalar values (
|
|
270
|
-
required_metrics_fields = ["
|
|
268
|
+
# 1. Full RolloutMetrics with episode_rewards (list), reward_mean, num_steps
|
|
269
|
+
# 2. Simple dict with scalar values (episode_rewards as float, reward_mean, num_steps)
|
|
270
|
+
required_metrics_fields = ["episode_rewards", "reward_mean", "num_steps"]
|
|
271
271
|
for field in required_metrics_fields:
|
|
272
272
|
if field not in metrics:
|
|
273
273
|
raise ValueError(
|
|
274
274
|
f"`/rollout` metrics missing required field '{field}'. "
|
|
275
|
-
f"Metrics must include:
|
|
275
|
+
f"Metrics must include: episode_rewards, reward_mean, and num_steps."
|
|
276
276
|
)
|
|
277
277
|
|
|
278
|
-
# Validate types -
|
|
279
|
-
|
|
280
|
-
if not isinstance(
|
|
278
|
+
# Validate types - episode_rewards can be either a list or a scalar
|
|
279
|
+
episode_rewards = metrics.get("episode_rewards")
|
|
280
|
+
if not isinstance(episode_rewards, list | int | float):
|
|
281
281
|
raise ValueError(
|
|
282
|
-
f"`/rollout` metrics.
|
|
282
|
+
f"`/rollout` metrics.episode_rewards must be a list or number, got {type(episode_rewards).__name__}"
|
|
283
283
|
)
|
|
284
284
|
|
|
285
|
-
|
|
286
|
-
if not isinstance(
|
|
285
|
+
reward_mean = metrics.get("reward_mean")
|
|
286
|
+
if not isinstance(reward_mean, int | float):
|
|
287
287
|
raise ValueError(
|
|
288
|
-
f"`/rollout` metrics.
|
|
288
|
+
f"`/rollout` metrics.reward_mean must be a number, got {type(reward_mean).__name__}"
|
|
289
289
|
)
|
|
290
290
|
|
|
291
291
|
num_steps = metrics.get("num_steps")
|
|
@@ -388,7 +388,6 @@ def test_route_contracts(app: ASGIApp) -> None:
|
|
|
388
388
|
"assert_proxy": True, # Backend always sets this for prompt learning
|
|
389
389
|
"proxy_only": True, # Backend always sets this for prompt learning
|
|
390
390
|
},
|
|
391
|
-
"ops": ["agent", "env"], # Critical: training sends this
|
|
392
391
|
"record": {"trajectories": True},
|
|
393
392
|
"mode": "eval",
|
|
394
393
|
}
|
|
@@ -307,8 +307,8 @@ def _extract_app_id(node: ast.Call) -> str | None:
|
|
|
307
307
|
|
|
308
308
|
def _is_register_task_app_call(node: ast.Call) -> bool:
    """Return True if *node* calls ``register_task_app`` or ``register_local_api``.

    Both bare-name calls (``register_task_app(...)``) and attribute calls
    (``module.register_local_api(...)``) are recognised; any other callee
    shape (e.g. a call whose callee is itself a call) is rejected.
    """
    callee = node.func
    if isinstance(callee, ast.Name):
        callee_name = callee.id
    elif isinstance(callee, ast.Attribute):
        callee_name = callee.attr
    else:
        return False
    return callee_name in {"register_task_app", "register_local_api"}
|
|
313
313
|
|
|
314
314
|
|
|
@@ -316,10 +316,10 @@ def _extract_register_app_id(node: ast.Call) -> str | None:
|
|
|
316
316
|
for kw in node.keywords:
|
|
317
317
|
if kw.arg == "entry" and isinstance(kw.value, ast.Call):
|
|
318
318
|
entry_call = kw.value
|
|
319
|
-
if isinstance(entry_call.func, ast.Name) and entry_call.func.id
|
|
319
|
+
if isinstance(entry_call.func, ast.Name) and entry_call.func.id in {"TaskAppEntry", "LocalAPIEntry"}:
|
|
320
320
|
for entry_kw in entry_call.keywords:
|
|
321
321
|
if (
|
|
322
|
-
entry_kw.arg
|
|
322
|
+
entry_kw.arg in {"app_id", "api_id"}
|
|
323
323
|
and isinstance(entry_kw.value, ast.Constant)
|
|
324
324
|
and isinstance(entry_kw.value.value, str)
|
|
325
325
|
):
|
|
@@ -535,7 +535,7 @@ def _has_modal_support_in_file(path: Path) -> bool:
|
|
|
535
535
|
entry_call = kw.value
|
|
536
536
|
if (
|
|
537
537
|
isinstance(entry_call.func, ast.Name)
|
|
538
|
-
and entry_call.func.id
|
|
538
|
+
and entry_call.func.id in {"TaskAppEntry", "LocalAPIEntry"}
|
|
539
539
|
):
|
|
540
540
|
for entry_kw in entry_call.keywords:
|
|
541
541
|
if entry_kw.arg == "modal" and isinstance(entry_kw.value, ast.Call):
|
|
@@ -562,7 +562,7 @@ def _extract_modal_config_from_file(path: Path) -> ModalDeploymentConfig | None:
|
|
|
562
562
|
entry_call = kw.value
|
|
563
563
|
if (
|
|
564
564
|
isinstance(entry_call.func, ast.Name)
|
|
565
|
-
and entry_call.func.id
|
|
565
|
+
and entry_call.func.id in {"TaskAppEntry", "LocalAPIEntry"}
|
|
566
566
|
):
|
|
567
567
|
for entry_kw in entry_call.keywords:
|
|
568
568
|
if entry_kw.arg == "modal" and isinstance(entry_kw.value, ast.Call):
|
synth_ai/cli/lib/train_cfgs.py
CHANGED
|
@@ -6,8 +6,8 @@ from typing import Any, Dict, List, Literal, Tuple
|
|
|
6
6
|
from synth_ai.cli.lib.prompts import ctx_print
|
|
7
7
|
from synth_ai.core.paths import is_hidden_path, validate_file_type
|
|
8
8
|
|
|
9
|
-
# Train config types: prompt optimization, reinforcement learning, supervised fine-tuning,
|
|
10
|
-
TrainType = Literal["prompt", "rl", "sft", "
|
|
9
|
+
# Train config types: prompt optimization, reinforcement learning, supervised fine-tuning, graph opt, context learning
|
|
10
|
+
TrainType = Literal["prompt", "rl", "sft", "graphgen", "context_learning"]
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
def get_type(config: Dict[str, Any]) -> TrainType | None:
|
|
@@ -17,9 +17,9 @@ def get_type(config: Dict[str, Any]) -> TrainType | None:
|
|
|
17
17
|
if "prompt_learning" in config:
|
|
18
18
|
return "prompt"
|
|
19
19
|
|
|
20
|
-
# Graph
|
|
21
|
-
if isinstance(config.get("graph"), dict)
|
|
22
|
-
return "
|
|
20
|
+
# Graph Opt jobs use a dedicated [graph] section.
|
|
21
|
+
if isinstance(config.get("graph"), dict):
|
|
22
|
+
return "graphgen"
|
|
23
23
|
|
|
24
24
|
algorithm = config.get("algorithm")
|
|
25
25
|
algo_type = None
|
|
@@ -221,14 +221,14 @@ def validate_rl_cfg(cfg: Dict[str, Any]) -> None:
|
|
|
221
221
|
return None
|
|
222
222
|
|
|
223
223
|
|
|
224
|
-
def validate_graph_cfg(cfg: Dict[str, Any], *, path: Path) -> None:
    """Validate the ``[graph]`` section of a graph opt TOML config.

    Delegates to the SDK's ``validate_graph_job_section`` so backend and CLI
    stay in sync. A missing or empty ``[graph]`` value is passed on as ``{}``.

    Args:
        cfg: Parsed TOML config.
        path: Location of the TOML file; its resolved parent directory is
            passed to the SDK validator as ``base_dir``.
    """
    from synth_ai.sdk.api.train.graph_validators import (
        validate_graph_job_section as _validate_section,
    )

    graph_section = cfg.get("graph") or {}
    base_dir = path.parent.resolve()
    _validate_section(graph_section, base_dir=base_dir)
|
|
233
233
|
|
|
234
234
|
|
|
@@ -262,8 +262,8 @@ def validate_train_cfg(path: Path, discovery: bool = False) -> TrainType:
|
|
|
262
262
|
validate_rl_cfg(cfg)
|
|
263
263
|
case "sft":
|
|
264
264
|
validate_sft_cfg(cfg)
|
|
265
|
-
case "
|
|
266
|
-
|
|
265
|
+
case "graphgen":
|
|
266
|
+
validate_graph_cfg(cfg, path=path)
|
|
267
267
|
print_pass()
|
|
268
268
|
|
|
269
269
|
return train_type
|
|
@@ -4,10 +4,17 @@ Commands for managing Synth task apps - local serving, Modal deployment,
|
|
|
4
4
|
validation, and discovery.
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import importlib
|
|
10
|
+
|
|
7
11
|
from synth_ai.cli.task_apps.commands import (
|
|
8
12
|
AppChoice,
|
|
9
13
|
TaskAppEntryType,
|
|
14
|
+
_find_modal_executable,
|
|
15
|
+
_is_modal_shim,
|
|
10
16
|
_markov_message_from_dict,
|
|
17
|
+
_modal_command_prefix,
|
|
11
18
|
register,
|
|
12
19
|
serve_command,
|
|
13
20
|
task_app_group,
|
|
@@ -22,5 +29,9 @@ __all__ = [
|
|
|
22
29
|
"task_app_group",
|
|
23
30
|
"serve_command",
|
|
24
31
|
"register",
|
|
32
|
+
"_find_modal_executable",
|
|
33
|
+
"_is_modal_shim",
|
|
34
|
+
"_modal_command_prefix",
|
|
25
35
|
"_markov_message_from_dict",
|
|
36
|
+
"importlib",
|
|
26
37
|
]
|
|
@@ -33,7 +33,6 @@ except Exception: # pragma: no cover - fallback
|
|
|
33
33
|
import click
|
|
34
34
|
from click.exceptions import Abort
|
|
35
35
|
|
|
36
|
-
from synth_ai.cli.commands.eval import core as eval_core
|
|
37
36
|
from synth_ai.cli.commands.filter import core as filter_core
|
|
38
37
|
|
|
39
38
|
# Tracing imports - make conditional for optional dependencies
|
|
@@ -569,20 +568,20 @@ def _extract_app_id(node: ast.Call) -> str | None:
|
|
|
569
568
|
|
|
570
569
|
def _is_register_task_app_call(node: ast.Call) -> bool:
    """Return True if *node* calls ``register_task_app`` or ``register_local_api``.

    Both bare-name calls (``register_task_app(...)``) and attribute calls
    (``module.register_local_api(...)``) are recognised; any other callee
    shape (e.g. a call whose callee is itself a call) is rejected.
    """
    callee = node.func
    if isinstance(callee, ast.Name):
        callee_name = callee.id
    elif isinstance(callee, ast.Attribute):
        callee_name = callee.attr
    else:
        return False
    return callee_name in {"register_task_app", "register_local_api"}
|
|
575
574
|
|
|
576
575
|
|
|
577
576
|
def _extract_register_app_id(node: ast.Call) -> str | None:
|
|
578
|
-
# Look for entry=TaskAppEntry(app_id="..."
|
|
577
|
+
# Look for entry=TaskAppEntry(app_id="...") or entry=LocalAPIEntry(api_id="...")
|
|
579
578
|
for kw in node.keywords:
|
|
580
579
|
if kw.arg == "entry" and isinstance(kw.value, ast.Call):
|
|
581
580
|
entry_call = kw.value
|
|
582
|
-
if isinstance(entry_call.func, ast.Name) and entry_call.func.id
|
|
581
|
+
if isinstance(entry_call.func, ast.Name) and entry_call.func.id in {"TaskAppEntry", "LocalAPIEntry"}:
|
|
583
582
|
for entry_kw in entry_call.keywords:
|
|
584
583
|
if (
|
|
585
|
-
entry_kw.arg
|
|
584
|
+
entry_kw.arg in {"app_id", "api_id"}
|
|
586
585
|
and isinstance(entry_kw.value, ast.Constant)
|
|
587
586
|
and isinstance(entry_kw.value.value, str)
|
|
588
587
|
):
|
|
@@ -865,7 +864,7 @@ def _has_modal_support_in_file(path: Path) -> bool:
|
|
|
865
864
|
entry_call = kw.value
|
|
866
865
|
if (
|
|
867
866
|
isinstance(entry_call.func, ast.Name)
|
|
868
|
-
and entry_call.func.id
|
|
867
|
+
and entry_call.func.id in {"TaskAppEntry", "LocalAPIEntry"}
|
|
869
868
|
):
|
|
870
869
|
for entry_kw in entry_call.keywords:
|
|
871
870
|
if entry_kw.arg == "modal" and isinstance(entry_kw.value, ast.Call):
|
|
@@ -895,7 +894,7 @@ def _extract_modal_config_from_file(path: Path) -> ModalDeploymentConfigType | N
|
|
|
895
894
|
entry_call = kw.value
|
|
896
895
|
if (
|
|
897
896
|
isinstance(entry_call.func, ast.Name)
|
|
898
|
-
and entry_call.func.id
|
|
897
|
+
and entry_call.func.id in {"TaskAppEntry", "LocalAPIEntry"}
|
|
899
898
|
):
|
|
900
899
|
for entry_kw in entry_call.keywords:
|
|
901
900
|
if entry_kw.arg == "modal" and isinstance(entry_kw.value, ast.Call):
|
|
@@ -3140,14 +3139,7 @@ def fastapi_app():
|
|
|
3140
3139
|
def register(cli: click.Group) -> None:
    """Attach the task-app commands to *cli*.

    Adds, in order: ``serve_command``, ``task_app_group`` and ``filter_command``.
    """
    for command in (serve_command, task_app_group, filter_command):
        cli.add_command(command)
|
|
3145
3143
|
|
|
3146
3144
|
|
|
3147
|
-
eval_command = eval_core.command
|
|
3148
|
-
|
|
3149
3145
|
filter_command = filter_core.command
|
|
3150
|
-
|
|
3151
|
-
|
|
3152
|
-
def register_eval(cli: click.Group) -> None:
|
|
3153
|
-
cli.add_command(eval_command)
|