synth-ai 0.2.12__py3-none-any.whl → 0.2.13.dev2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only and reflects the changes between those published versions.
Potentially problematic release.
This version of synth-ai might be problematic.
- examples/multi_step/configs/crafter_rl_outcome.toml +74 -0
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +186 -0
- examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +83 -0
- examples/multi_step/configs/crafter_rl_stepwise_simple.toml +78 -0
- examples/multi_step/crafter_rl_lora.md +51 -10
- examples/multi_step/sse_metrics_streaming_notes.md +357 -0
- examples/multi_step/task_app_config_notes.md +7 -1
- examples/swe/task_app/grpo_swe_mini.py +55 -26
- examples/swe/task_app/hosted/rollout.py +40 -0
- examples/swe/task_app/hosted/test_service.py +5 -6
- examples/task_apps/TESTING.md +275 -0
- examples/task_apps/__init__.py +0 -0
- examples/task_apps/crafter/__init__.py +0 -0
- examples/task_apps/crafter/task_app/__init__.py +2 -0
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/grpo_crafter.py +21 -46
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/grpo_crafter_task_app.py +1 -1
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/policy.py +60 -4
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/inference/openai_client.py +109 -45
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/policy_routes.py +67 -49
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/rollout.py +242 -193
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/test_service.py +5 -6
- examples/task_apps/dev/pokemon_emerald/__init__.py +2 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/README.md +811 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/__init__.py +120 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/action.py +160 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/memory.py +155 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/perception.py +69 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/planning.py +96 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/simple.py +1502 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/system_prompt.py +4 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/grab_map.py +68 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/manual.py +216 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/__init__.py +35 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emerald_utils.py +631 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emulator.py +1544 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/enums.py +1428 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/memory_reader.py +4848 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/types.py +41 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/utils.py +298 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pyproject.toml +95 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/run.py +204 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/__init__.py +0 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/app.py +2152 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/client.py +429 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/frame_server.py +155 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/README.md +78 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/__init__.py +0 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/run_tests.py +122 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_direct.py +76 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_prompts.py +413 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_battle_state_formatting.py +204 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection.py +133 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection_comprehensive.py +229 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_direct_agent_emulator.py +300 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_fps_adjustment_pytest.py +205 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_direct.py +200 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_transition.py +284 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_map_ground_truth_comparison.py +468 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_memory_map.py +575 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_server_map_validation.py +311 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_torchic_state.py +259 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/__init__.py +0 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/anticheat.py +372 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/checkpoint.py +296 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/error_handler.py +275 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/get_local_ip.py +22 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/helpers.py +44 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/llm_logger.py +514 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_formatter.py +415 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher.py +1763 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher_singleton.py +33 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_trimmer.py +106 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_visualizer.py +334 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/ocr_dialogue.py +1020 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/recording.py +188 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/state_formatter.py +1481 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/vlm.py +862 -0
- examples/task_apps/dev/pokemon_emerald/modal_app.py +114 -0
- examples/task_apps/dev/pokemon_emerald/task_app/README.md +81 -0
- examples/task_apps/dev/pokemon_emerald/task_app/__init__.py +6 -0
- examples/task_apps/dev/pokemon_emerald/task_app/pokemon_emerald.py +685 -0
- examples/task_apps/enron/__init__.py +1 -0
- examples/task_apps/enron/eval_groq_qwen32.toml +16 -0
- examples/task_apps/enron/task_app/README.md +14 -0
- examples/task_apps/enron/task_app/__init__.py +1 -0
- examples/task_apps/enron/task_app/grpo_enron.py +906 -0
- examples/task_apps/enron/task_app/grpo_enron_task_app.py +146 -0
- examples/task_apps/enron/tests/__init__.py +2 -0
- examples/task_apps/enron/tests/conftest.py +115 -0
- examples/task_apps/enron/tests/integration/__init__.py +2 -0
- examples/task_apps/enron/tests/integration/test_enron_eval.py +177 -0
- examples/task_apps/enron/tests/integration/test_enron_rollout.py +135 -0
- examples/task_apps/enron/tests/unit/__init__.py +2 -0
- examples/task_apps/enron/tests/unit/test_enron_environment.py +126 -0
- examples/task_apps/math/__init__.py +0 -0
- examples/{rl/task_app → task_apps/math}/math_single_step.py +19 -10
- examples/task_apps/pokemon_battle/__init__.py +2 -0
- examples/task_apps/pokemon_battle/modal_app.py +104 -0
- examples/task_apps/pokemon_battle/task_app/README.md +68 -0
- examples/task_apps/pokemon_battle/task_app/__init__.py +6 -0
- examples/task_apps/pokemon_battle/task_app/pokemon_showdown.py +932 -0
- examples/task_apps/pokemon_red/README.md +357 -0
- examples/task_apps/pokemon_red/__init__.py +3 -0
- examples/task_apps/pokemon_red/eval_pokemon_red_policy.py +225 -0
- examples/task_apps/pokemon_red/pallet_town_rl_config.toml +73 -0
- examples/task_apps/pokemon_red/task_app.py +606 -0
- examples/task_apps/pokemon_red/test_pallet_town_rewards.py +191 -0
- examples/task_apps/sokoban/README.md +307 -0
- examples/task_apps/sokoban/__init__.py +3 -0
- examples/task_apps/sokoban/eval_groq_qwen32.toml +16 -0
- examples/task_apps/sokoban/eval_openai_gpt5.toml +16 -0
- examples/task_apps/sokoban/task_app.py +1058 -0
- examples/task_apps/sokoban/tests/__init__.py +2 -0
- examples/task_apps/sokoban/tests/conftest.py +113 -0
- examples/task_apps/sokoban/tests/integration/__init__.py +2 -0
- examples/task_apps/sokoban/tests/integration/test_sokoban_eval.py +57 -0
- examples/task_apps/sokoban/tests/integration/test_sokoban_rollout.py +198 -0
- examples/task_apps/sokoban/tests/unit/__init__.py +2 -0
- examples/task_apps/sokoban/tests/unit/test_sokoban_environment.py +114 -0
- examples/task_apps/verilog/__init__.py +1 -0
- examples/task_apps/verilog/eval_groq_qwen32b.toml +20 -0
- examples/task_apps/verilog/task_app/README.md +12 -0
- examples/task_apps/verilog/task_app/__init__.py +1 -0
- examples/task_apps/verilog/task_app/grpo_verilog.py +931 -0
- examples/task_apps/verilog/task_app/grpo_verilog_task_app.py +145 -0
- examples/task_apps/verilog/tests/__init__.py +2 -0
- examples/task_apps/verilog/tests/conftest.py +115 -0
- examples/task_apps/verilog/tests/integration/__init__.py +2 -0
- examples/task_apps/verilog/tests/integration/test_verilog_eval.py +179 -0
- examples/task_apps/verilog/tests/integration/test_verilog_rollout.py +55 -0
- examples/task_apps/verilog/tests/unit/__init__.py +2 -0
- examples/task_apps/verilog/tests/unit/test_verilog_scoring.py +118 -0
- examples/vlm/crafter_openai_vlm_agent.py +4 -4
- examples/vlm/run_crafter_vlm_benchmark.py +4 -4
- examples/warming_up_to_rl/configs/eval_stepwise_complex.toml +4 -2
- examples/warming_up_to_rl/configs/eval_stepwise_simple.toml +4 -2
- examples/warming_up_to_rl/run_eval.py +127 -18
- examples/workflows/__init__.py +0 -0
- examples/workflows/math_rl/__init__.py +0 -0
- examples/workflows/math_rl/download_dataset.py +80 -0
- synth_ai/__init__.py +41 -1
- synth_ai/api/train/builders.py +73 -29
- synth_ai/api/train/cli.py +12 -6
- synth_ai/api/train/configs/__init__.py +44 -0
- synth_ai/api/train/configs/rl.py +134 -0
- synth_ai/api/train/configs/sft.py +95 -0
- synth_ai/api/train/configs/shared.py +24 -0
- synth_ai/api/train/env_resolver.py +5 -2
- synth_ai/api/train/supported_algos.py +10 -5
- synth_ai/api/train/utils.py +7 -4
- synth_ai/cli/__init__.py +7 -51
- synth_ai/cli/_storage.py +4 -3
- synth_ai/cli/_validate_task_app.py +11 -0
- synth_ai/cli/balance.py +4 -3
- synth_ai/cli/calc.py +2 -2
- synth_ai/cli/demo.py +49 -43
- synth_ai/cli/legacy_root_backup.py +1 -1
- synth_ai/cli/rl_demo.py +86 -106
- synth_ai/cli/root.py +0 -97
- synth_ai/cli/task_apps.py +1710 -186
- synth_ai/demos/core/cli.py +121 -159
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +28 -16
- synth_ai/environments/examples/crafter_classic/environment.py +16 -0
- synth_ai/environments/examples/enron/engine.py +7 -2
- synth_ai/environments/examples/enron/environment.py +68 -0
- synth_ai/environments/examples/red/engine.py +27 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +7 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_progression.py +477 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +32 -0
- synth_ai/environments/examples/red/environment.py +60 -0
- synth_ai/environments/examples/sokoban/taskset.py +116 -0
- synth_ai/environments/examples/verilog/engine.py +30 -4
- synth_ai/evals/__init__.py +15 -0
- synth_ai/evals/client.py +82 -0
- synth_ai/evals/types.py +42 -0
- synth_ai/jobs/client.py +16 -4
- synth_ai/judge_schemas.py +127 -0
- synth_ai/py.typed +0 -0
- synth_ai/task/__init__.py +14 -5
- synth_ai/task/contracts.py +124 -38
- synth_ai/task/proxy.py +48 -56
- synth_ai/task/rubrics/__init__.py +53 -0
- synth_ai/task/rubrics/loaders.py +133 -0
- synth_ai/task/rubrics/models.py +57 -0
- synth_ai/task/rubrics/scoring.py +113 -0
- synth_ai/task/rubrics/strict.py +149 -0
- synth_ai/task/server.py +8 -7
- synth_ai/task/validators.py +269 -6
- synth_ai/tracing_v3/decorators.py +7 -3
- synth_ai/tracing_v3/replica_sync.py +4 -4
- synth_ai/tracing_v3/serialization.py +130 -0
- synth_ai/tracing_v3/trace_utils.py +317 -0
- synth_ai/tracing_v3/turso/native_manager.py +3 -3
- {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev2.dist-info}/METADATA +4 -1
- {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev2.dist-info}/RECORD +228 -89
- {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev2.dist-info}/entry_points.txt +0 -1
- synth_ai/task/rubrics.py +0 -219
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/README.md +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/README.md +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/branching.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/environment_routes.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/app.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/environment.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/react_agent.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/shared.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/tools.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/hosted_app.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/inference/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/main.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/registry.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/storage/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/storage/volume.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/test_agents.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/utils.py +0 -0
- /examples/{rl/task_app → task_apps/math}/README.md +0 -0
- /examples/{rl/task_app → task_apps/math}/math_task_app.py +0 -0
- /examples/{rl → workflows/math_rl}/configs/eval_base_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/eval_rl_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/rl_from_base_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/rl_from_base_qwen17.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/rl_from_ft_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/run_eval.py +0 -0
- /examples/{rl → workflows/math_rl}/run_rl_and_save.py +0 -0
- {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev2.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev2.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev2.dist-info}/top_level.txt +0 -0
@@ -4,9 +4,11 @@
 # task_app_url = "https://YOUR-TASK-APP.modal.run"
 
 model = "qwen/qwen3-32b"
-
+# Route inference to local task app Groq proxy
+inference_url = "http://localhost:8001/proxy/groq"
+num_episodes = 10
 max_turns = 10
-concurrency =
+concurrency = 10
 # difficulty = "easy" # optional
 
 [rollout]
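Net effect of this hunk: inference is routed through the task app's local Groq proxy at http://localhost:8001/proxy/groq, the episode count is set explicitly to 10, and concurrency gets a concrete value of 10.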
@@ -14,12 +14,14 @@ import contextlib
 import json
 import os
 import re
-import
-from copy import deepcopy
+import sys
 from collections import Counter
+from copy import deepcopy
 from pathlib import Path
 from typing import Any
 
+import tomllib
+
 import httpx
 
 

@@ -332,6 +334,12 @@ async def eval_episode(client: TaskAppClient, seed: int) -> dict[str, Any]:
     observation = created.get("observation") if isinstance(created, dict) else None
     if not isinstance(observation, dict):
         observation = {}
+    try:
+        ach_map_initial = observation.get("achievements_status")
+        if isinstance(ach_map_initial, dict):
+            achievements.update(k for k, v in ach_map_initial.items() if v)
+    except Exception:
+        pass
 
     try:
         while turns < MAX_TURNS and not done:

@@ -351,6 +359,12 @@ async def eval_episode(client: TaskAppClient, seed: int) -> dict[str, Any]:
                 nxt = step.get("observation")
                 if isinstance(nxt, dict):
                     observation = nxt
+                    try:
+                        ach_map = observation.get("achievements_status")
+                        if isinstance(ach_map, dict):
+                            achievements.update(k for k, v in ach_map.items() if v)
+                    except Exception:
+                        pass
     finally:
         with contextlib.suppress(Exception):
             await client.terminate(env_name, env_id)

@@ -358,21 +372,45 @@ async def eval_episode(client: TaskAppClient, seed: int) -> dict[str, Any]:
     return {"seed": seed, "turns": turns, "achievements": sorted(achievements)}
 
 
-
-
+def _load_dotenv_defaults() -> None:
+    """Load .env-style key/value pairs without clobbering explicit exports."""
     try:
-
-        if env_path.exists():
-            for line in env_path.read_text(encoding="utf-8").splitlines():
-                line = line.strip()
-                if not line or line.startswith("#") or "=" not in line:
-                    continue
-                k, v = line.split("=", 1)
-                k = k.strip()
-                v = v.strip().strip('"').strip("'")
-                os.environ.setdefault(k, v)
+        script_path = Path(__file__).resolve()
     except Exception:
-
+        return
+    candidates: list[Path] = []
+    # Prefer the repo root .env, then allow per-directory overrides.
+    for base in [Path.cwd(), script_path.parent, *script_path.parents]:
+        env_path = base / ".env"
+        if env_path not in candidates and env_path.is_file():
+            candidates.append(env_path)
+    seen: set[str] = set()
+    try:
+        for env_path in candidates:
+            try:
+                for raw in env_path.read_text(encoding="utf-8").splitlines():
+                    line = raw.strip()
+                    if not line or line.startswith("#") or "=" not in line:
+                        continue
+                    key, value = line.split("=", 1)
+                    key = key.strip()
+                    if not key or key in seen:
+                        continue
+                    seen.add(key)
+                    val = value.strip().strip('"').strip("'")
+                    os.environ.setdefault(key, val)
+            except Exception:
+                continue
+    except Exception:
+        return
+
+
+async def main() -> None:
+    _load_dotenv_defaults()
+    if not (os.getenv("ENVIRONMENT_API_KEY") or os.getenv("DEV_ENVIRONMENT_API_KEY")):
+        raise RuntimeError(
+            "ENVIRONMENT_API_KEY is required. Export it or add it to your project .env."
+        )
 
     parser = argparse.ArgumentParser(
         description="Baseline eval against task app with optional TOML config"

@@ -497,7 +535,7 @@ async def main() -> None:
             if isinstance(step_block, dict):
                 stepwise_details = step_block
             # Extract achievements count if present
-
+            achieved: set[str] = set()
             try:
                 trajs = r.get("trajectories") or []
                 final_obs = (

@@ -511,9 +549,29 @@ async def main() -> None:
                     else None
                 )
                 if isinstance(ach_map, dict):
-
+                    achieved.update(k for k, v in ach_map.items() if v)
+            except Exception:
+                pass
+            try:
+                step_seen = stepwise_details.get("unique_achievements")
+            except Exception:
+                step_seen = None
+            if isinstance(step_seen, (list, tuple, set)):
+                achieved.update(str(a) for a in step_seen)
+            else:
+                try:
+                    alt_seen = stepwise_details.get("achievements_seen")
+                    if isinstance(alt_seen, (list, tuple, set)):
+                        achieved.update(str(a) for a in alt_seen)
+                except Exception:
+                    pass
+            try:
+                summary_final = stepwise_details.get("final_achievements")
+                if isinstance(summary_final, (list, tuple, set)):
+                    achieved.update(str(a) for a in summary_final)
             except Exception:
                 pass
+            ach = sorted(achieved)
             length = 0
             try:
                 trajs = r.get("trajectories") or []

@@ -556,7 +614,10 @@ async def main() -> None:
         stepwise_reward_sums: list[float] = []
         stepwise_indicator_sums: list[float] = []
         stepwise_new_ach_totals: list[float] = []
+        stepwise_resource_rewards: list[float] = []
         strategies_seen = Counter()
+        unique_union: set[str] = set()
+        final_union: set[str] = set()
         for r in results:
             if not isinstance(r, dict):
                 continue

@@ -577,6 +638,19 @@ async def main() -> None:
                 stepwise_new_ach_totals.append(
                     float(stepwise_block.get("new_achievements_total"))
                 )
+            with contextlib.suppress(Exception):
+                if stepwise_block.get("resource_reward") is not None:
+                    stepwise_resource_rewards.append(
+                        float(stepwise_block.get("resource_reward"))
+                    )
+            with contextlib.suppress(Exception):
+                uniq = stepwise_block.get("unique_achievements") or []
+                if isinstance(uniq, (list, tuple, set)):
+                    unique_union.update(str(v) for v in uniq)
+            with contextlib.suppress(Exception):
+                final = stepwise_block.get("final_achievements") or []
+                if isinstance(final, (list, tuple, set)):
+                    final_union.update(str(v) for v in final)
             strategy_name = stepwise_block.get("strategy")
             if isinstance(strategy_name, str) and strategy_name:
                 strategies_seen[strategy_name] += 1

@@ -603,14 +677,49 @@ async def main() -> None:
             aggregate["avg_stepwise_new_achievements"] = sum(stepwise_new_ach_totals) / len(
                 stepwise_new_ach_totals
             )
+        if stepwise_resource_rewards:
+            aggregate["avg_stepwise_resource_reward"] = (
+                sum(stepwise_resource_rewards) / len(stepwise_resource_rewards)
+            )
         if strategies_seen:
             aggregate["stepwise_strategies"] = dict(strategies_seen)
-        aggregate["stepwise_samples"] =
+        aggregate["stepwise_samples"] = max(
+            len(stepwise_reward_sums),
+            len(stepwise_indicator_sums),
+            len(stepwise_new_ach_totals),
+            len(stepwise_resource_rewards),
+        ) if any(
+            (
+                stepwise_reward_sums,
+                stepwise_indicator_sums,
+                stepwise_new_ach_totals,
+                stepwise_resource_rewards,
+            )
+        ) else 0
+        if not unique_union:
+            for r in results:
+                try:
+                    for a in r.get("achievements") or []:
+                        unique_union.add(str(a))
+                except Exception:
+                    continue
+        if not final_union:
+            final_union.update(unique_union)
+        if unique_union:
+            aggregate["unique_achievements_union"] = sorted(unique_union)
+        if final_union:
+            aggregate["final_achievements_union"] = sorted(final_union)
         summary = {
             "episodes": results,
             "aggregate": aggregate,
         }
         print(json.dumps(summary, indent=2))
+        # Failure guardrails: any error or zero-turn episodes across the board
+        any_errors = any(isinstance(r, dict) and r.get("error") for r in results)
+        all_zero_turns = all((int(r.get("turns") or 0) == 0) for r in results if isinstance(r, dict))
+        if any_errors or all_zero_turns:
+            # Exit non-zero so automation/CI treats this as a failure
+            sys.exit(2)
     else:
 
         async def _run(seed: int):
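In short, the eval script now collects achievements from both the initial and per-step observations, searches the working directory and the script's parent directories for a .env file (applying values with os.environ.setdefault so explicit exports win), refuses to start without ENVIRONMENT_API_KEY or DEV_ENVIRONMENT_API_KEY, aggregates stepwise resource rewards and achievement unions, and exits with status 2 when any episode reports an error or every episode ends with zero turns.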
@@ -0,0 +1,80 @@
+#!/usr/bin/env python3
+"""Download subsets of the MATH dataset to local JSONL files."""
+
+from __future__ import annotations
+
+import argparse
+import json
+from pathlib import Path
+from typing import Any
+
+from datasets import load_dataset
+
+
+def extract_examples(dataset: Any, *, limit: int | None) -> list[dict[str, str]]:
+    if limit is not None:
+        dataset = dataset.select(range(min(limit, len(dataset))))
+    examples: list[dict[str, str]] = []
+    for item in dataset:
+        problem = (item.get("problem") or "").strip()
+        solution = item.get("solution") or ""
+        if isinstance(solution, list):
+            solution = "\n".join(str(part) for part in solution)
+        examples.append(
+            {
+                "problem": problem,
+                "solution": solution,
+            }
+        )
+    return examples
+
+
+def write_jsonl(path: Path, rows: list[dict[str, str]]) -> None:
+    path.parent.mkdir(parents=True, exist_ok=True)
+    with path.open("w", encoding="utf-8") as fh:
+        for row in rows:
+            fh.write(json.dumps(row, ensure_ascii=False) + "\n")
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(
+        description="Download MATH dataset splits to JSONL for offline use"
+    )
+    parser.add_argument(
+        "--output-dir", default="examples/rl/data", help="Directory to write <split>.jsonl files"
+    )
+    parser.add_argument(
+        "--dataset",
+        default="nlile/hendrycks-MATH-benchmark",
+        help="Hugging Face dataset identifier",
+    )
+    parser.add_argument(
+        "--config", default="algebra", help="Hugging Face dataset config (if required)"
+    )
+    parser.add_argument(
+        "--splits", nargs="*", default=["train", "validation", "test"], help="Splits to download"
+    )
+    parser.add_argument(
+        "--limit", type=int, default=None, help="Optional cap on examples per split"
+    )
+    args = parser.parse_args()
+
+    output_dir = Path(args.output_dir).expanduser()
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    for split in args.splits:
+        print(f"[INFO] Downloading {args.dataset} ({args.config}) split={split}")
+        if args.config:
+            dataset = load_dataset(args.dataset, args.config, split=split)
+        else:
+            dataset = load_dataset(args.dataset, split=split)
+        rows = extract_examples(dataset, limit=args.limit)
+        out_path = output_dir / f"{split}.jsonl"
+        write_jsonl(out_path, rows)
+        print(f"[INFO] Wrote {len(rows)} examples to {out_path}")
+
+    print("Done. Set MATH_DATASET_LOCAL_DIR to the output directory when serving the task app.")
+
+
+if __name__ == "__main__":
+    main()
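A typical invocation of the new downloader, using only the flags defined above, is python examples/workflows/math_rl/download_dataset.py --splits test --limit 100; it writes test.jsonl under the default examples/rl/data directory and then prints a reminder to point MATH_DATASET_LOCAL_DIR at that directory when serving the task app.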
synth_ai/__init__.py CHANGED

@@ -2,6 +2,28 @@
 Synth AI - Software for aiding the best and multiplying the will.
 """
 
+from __future__ import annotations
+
+from importlib import metadata as _metadata
+from importlib.metadata import PackageNotFoundError
+from pathlib import Path
+
+try:  # Prefer the installed package metadata when available
+    __version__ = _metadata.version("synth-ai")
+except PackageNotFoundError:  # Fallback to pyproject version for editable installs
+    try:
+        import tomllib as _toml  # Python 3.11+
+    except ModuleNotFoundError:  # pragma: no cover - legacy interpreter guard
+        import tomli as _toml  # type: ignore[no-redef]
+
+    try:
+        pyproject_path = Path(__file__).resolve().parents[1] / "pyproject.toml"
+        with pyproject_path.open("rb") as fh:
+            _pyproject = _toml.load(fh)
+        __version__ = str(_pyproject["project"]["version"])
+    except Exception:
+        __version__ = "0.0.0.dev0"
+
 # Environment exports - moved from synth-env
 from synth_ai.environments import *  # noqa
 import synth_ai.environments as environments  # expose module name for __all__

@@ -21,12 +43,22 @@ try:
 except Exception:
     AsyncOpenAI = OpenAI = None  # type: ignore
 
+# Judge API contract schemas
+from synth_ai.judge_schemas import (
+    CriterionScorePayload,
+    JudgeOptions,
+    JudgeScoreRequest,
+    JudgeScoreResponse,
+    JudgeTaskApp,
+    JudgeTracePayload,
+    ReviewPayload,
+)
+
 # Legacy tracing v1 is not required for v3 usage and can be unavailable in minimal envs.
 tracing = None  # type: ignore
 EventPartitionElement = RewardSignal = SystemTrace = TrainingQuestion = None  # type: ignore
 trace_event_async = trace_event_sync = upload = None  # type: ignore
 
-__version__ = "0.2.6.dev4"
 __all__ = [
     "LM",
     "OpenAI",

@@ -34,4 +66,12 @@ __all__ = [
     "Anthropic",
     "AsyncAnthropic",
     "environments",
+    # Judge API contracts
+    "JudgeScoreRequest",
+    "JudgeScoreResponse",
+    "JudgeOptions",
+    "JudgeTaskApp",
+    "JudgeTracePayload",
+    "ReviewPayload",
+    "CriterionScorePayload",
 ]  # Explicitly define public API (v1 tracing omitted in minimal env)
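Taken together, the package version is now derived at import time and the judge contract schemas are importable from the top level. A minimal usage sketch, assuming an installed or editable synth-ai checkout and using only names listed in __all__ above:

import synth_ai

# Version is resolved from package metadata (or pyproject.toml in editable installs),
# no longer a hardcoded string in __init__.py.
print(synth_ai.__version__)

# Judge API contract schemas are re-exported at the top level.
request_cls = synth_ai.JudgeScoreRequest
response_cls = synth_ai.JudgeScoreResponse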
synth_ai/api/train/builders.py CHANGED

@@ -1,31 +1,45 @@
 from __future__ import annotations
 
 import importlib
+from collections.abc import Callable
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Any, cast
 
 import click
+from pydantic import ValidationError
 
 try:
-    _models_module =
-
-
-
+    _models_module = cast(
+        Any, importlib.import_module("synth_ai.api.models.supported")
+    )
+    UnsupportedModelError = cast(type[Exception], _models_module.UnsupportedModelError)
+    ensure_allowed_model = cast(
+        Callable[..., None], _models_module.ensure_allowed_model
+    )
+    normalize_model_identifier = cast(
+        Callable[[str], str], _models_module.normalize_model_identifier
+    )
 except Exception as exc:  # pragma: no cover - critical dependency
     raise RuntimeError("Unable to load supported model helpers") from exc
 
 try:
-
+    _sft_module = cast(
+        Any, importlib.import_module("synth_ai.learning.sft.config")
+    )
+    prepare_sft_job_payload = cast(
+        Callable[..., dict[str, Any]], _sft_module.prepare_sft_job_payload
+    )
 except Exception as exc:  # pragma: no cover - critical dependency
     raise RuntimeError("Unable to load SFT payload helpers") from exc
 
+from .configs import RLConfig, SFTConfig
 from .supported_algos import (
     AlgorithmValidationError,
     ensure_model_supported_for_algorithm,
     validate_algorithm_config,
 )
-from .utils import TrainError, ensure_api_base
+from .utils import TrainError, ensure_api_base
 
 
 @dataclass(slots=True)

@@ -42,6 +56,16 @@ class SFTBuildResult:
     validation_file: Path | None
 
 
+def _format_validation_error(path: Path, exc: ValidationError) -> str:
+    lines: list[str] = []
+    for error in exc.errors():
+        loc = ".".join(str(part) for part in error.get("loc", ()))
+        msg = error.get("msg", "invalid value")
+        lines.append(f"{loc or '<root>'}: {msg}")
+    details = "\n".join(f" - {line}" for line in lines) or " - Invalid configuration"
+    return f"Config validation failed ({path}):\n{details}"
+
+
 def build_rl_payload(
     *,
     config_path: Path,

@@ -50,13 +74,30 @@ def build_rl_payload(
     idempotency: str | None,
     allow_experimental: bool | None = None,
 ) -> RLBuildResult:
-    data = load_toml(config_path)
     try:
-
+        rl_cfg = RLConfig.from_path(config_path)
+    except ValidationError as exc:
+        raise click.ClickException(_format_validation_error(config_path, exc)) from exc
+
+    data = rl_cfg.to_dict()
+    # Ensure required [reference] section for backend validators
+    try:
+        ref_cfg = data.get("reference") if isinstance(data, dict) else None
+        if not isinstance(ref_cfg, dict):
+            data["reference"] = {"placement": "none"}
+        else:
+            ref_cfg.setdefault("placement", "none")
+    except Exception:
+        # Defensive: never fail builder due to optional defaults
+        data["reference"] = {"placement": "none"}
+    try:
+        spec = validate_algorithm_config(
+            rl_cfg.algorithm.model_dump(), expected_family="rl"
+        )
     except AlgorithmValidationError as exc:
         raise click.ClickException(str(exc)) from exc
     services = data.get("services") if isinstance(data.get("services"), dict) else {}
-    model_cfg =
+    model_cfg = rl_cfg.model
 
     final_task_url = (
         overrides.get("task_url")

@@ -69,10 +110,8 @@
             "Task app URL required (provide --task-url or set services.task_url in TOML)"
         )
 
-
-
-    raw_base = model_cfg.get("base") if isinstance(model_cfg, dict) else ""
-    model_base = str(raw_base or "").strip()
+    model_source = (model_cfg.source or "").strip()
+    model_base = (model_cfg.base or "").strip()
     override_model = (overrides.get("model") or "").strip()
     if override_model:
         model_source = override_model

@@ -98,7 +137,7 @@
         if model_source:
             model_source = normalize_model_identifier(model_source)
         if model_base:
-            model_base = normalize_model_identifier(model_base
+            model_base = normalize_model_identifier(model_base)
     except UnsupportedModelError as exc:
         raise click.ClickException(str(exc)) from exc
 

@@ -160,22 +199,23 @@ def build_sft_payload(
     dataset_override: Path | None,
     allow_experimental: bool | None,
 ) -> SFTBuildResult:
-    data = load_toml(config_path)
     try:
-
+        sft_cfg = SFTConfig.from_path(config_path)
+    except ValidationError as exc:
+        raise TrainError(_format_validation_error(config_path, exc)) from exc
+
+    data = sft_cfg.to_dict()
+    try:
+        algo_mapping = sft_cfg.algorithm.model_dump() if sft_cfg.algorithm else None
+        spec = validate_algorithm_config(algo_mapping, expected_family="sft")
     except AlgorithmValidationError as exc:
         raise TrainError(str(exc)) from exc
-    job_cfg = data.get("job") if isinstance(data.get("job"), dict) else {}
     data_cfg = data.get("data") if isinstance(data.get("data"), dict) else {}
     hp_cfg = data.get("hyperparameters") if isinstance(data.get("hyperparameters"), dict) else {}
     train_cfg = data.get("training") if isinstance(data.get("training"), dict) else {}
     compute_cfg = data.get("compute") if isinstance(data.get("compute"), dict) else {}
 
-    raw_dataset =
-        dataset_override
-        or (job_cfg.get("data") if isinstance(job_cfg, dict) else None)
-        or (job_cfg.get("data_path") if isinstance(job_cfg, dict) else None)
-    )
+    raw_dataset = dataset_override or sft_cfg.job.data or sft_cfg.job.data_path
     if not raw_dataset:
         raise TrainError("Dataset not specified; pass --dataset or set [job].data")
     dataset_path = Path(raw_dataset)

@@ -260,9 +300,11 @@ def build_sft_payload(
         "enabled": bool(validation_cfg.get("enabled", True))
     }
 
-    raw_model =
-
-
+    raw_model = (sft_cfg.job.model or "").strip()
+    if not raw_model:
+        model_block = data.get("model")
+        if isinstance(model_block, str):
+            raw_model = model_block.strip()
     if not raw_model:
         raise TrainError("Model not specified; set [job].model or [model].base in the config")
 

@@ -274,10 +316,12 @@ def build_sft_payload(
         )
     except UnsupportedModelError as exc:
         raise TrainError(str(exc)) from exc
-
-
-
-
+
+    if base_model:
+        try:
+            ensure_model_supported_for_algorithm(base_model, spec)
+        except AlgorithmValidationError as exc:
+            raise TrainError(str(exc)) from exc
 
     try:
         payload = prepare_sft_job_payload(
synth_ai/api/train/cli.py CHANGED

@@ -2,15 +2,17 @@ from __future__ import annotations
 
 import importlib
 import os
-from collections.abc import Mapping
+from collections.abc import Callable, Mapping
 from pathlib import Path
-from typing import Any
+from typing import Any, cast
 
 import click
 
 try:
-    _config_module =
-
+    _config_module = cast(
+        Any, importlib.import_module("synth_ai.config.base_url")
+    )
+    get_backend_from_env = cast(Callable[[], str], _config_module.get_backend_from_env)
 except Exception as exc:  # pragma: no cover - critical dependency
     raise RuntimeError("Unable to load backend configuration helpers") from exc
 

@@ -238,8 +240,12 @@ def train_command(
     ]
     if missing_keys:
         try:
-            _task_apps_module =
-
+            _task_apps_module = cast(
+                Any, importlib.import_module("synth_ai.cli.task_apps")
+            )
+            _interactive_fill_env = cast(
+                Callable[[Path], Path | None], _task_apps_module._interactive_fill_env
+            )
         except Exception as exc:  # pragma: no cover - protective fallback
             raise click.ClickException(f"Unable to prompt for env values: {exc}") from exc
 
@@ -0,0 +1,44 @@
+"""Typed training config loaders for RL and SFT jobs."""
+
+from .rl import (
+    EvaluationConfig,
+    JudgeConfig,
+    JudgeOptionsConfig,
+    ModelConfig,
+    RLConfig,
+    RLServicesConfig,
+    RLTrainingConfig,
+    RolloutConfig,
+    WeightSyncConfig,
+)
+from .sft import (
+    HyperparametersConfig,
+    HyperparametersParallelism,
+    JobConfig,
+    SFTConfig,
+    SFTDataConfig,
+    TrainingConfig,
+    TrainingValidationConfig,
+)
+from .shared import AlgorithmConfig, ComputeConfig
+
+__all__ = [
+    "AlgorithmConfig",
+    "ComputeConfig",
+    "EvaluationConfig",
+    "HyperparametersConfig",
+    "HyperparametersParallelism",
+    "JobConfig",
+    "JudgeConfig",
+    "JudgeOptionsConfig",
+    "ModelConfig",
+    "RLConfig",
+    "RLServicesConfig",
+    "RLTrainingConfig",
+    "RolloutConfig",
+    "SFTConfig",
+    "SFTDataConfig",
+    "TrainingConfig",
+    "TrainingValidationConfig",
+    "WeightSyncConfig",
+]
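A companion sketch for the SFT side, using only attributes that build_sft_payload references above; the config path is illustrative:

from pathlib import Path

from synth_ai.api.train.configs import SFTConfig

sft_cfg = SFTConfig.from_path(Path("configs/sft_qwen.toml"))  # illustrative path
dataset = sft_cfg.job.data or sft_cfg.job.data_path  # same fallback order as build_sft_payload
model = (sft_cfg.job.model or "").strip()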