synth-ai 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (38)
  1. examples/multi_step/task_app_config_notes.md +488 -0
  2. examples/warming_up_to_rl/configs/eval_stepwise_complex.toml +33 -0
  3. examples/warming_up_to_rl/configs/eval_stepwise_consistent.toml +26 -0
  4. examples/warming_up_to_rl/configs/eval_stepwise_per_achievement.toml +36 -0
  5. examples/warming_up_to_rl/configs/eval_stepwise_simple.toml +30 -0
  6. examples/warming_up_to_rl/run_eval.py +142 -25
  7. examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +146 -2
  8. synth_ai/api/train/builders.py +25 -14
  9. synth_ai/api/train/cli.py +29 -6
  10. synth_ai/api/train/env_resolver.py +18 -19
  11. synth_ai/api/train/supported_algos.py +8 -5
  12. synth_ai/api/train/utils.py +6 -1
  13. synth_ai/cli/__init__.py +4 -2
  14. synth_ai/cli/_storage.py +19 -0
  15. synth_ai/cli/balance.py +14 -2
  16. synth_ai/cli/calc.py +37 -22
  17. synth_ai/cli/legacy_root_backup.py +12 -14
  18. synth_ai/cli/recent.py +12 -7
  19. synth_ai/cli/status.py +4 -3
  20. synth_ai/cli/task_apps.py +143 -137
  21. synth_ai/cli/traces.py +4 -3
  22. synth_ai/cli/watch.py +3 -2
  23. synth_ai/jobs/client.py +15 -3
  24. synth_ai/task/server.py +14 -7
  25. synth_ai/tracing_v3/decorators.py +51 -26
  26. synth_ai/tracing_v3/examples/basic_usage.py +12 -7
  27. synth_ai/tracing_v3/llm_call_record_helpers.py +107 -53
  28. synth_ai/tracing_v3/replica_sync.py +8 -4
  29. synth_ai/tracing_v3/storage/utils.py +11 -9
  30. synth_ai/tracing_v3/turso/__init__.py +12 -0
  31. synth_ai/tracing_v3/turso/daemon.py +2 -1
  32. synth_ai/tracing_v3/turso/native_manager.py +28 -15
  33. {synth_ai-0.2.10.dist-info → synth_ai-0.2.12.dist-info}/METADATA +4 -2
  34. {synth_ai-0.2.10.dist-info → synth_ai-0.2.12.dist-info}/RECORD +38 -31
  35. {synth_ai-0.2.10.dist-info → synth_ai-0.2.12.dist-info}/WHEEL +0 -0
  36. {synth_ai-0.2.10.dist-info → synth_ai-0.2.12.dist-info}/entry_points.txt +0 -0
  37. {synth_ai-0.2.10.dist-info → synth_ai-0.2.12.dist-info}/licenses/LICENSE +0 -0
  38. {synth_ai-0.2.10.dist-info → synth_ai-0.2.12.dist-info}/top_level.txt +0 -0

examples/warming_up_to_rl/run_eval.py CHANGED
@@ -15,6 +15,7 @@ import json
 import os
 import re
 import tomllib
+from copy import deepcopy
 from collections import Counter
 from pathlib import Path
 from typing import Any
@@ -115,26 +116,34 @@ class TaskAppClient:
         run_id: str,
         env_name: str,
         seed: int,
-        difficulty: str,
+        difficulty: str | None,
         policy_name: str,
         policy_config: dict[str, Any],
         max_turns: int,
+        env_config: dict[str, Any] | None = None,
+        ops: list[str] | None = None,
     ) -> dict[str, Any]:
-        ops: list[str] = []
-        for _ in range(max_turns):
-            ops.extend(["agent", "env"])
+        ops_seq: list[str] = list(ops) if ops is not None else []
+        if not ops_seq:
+            for _ in range(max_turns):
+                ops_seq.extend(["agent", "env"])
+        env_cfg: dict[str, Any] = {}
+        if isinstance(env_config, dict):
+            env_cfg.update(env_config)
+        if difficulty is not None and "difficulty" not in env_cfg:
+            env_cfg["difficulty"] = difficulty
         payload: dict[str, Any] = {
             "run_id": run_id,
             "env": {
                 "env_name": env_name,
-                "config": {"difficulty": difficulty},
+                "config": env_cfg,
                 "seed": seed,
             },
             "policy": {
                 "policy_name": policy_name,
                 "config": policy_config,
             },
-            "ops": ops,
+            "ops": ops_seq,
             "on_done": "terminate",
         }
         # Ensure X-API-Key is included
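
For orientation, a minimal usage sketch of the updated client call above — hypothetical, not part of the package: it assumes an already-constructed TaskAppClient instance named client and shows how the new env_config and ops parameters feed into the request payload.

    # Hypothetical usage of TaskAppClient.rollout after this change.
    import asyncio

    async def demo(client) -> None:
        result = await client.rollout(
            run_id="eval-1",
            env_name="crafter",
            seed=1,
            difficulty=None,                    # optional now; may instead live in env_config
            policy_name="crafter",
            policy_config={"model": "my-model", "inference_url": "http://localhost:8000"},
            max_turns=10,
            env_config={"difficulty": "easy"},  # merged into the env config block
            ops=["agent", "env"] * 10,          # explicit op sequence; omit to derive from max_turns
        )
        print(result.get("metrics"))

    # asyncio.run(demo(client))  # given a configured client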
@@ -415,11 +424,20 @@ async def main() -> None:
         async with sem:
             try:
                 run_id = f"eval-{seed}"
-                # Build policy config from TOML (explicit control; no server-side guessing)
-                policy_cfg: dict[str, Any] = {
-                    "model": cfg.get("model", MODEL),
-                    "inference_url": inf_url,
-                }
+                rollout_cfg_raw = cfg.get("rollout") or {}
+                rollout_cfg = (
+                    dict(rollout_cfg_raw) if isinstance(rollout_cfg_raw, dict) else {}
+                )
+                env_config_raw = rollout_cfg.get("env_config") or {}
+                env_config = (
+                    deepcopy(env_config_raw) if isinstance(env_config_raw, dict) else {}
+                )
+                policy_cfg_raw = rollout_cfg.get("policy_config") or {}
+                policy_cfg = (
+                    deepcopy(policy_cfg_raw) if isinstance(policy_cfg_raw, dict) else {}
+                )
+                policy_cfg.setdefault("model", cfg.get("model", MODEL))
+                policy_cfg.setdefault("inference_url", inf_url)
                 for k in (
                     "max_tokens",
                     "temperature",
@@ -428,18 +446,56 @@ async def main() -> None:
                     "thinking_budget",
                     "use_tools",
                 ):
-                    if k in cfg and cfg.get(k) is not None:
+                    if k in cfg and cfg.get(k) is not None and k not in policy_cfg:
                         policy_cfg[k] = cfg.get(k)

+                env_name = str(rollout_cfg.get("env_name") or "crafter")
+                policy_name = str(
+                    rollout_cfg.get("policy_name") or cfg.get("policy_name") or "crafter"
+                )
+
+                max_turns_local = MAX_TURNS
+                for candidate in (rollout_cfg.get("max_turns"), cfg.get("max_turns")):
+                    if candidate is None:
+                        continue
+                    with contextlib.suppress(Exception):
+                        max_turns_local = int(candidate)
+                        break
+
+                difficulty_override: str | None = None
+                if isinstance(env_config, dict):
+                    diff_cfg = env_config.get("difficulty")
+                    if isinstance(diff_cfg, str) and diff_cfg:
+                        difficulty_override = diff_cfg
+                if difficulty_override is None:
+                    cfg_diff = rollout_cfg.get("difficulty") or cfg.get("difficulty")
+                    if isinstance(cfg_diff, str) and cfg_diff:
+                        difficulty_override = cfg_diff
+                if difficulty_override is None:
+                    difficulty_override = os.getenv("DIFFICULTY", "easy")
+
                 r = await client.rollout(
                     run_id=run_id,
-                    env_name="crafter",
+                    env_name=env_name,
                     seed=seed,
-                    difficulty=os.getenv("DIFFICULTY", "easy"),
-                    policy_name=cfg.get("policy_name", "crafter"),
+                    difficulty=difficulty_override,
+                    policy_name=policy_name,
                     policy_config=policy_cfg,
-                    max_turns=MAX_TURNS,
+                    max_turns=max_turns_local,
+                    env_config=env_config,
                 )
+                metrics_block = r.get("metrics") or {}
+                mean_return = None
+                if isinstance(metrics_block, dict):
+                    with contextlib.suppress(Exception):
+                        mean_return = float(metrics_block.get("mean_return"))
+                stepwise_details: dict[str, Any] = {}
+                if isinstance(metrics_block, dict):
+                    details_block = metrics_block.get("details") or {}
+                    if isinstance(details_block, dict):
+                        step_block = details_block.get("stepwise") or {}
+                        if isinstance(step_block, dict):
+                            stepwise_details = step_block
                 # Extract achievements count if present
                 ach = []
                 try:
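
The block above now reads its rollout settings from the TOML instead of hard-coding Crafter defaults. A rough sketch of the parsed config it expects is shown below as a Python dict; the key names mirror the lookups in the diff (rollout.env_name, rollout.policy_config, and so on), while the concrete values and the exact layout of the new eval_stepwise_*.toml examples are assumptions.

    # Illustrative shape of the loaded eval config (cfg) consumed above.
    cfg = {
        "model": "example-model-id",                 # fallback when policy_config omits "model"
        "max_turns": 10,                             # fallback when rollout.max_turns is absent
        "rollout": {
            "env_name": "crafter",
            "policy_name": "crafter",
            "max_turns": 10,
            "env_config": {"difficulty": "easy"},    # forwarded to the task app verbatim
            "policy_config": {"temperature": 0.2, "max_tokens": 512},
        },
    }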
@@ -465,9 +521,22 @@ async def main() -> None:
                     length = int(trajs[0].get("length") or 0)
                 except Exception:
                     pass
-                return {"seed": seed, "turns": length, "achievements": ach}
+                return {
+                    "seed": seed,
+                    "turns": length,
+                    "achievements": ach,
+                    "mean_return": mean_return,
+                    "stepwise": stepwise_details,
+                }
             except Exception as e:
-                return {"seed": seed, "turns": 0, "achievements": [], "error": str(e)}
+                return {
+                    "seed": seed,
+                    "turns": 0,
+                    "achievements": [],
+                    "mean_return": None,
+                    "stepwise": {},
+                    "error": str(e),
+                }

     results = await asyncio.gather(
         *[asyncio.create_task(_run(i)) for i in range(1, NUM_EPISODES + 1)],
@@ -483,15 +552,63 @@ async def main() -> None:
                     all_ach[a] += 1
             except Exception:
                 pass
+        mean_returns: list[float] = []
+        stepwise_reward_sums: list[float] = []
+        stepwise_indicator_sums: list[float] = []
+        stepwise_new_ach_totals: list[float] = []
+        strategies_seen = Counter()
+        for r in results:
+            if not isinstance(r, dict):
+                continue
+            with contextlib.suppress(Exception):
+                mean_val = r.get("mean_return")
+                if mean_val is not None:
+                    mean_returns.append(float(mean_val))
+            stepwise_block = r.get("stepwise")
+            if isinstance(stepwise_block, dict) and stepwise_block:
+                with contextlib.suppress(Exception):
+                    if stepwise_block.get("reward_sum") is not None:
+                        stepwise_reward_sums.append(float(stepwise_block.get("reward_sum")))
+                with contextlib.suppress(Exception):
+                    if stepwise_block.get("indicator_sum") is not None:
+                        stepwise_indicator_sums.append(float(stepwise_block.get("indicator_sum")))
+                with contextlib.suppress(Exception):
+                    if stepwise_block.get("new_achievements_total") is not None:
+                        stepwise_new_ach_totals.append(
+                            float(stepwise_block.get("new_achievements_total"))
+                        )
+                strategy_name = stepwise_block.get("strategy")
+                if isinstance(strategy_name, str) and strategy_name:
+                    strategies_seen[strategy_name] += 1
+        aggregate: dict[str, Any] = {
+            "completed": sum(
+                1 for r in results if isinstance(r, dict) and not r.get("error")
+            ),
+            "total": len(results),
+            "avg_turns": (sum(turns) / len(turns)) if turns else 0.0,
+            "avg_achievements": (sum(counts) / len(counts)) if counts else 0.0,
+            "achievements_freq": dict(all_ach),
+        }
+        if mean_returns:
+            aggregate["avg_mean_return"] = sum(mean_returns) / len(mean_returns)
+        if stepwise_reward_sums:
+            aggregate["avg_stepwise_reward_sum"] = sum(stepwise_reward_sums) / len(
+                stepwise_reward_sums
+            )
+        if stepwise_indicator_sums:
+            aggregate["avg_stepwise_indicator_sum"] = sum(stepwise_indicator_sums) / len(
+                stepwise_indicator_sums
+            )
+        if stepwise_new_ach_totals:
+            aggregate["avg_stepwise_new_achievements"] = sum(stepwise_new_ach_totals) / len(
+                stepwise_new_ach_totals
+            )
+        if strategies_seen:
+            aggregate["stepwise_strategies"] = dict(strategies_seen)
+            aggregate["stepwise_samples"] = len(stepwise_reward_sums)
         summary = {
             "episodes": results,
-            "aggregate": {
-                "completed": sum(1 for r in results if not r.get("error")),
-                "total": len(results),
-                "avg_turns": (sum(turns) / len(turns)) if turns else 0.0,
-                "avg_achievements": (sum(counts) / len(counts)) if counts else 0.0,
-                "achievements_freq": dict(all_ach),
-            },
+            "aggregate": aggregate,
         }
         print(json.dumps(summary, indent=2))
     else:
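
With the aggregation above, the printed summary gains optional stepwise fields when episodes report them. An illustrative output shape (values invented; optional keys appear only when stepwise data is present):

    summary = {
        "episodes": [
            {
                "seed": 1,
                "turns": 10,
                "achievements": ["collect_wood"],
                "mean_return": 1.0,
                "stepwise": {"reward_sum": 1.0, "indicator_sum": 1.0,
                             "new_achievements_total": 1, "strategy": "consistent"},
            },
        ],
        "aggregate": {
            "completed": 1,
            "total": 1,
            "avg_turns": 10.0,
            "avg_achievements": 1.0,
            "achievements_freq": {"collect_wood": 1},
            "avg_mean_return": 1.0,
            "avg_stepwise_reward_sum": 1.0,
            "avg_stepwise_indicator_sum": 1.0,
            "avg_stepwise_new_achievements": 1.0,
            "stepwise_strategies": {"consistent": 1},
            "stepwise_samples": 1,
        },
    }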

examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py CHANGED
@@ -9,7 +9,7 @@ from datetime import datetime
 from typing import Any

 from fastapi import APIRouter, HTTPException, Request, status
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 from synth_ai.lm.vendors.base import BaseLMResponse
 from synth_ai.task.tracing_utils import unique_sft_path
 from synth_ai.tracing_v3.abstractions import EnvironmentEvent, LMCAISEvent, TimeRecord
@@ -142,12 +142,59 @@ class RolloutTrajectory(BaseModel):
     decision_samples: list[dict[str, Any]] | None = None


+def _normalize_step_strategy(raw_strategy: Any) -> str:
+    if not isinstance(raw_strategy, str):
+        return "consistent"
+    candidate = raw_strategy.strip().lower()
+    if not candidate:
+        return "consistent"
+    mapping = {
+        "simple": "consistent",
+        "consistent": "consistent",
+        "consistent_stepwise": "consistent",
+        "decision_consistent": "consistent",
+        "per_achievement": "per_achievement",
+        "per-achievement": "per_achievement",
+        "perachievement": "per_achievement",
+        "achievement_weighted": "per_achievement",
+        "complex": "per_achievement",
+    }
+    return mapping.get(candidate, "consistent")
+
+
+def _coerce_weights(raw_weights: Any) -> dict[str, float]:
+    weights: dict[str, float] = {}
+    if isinstance(raw_weights, dict):
+        for key, value in raw_weights.items():
+            try:
+                weights[str(key)] = float(value)
+            except Exception:
+                continue
+    return weights
+
+
+def _coerce_k_limits(raw_limits: Any) -> dict[str, int]:
+    limits: dict[str, int] = {}
+    if isinstance(raw_limits, dict):
+        for key, value in raw_limits.items():
+            try:
+                limits[str(key)] = int(value)
+            except Exception:
+                continue
+    return limits
+
+
 def compute_stepwise_reward(
     prev_achievements: dict[str, bool],
     new_achievements: dict[str, bool],
     decision_index: int,
     actions_summary: list[dict[str, Any]],
     indicator_lambda: float,
+    *,
+    strategy: str | None = None,
+    weights: dict[str, float] | None = None,
+    k_limits: dict[str, int] | None = None,
+    episode_counts: dict[str, int] | None = None,
 ) -> tuple[dict[str, Any], dict[str, Any], dict[str, float]]:
     """Compute stepwise reward metadata given achievement states before/after a decision."""
@@ -156,24 +203,88 @@ def compute_stepwise_reward(

     unlocked = [name for name, value in next_map.items() if value and not prev_map.get(name, False)]
     indicator = 1 if unlocked else 0
-    reward_value = float(indicator_lambda) * indicator
+    normalized_strategy = _normalize_step_strategy(strategy)
+    base_reward = 0.0
+    reward_components: list[dict[str, Any]] = []
+    credited: list[str] = []
+
+    if indicator:
+        if normalized_strategy == "per_achievement":
+            weight_map = weights or {}
+            limit_map = k_limits or {}
+            counts = episode_counts if isinstance(episode_counts, dict) else {}
+            for name in unlocked:
+                try:
+                    limit_val = int(limit_map.get(name, 1))
+                except Exception:
+                    limit_val = 1
+                # limit_val <= 0 implies unlimited rewards
+                unlimited = limit_val <= 0
+                try:
+                    prev_count = int(counts.get(name, 0))
+                except Exception:
+                    prev_count = 0
+                should_credit = unlimited or (prev_count < max(limit_val, 0))
+                if should_credit:
+                    try:
+                        weight_val = float(weight_map.get(name, 1.0))
+                    except Exception:
+                        weight_val = 1.0
+                    base_reward += weight_val
+                    reward_components.append(
+                        {
+                            "achievement": name,
+                            "weight": weight_val,
+                            "count_prior": prev_count,
+                            "count_limit": limit_val,
+                        }
+                    )
+                    credited.append(name)
+                    if episode_counts is not None:
+                        episode_counts[name] = prev_count + 1
+        else:
+            base_reward = 1.0
+            reward_components.append(
+                {
+                    "achievement": "__indicator__",
+                    "weight": 1.0,
+                    "count_prior": 0,
+                    "count_limit": 1,
+                }
+            )
+
+    reward_value = float(indicator_lambda) * float(base_reward)

     stepwise_info = {
         "decision_index": decision_index,
         "indicator": indicator,
         "new_achievements": unlocked,
         "reward": reward_value,
+        "strategy": normalized_strategy,
+        "base_reward": float(base_reward),
     }
+    if reward_components:
+        stepwise_info["components"] = reward_components
+    if credited:
+        stepwise_info["credited_achievements"] = credited
+
     decision_sample = {
         "decision_index": decision_index,
         "indicator": indicator,
         "r_i": reward_value,
+        "base": float(base_reward),
+        "strategy": normalized_strategy,
         "actions": actions_summary,
     }
+    if reward_components:
+        decision_sample["components"] = reward_components
+
     stats = {
         "indicator": float(indicator),
         "reward": reward_value,
         "new_achievements_count": float(len(unlocked)),
+        "base_reward": float(base_reward),
+        "credited_achievements_count": float(len(credited)),
     }
     return stepwise_info, decision_sample, stats

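A hedged example of driving the extended compute_stepwise_reward with the per_achievement strategy; the achievement names and weights are invented, but the call shape and the mutation of episode_counts follow the code above.

    # First unlock of "collect_wood" with weight 2.0 and a per-episode limit of 1.
    episode_counts: dict[str, int] = {}
    info, sample, stats = compute_stepwise_reward(
        prev_achievements={"collect_wood": False, "collect_stone": False},
        new_achievements={"collect_wood": True, "collect_stone": False},
        decision_index=3,
        actions_summary=[{"tool": "interact", "args": {"action": "do"}}],
        indicator_lambda=1.0,
        strategy="per_achievement",
        weights={"collect_wood": 2.0},
        k_limits={"collect_wood": 1},
        episode_counts=episode_counts,   # mutated in place to enforce the limit
    )
    # info["reward"] == 2.0 here; a later decision unlocking the same achievement in
    # this episode would not be credited again because episode_counts["collect_wood"] == 1.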
@@ -183,6 +294,9 @@ class RolloutMetrics(BaseModel):
     mean_return: float
     num_steps: int
     num_episodes: int = 0
+    outcome_score: float | None = None
+    events_score: float | None = None
+    details: dict[str, Any] = Field(default_factory=dict)


 class RolloutResponse(BaseModel):
@@ -1053,6 +1167,9 @@ async def execute_rollout(

     step_rewards_enabled = bool(step_rewards_cfg_raw.get("enabled", False))
     step_rewards_mode = str(step_rewards_cfg_raw.get("mode") or "off").lower()
+    step_rewards_strategy = _normalize_step_strategy(step_rewards_cfg_raw.get("strategy"))
+    step_rewards_weights = _coerce_weights(step_rewards_cfg_raw.get("weights"))
+    step_rewards_k_limits = _coerce_k_limits(step_rewards_cfg_raw.get("k_limits"))
     try:
         step_rewards_indicator_lambda = float(
             step_rewards_cfg_raw.get("indicator_lambda") or 0.0
@@ -1113,6 +1230,7 @@ async def execute_rollout(
     episode_seen_achievements: set[str] = {
         k for k, v in (prev_achievements or {}).items() if bool(v)
     }
+    episode_achievement_counts: dict[str, int] = {}
     stepwise_indicator_sum = 0.0
     stepwise_reward_sum = 0.0
     stepwise_new_achievements_total = 0
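
For reference, the step-rewards settings read in the hunks above come from the step-rewards section of the rollout/task configuration. A plausible shape is sketched below; the key names match the lookups in the code, while the section placement and the mode value are assumptions.

    # Assumed shape of step_rewards_cfg_raw as parsed by execute_rollout.
    step_rewards_cfg_raw = {
        "enabled": True,
        "mode": "decision_stepwise",          # hypothetical mode string
        "strategy": "per_achievement",        # normalised by _normalize_step_strategy
        "indicator_lambda": 1.0,
        "weights": {"collect_wood": 2.0, "defeat_zombie": 5.0},   # per-achievement credit
        "k_limits": {"collect_wood": 1},      # <= 0 means unlimited credits per episode
    }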
@@ -1560,6 +1678,10 @@ async def execute_rollout(
                     decision_index,
                     decision_actions,
                     step_rewards_indicator_lambda,
+                    strategy=step_rewards_strategy,
+                    weights=step_rewards_weights,
+                    k_limits=step_rewards_k_limits,
+                    episode_counts=episode_achievement_counts,
                 )
                 indicator_val = int(stats.get("indicator", 0.0))
                 reward_stepwise = float(stats.get("reward", 0.0))
@@ -1656,6 +1778,11 @@ async def execute_rollout(

                 reset_response = await reset_environment(EnvResetRequest(env_id=env_id))
                 current_obs = reset_response.observation
+                prev_achievements = _extract_achievements(current_obs)
+                episode_seen_achievements = {
+                    k for k, v in (prev_achievements or {}).items() if bool(v)
+                }
+                episode_achievement_counts.clear()
             elif request.on_done == "terminate":
                 break

@@ -1704,6 +1831,23 @@ async def execute_rollout(
         num_steps=len(trajectory_steps),
         num_episodes=1,
     )
+    if step_rewards_active:
+        stepwise_summary: dict[str, Any] = {
+            "indicator_sum": float(stepwise_indicator_sum),
+            "reward_sum": float(stepwise_reward_sum),
+            "new_achievements_total": int(stepwise_new_achievements_total),
+            "mode": step_rewards_mode,
+            "strategy": step_rewards_strategy,
+            "indicator_lambda": float(step_rewards_indicator_lambda),
+        }
+        if step_rewards_beta:
+            stepwise_summary["step_beta"] = float(step_rewards_beta)
+        if step_rewards_strategy == "per_achievement":
+            if step_rewards_weights:
+                stepwise_summary["weights"] = dict(step_rewards_weights)
+            if step_rewards_k_limits:
+                stepwise_summary["k_limits"] = dict(step_rewards_k_limits)
+        metrics.details["stepwise"] = stepwise_summary

     # Environment-specific: Log summary if available
     try:
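
The stepwise summary assembled above ends up under RolloutMetrics.details and is what run_eval.py reads back as metrics["details"]["stepwise"]. An illustrative payload (values invented; step_beta, weights and k_limits only appear when configured):

    metrics_details = {
        "stepwise": {
            "indicator_sum": 2.0,
            "reward_sum": 7.0,
            "new_achievements_total": 2,
            "mode": "decision_stepwise",      # hypothetical mode string
            "strategy": "per_achievement",
            "indicator_lambda": 1.0,
            "step_beta": 0.5,
            "weights": {"collect_wood": 2.0, "defeat_zombie": 5.0},
            "k_limits": {"collect_wood": 1},
        }
    }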

synth_ai/api/train/builders.py CHANGED
@@ -1,16 +1,24 @@
 from __future__ import annotations

+import importlib
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any
+from typing import Any, cast

 import click
-from synth_ai.api.models.supported import (
-    UnsupportedModelError,
-    ensure_allowed_model,
-    normalize_model_identifier,
-)
-from synth_ai.learning.sft.config import prepare_sft_job_payload
+
+try:
+    _models_module = importlib.import_module("synth_ai.api.models.supported")
+    UnsupportedModelError = _models_module.UnsupportedModelError
+    ensure_allowed_model = _models_module.ensure_allowed_model
+    normalize_model_identifier = _models_module.normalize_model_identifier
+except Exception as exc:  # pragma: no cover - critical dependency
+    raise RuntimeError("Unable to load supported model helpers") from exc
+
+try:
+    prepare_sft_job_payload = importlib.import_module("synth_ai.learning.sft.config").prepare_sft_job_payload
+except Exception as exc:  # pragma: no cover - critical dependency
+    raise RuntimeError("Unable to load SFT payload helpers") from exc

 from .supported_algos import (
     AlgorithmValidationError,
@@ -122,23 +130,26 @@ def build_rl_payload(
     except Exception:
         pass

+    payload_data: dict[str, Any] = {
+        "endpoint_base_url": final_task_url.rstrip("/"),
+        "config": data,
+    }
     payload: dict[str, Any] = {
         "job_type": "rl",
         "compute": data.get("compute", {}),
-        "data": {
-            "endpoint_base_url": final_task_url.rstrip("/"),
-            "config": data,
-        },
+        "data": payload_data,
         "tags": {"source": "train-cli"},
     }
     if model_source:
-        payload["data"]["model"] = model_source
+        payload_data["model"] = model_source
     if model_base:
-        payload["data"]["base_model"] = model_base
+        payload_data["base_model"] = model_base

     backend = overrides.get("backend")
     if backend:
-        payload.setdefault("metadata", {})["backend_base_url"] = ensure_api_base(str(backend))
+        metadata_default: dict[str, Any] = {}
+        metadata = cast(dict[str, Any], payload.setdefault("metadata", metadata_default))
+        metadata["backend_base_url"] = ensure_api_base(str(backend))

     return RLBuildResult(payload=payload, task_url=final_task_url, idempotency=idempotency)

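After this refactor, build_rl_payload assembles the same job payload as before, just via the intermediate payload_data dict. A rough sketch of the result (field values are placeholders; "metadata" is attached only when a backend override is supplied):

    payload = {
        "job_type": "rl",
        "compute": {"...": "copied from the TOML compute table"},
        "data": {
            "endpoint_base_url": "https://task-app.example.com",
            "config": {"...": "full parsed TOML config"},
            "model": "optional model source",        # only when provided
            "base_model": "optional base model",     # only when provided
        },
        "tags": {"source": "train-cli"},
        "metadata": {"backend_base_url": "https://backend.example.com/api"},
    }
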
synth_ai/api/train/cli.py CHANGED
@@ -1,11 +1,18 @@
 from __future__ import annotations

+import importlib
 import os
+from collections.abc import Mapping
 from pathlib import Path
 from typing import Any

 import click
-from synth_ai.config.base_url import get_backend_from_env
+
+try:
+    _config_module = importlib.import_module("synth_ai.config.base_url")
+    get_backend_from_env = _config_module.get_backend_from_env
+except Exception as exc:  # pragma: no cover - critical dependency
+    raise RuntimeError("Unable to load backend configuration helpers") from exc

 from .builders import build_rl_payload, build_sft_payload
 from .config_finder import discover_configs, prompt_for_config
@@ -231,7 +238,8 @@ def train_command(
     ]
     if missing_keys:
         try:
-            from synth_ai.cli.task_apps import _interactive_fill_env
+            _task_apps_module = importlib.import_module("synth_ai.cli.task_apps")
+            _interactive_fill_env = _task_apps_module._interactive_fill_env
         except Exception as exc:  # pragma: no cover - protective fallback
             raise click.ClickException(f"Unable to prompt for env values: {exc}") from exc

@@ -386,9 +394,19 @@ def handle_rl(
                 verify_url, headers=verify_headers, json_body={"endpoint_base_url": build.task_url}
             )
             try:
-                vjs = vresp.json()
+                parsed_json = vresp.json()
             except Exception:
-                vjs = {"status": vresp.status_code, "text": (vresp.text or "")[:400]}
+                parsed_json = None
+
+            if isinstance(parsed_json, Mapping):
+                vjs: dict[str, Any] = dict(parsed_json)
+            else:
+                vjs = {
+                    "status": vresp.status_code,
+                    "text": (vresp.text or "")[:400],
+                }
+                if parsed_json is not None:
+                    vjs["body"] = parsed_json
         except Exception as _ve:
             raise click.ClickException(
                 f"Task app verification call failed: {type(_ve).__name__}: {_ve}"
@@ -404,8 +422,13 @@
     # Print concise summary
     try:
         cands = vjs.get("candidates_first15") or []
-        attempts = vjs.get("attempts") or []
-        statuses = [a.get("status") for a in attempts]
+        attempts_raw = vjs.get("attempts")
+        attempts: list[Mapping[str, Any]] = (
+            [a for a in attempts_raw if isinstance(a, Mapping)]
+            if isinstance(attempts_raw, list)
+            else []
+        )
+        statuses = [attempt.get("status") for attempt in attempts]
         click.echo(f"Verification OK (candidates={cands}, statuses={statuses})")
     except Exception:
         pass

synth_ai/api/train/env_resolver.py CHANGED
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import importlib
 import os
 from collections.abc import Callable, Iterable, MutableMapping
 from dataclasses import dataclass
@@ -11,6 +12,18 @@ from . import task_app
 from .utils import REPO_ROOT, mask_value, read_env_file, write_env_value


+def _load_saved_env_path() -> Path | None:
+    try:
+        module = importlib.import_module("synth_ai.demos.demo_task_apps.core")
+        loader = module.load_env_file_path
+        saved_path = loader()
+        if saved_path:
+            return Path(saved_path)
+    except Exception:
+        return None
+    return None
+
+
 @dataclass(slots=True)
 class KeySpec:
     name: str
@@ -156,25 +169,11 @@ def resolve_env(
                raise click.ClickException(f"Env file not found: {path}")
        resolver = EnvResolver(provided)
    else:
-        # Check for saved .env path from demo command
-        try:
-            from synth_ai.demos.demo_task_apps.core import load_env_file_path
-
-            saved_env_path = load_env_file_path()
-            if saved_env_path:
-                saved_path = Path(saved_env_path)
-                if saved_path.exists():
-                    click.echo(f"Using .env file: {saved_path}")
-                    resolver = EnvResolver([saved_path])
-                else:
-                    # Saved path no longer exists, fall back to prompt
-                    resolver = EnvResolver(_collect_default_candidates(config_path))
-                    resolver.select_new_env()
-            else:
-                resolver = EnvResolver(_collect_default_candidates(config_path))
-                resolver.select_new_env()
-        except Exception:
-            # If import fails or any error, fall back to original behavior
+        saved_path = _load_saved_env_path()
+        if saved_path and saved_path.exists():
+            click.echo(f"Using .env file: {saved_path}")
+            resolver = EnvResolver([saved_path])
+        else:
             resolver = EnvResolver(_collect_default_candidates(config_path))
             resolver.select_new_env()


synth_ai/api/train/supported_algos.py CHANGED
@@ -1,13 +1,16 @@
 from __future__ import annotations

+import importlib
 from collections.abc import Mapping
 from dataclasses import dataclass

-from synth_ai.api.models.supported import (
-    RL_SUPPORTED_MODELS,
-    SFT_SUPPORTED_MODELS,
-    training_modes_for_model,
-)
+try:
+    _models_module = importlib.import_module("synth_ai.api.models.supported")
+    RL_SUPPORTED_MODELS = _models_module.RL_SUPPORTED_MODELS
+    SFT_SUPPORTED_MODELS = _models_module.SFT_SUPPORTED_MODELS
+    training_modes_for_model = _models_module.training_modes_for_model
+except Exception as exc:  # pragma: no cover - critical dependency
+    raise RuntimeError("Unable to load supported model metadata") from exc


 @dataclass(frozen=True)