synth-ai 0.4.1__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- synth_ai/__init__.py +13 -13
- synth_ai/cli/__init__.py +6 -15
- synth_ai/cli/commands/eval/__init__.py +6 -15
- synth_ai/cli/commands/eval/config.py +338 -0
- synth_ai/cli/commands/eval/core.py +236 -1091
- synth_ai/cli/commands/eval/runner.py +704 -0
- synth_ai/cli/commands/eval/validation.py +44 -117
- synth_ai/cli/commands/filter/core.py +7 -7
- synth_ai/cli/commands/filter/validation.py +2 -2
- synth_ai/cli/commands/smoke/core.py +7 -17
- synth_ai/cli/commands/status/__init__.py +1 -64
- synth_ai/cli/commands/status/client.py +50 -151
- synth_ai/cli/commands/status/config.py +3 -83
- synth_ai/cli/commands/status/errors.py +4 -13
- synth_ai/cli/commands/status/subcommands/__init__.py +2 -8
- synth_ai/cli/commands/status/subcommands/config.py +13 -0
- synth_ai/cli/commands/status/subcommands/files.py +18 -63
- synth_ai/cli/commands/status/subcommands/jobs.py +28 -311
- synth_ai/cli/commands/status/subcommands/models.py +18 -62
- synth_ai/cli/commands/status/subcommands/runs.py +16 -63
- synth_ai/cli/commands/status/subcommands/session.py +67 -172
- synth_ai/cli/commands/status/subcommands/summary.py +24 -32
- synth_ai/cli/commands/status/subcommands/utils.py +41 -0
- synth_ai/cli/commands/status/utils.py +16 -107
- synth_ai/cli/commands/train/__init__.py +18 -20
- synth_ai/cli/commands/train/errors.py +3 -3
- synth_ai/cli/commands/train/prompt_learning_validation.py +15 -16
- synth_ai/cli/commands/train/validation.py +7 -7
- synth_ai/cli/commands/train/{judge_schemas.py → verifier_schemas.py} +33 -34
- synth_ai/cli/commands/train/verifier_validation.py +235 -0
- synth_ai/cli/demo_apps/demo_task_apps/math/config.toml +0 -1
- synth_ai/cli/demo_apps/demo_task_apps/math/modal_task_app.py +2 -6
- synth_ai/cli/demo_apps/math/config.toml +0 -1
- synth_ai/cli/demo_apps/math/modal_task_app.py +2 -6
- synth_ai/cli/demo_apps/mipro/task_app.py +25 -47
- synth_ai/cli/lib/apps/task_app.py +12 -13
- synth_ai/cli/lib/task_app_discovery.py +6 -6
- synth_ai/cli/lib/train_cfgs.py +10 -10
- synth_ai/cli/task_apps/__init__.py +11 -0
- synth_ai/cli/task_apps/commands.py +7 -15
- synth_ai/core/env.py +12 -1
- synth_ai/core/errors.py +1 -2
- synth_ai/core/integrations/cloudflare.py +209 -33
- synth_ai/core/tracing_v3/abstractions.py +46 -0
- synth_ai/data/__init__.py +3 -30
- synth_ai/data/enums.py +1 -20
- synth_ai/data/rewards.py +100 -3
- synth_ai/products/graph_evolve/__init__.py +1 -2
- synth_ai/products/graph_evolve/config.py +16 -16
- synth_ai/products/graph_evolve/converters/__init__.py +3 -3
- synth_ai/products/graph_evolve/converters/openai_sft.py +7 -7
- synth_ai/products/graph_evolve/examples/hotpotqa/config.toml +1 -1
- synth_ai/products/graph_gepa/__init__.py +23 -0
- synth_ai/products/graph_gepa/converters/__init__.py +19 -0
- synth_ai/products/graph_gepa/converters/openai_sft.py +29 -0
- synth_ai/sdk/__init__.py +45 -35
- synth_ai/sdk/api/eval/__init__.py +33 -0
- synth_ai/sdk/api/eval/job.py +732 -0
- synth_ai/sdk/api/research_agent/__init__.py +276 -66
- synth_ai/sdk/api/train/builders.py +181 -0
- synth_ai/sdk/api/train/cli.py +41 -33
- synth_ai/sdk/api/train/configs/__init__.py +6 -4
- synth_ai/sdk/api/train/configs/prompt_learning.py +127 -33
- synth_ai/sdk/api/train/configs/rl.py +264 -16
- synth_ai/sdk/api/train/configs/sft.py +165 -1
- synth_ai/sdk/api/train/graph_validators.py +12 -12
- synth_ai/sdk/api/train/graphgen.py +169 -51
- synth_ai/sdk/api/train/graphgen_models.py +95 -45
- synth_ai/sdk/api/train/local_api.py +10 -0
- synth_ai/sdk/api/train/pollers.py +36 -0
- synth_ai/sdk/api/train/prompt_learning.py +390 -60
- synth_ai/sdk/api/train/rl.py +41 -5
- synth_ai/sdk/api/train/sft.py +2 -0
- synth_ai/sdk/api/train/task_app.py +20 -0
- synth_ai/sdk/api/train/validators.py +17 -17
- synth_ai/sdk/graphs/completions.py +239 -33
- synth_ai/sdk/{judging/schemas.py → graphs/verifier_schemas.py} +23 -23
- synth_ai/sdk/learning/__init__.py +35 -5
- synth_ai/sdk/learning/context_learning_client.py +531 -0
- synth_ai/sdk/learning/context_learning_types.py +294 -0
- synth_ai/sdk/learning/prompt_learning_client.py +1 -1
- synth_ai/sdk/learning/prompt_learning_types.py +2 -1
- synth_ai/sdk/learning/rl/__init__.py +0 -4
- synth_ai/sdk/learning/rl/contracts.py +0 -4
- synth_ai/sdk/localapi/__init__.py +40 -0
- synth_ai/sdk/localapi/apps/__init__.py +28 -0
- synth_ai/sdk/localapi/client.py +10 -0
- synth_ai/sdk/localapi/contracts.py +10 -0
- synth_ai/sdk/localapi/helpers.py +519 -0
- synth_ai/sdk/localapi/rollouts.py +93 -0
- synth_ai/sdk/localapi/server.py +29 -0
- synth_ai/sdk/localapi/template.py +49 -0
- synth_ai/sdk/streaming/handlers.py +6 -6
- synth_ai/sdk/streaming/streamer.py +10 -6
- synth_ai/sdk/task/__init__.py +18 -5
- synth_ai/sdk/task/apps/__init__.py +37 -1
- synth_ai/sdk/task/client.py +9 -1
- synth_ai/sdk/task/config.py +6 -11
- synth_ai/sdk/task/contracts.py +137 -95
- synth_ai/sdk/task/in_process.py +32 -22
- synth_ai/sdk/task/in_process_runner.py +9 -4
- synth_ai/sdk/task/rubrics/__init__.py +2 -3
- synth_ai/sdk/task/rubrics/loaders.py +4 -4
- synth_ai/sdk/task/rubrics/strict.py +3 -4
- synth_ai/sdk/task/server.py +76 -16
- synth_ai/sdk/task/trace_correlation_helpers.py +190 -139
- synth_ai/sdk/task/validators.py +34 -49
- synth_ai/sdk/training/__init__.py +7 -16
- synth_ai/sdk/tunnels/__init__.py +118 -0
- synth_ai/sdk/tunnels/cleanup.py +83 -0
- synth_ai/sdk/tunnels/ports.py +120 -0
- synth_ai/sdk/tunnels/tunneled_api.py +363 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/METADATA +71 -4
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/RECORD +118 -128
- synth_ai/cli/commands/baseline/__init__.py +0 -12
- synth_ai/cli/commands/baseline/core.py +0 -636
- synth_ai/cli/commands/baseline/list.py +0 -94
- synth_ai/cli/commands/eval/errors.py +0 -81
- synth_ai/cli/commands/status/formatters.py +0 -164
- synth_ai/cli/commands/status/subcommands/pricing.py +0 -23
- synth_ai/cli/commands/status/subcommands/usage.py +0 -203
- synth_ai/cli/commands/train/judge_validation.py +0 -305
- synth_ai/cli/usage.py +0 -159
- synth_ai/data/specs.py +0 -36
- synth_ai/sdk/api/research_agent/cli.py +0 -428
- synth_ai/sdk/api/research_agent/config.py +0 -357
- synth_ai/sdk/api/research_agent/job.py +0 -717
- synth_ai/sdk/baseline/__init__.py +0 -25
- synth_ai/sdk/baseline/config.py +0 -209
- synth_ai/sdk/baseline/discovery.py +0 -216
- synth_ai/sdk/baseline/execution.py +0 -154
- synth_ai/sdk/judging/__init__.py +0 -15
- synth_ai/sdk/judging/base.py +0 -24
- synth_ai/sdk/judging/client.py +0 -191
- synth_ai/sdk/judging/types.py +0 -42
- synth_ai/sdk/research_agent/__init__.py +0 -34
- synth_ai/sdk/research_agent/container_builder.py +0 -328
- synth_ai/sdk/research_agent/container_spec.py +0 -198
- synth_ai/sdk/research_agent/defaults.py +0 -34
- synth_ai/sdk/research_agent/results_collector.py +0 -69
- synth_ai/sdk/specs/__init__.py +0 -46
- synth_ai/sdk/specs/dataclasses.py +0 -149
- synth_ai/sdk/specs/loader.py +0 -144
- synth_ai/sdk/specs/serializer.py +0 -199
- synth_ai/sdk/specs/validation.py +0 -250
- synth_ai/sdk/tracing/__init__.py +0 -39
- synth_ai/sdk/usage/__init__.py +0 -37
- synth_ai/sdk/usage/client.py +0 -171
- synth_ai/sdk/usage/models.py +0 -261
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/WHEEL +0 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/top_level.txt +0 -0
|
@@ -1,133 +1,60 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
import re
|
|
4
|
-
from collections.abc import MutableMapping
|
|
5
|
-
from typing import Any
|
|
6
|
-
|
|
7
|
-
__all__ = ["validate_eval_options"]
|
|
8
|
-
|
|
9
|
-
_SEED_RANGE = re.compile(r"^\s*(-?\d+)\s*-\s*(-?\d+)\s*$")
|
|
1
|
+
"""Validation helpers for eval options."""
|
|
10
2
|
|
|
11
|
-
|
|
12
|
-
def _coerce_bool(value: Any) -> bool:
|
|
13
|
-
if isinstance(value, str):
|
|
14
|
-
return value.strip().lower() in {"1", "true", "yes", "on"}
|
|
15
|
-
return bool(value)
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
def _coerce_int(value: Any) -> int | None:
|
|
19
|
-
if value is None or value == "":
|
|
20
|
-
return None
|
|
21
|
-
return int(value)
|
|
3
|
+
from __future__ import annotations
|
|
22
4
|
|
|
23
5
|
|
|
24
|
-
def _parse_seeds(value:
|
|
25
|
-
if value
|
|
6
|
+
def _parse_seeds(value: str | list[int]) -> list[int]:
|
|
7
|
+
if isinstance(value, list):
|
|
8
|
+
return value
|
|
9
|
+
if not value:
|
|
26
10
|
return []
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
elif isinstance(value, list | tuple | set):
|
|
30
|
-
chunks = list(value)
|
|
31
|
-
else:
|
|
32
|
-
chunks = [value]
|
|
33
|
-
seeds: list[int] = []
|
|
34
|
-
for chunk in chunks:
|
|
35
|
-
if isinstance(chunk, int):
|
|
36
|
-
seeds.append(chunk)
|
|
37
|
-
else:
|
|
38
|
-
text = str(chunk).strip()
|
|
39
|
-
if not text:
|
|
40
|
-
continue
|
|
41
|
-
match = _SEED_RANGE.match(text)
|
|
42
|
-
if match:
|
|
43
|
-
start = int(match.group(1))
|
|
44
|
-
end = int(match.group(2))
|
|
45
|
-
if start > end:
|
|
46
|
-
raise ValueError(f"Invalid seed range '{text}': start must be <= end")
|
|
47
|
-
seeds.extend(range(start, end + 1))
|
|
48
|
-
else:
|
|
49
|
-
seeds.append(int(text))
|
|
50
|
-
return seeds
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
def _normalize_metadata(value: Any) -> dict[str, str]:
|
|
54
|
-
if value is None:
|
|
55
|
-
return {}
|
|
56
|
-
if isinstance(value, MutableMapping):
|
|
57
|
-
return {str(k): str(v) for k, v in value.items()}
|
|
58
|
-
if isinstance(value, list | tuple):
|
|
59
|
-
result: dict[str, str] = {}
|
|
60
|
-
for item in value:
|
|
61
|
-
if isinstance(item, str) and "=" in item:
|
|
62
|
-
key, val = item.split("=", 1)
|
|
63
|
-
result[key.strip()] = val.strip()
|
|
64
|
-
return result
|
|
65
|
-
if isinstance(value, str) and "=" in value:
|
|
66
|
-
key, val = value.split("=", 1)
|
|
67
|
-
return {key.strip(): val.strip()}
|
|
68
|
-
return {}
|
|
11
|
+
parts = [part.strip() for part in value.split(",") if part.strip()]
|
|
12
|
+
return [int(part) for part in parts]
|
|
69
13
|
|
|
70
14
|
|
|
71
|
-
def
|
|
72
|
-
if
|
|
73
|
-
return
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
if "seeds" in result:
|
|
91
|
-
result["seeds"] = _parse_seeds(result.get("seeds"))
|
|
92
|
-
|
|
93
|
-
for field in ("max_turns", "max_llm_calls", "concurrency"):
|
|
94
|
-
try:
|
|
95
|
-
result[field] = _coerce_int(result.get(field))
|
|
96
|
-
except Exception as exc:
|
|
97
|
-
raise ValueError(f"Invalid value for {field}: {result.get(field)}") from exc
|
|
15
|
+
def _parse_metadata(values: list[str]) -> dict[str, str]:
|
|
16
|
+
if not values:
|
|
17
|
+
return {}
|
|
18
|
+
parsed: dict[str, str] = {}
|
|
19
|
+
for entry in values:
|
|
20
|
+
if "=" not in entry:
|
|
21
|
+
raise ValueError("Metadata filter must be key=value")
|
|
22
|
+
key, value = entry.split("=", 1)
|
|
23
|
+
parsed[key.strip()] = value.strip()
|
|
24
|
+
return parsed
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _parse_ops(value: str | list[str]) -> list[str]:
|
|
28
|
+
if isinstance(value, list):
|
|
29
|
+
return [str(item).strip() for item in value if str(item).strip()]
|
|
30
|
+
if not value:
|
|
31
|
+
return []
|
|
32
|
+
return [part.strip() for part in value.split(",") if part.strip()]
|
|
98
33
|
|
|
99
|
-
if result.get("max_llm_calls") is None:
|
|
100
|
-
result["max_llm_calls"] = 10
|
|
101
|
-
if result.get("concurrency") is None:
|
|
102
|
-
result["concurrency"] = 1
|
|
103
34
|
|
|
104
|
-
|
|
105
|
-
|
|
35
|
+
def validate_eval_options(options: dict[str, object]) -> dict[str, object]:
|
|
36
|
+
normalized = dict(options)
|
|
37
|
+
seeds = normalized.get("seeds") or ""
|
|
38
|
+
normalized["seeds"] = _parse_seeds(seeds) # type: ignore[arg-type]
|
|
106
39
|
|
|
107
|
-
|
|
108
|
-
|
|
40
|
+
metadata = normalized.get("metadata") or []
|
|
41
|
+
normalized["metadata"] = _parse_metadata(metadata) # type: ignore[arg-type]
|
|
109
42
|
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
43
|
+
for key in ("max_turns", "max_llm_calls", "concurrency"):
|
|
44
|
+
if key in normalized and normalized[key] not in (None, ""):
|
|
45
|
+
normalized[key] = int(normalized[key]) # type: ignore[arg-type]
|
|
113
46
|
|
|
114
|
-
|
|
115
|
-
|
|
47
|
+
if "ops" in normalized:
|
|
48
|
+
normalized["ops"] = _parse_ops(normalized.get("ops") or "") # type: ignore[arg-type]
|
|
116
49
|
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
result["trace_format"] = str(trace_format)
|
|
50
|
+
if "poll" in normalized and normalized["poll"] not in (None, ""):
|
|
51
|
+
normalized["poll"] = float(normalized["poll"]) # type: ignore[arg-type]
|
|
120
52
|
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
53
|
+
if "return_trace" in normalized:
|
|
54
|
+
value = str(normalized["return_trace"]).lower()
|
|
55
|
+
normalized["return_trace"] = value in ("true", "1", "yes")
|
|
124
56
|
|
|
125
|
-
|
|
126
|
-
if model is not None:
|
|
127
|
-
result["model"] = str(model)
|
|
57
|
+
return normalized
|
|
128
58
|
|
|
129
|
-
app_id = result.get("app_id")
|
|
130
|
-
if app_id is not None:
|
|
131
|
-
result["app_id"] = str(app_id)
|
|
132
59
|
|
|
133
|
-
|
|
60
|
+
__all__ = ["validate_eval_options"]
|
|
@@ -239,8 +239,8 @@ def filter_command(config_path: str) -> None:
|
|
|
239
239
|
models = set(filter_cfg.models)
|
|
240
240
|
min_official = filter_cfg.min_official_score
|
|
241
241
|
max_official = filter_cfg.max_official_score
|
|
242
|
-
|
|
243
|
-
|
|
242
|
+
min_verifier_scores = filter_cfg.min_verifier_scores
|
|
243
|
+
max_verifier_scores = filter_cfg.max_verifier_scores
|
|
244
244
|
min_created = _parse_datetime_for_trace(raw_cfg.get("min_created_at"))
|
|
245
245
|
max_created = _parse_datetime_for_trace(raw_cfg.get("max_created_at"))
|
|
246
246
|
limit = filter_cfg.limit
|
|
@@ -311,16 +311,16 @@ def filter_command(config_path: str) -> None:
|
|
|
311
311
|
elif min_official is not None:
|
|
312
312
|
continue
|
|
313
313
|
|
|
314
|
-
|
|
314
|
+
verifier_scores = metadata.get("verifier_scores") or {}
|
|
315
315
|
include = True
|
|
316
|
-
for
|
|
317
|
-
if not _score_ok(
|
|
316
|
+
for verifier_name, threshold in (min_verifier_scores or {}).items():
|
|
317
|
+
if not _score_ok(verifier_scores.get(verifier_name), threshold, None):
|
|
318
318
|
include = False
|
|
319
319
|
break
|
|
320
320
|
if not include:
|
|
321
321
|
continue
|
|
322
|
-
for
|
|
323
|
-
if not _score_ok(
|
|
322
|
+
for verifier_name, threshold in (max_verifier_scores or {}).items():
|
|
323
|
+
if not _score_ok(verifier_scores.get(verifier_name), None, threshold):
|
|
324
324
|
include = False
|
|
325
325
|
break
|
|
326
326
|
if not include:
|
|
@@ -44,8 +44,8 @@ def validate_filter_options(options: MutableMapping[str, Any]) -> MutableMapping
|
|
|
44
44
|
_coerce_list("splits")
|
|
45
45
|
_coerce_list("task_ids")
|
|
46
46
|
_coerce_list("models")
|
|
47
|
-
_coerce_dict("
|
|
48
|
-
_coerce_dict("
|
|
47
|
+
_coerce_dict("min_verifier_scores")
|
|
48
|
+
_coerce_dict("max_verifier_scores")
|
|
49
49
|
|
|
50
50
|
for duration_key in ("min_official_score", "max_official_score"):
|
|
51
51
|
value = result.get(duration_key)
|
|
@@ -18,14 +18,13 @@ import httpx
|
|
|
18
18
|
|
|
19
19
|
from synth_ai.core.tracing_v3.config import resolve_trace_db_settings
|
|
20
20
|
from synth_ai.core.tracing_v3.turso.daemon import start_sqld
|
|
21
|
-
from synth_ai.sdk.
|
|
21
|
+
from synth_ai.sdk.localapi.client import LocalAPIClient
|
|
22
22
|
from synth_ai.sdk.task.contracts import (
|
|
23
23
|
RolloutEnvSpec,
|
|
24
24
|
RolloutMode,
|
|
25
25
|
RolloutPolicySpec,
|
|
26
26
|
RolloutRecordConfig,
|
|
27
27
|
RolloutRequest,
|
|
28
|
-
RolloutSafetyConfig,
|
|
29
28
|
)
|
|
30
29
|
from synth_ai.sdk.task.validators import (
|
|
31
30
|
normalize_inference_url,
|
|
@@ -773,7 +772,7 @@ async def _run_smoke_async(
|
|
|
773
772
|
mock_backend = (mock_backend or "synthetic").strip().lower()
|
|
774
773
|
|
|
775
774
|
# Discover environment if not provided
|
|
776
|
-
async with
|
|
775
|
+
async with LocalAPIClient(base_url=base, api_key=api_key) as client:
|
|
777
776
|
# Probe basic info quickly
|
|
778
777
|
try:
|
|
779
778
|
_ = await client.health()
|
|
@@ -795,12 +794,6 @@ async def _run_smoke_async(
|
|
|
795
794
|
click.echo("Could not infer environment name; pass --env-name.", err=True)
|
|
796
795
|
return 2
|
|
797
796
|
|
|
798
|
-
# Build ops: alternating agent/env for max_steps
|
|
799
|
-
ops: list[str] = []
|
|
800
|
-
for _ in range(max_steps):
|
|
801
|
-
ops.append("agent")
|
|
802
|
-
ops.append("env")
|
|
803
|
-
|
|
804
797
|
# Inference URL: user override > preset > local mock > Synth API default
|
|
805
798
|
synth_base = (os.getenv("SYNTH_API_BASE") or os.getenv("SYNTH_BASE_URL") or "https://api.synth.run").rstrip("/")
|
|
806
799
|
# Avoid double '/api' if base already includes it
|
|
@@ -852,7 +845,6 @@ async def _run_smoke_async(
|
|
|
852
845
|
run_id=run_id,
|
|
853
846
|
env=RolloutEnvSpec(env_name=env_name, config={}, seed=i),
|
|
854
847
|
policy=RolloutPolicySpec(policy_name=policy_name, config=policy_cfg),
|
|
855
|
-
ops=ops,
|
|
856
848
|
record=RolloutRecordConfig(
|
|
857
849
|
trajectories=True,
|
|
858
850
|
logprobs=False,
|
|
@@ -861,7 +853,6 @@ async def _run_smoke_async(
|
|
|
861
853
|
trace_format=("structured" if return_trace else "compact"),
|
|
862
854
|
),
|
|
863
855
|
on_done="reset",
|
|
864
|
-
safety=RolloutSafetyConfig(max_ops=max_steps * 4, max_time_s=900.0),
|
|
865
856
|
training_session_id=None,
|
|
866
857
|
synth_base_url=synth_base,
|
|
867
858
|
mode=RolloutMode.RL,
|
|
@@ -869,7 +860,6 @@ async def _run_smoke_async(
|
|
|
869
860
|
|
|
870
861
|
try:
|
|
871
862
|
click.echo(f">> POST /rollout run_id={run_id} env={env_name} policy={policy_name} url={inference_url_with_cid}")
|
|
872
|
-
click.echo(f" ops={ops[:10]}{'...' if len(ops) > 10 else ''}")
|
|
873
863
|
response = await client.rollout(request)
|
|
874
864
|
except Exception as exc:
|
|
875
865
|
click.echo(f"Rollout[{i}:{g}] failed: {type(exc).__name__}: {exc}", err=True)
|
|
@@ -888,10 +878,10 @@ async def _run_smoke_async(
|
|
|
888
878
|
metrics = response.metrics
|
|
889
879
|
if inferred_url:
|
|
890
880
|
click.echo(f" rollout[{i}:{g}] inference_url: {inferred_url}")
|
|
891
|
-
click.echo(f" rollout[{i}:{g}] episodes={metrics.num_episodes} steps={metrics.num_steps}
|
|
881
|
+
click.echo(f" rollout[{i}:{g}] episodes={metrics.num_episodes} steps={metrics.num_steps} reward_mean={metrics.reward_mean:.4f}")
|
|
892
882
|
|
|
893
883
|
total_steps += int(metrics.num_steps)
|
|
894
|
-
if (metrics.
|
|
884
|
+
if (metrics.reward_mean or 0.0) != 0.0:
|
|
895
885
|
nonzero_returns += 1
|
|
896
886
|
if response.trace is not None and isinstance(response.trace, dict):
|
|
897
887
|
v3_traces += 1
|
|
@@ -916,8 +906,8 @@ async def _run_smoke_async(
|
|
|
916
906
|
metrics_dump = response.metrics.model_dump()
|
|
917
907
|
except Exception:
|
|
918
908
|
metrics_dump = {
|
|
919
|
-
"
|
|
920
|
-
"
|
|
909
|
+
"episode_rewards": getattr(response.metrics, "episode_rewards", None),
|
|
910
|
+
"reward_mean": getattr(response.metrics, "reward_mean", None),
|
|
921
911
|
"num_steps": getattr(response.metrics, "num_steps", None),
|
|
922
912
|
"num_episodes": getattr(response.metrics, "num_episodes", None),
|
|
923
913
|
"outcome_score": getattr(response.metrics, "outcome_score", None),
|
|
@@ -1016,7 +1006,7 @@ async def _run_smoke_async(
|
|
|
1016
1006
|
if v3_traces < successes:
|
|
1017
1007
|
click.echo(" ⚠ Some rollouts missing v3 traces (trace field)", err=True)
|
|
1018
1008
|
if total_steps == 0:
|
|
1019
|
-
click.echo(" ⚠ No steps executed; check
|
|
1009
|
+
click.echo(" ⚠ No steps executed; check policy config", err=True)
|
|
1020
1010
|
|
|
1021
1011
|
return 0
|
|
1022
1012
|
|
|
@@ -1,66 +1,3 @@
|
|
|
1
|
-
"""Status
|
|
1
|
+
"""Status command package."""
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
|
-
|
|
5
|
-
import click
|
|
6
|
-
|
|
7
|
-
from .config import resolve_backend_config
|
|
8
|
-
from .subcommands.files import files_group
|
|
9
|
-
from .subcommands.jobs import jobs_group
|
|
10
|
-
from .subcommands.models import models_group
|
|
11
|
-
from .subcommands.runs import runs_group
|
|
12
|
-
from .subcommands.session import session_status_cmd
|
|
13
|
-
from .subcommands.summary import summary_command
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
def _attach_group(cli: click.Group, group: click.Group, name: str) -> None:
|
|
17
|
-
"""Attach the provided Click group to the CLI if not already present."""
|
|
18
|
-
if name in cli.commands:
|
|
19
|
-
return
|
|
20
|
-
cli.add_command(group, name=name)
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
def register(cli: click.Group) -> None:
|
|
24
|
-
"""Register all status command groups on the provided CLI root."""
|
|
25
|
-
|
|
26
|
-
@click.group(help="Inspect training jobs, models, files, and job runs.")
|
|
27
|
-
@click.option(
|
|
28
|
-
"--base-url",
|
|
29
|
-
envvar="SYNTH_STATUS_BASE_URL",
|
|
30
|
-
default=None,
|
|
31
|
-
help="Synth backend base URL (defaults to environment configuration).",
|
|
32
|
-
)
|
|
33
|
-
@click.option(
|
|
34
|
-
"--api-key",
|
|
35
|
-
envvar="SYNTH_STATUS_API_KEY",
|
|
36
|
-
default=None,
|
|
37
|
-
help="API key for authenticated requests (falls back to Synth defaults).",
|
|
38
|
-
)
|
|
39
|
-
@click.option(
|
|
40
|
-
"--timeout",
|
|
41
|
-
default=30.0,
|
|
42
|
-
show_default=True,
|
|
43
|
-
type=float,
|
|
44
|
-
help="HTTP request timeout in seconds.",
|
|
45
|
-
)
|
|
46
|
-
@click.pass_context
|
|
47
|
-
def status(ctx: click.Context, base_url: str | None, api_key: str | None, timeout: float) -> None:
|
|
48
|
-
"""Populate shared backend configuration for subcommands."""
|
|
49
|
-
cfg = resolve_backend_config(base_url=base_url, api_key=api_key, timeout=timeout)
|
|
50
|
-
ctx.ensure_object(dict)
|
|
51
|
-
ctx.obj["status_backend_config"] = cfg
|
|
52
|
-
|
|
53
|
-
status.add_command(jobs_group, name="jobs")
|
|
54
|
-
status.add_command(models_group, name="models")
|
|
55
|
-
status.add_command(files_group, name="files")
|
|
56
|
-
status.add_command(runs_group, name="runs")
|
|
57
|
-
status.add_command(session_status_cmd, name="session")
|
|
58
|
-
status.add_command(summary_command, name="summary")
|
|
59
|
-
|
|
60
|
-
cli.add_command(status, name="status")
|
|
61
|
-
_attach_group(cli, jobs_group, "jobs")
|
|
62
|
-
_attach_group(cli, models_group, "models")
|
|
63
|
-
_attach_group(cli, files_group, "files")
|
|
64
|
-
_attach_group(cli, runs_group, "runs")
|
|
65
|
-
if "status-summary" not in cli.commands:
|
|
66
|
-
cli.add_command(summary_command, name="status-summary")
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""HTTP client for status commands."""
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
@@ -8,31 +8,43 @@ import httpx
|
|
|
8
8
|
|
|
9
9
|
from .config import BackendConfig
|
|
10
10
|
from .errors import StatusAPIError
|
|
11
|
+
from .utils import build_headers
|
|
11
12
|
|
|
12
13
|
|
|
13
14
|
class StatusAPIClient:
|
|
14
|
-
"""Thin wrapper around httpx.AsyncClient with convenience methods."""
|
|
15
|
-
|
|
16
15
|
def __init__(self, config: BackendConfig) -> None:
|
|
17
16
|
self._config = config
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
17
|
+
self._client: httpx.AsyncClient | None = None
|
|
18
|
+
|
|
19
|
+
async def __aenter__(self) -> "StatusAPIClient":
|
|
20
|
+
if self._client is None:
|
|
21
|
+
self._client = httpx.AsyncClient(
|
|
22
|
+
base_url=self._config.base_url,
|
|
23
|
+
headers=build_headers(self._config.api_key),
|
|
24
|
+
timeout=self._config.timeout,
|
|
25
|
+
)
|
|
27
26
|
return self
|
|
28
27
|
|
|
29
|
-
async def __aexit__(self,
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
28
|
+
async def __aexit__(self, exc_type, exc, tb) -> None:
|
|
29
|
+
if self._client is not None:
|
|
30
|
+
await self._client.aclose()
|
|
31
|
+
self._client = None
|
|
32
|
+
|
|
33
|
+
async def _get(self, path: str, params: dict[str, Any] | None = None) -> dict[str, Any]:
|
|
34
|
+
assert self._client is not None
|
|
35
|
+
resp = await self._client.get(path, params=params)
|
|
36
|
+
if resp.status_code >= 400:
|
|
37
|
+
detail = resp.json().get("detail", "")
|
|
38
|
+
raise StatusAPIError(detail or "Request failed", status_code=resp.status_code)
|
|
39
|
+
return resp.json()
|
|
40
|
+
|
|
41
|
+
async def _post(self, path: str, payload: dict[str, Any] | None = None) -> dict[str, Any]:
|
|
42
|
+
assert self._client is not None
|
|
43
|
+
resp = await self._client.post(path, json=payload or {})
|
|
44
|
+
if resp.status_code >= 400:
|
|
45
|
+
detail = resp.json().get("detail", "")
|
|
46
|
+
raise StatusAPIError(detail or "Request failed", status_code=resp.status_code)
|
|
47
|
+
return resp.json()
|
|
36
48
|
|
|
37
49
|
async def list_jobs(
|
|
38
50
|
self,
|
|
@@ -43,94 +55,22 @@ class StatusAPIClient:
|
|
|
43
55
|
limit: int | None = None,
|
|
44
56
|
) -> list[dict[str, Any]]:
|
|
45
57
|
params: dict[str, Any] = {}
|
|
46
|
-
if status:
|
|
58
|
+
if status is not None:
|
|
47
59
|
params["status"] = status
|
|
48
|
-
if job_type:
|
|
60
|
+
if job_type is not None:
|
|
49
61
|
params["type"] = job_type
|
|
50
|
-
if created_after:
|
|
62
|
+
if created_after is not None:
|
|
51
63
|
params["created_after"] = created_after
|
|
52
|
-
if limit:
|
|
64
|
+
if limit is not None:
|
|
53
65
|
params["limit"] = limit
|
|
54
|
-
|
|
55
|
-
return
|
|
66
|
+
payload = await self._get("/learning/jobs", params=params or None)
|
|
67
|
+
return payload.get("jobs", [])
|
|
56
68
|
|
|
57
69
|
async def get_job(self, job_id: str) -> dict[str, Any]:
|
|
58
|
-
|
|
59
|
-
return self._json(resp)
|
|
60
|
-
|
|
61
|
-
async def get_job_status(self, job_id: str) -> dict[str, Any]:
|
|
62
|
-
resp = await self._client.get(f"/learning/jobs/{job_id}/status")
|
|
63
|
-
return self._json(resp)
|
|
70
|
+
return await self._get(f"/learning/jobs/{job_id}")
|
|
64
71
|
|
|
65
72
|
async def cancel_job(self, job_id: str) -> dict[str, Any]:
|
|
66
|
-
|
|
67
|
-
return self._json(resp)
|
|
68
|
-
|
|
69
|
-
async def get_job_config(self, job_id: str) -> dict[str, Any]:
|
|
70
|
-
resp = await self._client.get(f"/learning/jobs/{job_id}/config")
|
|
71
|
-
return self._json(resp)
|
|
72
|
-
|
|
73
|
-
async def get_job_metrics(self, job_id: str) -> dict[str, Any]:
|
|
74
|
-
resp = await self._client.get(f"/learning/jobs/{job_id}/metrics")
|
|
75
|
-
return self._json(resp)
|
|
76
|
-
|
|
77
|
-
async def get_job_timeline(self, job_id: str) -> list[dict[str, Any]]:
|
|
78
|
-
resp = await self._client.get(f"/learning/jobs/{job_id}/timeline")
|
|
79
|
-
return self._json_list(resp, key="timeline")
|
|
80
|
-
|
|
81
|
-
async def list_job_runs(self, job_id: str) -> list[dict[str, Any]]:
|
|
82
|
-
resp = await self._client.get(f"/jobs/{job_id}/runs")
|
|
83
|
-
return self._json_list(resp, key="runs")
|
|
84
|
-
|
|
85
|
-
async def get_job_events(
|
|
86
|
-
self,
|
|
87
|
-
job_id: str,
|
|
88
|
-
*,
|
|
89
|
-
since: str | None = None,
|
|
90
|
-
limit: int | None = None,
|
|
91
|
-
after: str | None = None,
|
|
92
|
-
run_id: str | None = None,
|
|
93
|
-
) -> list[dict[str, Any]]:
|
|
94
|
-
params: dict[str, Any] = {}
|
|
95
|
-
if since:
|
|
96
|
-
params["since"] = since
|
|
97
|
-
if limit:
|
|
98
|
-
params["limit"] = limit
|
|
99
|
-
if after:
|
|
100
|
-
params["after"] = after
|
|
101
|
-
if run_id:
|
|
102
|
-
params["run"] = run_id
|
|
103
|
-
resp = await self._client.get(f"/learning/jobs/{job_id}/events", params=params)
|
|
104
|
-
return self._json_list(resp, key="events")
|
|
105
|
-
|
|
106
|
-
# Files ----------------------------------------------------------------
|
|
107
|
-
|
|
108
|
-
async def list_files(
|
|
109
|
-
self,
|
|
110
|
-
*,
|
|
111
|
-
purpose: str | None = None,
|
|
112
|
-
limit: int | None = None,
|
|
113
|
-
) -> list[dict[str, Any]]:
|
|
114
|
-
params: dict[str, Any] = {}
|
|
115
|
-
if purpose:
|
|
116
|
-
params["purpose"] = purpose
|
|
117
|
-
if limit:
|
|
118
|
-
params["limit"] = limit
|
|
119
|
-
resp = await self._client.get("/files", params=params)
|
|
120
|
-
data = self._json(resp)
|
|
121
|
-
if isinstance(data, dict):
|
|
122
|
-
for key in ("files", "data", "items"):
|
|
123
|
-
if isinstance(data.get(key), list):
|
|
124
|
-
return list(data[key])
|
|
125
|
-
if isinstance(data, list):
|
|
126
|
-
return list(data)
|
|
127
|
-
return []
|
|
128
|
-
|
|
129
|
-
async def get_file(self, file_id: str) -> dict[str, Any]:
|
|
130
|
-
resp = await self._client.get(f"/files/{file_id}")
|
|
131
|
-
return self._json(resp)
|
|
132
|
-
|
|
133
|
-
# Models ---------------------------------------------------------------
|
|
73
|
+
return await self._post(f"/learning/jobs/{job_id}/cancel")
|
|
134
74
|
|
|
135
75
|
async def list_models(
|
|
136
76
|
self,
|
|
@@ -138,55 +78,14 @@ class StatusAPIClient:
|
|
|
138
78
|
limit: int | None = None,
|
|
139
79
|
model_type: str | None = None,
|
|
140
80
|
) -> list[dict[str, Any]]:
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
return
|
|
147
|
-
|
|
148
|
-
async def
|
|
149
|
-
|
|
150
|
-
return
|
|
151
|
-
|
|
152
|
-
# Helpers --------------------------------------------------------------
|
|
153
|
-
|
|
154
|
-
def _json(self, response: httpx.Response) -> dict[str, Any]:
|
|
155
|
-
try:
|
|
156
|
-
response.raise_for_status()
|
|
157
|
-
except httpx.HTTPStatusError as exc:
|
|
158
|
-
detail = self._extract_detail(exc.response)
|
|
159
|
-
raise StatusAPIError(detail, exc.response.status_code if exc.response else None) from exc
|
|
160
|
-
try:
|
|
161
|
-
data = response.json()
|
|
162
|
-
except ValueError as exc:
|
|
163
|
-
raise StatusAPIError("Backend response was not valid JSON") from exc
|
|
164
|
-
if isinstance(data, dict):
|
|
165
|
-
return data
|
|
166
|
-
return {"data": data}
|
|
167
|
-
|
|
168
|
-
def _json_list(self, response: httpx.Response, *, key: str | None = None) -> list[dict[str, Any]]:
|
|
169
|
-
payload = self._json(response)
|
|
170
|
-
if key and isinstance(payload.get(key), list):
|
|
171
|
-
return list(payload[key])
|
|
172
|
-
if isinstance(payload.get("data"), list):
|
|
173
|
-
return list(payload["data"])
|
|
174
|
-
if isinstance(payload.get("results"), list):
|
|
175
|
-
return list(payload["results"])
|
|
176
|
-
if isinstance(payload, list):
|
|
177
|
-
return list(payload)
|
|
178
|
-
return []
|
|
179
|
-
|
|
180
|
-
@staticmethod
|
|
181
|
-
def _extract_detail(response: httpx.Response | None) -> str:
|
|
182
|
-
if response is None:
|
|
183
|
-
return "Backend request failed"
|
|
184
|
-
try:
|
|
185
|
-
data = response.json()
|
|
186
|
-
if isinstance(data, dict):
|
|
187
|
-
for key in ("detail", "message", "error"):
|
|
188
|
-
if data.get(key):
|
|
189
|
-
return str(data[key])
|
|
190
|
-
return response.text
|
|
191
|
-
except ValueError:
|
|
192
|
-
return response.text
|
|
81
|
+
if model_type:
|
|
82
|
+
payload = await self._get(f"/learning/models/{model_type}")
|
|
83
|
+
return payload.get("models", [])
|
|
84
|
+
params = {"limit": limit} if limit is not None else None
|
|
85
|
+
payload = await self._get("/learning/models", params=params)
|
|
86
|
+
return payload.get("models", [])
|
|
87
|
+
|
|
88
|
+
async def list_job_runs(self, job_id: str) -> list[dict[str, Any]]:
|
|
89
|
+
payload = await self._get(f"/jobs/{job_id}/runs")
|
|
90
|
+
return payload.get("runs", [])
|
|
91
|
+
|