synth-ai 0.4.1__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic. Click here for more details.

Files changed (153)
  1. synth_ai/__init__.py +13 -13
  2. synth_ai/cli/__init__.py +6 -15
  3. synth_ai/cli/commands/eval/__init__.py +6 -15
  4. synth_ai/cli/commands/eval/config.py +338 -0
  5. synth_ai/cli/commands/eval/core.py +236 -1091
  6. synth_ai/cli/commands/eval/runner.py +704 -0
  7. synth_ai/cli/commands/eval/validation.py +44 -117
  8. synth_ai/cli/commands/filter/core.py +7 -7
  9. synth_ai/cli/commands/filter/validation.py +2 -2
  10. synth_ai/cli/commands/smoke/core.py +7 -17
  11. synth_ai/cli/commands/status/__init__.py +1 -64
  12. synth_ai/cli/commands/status/client.py +50 -151
  13. synth_ai/cli/commands/status/config.py +3 -83
  14. synth_ai/cli/commands/status/errors.py +4 -13
  15. synth_ai/cli/commands/status/subcommands/__init__.py +2 -8
  16. synth_ai/cli/commands/status/subcommands/config.py +13 -0
  17. synth_ai/cli/commands/status/subcommands/files.py +18 -63
  18. synth_ai/cli/commands/status/subcommands/jobs.py +28 -311
  19. synth_ai/cli/commands/status/subcommands/models.py +18 -62
  20. synth_ai/cli/commands/status/subcommands/runs.py +16 -63
  21. synth_ai/cli/commands/status/subcommands/session.py +67 -172
  22. synth_ai/cli/commands/status/subcommands/summary.py +24 -32
  23. synth_ai/cli/commands/status/subcommands/utils.py +41 -0
  24. synth_ai/cli/commands/status/utils.py +16 -107
  25. synth_ai/cli/commands/train/__init__.py +18 -20
  26. synth_ai/cli/commands/train/errors.py +3 -3
  27. synth_ai/cli/commands/train/prompt_learning_validation.py +15 -16
  28. synth_ai/cli/commands/train/validation.py +7 -7
  29. synth_ai/cli/commands/train/{judge_schemas.py → verifier_schemas.py} +33 -34
  30. synth_ai/cli/commands/train/verifier_validation.py +235 -0
  31. synth_ai/cli/demo_apps/demo_task_apps/math/config.toml +0 -1
  32. synth_ai/cli/demo_apps/demo_task_apps/math/modal_task_app.py +2 -6
  33. synth_ai/cli/demo_apps/math/config.toml +0 -1
  34. synth_ai/cli/demo_apps/math/modal_task_app.py +2 -6
  35. synth_ai/cli/demo_apps/mipro/task_app.py +25 -47
  36. synth_ai/cli/lib/apps/task_app.py +12 -13
  37. synth_ai/cli/lib/task_app_discovery.py +6 -6
  38. synth_ai/cli/lib/train_cfgs.py +10 -10
  39. synth_ai/cli/task_apps/__init__.py +11 -0
  40. synth_ai/cli/task_apps/commands.py +7 -15
  41. synth_ai/core/env.py +12 -1
  42. synth_ai/core/errors.py +1 -2
  43. synth_ai/core/integrations/cloudflare.py +209 -33
  44. synth_ai/core/tracing_v3/abstractions.py +46 -0
  45. synth_ai/data/__init__.py +3 -30
  46. synth_ai/data/enums.py +1 -20
  47. synth_ai/data/rewards.py +100 -3
  48. synth_ai/products/graph_evolve/__init__.py +1 -2
  49. synth_ai/products/graph_evolve/config.py +16 -16
  50. synth_ai/products/graph_evolve/converters/__init__.py +3 -3
  51. synth_ai/products/graph_evolve/converters/openai_sft.py +7 -7
  52. synth_ai/products/graph_evolve/examples/hotpotqa/config.toml +1 -1
  53. synth_ai/products/graph_gepa/__init__.py +23 -0
  54. synth_ai/products/graph_gepa/converters/__init__.py +19 -0
  55. synth_ai/products/graph_gepa/converters/openai_sft.py +29 -0
  56. synth_ai/sdk/__init__.py +45 -35
  57. synth_ai/sdk/api/eval/__init__.py +33 -0
  58. synth_ai/sdk/api/eval/job.py +732 -0
  59. synth_ai/sdk/api/research_agent/__init__.py +276 -66
  60. synth_ai/sdk/api/train/builders.py +181 -0
  61. synth_ai/sdk/api/train/cli.py +41 -33
  62. synth_ai/sdk/api/train/configs/__init__.py +6 -4
  63. synth_ai/sdk/api/train/configs/prompt_learning.py +127 -33
  64. synth_ai/sdk/api/train/configs/rl.py +264 -16
  65. synth_ai/sdk/api/train/configs/sft.py +165 -1
  66. synth_ai/sdk/api/train/graph_validators.py +12 -12
  67. synth_ai/sdk/api/train/graphgen.py +169 -51
  68. synth_ai/sdk/api/train/graphgen_models.py +95 -45
  69. synth_ai/sdk/api/train/local_api.py +10 -0
  70. synth_ai/sdk/api/train/pollers.py +36 -0
  71. synth_ai/sdk/api/train/prompt_learning.py +390 -60
  72. synth_ai/sdk/api/train/rl.py +41 -5
  73. synth_ai/sdk/api/train/sft.py +2 -0
  74. synth_ai/sdk/api/train/task_app.py +20 -0
  75. synth_ai/sdk/api/train/validators.py +17 -17
  76. synth_ai/sdk/graphs/completions.py +239 -33
  77. synth_ai/sdk/{judging/schemas.py → graphs/verifier_schemas.py} +23 -23
  78. synth_ai/sdk/learning/__init__.py +35 -5
  79. synth_ai/sdk/learning/context_learning_client.py +531 -0
  80. synth_ai/sdk/learning/context_learning_types.py +294 -0
  81. synth_ai/sdk/learning/prompt_learning_client.py +1 -1
  82. synth_ai/sdk/learning/prompt_learning_types.py +2 -1
  83. synth_ai/sdk/learning/rl/__init__.py +0 -4
  84. synth_ai/sdk/learning/rl/contracts.py +0 -4
  85. synth_ai/sdk/localapi/__init__.py +40 -0
  86. synth_ai/sdk/localapi/apps/__init__.py +28 -0
  87. synth_ai/sdk/localapi/client.py +10 -0
  88. synth_ai/sdk/localapi/contracts.py +10 -0
  89. synth_ai/sdk/localapi/helpers.py +519 -0
  90. synth_ai/sdk/localapi/rollouts.py +93 -0
  91. synth_ai/sdk/localapi/server.py +29 -0
  92. synth_ai/sdk/localapi/template.py +49 -0
  93. synth_ai/sdk/streaming/handlers.py +6 -6
  94. synth_ai/sdk/streaming/streamer.py +10 -6
  95. synth_ai/sdk/task/__init__.py +18 -5
  96. synth_ai/sdk/task/apps/__init__.py +37 -1
  97. synth_ai/sdk/task/client.py +9 -1
  98. synth_ai/sdk/task/config.py +6 -11
  99. synth_ai/sdk/task/contracts.py +137 -95
  100. synth_ai/sdk/task/in_process.py +32 -22
  101. synth_ai/sdk/task/in_process_runner.py +9 -4
  102. synth_ai/sdk/task/rubrics/__init__.py +2 -3
  103. synth_ai/sdk/task/rubrics/loaders.py +4 -4
  104. synth_ai/sdk/task/rubrics/strict.py +3 -4
  105. synth_ai/sdk/task/server.py +76 -16
  106. synth_ai/sdk/task/trace_correlation_helpers.py +190 -139
  107. synth_ai/sdk/task/validators.py +34 -49
  108. synth_ai/sdk/training/__init__.py +7 -16
  109. synth_ai/sdk/tunnels/__init__.py +118 -0
  110. synth_ai/sdk/tunnels/cleanup.py +83 -0
  111. synth_ai/sdk/tunnels/ports.py +120 -0
  112. synth_ai/sdk/tunnels/tunneled_api.py +363 -0
  113. {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/METADATA +71 -4
  114. {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/RECORD +118 -128
  115. synth_ai/cli/commands/baseline/__init__.py +0 -12
  116. synth_ai/cli/commands/baseline/core.py +0 -636
  117. synth_ai/cli/commands/baseline/list.py +0 -94
  118. synth_ai/cli/commands/eval/errors.py +0 -81
  119. synth_ai/cli/commands/status/formatters.py +0 -164
  120. synth_ai/cli/commands/status/subcommands/pricing.py +0 -23
  121. synth_ai/cli/commands/status/subcommands/usage.py +0 -203
  122. synth_ai/cli/commands/train/judge_validation.py +0 -305
  123. synth_ai/cli/usage.py +0 -159
  124. synth_ai/data/specs.py +0 -36
  125. synth_ai/sdk/api/research_agent/cli.py +0 -428
  126. synth_ai/sdk/api/research_agent/config.py +0 -357
  127. synth_ai/sdk/api/research_agent/job.py +0 -717
  128. synth_ai/sdk/baseline/__init__.py +0 -25
  129. synth_ai/sdk/baseline/config.py +0 -209
  130. synth_ai/sdk/baseline/discovery.py +0 -216
  131. synth_ai/sdk/baseline/execution.py +0 -154
  132. synth_ai/sdk/judging/__init__.py +0 -15
  133. synth_ai/sdk/judging/base.py +0 -24
  134. synth_ai/sdk/judging/client.py +0 -191
  135. synth_ai/sdk/judging/types.py +0 -42
  136. synth_ai/sdk/research_agent/__init__.py +0 -34
  137. synth_ai/sdk/research_agent/container_builder.py +0 -328
  138. synth_ai/sdk/research_agent/container_spec.py +0 -198
  139. synth_ai/sdk/research_agent/defaults.py +0 -34
  140. synth_ai/sdk/research_agent/results_collector.py +0 -69
  141. synth_ai/sdk/specs/__init__.py +0 -46
  142. synth_ai/sdk/specs/dataclasses.py +0 -149
  143. synth_ai/sdk/specs/loader.py +0 -144
  144. synth_ai/sdk/specs/serializer.py +0 -199
  145. synth_ai/sdk/specs/validation.py +0 -250
  146. synth_ai/sdk/tracing/__init__.py +0 -39
  147. synth_ai/sdk/usage/__init__.py +0 -37
  148. synth_ai/sdk/usage/client.py +0 -171
  149. synth_ai/sdk/usage/models.py +0 -261
  150. {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/WHEEL +0 -0
  151. {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/entry_points.txt +0 -0
  152. {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/licenses/LICENSE +0 -0
  153. {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/top_level.txt +0 -0
@@ -1,133 +1,60 @@
1
- from __future__ import annotations
2
-
3
- import re
4
- from collections.abc import MutableMapping
5
- from typing import Any
6
-
7
- __all__ = ["validate_eval_options"]
8
-
9
- _SEED_RANGE = re.compile(r"^\s*(-?\d+)\s*-\s*(-?\d+)\s*$")
1
+ """Validation helpers for eval options."""
10
2
 
11
-
12
- def _coerce_bool(value: Any) -> bool:
13
- if isinstance(value, str):
14
- return value.strip().lower() in {"1", "true", "yes", "on"}
15
- return bool(value)
16
-
17
-
18
- def _coerce_int(value: Any) -> int | None:
19
- if value is None or value == "":
20
- return None
21
- return int(value)
3
+ from __future__ import annotations
22
4
 
23
5
 
24
- def _parse_seeds(value: Any) -> list[int]:
25
- if value is None:
6
+ def _parse_seeds(value: str | list[int]) -> list[int]:
7
+ if isinstance(value, list):
8
+ return value
9
+ if not value:
26
10
  return []
27
- if isinstance(value, str):
28
- chunks = [chunk.strip() for chunk in value.split(",") if chunk.strip()]
29
- elif isinstance(value, list | tuple | set):
30
- chunks = list(value)
31
- else:
32
- chunks = [value]
33
- seeds: list[int] = []
34
- for chunk in chunks:
35
- if isinstance(chunk, int):
36
- seeds.append(chunk)
37
- else:
38
- text = str(chunk).strip()
39
- if not text:
40
- continue
41
- match = _SEED_RANGE.match(text)
42
- if match:
43
- start = int(match.group(1))
44
- end = int(match.group(2))
45
- if start > end:
46
- raise ValueError(f"Invalid seed range '{text}': start must be <= end")
47
- seeds.extend(range(start, end + 1))
48
- else:
49
- seeds.append(int(text))
50
- return seeds
51
-
52
-
53
- def _normalize_metadata(value: Any) -> dict[str, str]:
54
- if value is None:
55
- return {}
56
- if isinstance(value, MutableMapping):
57
- return {str(k): str(v) for k, v in value.items()}
58
- if isinstance(value, list | tuple):
59
- result: dict[str, str] = {}
60
- for item in value:
61
- if isinstance(item, str) and "=" in item:
62
- key, val = item.split("=", 1)
63
- result[key.strip()] = val.strip()
64
- return result
65
- if isinstance(value, str) and "=" in value:
66
- key, val = value.split("=", 1)
67
- return {key.strip(): val.strip()}
68
- return {}
11
+ parts = [part.strip() for part in value.split(",") if part.strip()]
12
+ return [int(part) for part in parts]
69
13
 
70
14
 
71
- def _ensure_list(value: Any) -> list[str] | None:
72
- if value is None:
73
- return None
74
- if isinstance(value, list | tuple | set):
75
- return [str(item) for item in value]
76
- return [str(value)]
77
-
78
-
79
- def _ensure_dict(value: Any) -> dict[str, Any]:
80
- if isinstance(value, MutableMapping):
81
- return dict(value)
82
- return {}
83
-
84
-
85
- def validate_eval_options(options: MutableMapping[str, Any]) -> MutableMapping[str, Any]:
86
- """Validate and normalise eval configuration options."""
87
-
88
- result: dict[str, Any] = dict(options)
89
-
90
- if "seeds" in result:
91
- result["seeds"] = _parse_seeds(result.get("seeds"))
92
-
93
- for field in ("max_turns", "max_llm_calls", "concurrency"):
94
- try:
95
- result[field] = _coerce_int(result.get(field))
96
- except Exception as exc:
97
- raise ValueError(f"Invalid value for {field}: {result.get(field)}") from exc
15
+ def _parse_metadata(values: list[str]) -> dict[str, str]:
16
+ if not values:
17
+ return {}
18
+ parsed: dict[str, str] = {}
19
+ for entry in values:
20
+ if "=" not in entry:
21
+ raise ValueError("Metadata filter must be key=value")
22
+ key, value = entry.split("=", 1)
23
+ parsed[key.strip()] = value.strip()
24
+ return parsed
25
+
26
+
27
+ def _parse_ops(value: str | list[str]) -> list[str]:
28
+ if isinstance(value, list):
29
+ return [str(item).strip() for item in value if str(item).strip()]
30
+ if not value:
31
+ return []
32
+ return [part.strip() for part in value.split(",") if part.strip()]
98
33
 
99
- if result.get("max_llm_calls") is None:
100
- result["max_llm_calls"] = 10
101
- if result.get("concurrency") is None:
102
- result["concurrency"] = 1
103
34
 
104
- if "return_trace" in result:
105
- result["return_trace"] = _coerce_bool(result.get("return_trace"))
35
+ def validate_eval_options(options: dict[str, object]) -> dict[str, object]:
36
+ normalized = dict(options)
37
+ seeds = normalized.get("seeds") or ""
38
+ normalized["seeds"] = _parse_seeds(seeds) # type: ignore[arg-type]
106
39
 
107
- metadata_value = result.get("metadata")
108
- result["metadata"] = _normalize_metadata(metadata_value)
40
+ metadata = normalized.get("metadata") or []
41
+ normalized["metadata"] = _parse_metadata(metadata) # type: ignore[arg-type]
109
42
 
110
- if "ops" in result:
111
- ops_list = _ensure_list(result.get("ops"))
112
- result["ops"] = ops_list
43
+ for key in ("max_turns", "max_llm_calls", "concurrency"):
44
+ if key in normalized and normalized[key] not in (None, ""):
45
+ normalized[key] = int(normalized[key]) # type: ignore[arg-type]
113
46
 
114
- result["env_config"] = _ensure_dict(result.get("env_config"))
115
- result["policy_config"] = _ensure_dict(result.get("policy_config"))
47
+ if "ops" in normalized:
48
+ normalized["ops"] = _parse_ops(normalized.get("ops") or "") # type: ignore[arg-type]
116
49
 
117
- trace_format = result.get("trace_format")
118
- if trace_format is not None:
119
- result["trace_format"] = str(trace_format)
50
+ if "poll" in normalized and normalized["poll"] not in (None, ""):
51
+ normalized["poll"] = float(normalized["poll"]) # type: ignore[arg-type]
120
52
 
121
- metadata_sql = result.get("metadata_sql")
122
- if metadata_sql is not None and not isinstance(metadata_sql, str):
123
- result["metadata_sql"] = str(metadata_sql)
53
+ if "return_trace" in normalized:
54
+ value = str(normalized["return_trace"]).lower()
55
+ normalized["return_trace"] = value in ("true", "1", "yes")
124
56
 
125
- model = result.get("model")
126
- if model is not None:
127
- result["model"] = str(model)
57
+ return normalized
128
58
 
129
- app_id = result.get("app_id")
130
- if app_id is not None:
131
- result["app_id"] = str(app_id)
132
59
 
133
- return result
60
+ __all__ = ["validate_eval_options"]
@@ -239,8 +239,8 @@ def filter_command(config_path: str) -> None:
239
239
  models = set(filter_cfg.models)
240
240
  min_official = filter_cfg.min_official_score
241
241
  max_official = filter_cfg.max_official_score
242
- min_judge_scores = filter_cfg.min_judge_scores
243
- max_judge_scores = filter_cfg.max_judge_scores
242
+ min_verifier_scores = filter_cfg.min_verifier_scores
243
+ max_verifier_scores = filter_cfg.max_verifier_scores
244
244
  min_created = _parse_datetime_for_trace(raw_cfg.get("min_created_at"))
245
245
  max_created = _parse_datetime_for_trace(raw_cfg.get("max_created_at"))
246
246
  limit = filter_cfg.limit
@@ -311,16 +311,16 @@ def filter_command(config_path: str) -> None:
311
311
  elif min_official is not None:
312
312
  continue
313
313
 
314
- judge_scores = metadata.get("judge_scores") or {}
314
+ verifier_scores = metadata.get("verifier_scores") or {}
315
315
  include = True
316
- for judge_name, threshold in (min_judge_scores or {}).items():
317
- if not _score_ok(judge_scores.get(judge_name), threshold, None):
316
+ for verifier_name, threshold in (min_verifier_scores or {}).items():
317
+ if not _score_ok(verifier_scores.get(verifier_name), threshold, None):
318
318
  include = False
319
319
  break
320
320
  if not include:
321
321
  continue
322
- for judge_name, threshold in (max_judge_scores or {}).items():
323
- if not _score_ok(judge_scores.get(judge_name), None, threshold):
322
+ for verifier_name, threshold in (max_verifier_scores or {}).items():
323
+ if not _score_ok(verifier_scores.get(verifier_name), None, threshold):
324
324
  include = False
325
325
  break
326
326
  if not include:
@@ -44,8 +44,8 @@ def validate_filter_options(options: MutableMapping[str, Any]) -> MutableMapping
44
44
  _coerce_list("splits")
45
45
  _coerce_list("task_ids")
46
46
  _coerce_list("models")
47
- _coerce_dict("min_judge_scores")
48
- _coerce_dict("max_judge_scores")
47
+ _coerce_dict("min_verifier_scores")
48
+ _coerce_dict("max_verifier_scores")
49
49
 
50
50
  for duration_key in ("min_official_score", "max_official_score"):
51
51
  value = result.get(duration_key)
@@ -18,14 +18,13 @@ import httpx
18
18
 
19
19
  from synth_ai.core.tracing_v3.config import resolve_trace_db_settings
20
20
  from synth_ai.core.tracing_v3.turso.daemon import start_sqld
21
- from synth_ai.sdk.task.client import TaskAppClient
21
+ from synth_ai.sdk.localapi.client import LocalAPIClient
22
22
  from synth_ai.sdk.task.contracts import (
23
23
  RolloutEnvSpec,
24
24
  RolloutMode,
25
25
  RolloutPolicySpec,
26
26
  RolloutRecordConfig,
27
27
  RolloutRequest,
28
- RolloutSafetyConfig,
29
28
  )
30
29
  from synth_ai.sdk.task.validators import (
31
30
  normalize_inference_url,
@@ -773,7 +772,7 @@ async def _run_smoke_async(
773
772
  mock_backend = (mock_backend or "synthetic").strip().lower()
774
773
 
775
774
  # Discover environment if not provided
776
- async with TaskAppClient(base_url=base, api_key=api_key) as client:
775
+ async with LocalAPIClient(base_url=base, api_key=api_key) as client:
777
776
  # Probe basic info quickly
778
777
  try:
779
778
  _ = await client.health()
@@ -795,12 +794,6 @@ async def _run_smoke_async(
795
794
  click.echo("Could not infer environment name; pass --env-name.", err=True)
796
795
  return 2
797
796
 
798
- # Build ops: alternating agent/env for max_steps
799
- ops: list[str] = []
800
- for _ in range(max_steps):
801
- ops.append("agent")
802
- ops.append("env")
803
-
804
797
  # Inference URL: user override > preset > local mock > Synth API default
805
798
  synth_base = (os.getenv("SYNTH_API_BASE") or os.getenv("SYNTH_BASE_URL") or "https://api.synth.run").rstrip("/")
806
799
  # Avoid double '/api' if base already includes it
@@ -852,7 +845,6 @@ async def _run_smoke_async(
852
845
  run_id=run_id,
853
846
  env=RolloutEnvSpec(env_name=env_name, config={}, seed=i),
854
847
  policy=RolloutPolicySpec(policy_name=policy_name, config=policy_cfg),
855
- ops=ops,
856
848
  record=RolloutRecordConfig(
857
849
  trajectories=True,
858
850
  logprobs=False,
@@ -861,7 +853,6 @@ async def _run_smoke_async(
861
853
  trace_format=("structured" if return_trace else "compact"),
862
854
  ),
863
855
  on_done="reset",
864
- safety=RolloutSafetyConfig(max_ops=max_steps * 4, max_time_s=900.0),
865
856
  training_session_id=None,
866
857
  synth_base_url=synth_base,
867
858
  mode=RolloutMode.RL,
@@ -869,7 +860,6 @@ async def _run_smoke_async(
869
860
 
870
861
  try:
871
862
  click.echo(f">> POST /rollout run_id={run_id} env={env_name} policy={policy_name} url={inference_url_with_cid}")
872
- click.echo(f" ops={ops[:10]}{'...' if len(ops) > 10 else ''}")
873
863
  response = await client.rollout(request)
874
864
  except Exception as exc:
875
865
  click.echo(f"Rollout[{i}:{g}] failed: {type(exc).__name__}: {exc}", err=True)
@@ -888,10 +878,10 @@ async def _run_smoke_async(
888
878
  metrics = response.metrics
889
879
  if inferred_url:
890
880
  click.echo(f" rollout[{i}:{g}] inference_url: {inferred_url}")
891
- click.echo(f" rollout[{i}:{g}] episodes={metrics.num_episodes} steps={metrics.num_steps} mean_return={metrics.mean_return:.4f}")
881
+ click.echo(f" rollout[{i}:{g}] episodes={metrics.num_episodes} steps={metrics.num_steps} reward_mean={metrics.reward_mean:.4f}")
892
882
 
893
883
  total_steps += int(metrics.num_steps)
894
- if (metrics.mean_return or 0.0) != 0.0:
884
+ if (metrics.reward_mean or 0.0) != 0.0:
895
885
  nonzero_returns += 1
896
886
  if response.trace is not None and isinstance(response.trace, dict):
897
887
  v3_traces += 1
@@ -916,8 +906,8 @@ async def _run_smoke_async(
916
906
  metrics_dump = response.metrics.model_dump()
917
907
  except Exception:
918
908
  metrics_dump = {
919
- "episode_returns": getattr(response.metrics, "episode_returns", None),
920
- "mean_return": getattr(response.metrics, "mean_return", None),
909
+ "episode_rewards": getattr(response.metrics, "episode_rewards", None),
910
+ "reward_mean": getattr(response.metrics, "reward_mean", None),
921
911
  "num_steps": getattr(response.metrics, "num_steps", None),
922
912
  "num_episodes": getattr(response.metrics, "num_episodes", None),
923
913
  "outcome_score": getattr(response.metrics, "outcome_score", None),
@@ -1016,7 +1006,7 @@ async def _run_smoke_async(
1016
1006
  if v3_traces < successes:
1017
1007
  click.echo(" ⚠ Some rollouts missing v3 traces (trace field)", err=True)
1018
1008
  if total_steps == 0:
1019
- click.echo(" ⚠ No steps executed; check ops/policy config", err=True)
1009
+ click.echo(" ⚠ No steps executed; check policy config", err=True)
1020
1010
 
1021
1011
  return 0
1022
1012
 
@@ -1,66 +1,3 @@
1
- """Status and listing commands for the Synth CLI."""
1
+ """Status command package."""
2
2
 
3
3
  from __future__ import annotations
4
-
5
- import click
6
-
7
- from .config import resolve_backend_config
8
- from .subcommands.files import files_group
9
- from .subcommands.jobs import jobs_group
10
- from .subcommands.models import models_group
11
- from .subcommands.runs import runs_group
12
- from .subcommands.session import session_status_cmd
13
- from .subcommands.summary import summary_command
14
-
15
-
16
- def _attach_group(cli: click.Group, group: click.Group, name: str) -> None:
17
- """Attach the provided Click group to the CLI if not already present."""
18
- if name in cli.commands:
19
- return
20
- cli.add_command(group, name=name)
21
-
22
-
23
- def register(cli: click.Group) -> None:
24
- """Register all status command groups on the provided CLI root."""
25
-
26
- @click.group(help="Inspect training jobs, models, files, and job runs.")
27
- @click.option(
28
- "--base-url",
29
- envvar="SYNTH_STATUS_BASE_URL",
30
- default=None,
31
- help="Synth backend base URL (defaults to environment configuration).",
32
- )
33
- @click.option(
34
- "--api-key",
35
- envvar="SYNTH_STATUS_API_KEY",
36
- default=None,
37
- help="API key for authenticated requests (falls back to Synth defaults).",
38
- )
39
- @click.option(
40
- "--timeout",
41
- default=30.0,
42
- show_default=True,
43
- type=float,
44
- help="HTTP request timeout in seconds.",
45
- )
46
- @click.pass_context
47
- def status(ctx: click.Context, base_url: str | None, api_key: str | None, timeout: float) -> None:
48
- """Populate shared backend configuration for subcommands."""
49
- cfg = resolve_backend_config(base_url=base_url, api_key=api_key, timeout=timeout)
50
- ctx.ensure_object(dict)
51
- ctx.obj["status_backend_config"] = cfg
52
-
53
- status.add_command(jobs_group, name="jobs")
54
- status.add_command(models_group, name="models")
55
- status.add_command(files_group, name="files")
56
- status.add_command(runs_group, name="runs")
57
- status.add_command(session_status_cmd, name="session")
58
- status.add_command(summary_command, name="summary")
59
-
60
- cli.add_command(status, name="status")
61
- _attach_group(cli, jobs_group, "jobs")
62
- _attach_group(cli, models_group, "models")
63
- _attach_group(cli, files_group, "files")
64
- _attach_group(cli, runs_group, "runs")
65
- if "status-summary" not in cli.commands:
66
- cli.add_command(summary_command, name="status-summary")
@@ -1,4 +1,4 @@
1
- """Async HTTP client for Synth status and listing endpoints."""
1
+ """HTTP client for status commands."""
2
2
 
3
3
  from __future__ import annotations
4
4
 
@@ -8,31 +8,43 @@ import httpx
8
8
 
9
9
  from .config import BackendConfig
10
10
  from .errors import StatusAPIError
11
+ from .utils import build_headers
11
12
 
12
13
 
13
14
  class StatusAPIClient:
14
- """Thin wrapper around httpx.AsyncClient with convenience methods."""
15
-
16
15
  def __init__(self, config: BackendConfig) -> None:
17
16
  self._config = config
18
- timeout = httpx.Timeout(config.timeout)
19
- self._client = httpx.AsyncClient(
20
- base_url=config.base_url,
21
- headers=config.headers,
22
- timeout=timeout,
23
- )
24
-
25
- async def __aenter__(self) -> StatusAPIClient:
26
- await self._client.__aenter__()
17
+ self._client: httpx.AsyncClient | None = None
18
+
19
+ async def __aenter__(self) -> "StatusAPIClient":
20
+ if self._client is None:
21
+ self._client = httpx.AsyncClient(
22
+ base_url=self._config.base_url,
23
+ headers=build_headers(self._config.api_key),
24
+ timeout=self._config.timeout,
25
+ )
27
26
  return self
28
27
 
29
- async def __aexit__(self, *args: Any) -> None:
30
- await self._client.__aexit__(*args)
31
-
32
- async def close(self) -> None:
33
- await self._client.aclose()
34
-
35
- # Jobs -----------------------------------------------------------------
28
+ async def __aexit__(self, exc_type, exc, tb) -> None:
29
+ if self._client is not None:
30
+ await self._client.aclose()
31
+ self._client = None
32
+
33
+ async def _get(self, path: str, params: dict[str, Any] | None = None) -> dict[str, Any]:
34
+ assert self._client is not None
35
+ resp = await self._client.get(path, params=params)
36
+ if resp.status_code >= 400:
37
+ detail = resp.json().get("detail", "")
38
+ raise StatusAPIError(detail or "Request failed", status_code=resp.status_code)
39
+ return resp.json()
40
+
41
+ async def _post(self, path: str, payload: dict[str, Any] | None = None) -> dict[str, Any]:
42
+ assert self._client is not None
43
+ resp = await self._client.post(path, json=payload or {})
44
+ if resp.status_code >= 400:
45
+ detail = resp.json().get("detail", "")
46
+ raise StatusAPIError(detail or "Request failed", status_code=resp.status_code)
47
+ return resp.json()
36
48
 
37
49
  async def list_jobs(
38
50
  self,
@@ -43,94 +55,22 @@ class StatusAPIClient:
43
55
  limit: int | None = None,
44
56
  ) -> list[dict[str, Any]]:
45
57
  params: dict[str, Any] = {}
46
- if status:
58
+ if status is not None:
47
59
  params["status"] = status
48
- if job_type:
60
+ if job_type is not None:
49
61
  params["type"] = job_type
50
- if created_after:
62
+ if created_after is not None:
51
63
  params["created_after"] = created_after
52
- if limit:
64
+ if limit is not None:
53
65
  params["limit"] = limit
54
- resp = await self._client.get("/learning/jobs", params=params)
55
- return self._json_list(resp, key="jobs")
66
+ payload = await self._get("/learning/jobs", params=params or None)
67
+ return payload.get("jobs", [])
56
68
 
57
69
  async def get_job(self, job_id: str) -> dict[str, Any]:
58
- resp = await self._client.get(f"/learning/jobs/{job_id}")
59
- return self._json(resp)
60
-
61
- async def get_job_status(self, job_id: str) -> dict[str, Any]:
62
- resp = await self._client.get(f"/learning/jobs/{job_id}/status")
63
- return self._json(resp)
70
+ return await self._get(f"/learning/jobs/{job_id}")
64
71
 
65
72
  async def cancel_job(self, job_id: str) -> dict[str, Any]:
66
- resp = await self._client.post(f"/learning/jobs/{job_id}/cancel")
67
- return self._json(resp)
68
-
69
- async def get_job_config(self, job_id: str) -> dict[str, Any]:
70
- resp = await self._client.get(f"/learning/jobs/{job_id}/config")
71
- return self._json(resp)
72
-
73
- async def get_job_metrics(self, job_id: str) -> dict[str, Any]:
74
- resp = await self._client.get(f"/learning/jobs/{job_id}/metrics")
75
- return self._json(resp)
76
-
77
- async def get_job_timeline(self, job_id: str) -> list[dict[str, Any]]:
78
- resp = await self._client.get(f"/learning/jobs/{job_id}/timeline")
79
- return self._json_list(resp, key="timeline")
80
-
81
- async def list_job_runs(self, job_id: str) -> list[dict[str, Any]]:
82
- resp = await self._client.get(f"/jobs/{job_id}/runs")
83
- return self._json_list(resp, key="runs")
84
-
85
- async def get_job_events(
86
- self,
87
- job_id: str,
88
- *,
89
- since: str | None = None,
90
- limit: int | None = None,
91
- after: str | None = None,
92
- run_id: str | None = None,
93
- ) -> list[dict[str, Any]]:
94
- params: dict[str, Any] = {}
95
- if since:
96
- params["since"] = since
97
- if limit:
98
- params["limit"] = limit
99
- if after:
100
- params["after"] = after
101
- if run_id:
102
- params["run"] = run_id
103
- resp = await self._client.get(f"/learning/jobs/{job_id}/events", params=params)
104
- return self._json_list(resp, key="events")
105
-
106
- # Files ----------------------------------------------------------------
107
-
108
- async def list_files(
109
- self,
110
- *,
111
- purpose: str | None = None,
112
- limit: int | None = None,
113
- ) -> list[dict[str, Any]]:
114
- params: dict[str, Any] = {}
115
- if purpose:
116
- params["purpose"] = purpose
117
- if limit:
118
- params["limit"] = limit
119
- resp = await self._client.get("/files", params=params)
120
- data = self._json(resp)
121
- if isinstance(data, dict):
122
- for key in ("files", "data", "items"):
123
- if isinstance(data.get(key), list):
124
- return list(data[key])
125
- if isinstance(data, list):
126
- return list(data)
127
- return []
128
-
129
- async def get_file(self, file_id: str) -> dict[str, Any]:
130
- resp = await self._client.get(f"/files/{file_id}")
131
- return self._json(resp)
132
-
133
- # Models ---------------------------------------------------------------
73
+ return await self._post(f"/learning/jobs/{job_id}/cancel")
134
74
 
135
75
  async def list_models(
136
76
  self,
@@ -138,55 +78,14 @@ class StatusAPIClient:
138
78
  limit: int | None = None,
139
79
  model_type: str | None = None,
140
80
  ) -> list[dict[str, Any]]:
141
- params: dict[str, Any] = {}
142
- if limit:
143
- params["limit"] = limit
144
- endpoint = "/learning/models/rl" if model_type == "rl" else "/learning/models"
145
- resp = await self._client.get(endpoint, params=params)
146
- return self._json_list(resp, key="models")
147
-
148
- async def get_model(self, model_id: str) -> dict[str, Any]:
149
- resp = await self._client.get(f"/learning/models/{model_id}")
150
- return self._json(resp)
151
-
152
- # Helpers --------------------------------------------------------------
153
-
154
- def _json(self, response: httpx.Response) -> dict[str, Any]:
155
- try:
156
- response.raise_for_status()
157
- except httpx.HTTPStatusError as exc:
158
- detail = self._extract_detail(exc.response)
159
- raise StatusAPIError(detail, exc.response.status_code if exc.response else None) from exc
160
- try:
161
- data = response.json()
162
- except ValueError as exc:
163
- raise StatusAPIError("Backend response was not valid JSON") from exc
164
- if isinstance(data, dict):
165
- return data
166
- return {"data": data}
167
-
168
- def _json_list(self, response: httpx.Response, *, key: str | None = None) -> list[dict[str, Any]]:
169
- payload = self._json(response)
170
- if key and isinstance(payload.get(key), list):
171
- return list(payload[key])
172
- if isinstance(payload.get("data"), list):
173
- return list(payload["data"])
174
- if isinstance(payload.get("results"), list):
175
- return list(payload["results"])
176
- if isinstance(payload, list):
177
- return list(payload)
178
- return []
179
-
180
- @staticmethod
181
- def _extract_detail(response: httpx.Response | None) -> str:
182
- if response is None:
183
- return "Backend request failed"
184
- try:
185
- data = response.json()
186
- if isinstance(data, dict):
187
- for key in ("detail", "message", "error"):
188
- if data.get(key):
189
- return str(data[key])
190
- return response.text
191
- except ValueError:
192
- return response.text
81
+ if model_type:
82
+ payload = await self._get(f"/learning/models/{model_type}")
83
+ return payload.get("models", [])
84
+ params = {"limit": limit} if limit is not None else None
85
+ payload = await self._get("/learning/models", params=params)
86
+ return payload.get("models", [])
87
+
88
+ async def list_job_runs(self, job_id: str) -> list[dict[str, Any]]:
89
+ payload = await self._get(f"/jobs/{job_id}/runs")
90
+ return payload.get("runs", [])
91
+