synth-ai 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of synth-ai might be problematic.
- examples/multi_step/task_app_config_notes.md +488 -0
- examples/warming_up_to_rl/configs/eval_stepwise_complex.toml +33 -0
- examples/warming_up_to_rl/configs/eval_stepwise_consistent.toml +26 -0
- examples/warming_up_to_rl/configs/eval_stepwise_per_achievement.toml +36 -0
- examples/warming_up_to_rl/configs/eval_stepwise_simple.toml +30 -0
- examples/warming_up_to_rl/run_eval.py +142 -25
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +146 -2
- synth_ai/api/train/builders.py +25 -14
- synth_ai/api/train/cli.py +29 -6
- synth_ai/api/train/env_resolver.py +18 -19
- synth_ai/api/train/supported_algos.py +8 -5
- synth_ai/api/train/utils.py +6 -1
- synth_ai/cli/__init__.py +4 -2
- synth_ai/cli/_storage.py +19 -0
- synth_ai/cli/balance.py +14 -2
- synth_ai/cli/calc.py +37 -22
- synth_ai/cli/legacy_root_backup.py +12 -14
- synth_ai/cli/recent.py +12 -7
- synth_ai/cli/status.py +4 -3
- synth_ai/cli/task_apps.py +143 -137
- synth_ai/cli/traces.py +4 -3
- synth_ai/cli/watch.py +3 -2
- synth_ai/jobs/client.py +15 -3
- synth_ai/task/server.py +14 -7
- synth_ai/tracing_v3/decorators.py +51 -26
- synth_ai/tracing_v3/examples/basic_usage.py +12 -7
- synth_ai/tracing_v3/llm_call_record_helpers.py +107 -53
- synth_ai/tracing_v3/replica_sync.py +8 -4
- synth_ai/tracing_v3/storage/utils.py +11 -9
- synth_ai/tracing_v3/turso/__init__.py +12 -0
- synth_ai/tracing_v3/turso/daemon.py +2 -1
- synth_ai/tracing_v3/turso/native_manager.py +28 -15
- {synth_ai-0.2.10.dist-info → synth_ai-0.2.12.dist-info}/METADATA +4 -2
- {synth_ai-0.2.10.dist-info → synth_ai-0.2.12.dist-info}/RECORD +38 -31
- {synth_ai-0.2.10.dist-info → synth_ai-0.2.12.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.10.dist-info → synth_ai-0.2.12.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.10.dist-info → synth_ai-0.2.12.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.10.dist-info → synth_ai-0.2.12.dist-info}/top_level.txt +0 -0
examples/warming_up_to_rl/run_eval.py
CHANGED

@@ -15,6 +15,7 @@ import json
 import os
 import re
 import tomllib
+from copy import deepcopy
 from collections import Counter
 from pathlib import Path
 from typing import Any
@@ -115,26 +116,34 @@ class TaskAppClient:
         run_id: str,
         env_name: str,
         seed: int,
-        difficulty: str,
+        difficulty: str | None,
         policy_name: str,
         policy_config: dict[str, Any],
         max_turns: int,
+        env_config: dict[str, Any] | None = None,
+        ops: list[str] | None = None,
     ) -> dict[str, Any]:
-
-
-
+        ops_seq: list[str] = list(ops) if ops is not None else []
+        if not ops_seq:
+            for _ in range(max_turns):
+                ops_seq.extend(["agent", "env"])
+        env_cfg: dict[str, Any] = {}
+        if isinstance(env_config, dict):
+            env_cfg.update(env_config)
+        if difficulty is not None and "difficulty" not in env_cfg:
+            env_cfg["difficulty"] = difficulty
         payload: dict[str, Any] = {
             "run_id": run_id,
             "env": {
                 "env_name": env_name,
-                "config":
+                "config": env_cfg,
                 "seed": seed,
             },
             "policy": {
                 "policy_name": policy_name,
                 "config": policy_config,
             },
-            "ops":
+            "ops": ops_seq,
             "on_done": "terminate",
         }
         # Ensure X-API-Key is included
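For reference, the rewritten helper now synthesizes an alternating agent/env op schedule when the caller omits ops, and merges difficulty into the env config only when the caller has not already set it. A sketch of the payload this produces for a hypothetical two-turn call (all field values are placeholders, not taken from the package):

# Illustrative payload produced by the updated TaskAppClient.rollout when
# ops and env_config are omitted and max_turns=2; values are placeholders.
payload = {
    "run_id": "eval-1",
    "env": {
        "env_name": "crafter",
        "config": {"difficulty": "easy"},  # difficulty folded in only when absent
        "seed": 1,
    },
    "policy": {
        "policy_name": "crafter",
        "config": {"model": "example-model"},
    },
    "ops": ["agent", "env", "agent", "env"],  # two agent/env pairs for max_turns=2
    "on_done": "terminate",
}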
@@ -415,11 +424,20 @@ async def main() -> None:
         async with sem:
             try:
                 run_id = f"eval-{seed}"
-
-
-
-
-                }
+                rollout_cfg_raw = cfg.get("rollout") or {}
+                rollout_cfg = (
+                    dict(rollout_cfg_raw) if isinstance(rollout_cfg_raw, dict) else {}
+                )
+                env_config_raw = rollout_cfg.get("env_config") or {}
+                env_config = (
+                    deepcopy(env_config_raw) if isinstance(env_config_raw, dict) else {}
+                )
+                policy_cfg_raw = rollout_cfg.get("policy_config") or {}
+                policy_cfg = (
+                    deepcopy(policy_cfg_raw) if isinstance(policy_cfg_raw, dict) else {}
+                )
+                policy_cfg.setdefault("model", cfg.get("model", MODEL))
+                policy_cfg.setdefault("inference_url", inf_url)
                 for k in (
                     "max_tokens",
                     "temperature",
@@ -428,18 +446,56 @@ async def main() -> None:
                     "thinking_budget",
                     "use_tools",
                 ):
-                    if k in cfg and cfg.get(k) is not None:
+                    if k in cfg and cfg.get(k) is not None and k not in policy_cfg:
                         policy_cfg[k] = cfg.get(k)

+                env_name = str(rollout_cfg.get("env_name") or "crafter")
+                policy_name = str(
+                    rollout_cfg.get("policy_name") or cfg.get("policy_name") or "crafter"
+                )
+
+                max_turns_local = MAX_TURNS
+                for candidate in (rollout_cfg.get("max_turns"), cfg.get("max_turns")):
+                    if candidate is None:
+                        continue
+                    with contextlib.suppress(Exception):
+                        max_turns_local = int(candidate)
+                        break
+
+                difficulty_override: str | None = None
+                if isinstance(env_config, dict):
+                    diff_cfg = env_config.get("difficulty")
+                    if isinstance(diff_cfg, str) and diff_cfg:
+                        difficulty_override = diff_cfg
+                if difficulty_override is None:
+                    cfg_diff = rollout_cfg.get("difficulty") or cfg.get("difficulty")
+                    if isinstance(cfg_diff, str) and cfg_diff:
+                        difficulty_override = cfg_diff
+                if difficulty_override is None:
+                    difficulty_override = os.getenv("DIFFICULTY", "easy")
+
                 r = await client.rollout(
                     run_id=run_id,
-                    env_name=
+                    env_name=env_name,
                     seed=seed,
-                    difficulty=
-                    policy_name=
+                    difficulty=difficulty_override,
+                    policy_name=policy_name,
                     policy_config=policy_cfg,
-                    max_turns=
+                    max_turns=max_turns_local,
+                    env_config=env_config,
                 )
+                metrics_block = r.get("metrics") or {}
+                mean_return = None
+                if isinstance(metrics_block, dict):
+                    with contextlib.suppress(Exception):
+                        mean_return = float(metrics_block.get("mean_return"))
+                stepwise_details: dict[str, Any] = {}
+                if isinstance(metrics_block, dict):
+                    details_block = metrics_block.get("details") or {}
+                    if isinstance(details_block, dict):
+                        step_block = details_block.get("stepwise") or {}
+                        if isinstance(step_block, dict):
+                            stepwise_details = step_block
                 # Extract achievements count if present
                 ach = []
                 try:
@@ -465,9 +521,22 @@ async def main() -> None:
                     length = int(trajs[0].get("length") or 0)
                 except Exception:
                     pass
-                return {
+                return {
+                    "seed": seed,
+                    "turns": length,
+                    "achievements": ach,
+                    "mean_return": mean_return,
+                    "stepwise": stepwise_details,
+                }
             except Exception as e:
-                return {
+                return {
+                    "seed": seed,
+                    "turns": 0,
+                    "achievements": [],
+                    "mean_return": None,
+                    "stepwise": {},
+                    "error": str(e),
+                }

     results = await asyncio.gather(
         *[asyncio.create_task(_run(i)) for i in range(1, NUM_EPISODES + 1)],
@@ -483,15 +552,63 @@ async def main() -> None:
                    all_ach[a] += 1
            except Exception:
                pass
+        mean_returns: list[float] = []
+        stepwise_reward_sums: list[float] = []
+        stepwise_indicator_sums: list[float] = []
+        stepwise_new_ach_totals: list[float] = []
+        strategies_seen = Counter()
+        for r in results:
+            if not isinstance(r, dict):
+                continue
+            with contextlib.suppress(Exception):
+                mean_val = r.get("mean_return")
+                if mean_val is not None:
+                    mean_returns.append(float(mean_val))
+            stepwise_block = r.get("stepwise")
+            if isinstance(stepwise_block, dict) and stepwise_block:
+                with contextlib.suppress(Exception):
+                    if stepwise_block.get("reward_sum") is not None:
+                        stepwise_reward_sums.append(float(stepwise_block.get("reward_sum")))
+                with contextlib.suppress(Exception):
+                    if stepwise_block.get("indicator_sum") is not None:
+                        stepwise_indicator_sums.append(float(stepwise_block.get("indicator_sum")))
+                with contextlib.suppress(Exception):
+                    if stepwise_block.get("new_achievements_total") is not None:
+                        stepwise_new_ach_totals.append(
+                            float(stepwise_block.get("new_achievements_total"))
+                        )
+                strategy_name = stepwise_block.get("strategy")
+                if isinstance(strategy_name, str) and strategy_name:
+                    strategies_seen[strategy_name] += 1
+        aggregate: dict[str, Any] = {
+            "completed": sum(
+                1 for r in results if isinstance(r, dict) and not r.get("error")
+            ),
+            "total": len(results),
+            "avg_turns": (sum(turns) / len(turns)) if turns else 0.0,
+            "avg_achievements": (sum(counts) / len(counts)) if counts else 0.0,
+            "achievements_freq": dict(all_ach),
+        }
+        if mean_returns:
+            aggregate["avg_mean_return"] = sum(mean_returns) / len(mean_returns)
+        if stepwise_reward_sums:
+            aggregate["avg_stepwise_reward_sum"] = sum(stepwise_reward_sums) / len(
+                stepwise_reward_sums
+            )
+        if stepwise_indicator_sums:
+            aggregate["avg_stepwise_indicator_sum"] = sum(stepwise_indicator_sums) / len(
+                stepwise_indicator_sums
+            )
+        if stepwise_new_ach_totals:
+            aggregate["avg_stepwise_new_achievements"] = sum(stepwise_new_ach_totals) / len(
+                stepwise_new_ach_totals
+            )
+        if strategies_seen:
+            aggregate["stepwise_strategies"] = dict(strategies_seen)
+        aggregate["stepwise_samples"] = len(stepwise_reward_sums)
        summary = {
            "episodes": results,
-            "aggregate":
-                "completed": sum(1 for r in results if not r.get("error")),
-                "total": len(results),
-                "avg_turns": (sum(turns) / len(turns)) if turns else 0.0,
-                "avg_achievements": (sum(counts) / len(counts)) if counts else 0.0,
-                "achievements_freq": dict(all_ach),
-            },
+            "aggregate": aggregate,
        }
        print(json.dumps(summary, indent=2))
    else:
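Taken together, run_eval.py now pulls per-episode stepwise metrics out of metrics.details and rolls everything up under a single aggregate key. An illustrative sketch of the printed summary shape (all values made up; the optional keys appear only when the corresponding samples exist):

# Illustrative shape of the JSON summary printed by run_eval.py after this
# change; every value below is a placeholder.
summary = {
    "episodes": [
        {
            "seed": 1,
            "turns": 10,
            "achievements": ["collect_wood"],
            "mean_return": 1.0,
            "stepwise": {"reward_sum": 2.0, "strategy": "consistent"},
        },
    ],
    "aggregate": {
        "completed": 1,
        "total": 1,
        "avg_turns": 10.0,
        "avg_achievements": 1.0,
        "achievements_freq": {"collect_wood": 1},
        "avg_mean_return": 1.0,          # only when mean returns were parsed
        "avg_stepwise_reward_sum": 2.0,  # only when stepwise blocks were present
        "stepwise_strategies": {"consistent": 1},
        "stepwise_samples": 1,
    },
}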
examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py
CHANGED

@@ -9,7 +9,7 @@ from datetime import datetime
 from typing import Any

 from fastapi import APIRouter, HTTPException, Request, status
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 from synth_ai.lm.vendors.base import BaseLMResponse
 from synth_ai.task.tracing_utils import unique_sft_path
 from synth_ai.tracing_v3.abstractions import EnvironmentEvent, LMCAISEvent, TimeRecord
@@ -142,12 +142,59 @@ class RolloutTrajectory(BaseModel):
     decision_samples: list[dict[str, Any]] | None = None


+def _normalize_step_strategy(raw_strategy: Any) -> str:
+    if not isinstance(raw_strategy, str):
+        return "consistent"
+    candidate = raw_strategy.strip().lower()
+    if not candidate:
+        return "consistent"
+    mapping = {
+        "simple": "consistent",
+        "consistent": "consistent",
+        "consistent_stepwise": "consistent",
+        "decision_consistent": "consistent",
+        "per_achievement": "per_achievement",
+        "per-achievement": "per_achievement",
+        "perachievement": "per_achievement",
+        "achievement_weighted": "per_achievement",
+        "complex": "per_achievement",
+    }
+    return mapping.get(candidate, "consistent")
+
+
+def _coerce_weights(raw_weights: Any) -> dict[str, float]:
+    weights: dict[str, float] = {}
+    if isinstance(raw_weights, dict):
+        for key, value in raw_weights.items():
+            try:
+                weights[str(key)] = float(value)
+            except Exception:
+                continue
+    return weights
+
+
+def _coerce_k_limits(raw_limits: Any) -> dict[str, int]:
+    limits: dict[str, int] = {}
+    if isinstance(raw_limits, dict):
+        for key, value in raw_limits.items():
+            try:
+                limits[str(key)] = int(value)
+            except Exception:
+                continue
+    return limits
+
+
 def compute_stepwise_reward(
     prev_achievements: dict[str, bool],
     new_achievements: dict[str, bool],
     decision_index: int,
     actions_summary: list[dict[str, Any]],
     indicator_lambda: float,
+    *,
+    strategy: str | None = None,
+    weights: dict[str, float] | None = None,
+    k_limits: dict[str, int] | None = None,
+    episode_counts: dict[str, int] | None = None,
 ) -> tuple[dict[str, Any], dict[str, Any], dict[str, float]]:
     """Compute stepwise reward metadata given achievement states before/after a decision."""

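A short usage sketch for the normalization helper, assuming _normalize_step_strategy from the hunk above is in scope; the alias table collapses every spelling onto one of two canonical strategies:

# Usage sketch (assumes _normalize_step_strategy above is in scope):
# anything unrecognized or non-string falls back to "consistent".
assert _normalize_step_strategy("simple") == "consistent"
assert _normalize_step_strategy("per-achievement") == "per_achievement"
assert _normalize_step_strategy("complex") == "per_achievement"
assert _normalize_step_strategy(None) == "consistent"
assert _normalize_step_strategy("unknown") == "consistent"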
@@ -156,24 +203,88 @@ def compute_stepwise_reward(

     unlocked = [name for name, value in next_map.items() if value and not prev_map.get(name, False)]
     indicator = 1 if unlocked else 0
-
+    normalized_strategy = _normalize_step_strategy(strategy)
+    base_reward = 0.0
+    reward_components: list[dict[str, Any]] = []
+    credited: list[str] = []
+
+    if indicator:
+        if normalized_strategy == "per_achievement":
+            weight_map = weights or {}
+            limit_map = k_limits or {}
+            counts = episode_counts if isinstance(episode_counts, dict) else {}
+            for name in unlocked:
+                try:
+                    limit_val = int(limit_map.get(name, 1))
+                except Exception:
+                    limit_val = 1
+                # limit_val <= 0 implies unlimited rewards
+                unlimited = limit_val <= 0
+                try:
+                    prev_count = int(counts.get(name, 0))
+                except Exception:
+                    prev_count = 0
+                should_credit = unlimited or (prev_count < max(limit_val, 0))
+                if should_credit:
+                    try:
+                        weight_val = float(weight_map.get(name, 1.0))
+                    except Exception:
+                        weight_val = 1.0
+                    base_reward += weight_val
+                    reward_components.append(
+                        {
+                            "achievement": name,
+                            "weight": weight_val,
+                            "count_prior": prev_count,
+                            "count_limit": limit_val,
+                        }
+                    )
+                    credited.append(name)
+                if episode_counts is not None:
+                    episode_counts[name] = prev_count + 1
+        else:
+            base_reward = 1.0
+            reward_components.append(
+                {
+                    "achievement": "__indicator__",
+                    "weight": 1.0,
+                    "count_prior": 0,
+                    "count_limit": 1,
+                }
+            )
+
+    reward_value = float(indicator_lambda) * float(base_reward)

     stepwise_info = {
         "decision_index": decision_index,
         "indicator": indicator,
         "new_achievements": unlocked,
         "reward": reward_value,
+        "strategy": normalized_strategy,
+        "base_reward": float(base_reward),
     }
+    if reward_components:
+        stepwise_info["components"] = reward_components
+    if credited:
+        stepwise_info["credited_achievements"] = credited
+
     decision_sample = {
         "decision_index": decision_index,
         "indicator": indicator,
         "r_i": reward_value,
+        "base": float(base_reward),
+        "strategy": normalized_strategy,
         "actions": actions_summary,
     }
+    if reward_components:
+        decision_sample["components"] = reward_components
+
     stats = {
         "indicator": float(indicator),
         "reward": reward_value,
         "new_achievements_count": float(len(unlocked)),
+        "base_reward": float(base_reward),
+        "credited_achievements_count": float(len(credited)),
     }
     return stepwise_info, decision_sample, stats

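To make the crediting rule concrete, here is a minimal standalone reimplementation of the per_achievement branch above (achievement names and weights are illustrative; in the task app, the returned base is then multiplied by indicator_lambda to give the step reward):

# Minimal sketch of the per_achievement crediting rule: each newly unlocked
# achievement earns its weight, at most k_limits[name] times per episode
# (a limit <= 0 means unlimited).
def credit(unlocked: list[str], weights: dict[str, float],
           k_limits: dict[str, int], episode_counts: dict[str, int]) -> float:
    base = 0.0
    for name in unlocked:
        limit = int(k_limits.get(name, 1))       # default: credit once per episode
        prior = int(episode_counts.get(name, 0))
        if limit <= 0 or prior < limit:
            base += float(weights.get(name, 1.0))
        episode_counts[name] = prior + 1         # every unlock is counted
    return base

counts: dict[str, int] = {}
print(credit(["collect_wood"], {"collect_wood": 0.5}, {}, counts))  # 0.5
print(credit(["collect_wood"], {"collect_wood": 0.5}, {}, counts))  # 0.0, limit of 1 reached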
@@ -183,6 +294,9 @@ class RolloutMetrics(BaseModel):
     mean_return: float
     num_steps: int
     num_episodes: int = 0
+    outcome_score: float | None = None
+    events_score: float | None = None
+    details: dict[str, Any] = Field(default_factory=dict)


 class RolloutResponse(BaseModel):
@@ -1053,6 +1167,9 @@ async def execute_rollout(

     step_rewards_enabled = bool(step_rewards_cfg_raw.get("enabled", False))
     step_rewards_mode = str(step_rewards_cfg_raw.get("mode") or "off").lower()
+    step_rewards_strategy = _normalize_step_strategy(step_rewards_cfg_raw.get("strategy"))
+    step_rewards_weights = _coerce_weights(step_rewards_cfg_raw.get("weights"))
+    step_rewards_k_limits = _coerce_k_limits(step_rewards_cfg_raw.get("k_limits"))
     try:
         step_rewards_indicator_lambda = float(
             step_rewards_cfg_raw.get("indicator_lambda") or 0.0
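These keys are read from the rollout request's step-rewards configuration. A hedged sketch of a matching config block, parsed with tomllib; only the key names (enabled, mode, strategy, weights, k_limits, indicator_lambda) come from the reads above, while the section name and the mode value are assumptions (the shipped eval_stepwise_*.toml files are not reproduced in this diff):

import tomllib

# Hypothetical step-rewards block; section name and mode value are
# illustrative assumptions, key names come from the code above.
cfg = tomllib.loads("""
[step_rewards]
enabled = true
mode = "stepwise"
strategy = "per_achievement"
indicator_lambda = 1.0

[step_rewards.weights]
collect_diamond = 5.0

[step_rewards.k_limits]
collect_wood = 2   # credit at most twice per episode; <= 0 means unlimited
""")
print(cfg["step_rewards"]["strategy"])  # -> per_achievement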
@@ -1113,6 +1230,7 @@ async def execute_rollout(
     episode_seen_achievements: set[str] = {
         k for k, v in (prev_achievements or {}).items() if bool(v)
     }
+    episode_achievement_counts: dict[str, int] = {}
     stepwise_indicator_sum = 0.0
     stepwise_reward_sum = 0.0
     stepwise_new_achievements_total = 0
@@ -1560,6 +1678,10 @@ async def execute_rollout(
                     decision_index,
                     decision_actions,
                     step_rewards_indicator_lambda,
+                    strategy=step_rewards_strategy,
+                    weights=step_rewards_weights,
+                    k_limits=step_rewards_k_limits,
+                    episode_counts=episode_achievement_counts,
                 )
                 indicator_val = int(stats.get("indicator", 0.0))
                 reward_stepwise = float(stats.get("reward", 0.0))
@@ -1656,6 +1778,11 @@ async def execute_rollout(

             reset_response = await reset_environment(EnvResetRequest(env_id=env_id))
             current_obs = reset_response.observation
+            prev_achievements = _extract_achievements(current_obs)
+            episode_seen_achievements = {
+                k for k, v in (prev_achievements or {}).items() if bool(v)
+            }
+            episode_achievement_counts.clear()
         elif request.on_done == "terminate":
             break

@@ -1704,6 +1831,23 @@ async def execute_rollout(
         num_steps=len(trajectory_steps),
         num_episodes=1,
     )
+    if step_rewards_active:
+        stepwise_summary: dict[str, Any] = {
+            "indicator_sum": float(stepwise_indicator_sum),
+            "reward_sum": float(stepwise_reward_sum),
+            "new_achievements_total": int(stepwise_new_achievements_total),
+            "mode": step_rewards_mode,
+            "strategy": step_rewards_strategy,
+            "indicator_lambda": float(step_rewards_indicator_lambda),
+        }
+        if step_rewards_beta:
+            stepwise_summary["step_beta"] = float(step_rewards_beta)
+        if step_rewards_strategy == "per_achievement":
+            if step_rewards_weights:
+                stepwise_summary["weights"] = dict(step_rewards_weights)
+            if step_rewards_k_limits:
+                stepwise_summary["k_limits"] = dict(step_rewards_k_limits)
+        metrics.details["stepwise"] = stepwise_summary

     # Environment-specific: Log summary if available
     try:
synth_ai/api/train/builders.py
CHANGED

@@ -1,16 +1,24 @@
 from __future__ import annotations

+import importlib
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any
+from typing import Any, cast

 import click
-
-
-
-
-
-
+
+try:
+    _models_module = importlib.import_module("synth_ai.api.models.supported")
+    UnsupportedModelError = _models_module.UnsupportedModelError
+    ensure_allowed_model = _models_module.ensure_allowed_model
+    normalize_model_identifier = _models_module.normalize_model_identifier
+except Exception as exc:  # pragma: no cover - critical dependency
+    raise RuntimeError("Unable to load supported model helpers") from exc
+
+try:
+    prepare_sft_job_payload = importlib.import_module("synth_ai.learning.sft.config").prepare_sft_job_payload
+except Exception as exc:  # pragma: no cover - critical dependency
+    raise RuntimeError("Unable to load SFT payload helpers") from exc

 from .supported_algos import (
     AlgorithmValidationError,
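builders.py, cli.py, env_resolver.py, and supported_algos.py all move to the same guarded dynamic-import pattern in this release. A self-contained sketch of the pattern, using the stdlib json module as a stand-in for the synth_ai dependency:

import importlib

# Guarded dynamic import: resolve the dependency via importlib at import
# time and fail loudly with context if it is missing. "json" is a stand-in.
try:
    _mod = importlib.import_module("json")
    dumps = _mod.dumps
except Exception as exc:  # pragma: no cover - critical dependency
    raise RuntimeError("Unable to load helper module") from exc

print(dumps({"ok": True}))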
@@ -122,23 +130,26 @@ def build_rl_payload(
     except Exception:
         pass

+    payload_data: dict[str, Any] = {
+        "endpoint_base_url": final_task_url.rstrip("/"),
+        "config": data,
+    }
     payload: dict[str, Any] = {
         "job_type": "rl",
         "compute": data.get("compute", {}),
-        "data":
-            "endpoint_base_url": final_task_url.rstrip("/"),
-            "config": data,
-        },
+        "data": payload_data,
         "tags": {"source": "train-cli"},
     }
     if model_source:
-
+        payload_data["model"] = model_source
     if model_base:
-
+        payload_data["base_model"] = model_base

     backend = overrides.get("backend")
     if backend:
-
+        metadata_default: dict[str, Any] = {}
+        metadata = cast(dict[str, Any], payload.setdefault("metadata", metadata_default))
+        metadata["backend_base_url"] = ensure_api_base(str(backend))

     return RLBuildResult(payload=payload, task_url=final_task_url, idempotency=idempotency)

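With payload_data hoisted into its own dict, the model fields are attached to the job's data block by reference after construction. An illustrative sketch of the resulting RL payload (URLs and model identifiers are placeholders; the optional keys appear only when set):

# Illustrative final payload from build_rl_payload; all values below are
# placeholders, not taken from the package.
payload = {
    "job_type": "rl",
    "compute": {},
    "data": {
        "endpoint_base_url": "https://task-app.example.com",
        "config": {},                       # the parsed TOML config dict
        "model": "ft:example-checkpoint",   # from model_source, when present
        "base_model": "example/base-model", # from model_base, when present
    },
    "tags": {"source": "train-cli"},
    "metadata": {"backend_base_url": "https://backend.example.com/api"},
}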
synth_ai/api/train/cli.py
CHANGED

@@ -1,11 +1,18 @@
 from __future__ import annotations

+import importlib
 import os
+from collections.abc import Mapping
 from pathlib import Path
 from typing import Any

 import click
-
+
+try:
+    _config_module = importlib.import_module("synth_ai.config.base_url")
+    get_backend_from_env = _config_module.get_backend_from_env
+except Exception as exc:  # pragma: no cover - critical dependency
+    raise RuntimeError("Unable to load backend configuration helpers") from exc

 from .builders import build_rl_payload, build_sft_payload
 from .config_finder import discover_configs, prompt_for_config
@@ -231,7 +238,8 @@ def train_command(
     ]
     if missing_keys:
         try:
-
+            _task_apps_module = importlib.import_module("synth_ai.cli.task_apps")
+            _interactive_fill_env = _task_apps_module._interactive_fill_env
         except Exception as exc:  # pragma: no cover - protective fallback
             raise click.ClickException(f"Unable to prompt for env values: {exc}") from exc

@@ -386,9 +394,19 @@ def handle_rl(
             verify_url, headers=verify_headers, json_body={"endpoint_base_url": build.task_url}
         )
         try:
-
+            parsed_json = vresp.json()
         except Exception:
-
+            parsed_json = None
+
+        if isinstance(parsed_json, Mapping):
+            vjs: dict[str, Any] = dict(parsed_json)
+        else:
+            vjs = {
+                "status": vresp.status_code,
+                "text": (vresp.text or "")[:400],
+            }
+            if parsed_json is not None:
+                vjs["body"] = parsed_json
     except Exception as _ve:
         raise click.ClickException(
             f"Task app verification call failed: {type(_ve).__name__}: {_ve}"
@@ -404,8 +422,13 @@ def handle_rl(
     # Print concise summary
     try:
         cands = vjs.get("candidates_first15") or []
-
-
+        attempts_raw = vjs.get("attempts")
+        attempts: list[Mapping[str, Any]] = (
+            [a for a in attempts_raw if isinstance(a, Mapping)]
+            if isinstance(attempts_raw, list)
+            else []
+        )
+        statuses = [attempt.get("status") for attempt in attempts]
         click.echo(f"Verification OK (candidates={cands}, statuses={statuses})")
     except Exception:
         pass
synth_ai/api/train/env_resolver.py
CHANGED

@@ -1,5 +1,6 @@
 from __future__ import annotations

+import importlib
 import os
 from collections.abc import Callable, Iterable, MutableMapping
 from dataclasses import dataclass
@@ -11,6 +12,18 @@ from . import task_app
 from .utils import REPO_ROOT, mask_value, read_env_file, write_env_value


+def _load_saved_env_path() -> Path | None:
+    try:
+        module = importlib.import_module("synth_ai.demos.demo_task_apps.core")
+        loader = module.load_env_file_path
+        saved_path = loader()
+        if saved_path:
+            return Path(saved_path)
+    except Exception:
+        return None
+    return None
+
+
 @dataclass(slots=True)
 class KeySpec:
     name: str
@@ -156,25 +169,11 @@ def resolve_env(
                raise click.ClickException(f"Env file not found: {path}")
        resolver = EnvResolver(provided)
    else:
-
-
-
-
-
-        if saved_env_path:
-            saved_path = Path(saved_env_path)
-            if saved_path.exists():
-                click.echo(f"Using .env file: {saved_path}")
-                resolver = EnvResolver([saved_path])
-            else:
-                # Saved path no longer exists, fall back to prompt
-                resolver = EnvResolver(_collect_default_candidates(config_path))
-                resolver.select_new_env()
-        else:
-            resolver = EnvResolver(_collect_default_candidates(config_path))
-            resolver.select_new_env()
-    except Exception:
-        # If import fails or any error, fall back to original behavior
+        saved_path = _load_saved_env_path()
+        if saved_path and saved_path.exists():
+            click.echo(f"Using .env file: {saved_path}")
+            resolver = EnvResolver([saved_path])
+        else:
            resolver = EnvResolver(_collect_default_candidates(config_path))
            resolver.select_new_env()

synth_ai/api/train/supported_algos.py
CHANGED

@@ -1,13 +1,16 @@
 from __future__ import annotations

+import importlib
 from collections.abc import Mapping
 from dataclasses import dataclass

-
-
-
-
-
+try:
+    _models_module = importlib.import_module("synth_ai.api.models.supported")
+    RL_SUPPORTED_MODELS = _models_module.RL_SUPPORTED_MODELS
+    SFT_SUPPORTED_MODELS = _models_module.SFT_SUPPORTED_MODELS
+    training_modes_for_model = _models_module.training_modes_for_model
+except Exception as exc:  # pragma: no cover - critical dependency
+    raise RuntimeError("Unable to load supported model metadata") from exc


 @dataclass(frozen=True)