synth-ai 0.2.10__py3-none-any.whl → 0.2.13.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/agora_ex/README_MoE.md +224 -0
- examples/agora_ex/__init__.py +7 -0
- examples/agora_ex/agora_ex.py +65 -0
- examples/agora_ex/agora_ex_task_app.py +590 -0
- examples/agora_ex/configs/rl_lora_qwen3_moe_2xh200.toml +121 -0
- examples/agora_ex/reward_fn_grpo-human.py +129 -0
- examples/agora_ex/system_prompt_CURRENT.md +63 -0
- examples/agora_ex/task_app/agora_ex_task_app.py +590 -0
- examples/agora_ex/task_app/reward_fn_grpo-human.py +129 -0
- examples/agora_ex/task_app/system_prompt_CURRENT.md +63 -0
- examples/multi_step/configs/crafter_rl_outcome.toml +74 -0
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +175 -0
- examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +83 -0
- examples/multi_step/configs/crafter_rl_stepwise_simple.toml +78 -0
- examples/multi_step/crafter_rl_lora.md +51 -10
- examples/multi_step/sse_metrics_streaming_notes.md +357 -0
- examples/multi_step/task_app_config_notes.md +494 -0
- examples/warming_up_to_rl/configs/eval_stepwise_complex.toml +35 -0
- examples/warming_up_to_rl/configs/eval_stepwise_consistent.toml +26 -0
- examples/warming_up_to_rl/configs/eval_stepwise_per_achievement.toml +36 -0
- examples/warming_up_to_rl/configs/eval_stepwise_simple.toml +32 -0
- examples/warming_up_to_rl/run_eval.py +267 -41
- examples/warming_up_to_rl/task_app/grpo_crafter.py +3 -33
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +109 -45
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +42 -46
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +376 -193
- synth_ai/__init__.py +41 -1
- synth_ai/api/train/builders.py +74 -33
- synth_ai/api/train/cli.py +29 -6
- synth_ai/api/train/configs/__init__.py +44 -0
- synth_ai/api/train/configs/rl.py +133 -0
- synth_ai/api/train/configs/sft.py +94 -0
- synth_ai/api/train/configs/shared.py +24 -0
- synth_ai/api/train/env_resolver.py +18 -19
- synth_ai/api/train/supported_algos.py +8 -5
- synth_ai/api/train/utils.py +6 -1
- synth_ai/cli/__init__.py +4 -2
- synth_ai/cli/_storage.py +19 -0
- synth_ai/cli/balance.py +14 -2
- synth_ai/cli/calc.py +37 -22
- synth_ai/cli/demo.py +38 -39
- synth_ai/cli/legacy_root_backup.py +12 -14
- synth_ai/cli/recent.py +12 -7
- synth_ai/cli/rl_demo.py +81 -102
- synth_ai/cli/status.py +4 -3
- synth_ai/cli/task_apps.py +146 -137
- synth_ai/cli/traces.py +4 -3
- synth_ai/cli/watch.py +3 -2
- synth_ai/demos/core/cli.py +121 -159
- synth_ai/environments/examples/crafter_classic/environment.py +16 -0
- synth_ai/evals/__init__.py +15 -0
- synth_ai/evals/client.py +85 -0
- synth_ai/evals/types.py +42 -0
- synth_ai/jobs/client.py +15 -3
- synth_ai/judge_schemas.py +127 -0
- synth_ai/rubrics/__init__.py +22 -0
- synth_ai/rubrics/validators.py +126 -0
- synth_ai/task/server.py +14 -7
- synth_ai/tracing_v3/decorators.py +51 -26
- synth_ai/tracing_v3/examples/basic_usage.py +12 -7
- synth_ai/tracing_v3/llm_call_record_helpers.py +107 -53
- synth_ai/tracing_v3/replica_sync.py +8 -4
- synth_ai/tracing_v3/serialization.py +130 -0
- synth_ai/tracing_v3/storage/utils.py +11 -9
- synth_ai/tracing_v3/turso/__init__.py +12 -0
- synth_ai/tracing_v3/turso/daemon.py +2 -1
- synth_ai/tracing_v3/turso/native_manager.py +28 -15
- {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/METADATA +4 -2
- {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/RECORD +73 -40
- {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/entry_points.txt +0 -1
- {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/top_level.txt +0 -0
|
@@ -14,11 +14,14 @@ import contextlib
|
|
|
14
14
|
import json
|
|
15
15
|
import os
|
|
16
16
|
import re
|
|
17
|
-
import
|
|
17
|
+
import sys
|
|
18
18
|
from collections import Counter
|
|
19
|
+
from copy import deepcopy
|
|
19
20
|
from pathlib import Path
|
|
20
21
|
from typing import Any
|
|
21
22
|
|
|
23
|
+
import tomllib
|
|
24
|
+
|
|
22
25
|
import httpx
|
|
23
26
|
|
|
24
27
|
|
|
@@ -115,26 +118,34 @@ class TaskAppClient:
|
|
|
115
118
|
run_id: str,
|
|
116
119
|
env_name: str,
|
|
117
120
|
seed: int,
|
|
118
|
-
difficulty: str,
|
|
121
|
+
difficulty: str | None,
|
|
119
122
|
policy_name: str,
|
|
120
123
|
policy_config: dict[str, Any],
|
|
121
124
|
max_turns: int,
|
|
125
|
+
env_config: dict[str, Any] | None = None,
|
|
126
|
+
ops: list[str] | None = None,
|
|
122
127
|
) -> dict[str, Any]:
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
128
|
+
ops_seq: list[str] = list(ops) if ops is not None else []
|
|
129
|
+
if not ops_seq:
|
|
130
|
+
for _ in range(max_turns):
|
|
131
|
+
ops_seq.extend(["agent", "env"])
|
|
132
|
+
env_cfg: dict[str, Any] = {}
|
|
133
|
+
if isinstance(env_config, dict):
|
|
134
|
+
env_cfg.update(env_config)
|
|
135
|
+
if difficulty is not None and "difficulty" not in env_cfg:
|
|
136
|
+
env_cfg["difficulty"] = difficulty
|
|
126
137
|
payload: dict[str, Any] = {
|
|
127
138
|
"run_id": run_id,
|
|
128
139
|
"env": {
|
|
129
140
|
"env_name": env_name,
|
|
130
|
-
"config":
|
|
141
|
+
"config": env_cfg,
|
|
131
142
|
"seed": seed,
|
|
132
143
|
},
|
|
133
144
|
"policy": {
|
|
134
145
|
"policy_name": policy_name,
|
|
135
146
|
"config": policy_config,
|
|
136
147
|
},
|
|
137
|
-
"ops":
|
|
148
|
+
"ops": ops_seq,
|
|
138
149
|
"on_done": "terminate",
|
|
139
150
|
}
|
|
140
151
|
# Ensure X-API-Key is included
|
|
@@ -323,6 +334,12 @@ async def eval_episode(client: TaskAppClient, seed: int) -> dict[str, Any]:
|
|
|
323
334
|
observation = created.get("observation") if isinstance(created, dict) else None
|
|
324
335
|
if not isinstance(observation, dict):
|
|
325
336
|
observation = {}
|
|
337
|
+
try:
|
|
338
|
+
ach_map_initial = observation.get("achievements_status")
|
|
339
|
+
if isinstance(ach_map_initial, dict):
|
|
340
|
+
achievements.update(k for k, v in ach_map_initial.items() if v)
|
|
341
|
+
except Exception:
|
|
342
|
+
pass
|
|
326
343
|
|
|
327
344
|
try:
|
|
328
345
|
while turns < MAX_TURNS and not done:
|
|
@@ -342,6 +359,12 @@ async def eval_episode(client: TaskAppClient, seed: int) -> dict[str, Any]:
|
|
|
342
359
|
nxt = step.get("observation")
|
|
343
360
|
if isinstance(nxt, dict):
|
|
344
361
|
observation = nxt
|
|
362
|
+
try:
|
|
363
|
+
ach_map = observation.get("achievements_status")
|
|
364
|
+
if isinstance(ach_map, dict):
|
|
365
|
+
achievements.update(k for k, v in ach_map.items() if v)
|
|
366
|
+
except Exception:
|
|
367
|
+
pass
|
|
345
368
|
finally:
|
|
346
369
|
with contextlib.suppress(Exception):
|
|
347
370
|
await client.terminate(env_name, env_id)
|
|
@@ -349,21 +372,45 @@ async def eval_episode(client: TaskAppClient, seed: int) -> dict[str, Any]:
|
|
|
349
372
|
return {"seed": seed, "turns": turns, "achievements": sorted(achievements)}
|
|
350
373
|
|
|
351
374
|
|
|
352
|
-
|
|
353
|
-
|
|
375
|
+
def _load_dotenv_defaults() -> None:
|
|
376
|
+
"""Load .env-style key/value pairs without clobbering explicit exports."""
|
|
354
377
|
try:
|
|
355
|
-
|
|
356
|
-
if env_path.exists():
|
|
357
|
-
for line in env_path.read_text(encoding="utf-8").splitlines():
|
|
358
|
-
line = line.strip()
|
|
359
|
-
if not line or line.startswith("#") or "=" not in line:
|
|
360
|
-
continue
|
|
361
|
-
k, v = line.split("=", 1)
|
|
362
|
-
k = k.strip()
|
|
363
|
-
v = v.strip().strip('"').strip("'")
|
|
364
|
-
os.environ.setdefault(k, v)
|
|
378
|
+
script_path = Path(__file__).resolve()
|
|
365
379
|
except Exception:
|
|
366
|
-
|
|
380
|
+
return
|
|
381
|
+
candidates: list[Path] = []
|
|
382
|
+
# Prefer the repo root .env, then allow per-directory overrides.
|
|
383
|
+
for base in [Path.cwd(), script_path.parent, *script_path.parents]:
|
|
384
|
+
env_path = base / ".env"
|
|
385
|
+
if env_path not in candidates and env_path.is_file():
|
|
386
|
+
candidates.append(env_path)
|
|
387
|
+
seen: set[str] = set()
|
|
388
|
+
try:
|
|
389
|
+
for env_path in candidates:
|
|
390
|
+
try:
|
|
391
|
+
for raw in env_path.read_text(encoding="utf-8").splitlines():
|
|
392
|
+
line = raw.strip()
|
|
393
|
+
if not line or line.startswith("#") or "=" not in line:
|
|
394
|
+
continue
|
|
395
|
+
key, value = line.split("=", 1)
|
|
396
|
+
key = key.strip()
|
|
397
|
+
if not key or key in seen:
|
|
398
|
+
continue
|
|
399
|
+
seen.add(key)
|
|
400
|
+
val = value.strip().strip('"').strip("'")
|
|
401
|
+
os.environ.setdefault(key, val)
|
|
402
|
+
except Exception:
|
|
403
|
+
continue
|
|
404
|
+
except Exception:
|
|
405
|
+
return
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
async def main() -> None:
|
|
409
|
+
_load_dotenv_defaults()
|
|
410
|
+
if not (os.getenv("ENVIRONMENT_API_KEY") or os.getenv("DEV_ENVIRONMENT_API_KEY")):
|
|
411
|
+
raise RuntimeError(
|
|
412
|
+
"ENVIRONMENT_API_KEY is required. Export it or add it to your project .env."
|
|
413
|
+
)
|
|
367
414
|
|
|
368
415
|
parser = argparse.ArgumentParser(
|
|
369
416
|
description="Baseline eval against task app with optional TOML config"
|
|
@@ -415,11 +462,20 @@ async def main() -> None:
|
|
|
415
462
|
async with sem:
|
|
416
463
|
try:
|
|
417
464
|
run_id = f"eval-{seed}"
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
}
|
|
465
|
+
rollout_cfg_raw = cfg.get("rollout") or {}
|
|
466
|
+
rollout_cfg = (
|
|
467
|
+
dict(rollout_cfg_raw) if isinstance(rollout_cfg_raw, dict) else {}
|
|
468
|
+
)
|
|
469
|
+
env_config_raw = rollout_cfg.get("env_config") or {}
|
|
470
|
+
env_config = (
|
|
471
|
+
deepcopy(env_config_raw) if isinstance(env_config_raw, dict) else {}
|
|
472
|
+
)
|
|
473
|
+
policy_cfg_raw = rollout_cfg.get("policy_config") or {}
|
|
474
|
+
policy_cfg = (
|
|
475
|
+
deepcopy(policy_cfg_raw) if isinstance(policy_cfg_raw, dict) else {}
|
|
476
|
+
)
|
|
477
|
+
policy_cfg.setdefault("model", cfg.get("model", MODEL))
|
|
478
|
+
policy_cfg.setdefault("inference_url", inf_url)
|
|
423
479
|
for k in (
|
|
424
480
|
"max_tokens",
|
|
425
481
|
"temperature",
|
|
@@ -428,20 +484,58 @@ async def main() -> None:
|
|
|
428
484
|
"thinking_budget",
|
|
429
485
|
"use_tools",
|
|
430
486
|
):
|
|
431
|
-
if k in cfg and cfg.get(k) is not None:
|
|
487
|
+
if k in cfg and cfg.get(k) is not None and k not in policy_cfg:
|
|
432
488
|
policy_cfg[k] = cfg.get(k)
|
|
433
489
|
|
|
490
|
+
env_name = str(rollout_cfg.get("env_name") or "crafter")
|
|
491
|
+
policy_name = str(
|
|
492
|
+
rollout_cfg.get("policy_name") or cfg.get("policy_name") or "crafter"
|
|
493
|
+
)
|
|
494
|
+
|
|
495
|
+
max_turns_local = MAX_TURNS
|
|
496
|
+
for candidate in (rollout_cfg.get("max_turns"), cfg.get("max_turns")):
|
|
497
|
+
if candidate is None:
|
|
498
|
+
continue
|
|
499
|
+
with contextlib.suppress(Exception):
|
|
500
|
+
max_turns_local = int(candidate)
|
|
501
|
+
break
|
|
502
|
+
|
|
503
|
+
difficulty_override: str | None = None
|
|
504
|
+
if isinstance(env_config, dict):
|
|
505
|
+
diff_cfg = env_config.get("difficulty")
|
|
506
|
+
if isinstance(diff_cfg, str) and diff_cfg:
|
|
507
|
+
difficulty_override = diff_cfg
|
|
508
|
+
if difficulty_override is None:
|
|
509
|
+
cfg_diff = rollout_cfg.get("difficulty") or cfg.get("difficulty")
|
|
510
|
+
if isinstance(cfg_diff, str) and cfg_diff:
|
|
511
|
+
difficulty_override = cfg_diff
|
|
512
|
+
if difficulty_override is None:
|
|
513
|
+
difficulty_override = os.getenv("DIFFICULTY", "easy")
|
|
514
|
+
|
|
434
515
|
r = await client.rollout(
|
|
435
516
|
run_id=run_id,
|
|
436
|
-
env_name=
|
|
517
|
+
env_name=env_name,
|
|
437
518
|
seed=seed,
|
|
438
|
-
difficulty=
|
|
439
|
-
policy_name=
|
|
519
|
+
difficulty=difficulty_override,
|
|
520
|
+
policy_name=policy_name,
|
|
440
521
|
policy_config=policy_cfg,
|
|
441
|
-
max_turns=
|
|
522
|
+
max_turns=max_turns_local,
|
|
523
|
+
env_config=env_config,
|
|
442
524
|
)
|
|
525
|
+
metrics_block = r.get("metrics") or {}
|
|
526
|
+
mean_return = None
|
|
527
|
+
if isinstance(metrics_block, dict):
|
|
528
|
+
with contextlib.suppress(Exception):
|
|
529
|
+
mean_return = float(metrics_block.get("mean_return"))
|
|
530
|
+
stepwise_details: dict[str, Any] = {}
|
|
531
|
+
if isinstance(metrics_block, dict):
|
|
532
|
+
details_block = metrics_block.get("details") or {}
|
|
533
|
+
if isinstance(details_block, dict):
|
|
534
|
+
step_block = details_block.get("stepwise") or {}
|
|
535
|
+
if isinstance(step_block, dict):
|
|
536
|
+
stepwise_details = step_block
|
|
443
537
|
# Extract achievements count if present
|
|
444
|
-
|
|
538
|
+
achieved: set[str] = set()
|
|
445
539
|
try:
|
|
446
540
|
trajs = r.get("trajectories") or []
|
|
447
541
|
final_obs = (
|
|
@@ -455,9 +549,29 @@ async def main() -> None:
|
|
|
455
549
|
else None
|
|
456
550
|
)
|
|
457
551
|
if isinstance(ach_map, dict):
|
|
458
|
-
|
|
552
|
+
achieved.update(k for k, v in ach_map.items() if v)
|
|
553
|
+
except Exception:
|
|
554
|
+
pass
|
|
555
|
+
try:
|
|
556
|
+
step_seen = stepwise_details.get("unique_achievements")
|
|
557
|
+
except Exception:
|
|
558
|
+
step_seen = None
|
|
559
|
+
if isinstance(step_seen, (list, tuple, set)):
|
|
560
|
+
achieved.update(str(a) for a in step_seen)
|
|
561
|
+
else:
|
|
562
|
+
try:
|
|
563
|
+
alt_seen = stepwise_details.get("achievements_seen")
|
|
564
|
+
if isinstance(alt_seen, (list, tuple, set)):
|
|
565
|
+
achieved.update(str(a) for a in alt_seen)
|
|
566
|
+
except Exception:
|
|
567
|
+
pass
|
|
568
|
+
try:
|
|
569
|
+
summary_final = stepwise_details.get("final_achievements")
|
|
570
|
+
if isinstance(summary_final, (list, tuple, set)):
|
|
571
|
+
achieved.update(str(a) for a in summary_final)
|
|
459
572
|
except Exception:
|
|
460
573
|
pass
|
|
574
|
+
ach = sorted(achieved)
|
|
461
575
|
length = 0
|
|
462
576
|
try:
|
|
463
577
|
trajs = r.get("trajectories") or []
|
|
@@ -465,9 +579,22 @@ async def main() -> None:
|
|
|
465
579
|
length = int(trajs[0].get("length") or 0)
|
|
466
580
|
except Exception:
|
|
467
581
|
pass
|
|
468
|
-
return {
|
|
582
|
+
return {
|
|
583
|
+
"seed": seed,
|
|
584
|
+
"turns": length,
|
|
585
|
+
"achievements": ach,
|
|
586
|
+
"mean_return": mean_return,
|
|
587
|
+
"stepwise": stepwise_details,
|
|
588
|
+
}
|
|
469
589
|
except Exception as e:
|
|
470
|
-
return {
|
|
590
|
+
return {
|
|
591
|
+
"seed": seed,
|
|
592
|
+
"turns": 0,
|
|
593
|
+
"achievements": [],
|
|
594
|
+
"mean_return": None,
|
|
595
|
+
"stepwise": {},
|
|
596
|
+
"error": str(e),
|
|
597
|
+
}
|
|
471
598
|
|
|
472
599
|
results = await asyncio.gather(
|
|
473
600
|
*[asyncio.create_task(_run(i)) for i in range(1, NUM_EPISODES + 1)],
|
|
@@ -483,17 +610,116 @@ async def main() -> None:
|
|
|
483
610
|
all_ach[a] += 1
|
|
484
611
|
except Exception:
|
|
485
612
|
pass
|
|
613
|
+
mean_returns: list[float] = []
|
|
614
|
+
stepwise_reward_sums: list[float] = []
|
|
615
|
+
stepwise_indicator_sums: list[float] = []
|
|
616
|
+
stepwise_new_ach_totals: list[float] = []
|
|
617
|
+
stepwise_resource_rewards: list[float] = []
|
|
618
|
+
strategies_seen = Counter()
|
|
619
|
+
unique_union: set[str] = set()
|
|
620
|
+
final_union: set[str] = set()
|
|
621
|
+
for r in results:
|
|
622
|
+
if not isinstance(r, dict):
|
|
623
|
+
continue
|
|
624
|
+
with contextlib.suppress(Exception):
|
|
625
|
+
mean_val = r.get("mean_return")
|
|
626
|
+
if mean_val is not None:
|
|
627
|
+
mean_returns.append(float(mean_val))
|
|
628
|
+
stepwise_block = r.get("stepwise")
|
|
629
|
+
if isinstance(stepwise_block, dict) and stepwise_block:
|
|
630
|
+
with contextlib.suppress(Exception):
|
|
631
|
+
if stepwise_block.get("reward_sum") is not None:
|
|
632
|
+
stepwise_reward_sums.append(float(stepwise_block.get("reward_sum")))
|
|
633
|
+
with contextlib.suppress(Exception):
|
|
634
|
+
if stepwise_block.get("indicator_sum") is not None:
|
|
635
|
+
stepwise_indicator_sums.append(float(stepwise_block.get("indicator_sum")))
|
|
636
|
+
with contextlib.suppress(Exception):
|
|
637
|
+
if stepwise_block.get("new_achievements_total") is not None:
|
|
638
|
+
stepwise_new_ach_totals.append(
|
|
639
|
+
float(stepwise_block.get("new_achievements_total"))
|
|
640
|
+
)
|
|
641
|
+
with contextlib.suppress(Exception):
|
|
642
|
+
if stepwise_block.get("resource_reward") is not None:
|
|
643
|
+
stepwise_resource_rewards.append(
|
|
644
|
+
float(stepwise_block.get("resource_reward"))
|
|
645
|
+
)
|
|
646
|
+
with contextlib.suppress(Exception):
|
|
647
|
+
uniq = stepwise_block.get("unique_achievements") or []
|
|
648
|
+
if isinstance(uniq, (list, tuple, set)):
|
|
649
|
+
unique_union.update(str(v) for v in uniq)
|
|
650
|
+
with contextlib.suppress(Exception):
|
|
651
|
+
final = stepwise_block.get("final_achievements") or []
|
|
652
|
+
if isinstance(final, (list, tuple, set)):
|
|
653
|
+
final_union.update(str(v) for v in final)
|
|
654
|
+
strategy_name = stepwise_block.get("strategy")
|
|
655
|
+
if isinstance(strategy_name, str) and strategy_name:
|
|
656
|
+
strategies_seen[strategy_name] += 1
|
|
657
|
+
aggregate: dict[str, Any] = {
|
|
658
|
+
"completed": sum(
|
|
659
|
+
1 for r in results if isinstance(r, dict) and not r.get("error")
|
|
660
|
+
),
|
|
661
|
+
"total": len(results),
|
|
662
|
+
"avg_turns": (sum(turns) / len(turns)) if turns else 0.0,
|
|
663
|
+
"avg_achievements": (sum(counts) / len(counts)) if counts else 0.0,
|
|
664
|
+
"achievements_freq": dict(all_ach),
|
|
665
|
+
}
|
|
666
|
+
if mean_returns:
|
|
667
|
+
aggregate["avg_mean_return"] = sum(mean_returns) / len(mean_returns)
|
|
668
|
+
if stepwise_reward_sums:
|
|
669
|
+
aggregate["avg_stepwise_reward_sum"] = sum(stepwise_reward_sums) / len(
|
|
670
|
+
stepwise_reward_sums
|
|
671
|
+
)
|
|
672
|
+
if stepwise_indicator_sums:
|
|
673
|
+
aggregate["avg_stepwise_indicator_sum"] = sum(stepwise_indicator_sums) / len(
|
|
674
|
+
stepwise_indicator_sums
|
|
675
|
+
)
|
|
676
|
+
if stepwise_new_ach_totals:
|
|
677
|
+
aggregate["avg_stepwise_new_achievements"] = sum(stepwise_new_ach_totals) / len(
|
|
678
|
+
stepwise_new_ach_totals
|
|
679
|
+
)
|
|
680
|
+
if stepwise_resource_rewards:
|
|
681
|
+
aggregate["avg_stepwise_resource_reward"] = (
|
|
682
|
+
sum(stepwise_resource_rewards) / len(stepwise_resource_rewards)
|
|
683
|
+
)
|
|
684
|
+
if strategies_seen:
|
|
685
|
+
aggregate["stepwise_strategies"] = dict(strategies_seen)
|
|
686
|
+
aggregate["stepwise_samples"] = max(
|
|
687
|
+
len(stepwise_reward_sums),
|
|
688
|
+
len(stepwise_indicator_sums),
|
|
689
|
+
len(stepwise_new_ach_totals),
|
|
690
|
+
len(stepwise_resource_rewards),
|
|
691
|
+
) if any(
|
|
692
|
+
(
|
|
693
|
+
stepwise_reward_sums,
|
|
694
|
+
stepwise_indicator_sums,
|
|
695
|
+
stepwise_new_ach_totals,
|
|
696
|
+
stepwise_resource_rewards,
|
|
697
|
+
)
|
|
698
|
+
) else 0
|
|
699
|
+
if not unique_union:
|
|
700
|
+
for r in results:
|
|
701
|
+
try:
|
|
702
|
+
for a in r.get("achievements") or []:
|
|
703
|
+
unique_union.add(str(a))
|
|
704
|
+
except Exception:
|
|
705
|
+
continue
|
|
706
|
+
if not final_union:
|
|
707
|
+
final_union.update(unique_union)
|
|
708
|
+
if unique_union:
|
|
709
|
+
aggregate["unique_achievements_union"] = sorted(unique_union)
|
|
710
|
+
if final_union:
|
|
711
|
+
aggregate["final_achievements_union"] = sorted(final_union)
|
|
486
712
|
summary = {
|
|
487
713
|
"episodes": results,
|
|
488
|
-
"aggregate":
|
|
489
|
-
"completed": sum(1 for r in results if not r.get("error")),
|
|
490
|
-
"total": len(results),
|
|
491
|
-
"avg_turns": (sum(turns) / len(turns)) if turns else 0.0,
|
|
492
|
-
"avg_achievements": (sum(counts) / len(counts)) if counts else 0.0,
|
|
493
|
-
"achievements_freq": dict(all_ach),
|
|
494
|
-
},
|
|
714
|
+
"aggregate": aggregate,
|
|
495
715
|
}
|
|
496
716
|
print(json.dumps(summary, indent=2))
|
|
717
|
+
# Failure guardrails: any error or zero-turn episodes across the board
|
|
718
|
+
any_errors = any(isinstance(r, dict) and r.get("error") for r in results)
|
|
719
|
+
all_zero_turns = all((int(r.get("turns") or 0) == 0) for r in results if isinstance(r, dict))
|
|
720
|
+
if any_errors or all_zero_turns:
|
|
721
|
+
# Exit non-zero so automation/CI treats this as a failure
|
|
722
|
+
sys.exit(2)
|
|
497
723
|
else:
|
|
498
724
|
|
|
499
725
|
async def _run(seed: int):
|
|
@@ -93,6 +93,7 @@ TASK_APP_ROOT = _resolve_task_app_root(REPO_ROOT)
|
|
|
93
93
|
SYNTH_ENVS_HOSTED_ROOT = (TASK_APP_ROOT / "synth_envs_hosted").resolve()
|
|
94
94
|
|
|
95
95
|
EXAMPLES_ROOT = (REPO_ROOT / "examples").resolve()
|
|
96
|
+
RUBRICS_ROOT = (EXAMPLES_ROOT / "multi_step" / "rubrics").resolve()
|
|
96
97
|
|
|
97
98
|
for path in (REPO_ROOT, TASK_APP_ROOT, SYNTH_ENVS_HOSTED_ROOT, EXAMPLES_ROOT):
|
|
98
99
|
try:
|
|
@@ -344,40 +345,9 @@ def _base_task_info(dataset: CrafterDataset) -> TaskInfo:
|
|
|
344
345
|
)
|
|
345
346
|
|
|
346
347
|
|
|
347
|
-
OUTCOME_RUBRIC = load_rubric(
|
|
348
|
-
{
|
|
349
|
-
"version": "1",
|
|
350
|
-
"goal_text": "Reward unlocking Crafter achievements and survival.",
|
|
351
|
-
"aggregation": "weighted_sum",
|
|
352
|
-
"criteria": [
|
|
353
|
-
{
|
|
354
|
-
"id": "achievements",
|
|
355
|
-
"description": "Unlock achievements or crafting milestones.",
|
|
356
|
-
"weight": 1.0,
|
|
357
|
-
},
|
|
358
|
-
{
|
|
359
|
-
"id": "survival",
|
|
360
|
-
"description": "Maintain health, food, and drink levels.",
|
|
361
|
-
"weight": 1.0,
|
|
362
|
-
},
|
|
363
|
-
],
|
|
364
|
-
}
|
|
365
|
-
)
|
|
348
|
+
OUTCOME_RUBRIC = load_rubric(str(RUBRICS_ROOT / "crafter_outcome_rubric.json"))
|
|
366
349
|
|
|
367
|
-
EVENTS_RUBRIC = load_rubric(
|
|
368
|
-
{
|
|
369
|
-
"version": "1",
|
|
370
|
-
"goal_text": "Encourage purposeful step-wise exploration and crafting.",
|
|
371
|
-
"aggregation": "weighted_sum",
|
|
372
|
-
"criteria": [
|
|
373
|
-
{
|
|
374
|
-
"id": "progress_steps",
|
|
375
|
-
"description": "Actions progress quests, crafting, or exploration.",
|
|
376
|
-
"weight": 1.0,
|
|
377
|
-
}
|
|
378
|
-
],
|
|
379
|
-
}
|
|
380
|
-
)
|
|
350
|
+
EVENTS_RUBRIC = load_rubric(str(RUBRICS_ROOT / "crafter_events_rubric.json"))
|
|
381
351
|
|
|
382
352
|
|
|
383
353
|
def describe_taskset(dataset: CrafterDataset) -> dict[str, Any]:
|