synth-ai 0.2.10__py3-none-any.whl → 0.2.13.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic. Click here for more details.

Files changed (73)
  1. examples/agora_ex/README_MoE.md +224 -0
  2. examples/agora_ex/__init__.py +7 -0
  3. examples/agora_ex/agora_ex.py +65 -0
  4. examples/agora_ex/agora_ex_task_app.py +590 -0
  5. examples/agora_ex/configs/rl_lora_qwen3_moe_2xh200.toml +121 -0
  6. examples/agora_ex/reward_fn_grpo-human.py +129 -0
  7. examples/agora_ex/system_prompt_CURRENT.md +63 -0
  8. examples/agora_ex/task_app/agora_ex_task_app.py +590 -0
  9. examples/agora_ex/task_app/reward_fn_grpo-human.py +129 -0
  10. examples/agora_ex/task_app/system_prompt_CURRENT.md +63 -0
  11. examples/multi_step/configs/crafter_rl_outcome.toml +74 -0
  12. examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +175 -0
  13. examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +83 -0
  14. examples/multi_step/configs/crafter_rl_stepwise_simple.toml +78 -0
  15. examples/multi_step/crafter_rl_lora.md +51 -10
  16. examples/multi_step/sse_metrics_streaming_notes.md +357 -0
  17. examples/multi_step/task_app_config_notes.md +494 -0
  18. examples/warming_up_to_rl/configs/eval_stepwise_complex.toml +35 -0
  19. examples/warming_up_to_rl/configs/eval_stepwise_consistent.toml +26 -0
  20. examples/warming_up_to_rl/configs/eval_stepwise_per_achievement.toml +36 -0
  21. examples/warming_up_to_rl/configs/eval_stepwise_simple.toml +32 -0
  22. examples/warming_up_to_rl/run_eval.py +267 -41
  23. examples/warming_up_to_rl/task_app/grpo_crafter.py +3 -33
  24. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +109 -45
  25. examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +42 -46
  26. examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +376 -193
  27. synth_ai/__init__.py +41 -1
  28. synth_ai/api/train/builders.py +74 -33
  29. synth_ai/api/train/cli.py +29 -6
  30. synth_ai/api/train/configs/__init__.py +44 -0
  31. synth_ai/api/train/configs/rl.py +133 -0
  32. synth_ai/api/train/configs/sft.py +94 -0
  33. synth_ai/api/train/configs/shared.py +24 -0
  34. synth_ai/api/train/env_resolver.py +18 -19
  35. synth_ai/api/train/supported_algos.py +8 -5
  36. synth_ai/api/train/utils.py +6 -1
  37. synth_ai/cli/__init__.py +4 -2
  38. synth_ai/cli/_storage.py +19 -0
  39. synth_ai/cli/balance.py +14 -2
  40. synth_ai/cli/calc.py +37 -22
  41. synth_ai/cli/demo.py +38 -39
  42. synth_ai/cli/legacy_root_backup.py +12 -14
  43. synth_ai/cli/recent.py +12 -7
  44. synth_ai/cli/rl_demo.py +81 -102
  45. synth_ai/cli/status.py +4 -3
  46. synth_ai/cli/task_apps.py +146 -137
  47. synth_ai/cli/traces.py +4 -3
  48. synth_ai/cli/watch.py +3 -2
  49. synth_ai/demos/core/cli.py +121 -159
  50. synth_ai/environments/examples/crafter_classic/environment.py +16 -0
  51. synth_ai/evals/__init__.py +15 -0
  52. synth_ai/evals/client.py +85 -0
  53. synth_ai/evals/types.py +42 -0
  54. synth_ai/jobs/client.py +15 -3
  55. synth_ai/judge_schemas.py +127 -0
  56. synth_ai/rubrics/__init__.py +22 -0
  57. synth_ai/rubrics/validators.py +126 -0
  58. synth_ai/task/server.py +14 -7
  59. synth_ai/tracing_v3/decorators.py +51 -26
  60. synth_ai/tracing_v3/examples/basic_usage.py +12 -7
  61. synth_ai/tracing_v3/llm_call_record_helpers.py +107 -53
  62. synth_ai/tracing_v3/replica_sync.py +8 -4
  63. synth_ai/tracing_v3/serialization.py +130 -0
  64. synth_ai/tracing_v3/storage/utils.py +11 -9
  65. synth_ai/tracing_v3/turso/__init__.py +12 -0
  66. synth_ai/tracing_v3/turso/daemon.py +2 -1
  67. synth_ai/tracing_v3/turso/native_manager.py +28 -15
  68. {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/METADATA +4 -2
  69. {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/RECORD +73 -40
  70. {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/entry_points.txt +0 -1
  71. {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/WHEEL +0 -0
  72. {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/licenses/LICENSE +0 -0
  73. {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/top_level.txt +0 -0
synth_ai/__init__.py CHANGED
@@ -2,6 +2,28 @@
2
2
  Synth AI - Software for aiding the best and multiplying the will.
3
3
  """
4
4
 
5
+ from __future__ import annotations
6
+
7
+ from importlib import metadata as _metadata
8
+ from importlib.metadata import PackageNotFoundError
9
+ from pathlib import Path
10
+
11
def _resolve_version(dist_name: str = "synth-ai") -> str:
    """Return the version string for the distribution *dist_name*.

    Resolution order:

    1. Installed distribution metadata (``importlib.metadata.version``).
    2. ``[project].version`` from the ``pyproject.toml`` one directory above this
       package (editable / source-checkout installs).
    3. The placeholder ``"0.0.0.dev0"``.

    Steps 2-3 are deliberately best-effort: resolving the version must never make
    ``import synth_ai`` fail.
    """
    try:  # Prefer the installed package metadata when available
        return _metadata.version(dist_name)
    except PackageNotFoundError:  # Fallback to pyproject version for editable installs
        pass

    try:
        # The TOML parser import lives INSIDE the guarded block: on Python < 3.11
        # ``tomli`` is an optional third-party package, and its absence must not
        # turn a missing-version situation into an ImportError for the package.
        try:
            import tomllib as _toml  # Python 3.11+
        except ModuleNotFoundError:  # pragma: no cover - legacy interpreter guard
            import tomli as _toml  # type: ignore[no-redef]

        pyproject_path = Path(__file__).resolve().parents[1] / "pyproject.toml"
        with pyproject_path.open("rb") as fh:
            pyproject = _toml.load(fh)
        return str(pyproject["project"]["version"])
    except Exception:  # deliberate best-effort: any failure -> placeholder version
        return "0.0.0.dev0"


__version__ = _resolve_version()
26
+
5
27
  # Environment exports - moved from synth-env
6
28
  from synth_ai.environments import * # noqa
7
29
  import synth_ai.environments as environments # expose module name for __all__
@@ -21,12 +43,22 @@ try:
21
43
  except Exception:
22
44
  AsyncOpenAI = OpenAI = None # type: ignore
23
45
 
46
+ # Judge API contract schemas
47
+ from synth_ai.judge_schemas import (
48
+ JudgeScoreRequest,
49
+ JudgeScoreResponse,
50
+ JudgeOptions,
51
+ JudgeTaskApp,
52
+ JudgeTracePayload,
53
+ ReviewPayload,
54
+ CriterionScorePayload,
55
+ )
56
+
24
57
  # Legacy tracing v1 is not required for v3 usage and can be unavailable in minimal envs.
25
58
  tracing = None # type: ignore
26
59
  EventPartitionElement = RewardSignal = SystemTrace = TrainingQuestion = None # type: ignore
27
60
  trace_event_async = trace_event_sync = upload = None # type: ignore
28
61
 
29
- __version__ = "0.2.6.dev4"
30
62
  __all__ = [
31
63
  "LM",
32
64
  "OpenAI",
@@ -34,4 +66,12 @@ __all__ = [
34
66
  "Anthropic",
35
67
  "AsyncAnthropic",
36
68
  "environments",
69
+ # Judge API contracts
70
+ "JudgeScoreRequest",
71
+ "JudgeScoreResponse",
72
+ "JudgeOptions",
73
+ "JudgeTaskApp",
74
+ "JudgeTracePayload",
75
+ "ReviewPayload",
76
+ "CriterionScorePayload",
37
77
  ] # Explicitly define public API (v1 tracing omitted in minimal env)
synth_ai/api/train/builders.py CHANGED
@@ -1,23 +1,33 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import importlib
3
4
  from dataclasses import dataclass
4
5
  from pathlib import Path
5
- from typing import Any
6
+ from typing import Any, cast
6
7
 
7
8
  import click
8
- from synth_ai.api.models.supported import (
9
- UnsupportedModelError,
10
- ensure_allowed_model,
11
- normalize_model_identifier,
12
- )
13
- from synth_ai.learning.sft.config import prepare_sft_job_payload
9
+ from pydantic import ValidationError
10
+
11
+ try:
12
+ _models_module = importlib.import_module("synth_ai.api.models.supported")
13
+ UnsupportedModelError = _models_module.UnsupportedModelError
14
+ ensure_allowed_model = _models_module.ensure_allowed_model
15
+ normalize_model_identifier = _models_module.normalize_model_identifier
16
+ except Exception as exc: # pragma: no cover - critical dependency
17
+ raise RuntimeError("Unable to load supported model helpers") from exc
18
+
19
+ try:
20
+ prepare_sft_job_payload = importlib.import_module("synth_ai.learning.sft.config").prepare_sft_job_payload
21
+ except Exception as exc: # pragma: no cover - critical dependency
22
+ raise RuntimeError("Unable to load SFT payload helpers") from exc
14
23
 
15
24
  from .supported_algos import (
16
25
  AlgorithmValidationError,
17
26
  ensure_model_supported_for_algorithm,
18
27
  validate_algorithm_config,
19
28
  )
20
- from .utils import TrainError, ensure_api_base, load_toml
29
+ from .utils import TrainError, ensure_api_base
30
+ from .configs import RLConfig, SFTConfig
21
31
 
22
32
 
23
33
  @dataclass(slots=True)
@@ -34,6 +44,16 @@ class SFTBuildResult:
34
44
  validation_file: Path | None
35
45
 
36
46
 
47
def _format_validation_error(path: Path, exc: ValidationError) -> str:
    """Render a pydantic ``ValidationError`` as a readable multi-line message.

    Every reported error becomes one bullet of the form ``<dotted.loc>: <msg>``.
    An error without a location is reported against ``<root>``, and a missing
    message falls back to ``invalid value``. If the error list is empty, a single
    generic bullet is emitted so the message always has a body.
    """
    bullets = []
    for err in exc.errors():
        dotted_loc = ".".join(map(str, err.get("loc", ())))
        where = dotted_loc if dotted_loc else "<root>"
        reason = err.get("msg", "invalid value")
        bullets.append(f" - {where}: {reason}")
    body = "\n".join(bullets) if bullets else " - Invalid configuration"
    return f"Config validation failed ({path}):\n{body}"
55
+
56
+
37
57
  def build_rl_payload(
38
58
  *,
39
59
  config_path: Path,
@@ -42,13 +62,30 @@ def build_rl_payload(
42
62
  idempotency: str | None,
43
63
  allow_experimental: bool | None = None,
44
64
  ) -> RLBuildResult:
45
- data = load_toml(config_path)
46
65
  try:
47
- spec = validate_algorithm_config(data.get("algorithm"), expected_family="rl")
66
+ rl_cfg = RLConfig.from_path(config_path)
67
+ except ValidationError as exc:
68
+ raise click.ClickException(_format_validation_error(config_path, exc)) from exc
69
+
70
+ data = rl_cfg.to_dict()
71
+ # Ensure required [reference] section for backend validators
72
+ try:
73
+ ref_cfg = data.get("reference") if isinstance(data, dict) else None
74
+ if not isinstance(ref_cfg, dict):
75
+ data["reference"] = {"placement": "none"}
76
+ else:
77
+ ref_cfg.setdefault("placement", "none")
78
+ except Exception:
79
+ # Defensive: never fail builder due to optional defaults
80
+ data["reference"] = {"placement": "none"}
81
+ try:
82
+ spec = validate_algorithm_config(
83
+ rl_cfg.algorithm.model_dump(), expected_family="rl"
84
+ )
48
85
  except AlgorithmValidationError as exc:
49
86
  raise click.ClickException(str(exc)) from exc
50
87
  services = data.get("services") if isinstance(data.get("services"), dict) else {}
51
- model_cfg = data.get("model") if isinstance(data.get("model"), dict) else {}
88
+ model_cfg = rl_cfg.model
52
89
 
53
90
  final_task_url = (
54
91
  overrides.get("task_url")
@@ -61,10 +98,8 @@ def build_rl_payload(
61
98
  "Task app URL required (provide --task-url or set services.task_url in TOML)"
62
99
  )
63
100
 
64
- raw_source = model_cfg.get("source") if isinstance(model_cfg, dict) else ""
65
- model_source = str(raw_source or "").strip()
66
- raw_base = model_cfg.get("base") if isinstance(model_cfg, dict) else ""
67
- model_base = str(raw_base or "").strip()
101
+ model_source = (model_cfg.source or "").strip()
102
+ model_base = (model_cfg.base or "").strip()
68
103
  override_model = (overrides.get("model") or "").strip()
69
104
  if override_model:
70
105
  model_source = override_model
@@ -122,23 +157,26 @@ def build_rl_payload(
122
157
  except Exception:
123
158
  pass
124
159
 
160
+ payload_data: dict[str, Any] = {
161
+ "endpoint_base_url": final_task_url.rstrip("/"),
162
+ "config": data,
163
+ }
125
164
  payload: dict[str, Any] = {
126
165
  "job_type": "rl",
127
166
  "compute": data.get("compute", {}),
128
- "data": {
129
- "endpoint_base_url": final_task_url.rstrip("/"),
130
- "config": data,
131
- },
167
+ "data": payload_data,
132
168
  "tags": {"source": "train-cli"},
133
169
  }
134
170
  if model_source:
135
- payload["data"]["model"] = model_source
171
+ payload_data["model"] = model_source
136
172
  if model_base:
137
- payload["data"]["base_model"] = model_base
173
+ payload_data["base_model"] = model_base
138
174
 
139
175
  backend = overrides.get("backend")
140
176
  if backend:
141
- payload.setdefault("metadata", {})["backend_base_url"] = ensure_api_base(str(backend))
177
+ metadata_default: dict[str, Any] = {}
178
+ metadata = cast(dict[str, Any], payload.setdefault("metadata", metadata_default))
179
+ metadata["backend_base_url"] = ensure_api_base(str(backend))
142
180
 
143
181
  return RLBuildResult(payload=payload, task_url=final_task_url, idempotency=idempotency)
144
182
 
@@ -149,22 +187,23 @@ def build_sft_payload(
149
187
  dataset_override: Path | None,
150
188
  allow_experimental: bool | None,
151
189
  ) -> SFTBuildResult:
152
- data = load_toml(config_path)
153
190
  try:
154
- spec = validate_algorithm_config(data.get("algorithm"), expected_family="sft")
191
+ sft_cfg = SFTConfig.from_path(config_path)
192
+ except ValidationError as exc:
193
+ raise TrainError(_format_validation_error(config_path, exc)) from exc
194
+
195
+ data = sft_cfg.to_dict()
196
+ try:
197
+ algo_mapping = sft_cfg.algorithm.model_dump() if sft_cfg.algorithm else None
198
+ spec = validate_algorithm_config(algo_mapping, expected_family="sft")
155
199
  except AlgorithmValidationError as exc:
156
200
  raise TrainError(str(exc)) from exc
157
- job_cfg = data.get("job") if isinstance(data.get("job"), dict) else {}
158
201
  data_cfg = data.get("data") if isinstance(data.get("data"), dict) else {}
159
202
  hp_cfg = data.get("hyperparameters") if isinstance(data.get("hyperparameters"), dict) else {}
160
203
  train_cfg = data.get("training") if isinstance(data.get("training"), dict) else {}
161
204
  compute_cfg = data.get("compute") if isinstance(data.get("compute"), dict) else {}
162
205
 
163
- raw_dataset = (
164
- dataset_override
165
- or (job_cfg.get("data") if isinstance(job_cfg, dict) else None)
166
- or (job_cfg.get("data_path") if isinstance(job_cfg, dict) else None)
167
- )
206
+ raw_dataset = dataset_override or sft_cfg.job.data or sft_cfg.job.data_path
168
207
  if not raw_dataset:
169
208
  raise TrainError("Dataset not specified; pass --dataset or set [job].data")
170
209
  dataset_path = Path(raw_dataset)
@@ -249,9 +288,11 @@ def build_sft_payload(
249
288
  "enabled": bool(validation_cfg.get("enabled", True))
250
289
  }
251
290
 
252
- raw_model = str(
253
- job_cfg.get("model") if isinstance(job_cfg, dict) else None or data.get("model") or ""
254
- ).strip()
291
+ raw_model = (sft_cfg.job.model or "").strip()
292
+ if not raw_model:
293
+ model_block = data.get("model")
294
+ if isinstance(model_block, str):
295
+ raw_model = model_block.strip()
255
296
  if not raw_model:
256
297
  raise TrainError("Model not specified; set [job].model or [model].base in the config")
257
298
 
synth_ai/api/train/cli.py CHANGED
@@ -1,11 +1,18 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import importlib
3
4
  import os
5
+ from collections.abc import Mapping
4
6
  from pathlib import Path
5
7
  from typing import Any
6
8
 
7
9
  import click
8
- from synth_ai.config.base_url import get_backend_from_env
10
+
11
+ try:
12
+ _config_module = importlib.import_module("synth_ai.config.base_url")
13
+ get_backend_from_env = _config_module.get_backend_from_env
14
+ except Exception as exc: # pragma: no cover - critical dependency
15
+ raise RuntimeError("Unable to load backend configuration helpers") from exc
9
16
 
10
17
  from .builders import build_rl_payload, build_sft_payload
11
18
  from .config_finder import discover_configs, prompt_for_config
@@ -231,7 +238,8 @@ def train_command(
231
238
  ]
232
239
  if missing_keys:
233
240
  try:
234
- from synth_ai.cli.task_apps import _interactive_fill_env
241
+ _task_apps_module = importlib.import_module("synth_ai.cli.task_apps")
242
+ _interactive_fill_env = _task_apps_module._interactive_fill_env
235
243
  except Exception as exc: # pragma: no cover - protective fallback
236
244
  raise click.ClickException(f"Unable to prompt for env values: {exc}") from exc
237
245
 
@@ -386,9 +394,19 @@ def handle_rl(
386
394
  verify_url, headers=verify_headers, json_body={"endpoint_base_url": build.task_url}
387
395
  )
388
396
  try:
389
- vjs = vresp.json()
397
+ parsed_json = vresp.json()
390
398
  except Exception:
391
- vjs = {"status": vresp.status_code, "text": (vresp.text or "")[:400]}
399
+ parsed_json = None
400
+
401
+ if isinstance(parsed_json, Mapping):
402
+ vjs: dict[str, Any] = dict(parsed_json)
403
+ else:
404
+ vjs = {
405
+ "status": vresp.status_code,
406
+ "text": (vresp.text or "")[:400],
407
+ }
408
+ if parsed_json is not None:
409
+ vjs["body"] = parsed_json
392
410
  except Exception as _ve:
393
411
  raise click.ClickException(
394
412
  f"Task app verification call failed: {type(_ve).__name__}: {_ve}"
@@ -404,8 +422,13 @@ def handle_rl(
404
422
  # Print concise summary
405
423
  try:
406
424
  cands = vjs.get("candidates_first15") or []
407
- attempts = vjs.get("attempts") or []
408
- statuses = [a.get("status") for a in attempts]
425
+ attempts_raw = vjs.get("attempts")
426
+ attempts: list[Mapping[str, Any]] = (
427
+ [a for a in attempts_raw if isinstance(a, Mapping)]
428
+ if isinstance(attempts_raw, list)
429
+ else []
430
+ )
431
+ statuses = [attempt.get("status") for attempt in attempts]
409
432
  click.echo(f"Verification OK (candidates={cands}, statuses={statuses})")
410
433
  except Exception:
411
434
  pass
synth_ai/api/train/configs/__init__.py ADDED
@@ -0,0 +1,44 @@
1
+ """Typed training config loaders for RL and SFT jobs."""
2
+
3
+ from .shared import AlgorithmConfig, ComputeConfig
4
+ from .sft import (
5
+ HyperparametersConfig,
6
+ HyperparametersParallelism,
7
+ JobConfig,
8
+ SFTConfig,
9
+ SFTDataConfig,
10
+ TrainingConfig,
11
+ TrainingValidationConfig,
12
+ )
13
+ from .rl import (
14
+ EvaluationConfig,
15
+ JudgeConfig,
16
+ JudgeOptionsConfig,
17
+ ModelConfig,
18
+ RLConfig,
19
+ RLServicesConfig,
20
+ RLTrainingConfig,
21
+ RolloutConfig,
22
+ WeightSyncConfig,
23
+ )
24
+
25
# Public API of the configs package: the typed config models imported above
# from .shared, .sft and .rl, re-exported under their original names.
__all__ = [
    "AlgorithmConfig",
    "ComputeConfig",
    "EvaluationConfig",
    "HyperparametersConfig",
    "HyperparametersParallelism",
    "JobConfig",
    "JudgeConfig",
    "JudgeOptionsConfig",
    "ModelConfig",
    "RLConfig",
    "RLServicesConfig",
    "RLTrainingConfig",
    "RolloutConfig",
    "SFTConfig",
    "SFTDataConfig",
    "TrainingConfig",
    "TrainingValidationConfig",
    "WeightSyncConfig",
]
synth_ai/api/train/configs/rl.py ADDED
@@ -0,0 +1,133 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import Any, Mapping
5
+
6
+ from pydantic import model_validator
7
+
8
+ from ..utils import load_toml
9
+ from .shared import AlgorithmConfig, ComputeConfig, ExtraModel
10
+
11
+
12
class RLServicesConfig(ExtraModel):
    """Schema for the ``[services]`` table of an RL config.

    ``task_url`` is optional at the schema level because the train CLI also
    accepts ``--task-url`` as an override. Requiring it here made that override
    unreachable: validation failed before the override was consulted. The RL
    payload builder still raises "Task app URL required (provide --task-url or
    set services.task_url in TOML)" when neither source supplies a URL.
    """

    task_url: str | None = None
    judge_url: str | None = None
15
+
16
+
17
class ModelConfig(ExtraModel):
    """Schema for the ``[model]`` table of an RL config.

    ``trainer_mode`` and ``label`` are required strings. Exactly one of
    ``source`` or ``base`` must be set to a truthy (non-empty) value.
    """

    source: str | None = None
    base: str | None = None
    trainer_mode: str
    label: str

    @model_validator(mode="after")
    def _ensure_exactly_one_source_or_base(self) -> "ModelConfig":
        # XOR on truthiness rejects both "neither set" and "both set".
        if not (bool(self.source) ^ bool(self.base)):
            raise ValueError("Config must set exactly one of [model].source or [model].base")
        return self
28
+
29
+
30
class RolloutConfig(ExtraModel):
    """Schema for the optional ``[rollout]`` table of an RL config.

    Required keys: ``env_name``, ``policy_name``, ``max_turns``,
    ``episodes_per_batch`` and ``max_concurrent_rollouts``. All others are optional.
    """

    env_name: str
    policy_name: str
    # Free-form tables; their inner keys are not validated by this model.
    env_config: dict[str, Any] | None = None
    policy_config: dict[str, Any] | None = None
    max_turns: int
    episodes_per_batch: int
    max_concurrent_rollouts: int
    batches_per_step: int | None = None
    ops: list[str] | None = None
40
+
41
+
42
class WeightSyncConfig(ExtraModel):
    """Schema for the optional ``[training.weight_sync]`` table of an RL config.

    Every key is optional; omitted keys are dropped by ``RLConfig.to_dict``
    (which dumps with ``exclude_none=True``).
    """

    enable: bool | None = None
    targets: list[str] | None = None
    mode: str | None = None
    direct: bool | None = None
    verify_every_k: int | None = None
48
+
49
+
50
class RLTrainingConfig(ExtraModel):
    """Schema for the optional ``[training]`` table of an RL config.

    Required keys: ``num_epochs``, ``iterations_per_epoch``, ``max_turns``,
    ``batch_size``, ``group_size`` and ``learning_rate``. Everything else is optional.
    """

    num_epochs: int
    iterations_per_epoch: int
    gradient_accumulation_steps: int | None = None
    max_accumulated_minibatch: int | None = None
    max_turns: int
    batch_size: int
    group_size: int
    learning_rate: float
    log_interval: int | None = None
    weight_sync_interval: int | None = None
    # Step/event reward-shaping knobs. Only their types are checked here;
    # presumably they are interpreted by the training backend (not visible here).
    step_rewards_enabled: bool | None = None
    step_rewards_mode: str | None = None
    step_rewards_indicator_lambda: float | None = None
    step_rewards_beta: float | None = None
    step_rewards_strategy: str | None = None
    event_rewards_kind: str | None = None
    # Nested [training.weight_sync] table.
    weight_sync: WeightSyncConfig | None = None
68
+
69
+
70
class EvaluationConfig(ExtraModel):
    """Schema for the optional ``[evaluation]`` table of an RL config.

    All three keys are required whenever the table is present.
    """

    instances: int
    every_n_iters: int
    seeds: list[int]
74
+
75
+
76
class JudgeOptionsConfig(ExtraModel):
    """Schema for the optional ``[judge.options]`` table of an RL config.

    Every key is optional. Only value types are validated here.
    """

    event: bool | None = None
    outcome: bool | None = None
    provider: str | None = None
    model: str | None = None
    rubric_id: str | None = None
    # Free-form overrides; inner keys are not validated by this model.
    rubric_overrides: dict[str, Any] | None = None
    tracks: list[str] | None = None
    weights: dict[str, float] | None = None
    max_concurrency: int | None = None
86
+
87
+
88
class JudgeConfig(ExtraModel):
    """Schema for the optional ``[judge]`` table of an RL config.

    All keys are optional. ``options`` holds the nested ``[judge.options]`` table.
    """

    type: str | None = None
    timeout_s: int | None = None
    options: JudgeOptionsConfig | None = None
92
+
93
+
94
class RLConfig(ExtraModel):
    """Top-level typed view of an RL training TOML file.

    ``algorithm``, ``services`` and ``model`` are required tables; every other
    section is optional. Unknown keys are retained because the model inherits
    ``extra="allow"`` from ``ExtraModel``.
    """

    algorithm: AlgorithmConfig
    services: RLServicesConfig
    compute: ComputeConfig | None = None
    topology: dict[str, Any] | None = None
    vllm: dict[str, Any] | None = None
    reference: dict[str, Any] | None = None
    model: ModelConfig
    lora: dict[str, Any] | None = None
    rollout: RolloutConfig | None = None
    evaluation: EvaluationConfig | None = None
    training: RLTrainingConfig | None = None
    rubric: dict[str, Any] | None = None
    judge: JudgeConfig | None = None
    tags: dict[str, Any] | None = None

    @classmethod
    def from_path(cls, path: Path) -> "RLConfig":
        """Parse *path* as TOML and validate the result."""
        return cls.from_mapping(load_toml(path))

    @classmethod
    def from_mapping(cls, data: Mapping[str, Any]) -> "RLConfig":
        """Validate an already-parsed mapping (copied into a plain ``dict`` first)."""
        as_dict = dict(data)
        return cls.model_validate(as_dict)

    def to_dict(self) -> dict[str, Any]:
        """Dump to plain Python values, omitting every ``None``-valued field."""
        serialised = self.model_dump(mode="python", exclude_none=True)
        return serialised
121
+
122
+
123
# Public names of the RL config module. The shared AlgorithmConfig and
# ComputeConfig models imported from .shared are intentionally not re-listed here.
__all__ = [
    "EvaluationConfig",
    "JudgeConfig",
    "JudgeOptionsConfig",
    "ModelConfig",
    "RLConfig",
    "RLServicesConfig",
    "RLTrainingConfig",
    "RolloutConfig",
    "WeightSyncConfig",
]
synth_ai/api/train/configs/sft.py ADDED
@@ -0,0 +1,94 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import Any, Mapping
5
+
6
+ from pydantic import Field
7
+
8
+ from ..utils import load_toml
9
+ from .shared import AlgorithmConfig, ComputeConfig, ExtraModel
10
+
11
+
12
class JobConfig(ExtraModel):
    """Schema for the ``[job]`` table of an SFT config.

    All keys are optional at the schema level:

    * ``model``: the SFT payload builder falls back to a top-level ``model``
      string and raises its own "Model not specified" error when neither is
      present. A required field here made that fallback unreachable and
      regressed configs that previously omitted ``[job].model``.
    * ``data`` / ``data_path``: may be omitted when the dataset is supplied via
      ``--dataset``. The builder reports "Dataset not specified" otherwise.
    """

    model: str | None = None
    data: str | None = None
    data_path: str | None = None
    poll_seconds: int | None = None
17
+
18
+
19
class SFTDataConfig(ExtraModel):
    """Schema for the optional ``[data]`` table of an SFT config (all keys optional)."""

    # Free-form table; inner keys are not validated by this model.
    topology: dict[str, Any] | None = None
    validation_path: str | None = None
22
+
23
+
24
class TrainingValidationConfig(ExtraModel):
    """Schema for the optional ``[training.validation]`` table of an SFT config.

    Every key is optional. Only value types are validated here.
    """

    enabled: bool | None = None
    evaluation_strategy: str | None = None
    eval_steps: int | None = None
    save_best_model_at_end: bool | None = None
    metric_for_best_model: str | None = None
    greater_is_better: bool | None = None
31
+
32
+
33
class TrainingConfig(ExtraModel):
    """Schema for the optional ``[training]`` table of an SFT config (all keys optional)."""

    mode: str | None = None
    use_qlora: bool | None = None
    # Nested [training.validation] table.
    validation: TrainingValidationConfig | None = None
37
+
38
+
39
class HyperparametersParallelism(ExtraModel):
    """Schema for the optional ``[hyperparameters.parallelism]`` table of an SFT config.

    Every key is optional. Only value types are validated here; mutually
    exclusive combinations (e.g. ``bf16`` with ``fp16``) are not rejected here.
    """

    use_deepspeed: bool | None = None
    deepspeed_stage: int | None = None
    fsdp: bool | None = None
    bf16: bool | None = None
    fp16: bool | None = None
    activation_checkpointing: bool | None = None
    tensor_parallel_size: int | None = None
    pipeline_parallel_size: int | None = None
48
+
49
+
50
class HyperparametersConfig(ExtraModel):
    """Schema for the ``[hyperparameters]`` table of an SFT config.

    ``n_epochs`` defaults to 1. Every other key is optional.
    """

    n_epochs: int = 1
    batch_size: int | None = None
    global_batch: int | None = None
    per_device_batch: int | None = None
    gradient_accumulation_steps: int | None = None
    sequence_length: int | None = None
    learning_rate: float | None = None
    warmup_ratio: float | None = None
    train_kind: str | None = None
    weight_decay: float | None = None
    # Nested [hyperparameters.parallelism] table.
    parallelism: HyperparametersParallelism | None = None
62
+
63
+
64
class SFTConfig(ExtraModel):
    """Top-level typed view of an SFT training TOML file.

    Only ``[job]`` is required. ``hyperparameters`` is always populated
    (defaulting to ``HyperparametersConfig()``). Unknown keys are retained
    because the model inherits ``extra="allow"`` from ``ExtraModel``.
    """

    algorithm: AlgorithmConfig | None = None
    job: JobConfig
    compute: ComputeConfig | None = None
    data: SFTDataConfig | None = None
    training: TrainingConfig | None = None
    hyperparameters: HyperparametersConfig = Field(default_factory=HyperparametersConfig)
    tags: dict[str, Any] | None = None

    @classmethod
    def from_path(cls, path: Path) -> "SFTConfig":
        """Parse *path* as TOML and validate the result."""
        return cls.from_mapping(load_toml(path))

    @classmethod
    def from_mapping(cls, data: Mapping[str, Any]) -> "SFTConfig":
        """Validate an already-parsed mapping (copied into a plain ``dict`` first)."""
        as_dict = dict(data)
        return cls.model_validate(as_dict)

    def to_dict(self) -> dict[str, Any]:
        """Dump to plain Python values, omitting every ``None``-valued field."""
        serialised = self.model_dump(mode="python", exclude_none=True)
        return serialised
84
+
85
+
86
# Public names of the SFT config module. The shared AlgorithmConfig and
# ComputeConfig models imported from .shared are intentionally not re-listed here.
__all__ = [
    "HyperparametersConfig",
    "HyperparametersParallelism",
    "JobConfig",
    "SFTConfig",
    "SFTDataConfig",
    "TrainingConfig",
    "TrainingValidationConfig",
]
synth_ai/api/train/configs/shared.py ADDED
@@ -0,0 +1,24 @@
1
+ from __future__ import annotations
2
+
3
+ from pydantic import BaseModel, ConfigDict
4
+
5
+
6
class ExtraModel(BaseModel):
    """Base model that tolerates unknown keys so configs keep forward compatibility."""

    # pydantic extra="allow": keys not declared as fields are accepted and kept
    # on the instance rather than raising a validation error.
    model_config = ConfigDict(extra="allow")
10
+
11
+
12
class AlgorithmConfig(ExtraModel):
    """Schema for the ``[algorithm]`` table shared by RL and SFT configs.

    All three keys are required strings. This model checks only presence and
    type. The train builders pass its ``model_dump()`` to
    ``validate_algorithm_config`` for the semantic check of allowed values.
    """

    type: str
    method: str
    variety: str
16
+
17
+
18
class ComputeConfig(ExtraModel):
    """Schema for the ``[compute]`` table shared by RL and SFT configs.

    ``gpu_type`` and ``gpu_count`` are required whenever the table is present.
    ``nodes`` is optional.
    """

    gpu_type: str
    gpu_count: int
    nodes: int | None = None
22
+
23
+
24
# Public names of the shared config module (base model plus the two tables
# used by both RL and SFT configs).
__all__ = ["ExtraModel", "AlgorithmConfig", "ComputeConfig"]
synth_ai/api/train/env_resolver.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import importlib
3
4
  import os
4
5
  from collections.abc import Callable, Iterable, MutableMapping
5
6
  from dataclasses import dataclass
@@ -11,6 +12,18 @@ from . import task_app
11
12
  from .utils import REPO_ROOT, mask_value, read_env_file, write_env_value
12
13
 
13
14
 
15
def _load_saved_env_path() -> Path | None:
    """Return the ``.env`` path remembered by the demo command, if any.

    This is best-effort. Any failure while importing
    ``synth_ai.demos.demo_task_apps.core`` or calling its
    ``load_env_file_path()`` yields ``None`` instead of an error. A falsy stored
    value also yields ``None``. Whether the returned path exists is left to the
    caller to check.
    """
    try:
        demo_core = importlib.import_module("synth_ai.demos.demo_task_apps.core")
        remembered = demo_core.load_env_file_path()
        return Path(remembered) if remembered else None
    except Exception:
        return None
25
+
26
+
14
27
  @dataclass(slots=True)
15
28
  class KeySpec:
16
29
  name: str
@@ -156,25 +169,11 @@ def resolve_env(
156
169
  raise click.ClickException(f"Env file not found: {path}")
157
170
  resolver = EnvResolver(provided)
158
171
  else:
159
- # Check for saved .env path from demo command
160
- try:
161
- from synth_ai.demos.demo_task_apps.core import load_env_file_path
162
-
163
- saved_env_path = load_env_file_path()
164
- if saved_env_path:
165
- saved_path = Path(saved_env_path)
166
- if saved_path.exists():
167
- click.echo(f"Using .env file: {saved_path}")
168
- resolver = EnvResolver([saved_path])
169
- else:
170
- # Saved path no longer exists, fall back to prompt
171
- resolver = EnvResolver(_collect_default_candidates(config_path))
172
- resolver.select_new_env()
173
- else:
174
- resolver = EnvResolver(_collect_default_candidates(config_path))
175
- resolver.select_new_env()
176
- except Exception:
177
- # If import fails or any error, fall back to original behavior
172
+ saved_path = _load_saved_env_path()
173
+ if saved_path and saved_path.exists():
174
+ click.echo(f"Using .env file: {saved_path}")
175
+ resolver = EnvResolver([saved_path])
176
+ else:
178
177
  resolver = EnvResolver(_collect_default_candidates(config_path))
179
178
  resolver.select_new_env()
180
179