synth-ai 0.4.1__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of synth-ai has been flagged by the registry.

Files changed (153)
  1. synth_ai/__init__.py +13 -13
  2. synth_ai/cli/__init__.py +6 -15
  3. synth_ai/cli/commands/eval/__init__.py +6 -15
  4. synth_ai/cli/commands/eval/config.py +338 -0
  5. synth_ai/cli/commands/eval/core.py +236 -1091
  6. synth_ai/cli/commands/eval/runner.py +704 -0
  7. synth_ai/cli/commands/eval/validation.py +44 -117
  8. synth_ai/cli/commands/filter/core.py +7 -7
  9. synth_ai/cli/commands/filter/validation.py +2 -2
  10. synth_ai/cli/commands/smoke/core.py +7 -17
  11. synth_ai/cli/commands/status/__init__.py +1 -64
  12. synth_ai/cli/commands/status/client.py +50 -151
  13. synth_ai/cli/commands/status/config.py +3 -83
  14. synth_ai/cli/commands/status/errors.py +4 -13
  15. synth_ai/cli/commands/status/subcommands/__init__.py +2 -8
  16. synth_ai/cli/commands/status/subcommands/config.py +13 -0
  17. synth_ai/cli/commands/status/subcommands/files.py +18 -63
  18. synth_ai/cli/commands/status/subcommands/jobs.py +28 -311
  19. synth_ai/cli/commands/status/subcommands/models.py +18 -62
  20. synth_ai/cli/commands/status/subcommands/runs.py +16 -63
  21. synth_ai/cli/commands/status/subcommands/session.py +67 -172
  22. synth_ai/cli/commands/status/subcommands/summary.py +24 -32
  23. synth_ai/cli/commands/status/subcommands/utils.py +41 -0
  24. synth_ai/cli/commands/status/utils.py +16 -107
  25. synth_ai/cli/commands/train/__init__.py +18 -20
  26. synth_ai/cli/commands/train/errors.py +3 -3
  27. synth_ai/cli/commands/train/prompt_learning_validation.py +15 -16
  28. synth_ai/cli/commands/train/validation.py +7 -7
  29. synth_ai/cli/commands/train/{judge_schemas.py → verifier_schemas.py} +33 -34
  30. synth_ai/cli/commands/train/verifier_validation.py +235 -0
  31. synth_ai/cli/demo_apps/demo_task_apps/math/config.toml +0 -1
  32. synth_ai/cli/demo_apps/demo_task_apps/math/modal_task_app.py +2 -6
  33. synth_ai/cli/demo_apps/math/config.toml +0 -1
  34. synth_ai/cli/demo_apps/math/modal_task_app.py +2 -6
  35. synth_ai/cli/demo_apps/mipro/task_app.py +25 -47
  36. synth_ai/cli/lib/apps/task_app.py +12 -13
  37. synth_ai/cli/lib/task_app_discovery.py +6 -6
  38. synth_ai/cli/lib/train_cfgs.py +10 -10
  39. synth_ai/cli/task_apps/__init__.py +11 -0
  40. synth_ai/cli/task_apps/commands.py +7 -15
  41. synth_ai/core/env.py +12 -1
  42. synth_ai/core/errors.py +1 -2
  43. synth_ai/core/integrations/cloudflare.py +209 -33
  44. synth_ai/core/tracing_v3/abstractions.py +46 -0
  45. synth_ai/data/__init__.py +3 -30
  46. synth_ai/data/enums.py +1 -20
  47. synth_ai/data/rewards.py +100 -3
  48. synth_ai/products/graph_evolve/__init__.py +1 -2
  49. synth_ai/products/graph_evolve/config.py +16 -16
  50. synth_ai/products/graph_evolve/converters/__init__.py +3 -3
  51. synth_ai/products/graph_evolve/converters/openai_sft.py +7 -7
  52. synth_ai/products/graph_evolve/examples/hotpotqa/config.toml +1 -1
  53. synth_ai/products/graph_gepa/__init__.py +23 -0
  54. synth_ai/products/graph_gepa/converters/__init__.py +19 -0
  55. synth_ai/products/graph_gepa/converters/openai_sft.py +29 -0
  56. synth_ai/sdk/__init__.py +45 -35
  57. synth_ai/sdk/api/eval/__init__.py +33 -0
  58. synth_ai/sdk/api/eval/job.py +732 -0
  59. synth_ai/sdk/api/research_agent/__init__.py +276 -66
  60. synth_ai/sdk/api/train/builders.py +181 -0
  61. synth_ai/sdk/api/train/cli.py +41 -33
  62. synth_ai/sdk/api/train/configs/__init__.py +6 -4
  63. synth_ai/sdk/api/train/configs/prompt_learning.py +127 -33
  64. synth_ai/sdk/api/train/configs/rl.py +264 -16
  65. synth_ai/sdk/api/train/configs/sft.py +165 -1
  66. synth_ai/sdk/api/train/graph_validators.py +12 -12
  67. synth_ai/sdk/api/train/graphgen.py +169 -51
  68. synth_ai/sdk/api/train/graphgen_models.py +95 -45
  69. synth_ai/sdk/api/train/local_api.py +10 -0
  70. synth_ai/sdk/api/train/pollers.py +36 -0
  71. synth_ai/sdk/api/train/prompt_learning.py +390 -60
  72. synth_ai/sdk/api/train/rl.py +41 -5
  73. synth_ai/sdk/api/train/sft.py +2 -0
  74. synth_ai/sdk/api/train/task_app.py +20 -0
  75. synth_ai/sdk/api/train/validators.py +17 -17
  76. synth_ai/sdk/graphs/completions.py +239 -33
  77. synth_ai/sdk/{judging/schemas.py → graphs/verifier_schemas.py} +23 -23
  78. synth_ai/sdk/learning/__init__.py +35 -5
  79. synth_ai/sdk/learning/context_learning_client.py +531 -0
  80. synth_ai/sdk/learning/context_learning_types.py +294 -0
  81. synth_ai/sdk/learning/prompt_learning_client.py +1 -1
  82. synth_ai/sdk/learning/prompt_learning_types.py +2 -1
  83. synth_ai/sdk/learning/rl/__init__.py +0 -4
  84. synth_ai/sdk/learning/rl/contracts.py +0 -4
  85. synth_ai/sdk/localapi/__init__.py +40 -0
  86. synth_ai/sdk/localapi/apps/__init__.py +28 -0
  87. synth_ai/sdk/localapi/client.py +10 -0
  88. synth_ai/sdk/localapi/contracts.py +10 -0
  89. synth_ai/sdk/localapi/helpers.py +519 -0
  90. synth_ai/sdk/localapi/rollouts.py +93 -0
  91. synth_ai/sdk/localapi/server.py +29 -0
  92. synth_ai/sdk/localapi/template.py +49 -0
  93. synth_ai/sdk/streaming/handlers.py +6 -6
  94. synth_ai/sdk/streaming/streamer.py +10 -6
  95. synth_ai/sdk/task/__init__.py +18 -5
  96. synth_ai/sdk/task/apps/__init__.py +37 -1
  97. synth_ai/sdk/task/client.py +9 -1
  98. synth_ai/sdk/task/config.py +6 -11
  99. synth_ai/sdk/task/contracts.py +137 -95
  100. synth_ai/sdk/task/in_process.py +32 -22
  101. synth_ai/sdk/task/in_process_runner.py +9 -4
  102. synth_ai/sdk/task/rubrics/__init__.py +2 -3
  103. synth_ai/sdk/task/rubrics/loaders.py +4 -4
  104. synth_ai/sdk/task/rubrics/strict.py +3 -4
  105. synth_ai/sdk/task/server.py +76 -16
  106. synth_ai/sdk/task/trace_correlation_helpers.py +190 -139
  107. synth_ai/sdk/task/validators.py +34 -49
  108. synth_ai/sdk/training/__init__.py +7 -16
  109. synth_ai/sdk/tunnels/__init__.py +118 -0
  110. synth_ai/sdk/tunnels/cleanup.py +83 -0
  111. synth_ai/sdk/tunnels/ports.py +120 -0
  112. synth_ai/sdk/tunnels/tunneled_api.py +363 -0
  113. {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/METADATA +71 -4
  114. {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/RECORD +118 -128
  115. synth_ai/cli/commands/baseline/__init__.py +0 -12
  116. synth_ai/cli/commands/baseline/core.py +0 -636
  117. synth_ai/cli/commands/baseline/list.py +0 -94
  118. synth_ai/cli/commands/eval/errors.py +0 -81
  119. synth_ai/cli/commands/status/formatters.py +0 -164
  120. synth_ai/cli/commands/status/subcommands/pricing.py +0 -23
  121. synth_ai/cli/commands/status/subcommands/usage.py +0 -203
  122. synth_ai/cli/commands/train/judge_validation.py +0 -305
  123. synth_ai/cli/usage.py +0 -159
  124. synth_ai/data/specs.py +0 -36
  125. synth_ai/sdk/api/research_agent/cli.py +0 -428
  126. synth_ai/sdk/api/research_agent/config.py +0 -357
  127. synth_ai/sdk/api/research_agent/job.py +0 -717
  128. synth_ai/sdk/baseline/__init__.py +0 -25
  129. synth_ai/sdk/baseline/config.py +0 -209
  130. synth_ai/sdk/baseline/discovery.py +0 -216
  131. synth_ai/sdk/baseline/execution.py +0 -154
  132. synth_ai/sdk/judging/__init__.py +0 -15
  133. synth_ai/sdk/judging/base.py +0 -24
  134. synth_ai/sdk/judging/client.py +0 -191
  135. synth_ai/sdk/judging/types.py +0 -42
  136. synth_ai/sdk/research_agent/__init__.py +0 -34
  137. synth_ai/sdk/research_agent/container_builder.py +0 -328
  138. synth_ai/sdk/research_agent/container_spec.py +0 -198
  139. synth_ai/sdk/research_agent/defaults.py +0 -34
  140. synth_ai/sdk/research_agent/results_collector.py +0 -69
  141. synth_ai/sdk/specs/__init__.py +0 -46
  142. synth_ai/sdk/specs/dataclasses.py +0 -149
  143. synth_ai/sdk/specs/loader.py +0 -144
  144. synth_ai/sdk/specs/serializer.py +0 -199
  145. synth_ai/sdk/specs/validation.py +0 -250
  146. synth_ai/sdk/tracing/__init__.py +0 -39
  147. synth_ai/sdk/usage/__init__.py +0 -37
  148. synth_ai/sdk/usage/client.py +0 -171
  149. synth_ai/sdk/usage/models.py +0 -261
  150. {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/WHEEL +0 -0
  151. {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/entry_points.txt +0 -0
  152. {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/licenses/LICENSE +0 -0
  153. {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/top_level.txt +0 -0
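The dominant theme in this range is a rename of "judge" terminology to "verifier" throughout the SDK (`train/judge_schemas.py → train/verifier_schemas.py`, `sdk/judging/schemas.py → sdk/graphs/verifier_schemas.py`), the removal of the `baseline`, `judging`, `specs`, `usage`, and old `research_agent` modules, and new `eval`, `localapi`, and `tunnels` packages. A minimal sketch of what the rename means at a 0.4.1 call site, using only symbols that appear in the diffs below; the dict keys mirror the rl.py hunks and the values are illustrative, not a complete or tested config:

```python
# Before (0.4.1), per the old rl.py shown below:
#   from synth_ai.sdk.api.train.configs.rl import JudgeConfig, JudgeOptionsConfig
#   services = {"task_url": ..., "judge_url": ...}

# After (0.4.4): verifier-flavored names.
from synth_ai.sdk.api.train.configs.rl import RLConfig, VerifierConfig

config = RLConfig.from_mapping({
    "algorithm": {"type": "online", "method": "policy_gradient", "variety": "gspo"},
    "services": {
        "task_url": "https://my-tunnel.trycloudflare.com",
        "verifier_url": None,  # was services.judge_url in 0.4.1
    },
    "model": {"base": "Qwen/Qwen3-4B", "trainer_mode": "lora", "label": "my-model"},
    "verifier": {"type": "synth", "enabled": True},  # was the judge section
})
assert isinstance(config.verifier, VerifierConfig)
```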
synth_ai/sdk/api/train/configs/rl.py
@@ -1,3 +1,42 @@
+"""RL (Reinforcement Learning) configuration models.
+
+This module defines the configuration schema for RL training jobs using GSPO
+(Group Sequence Policy Optimization) or other policy gradient methods.
+
+Example TOML configuration:
+```toml
+[algorithm]
+type = "online"
+method = "policy_gradient"
+variety = "gspo"
+
+[services]
+task_url = "https://your-tunnel.trycloudflare.com"
+
+[model]
+base = "Qwen/Qwen3-4B"
+trainer_mode = "lora"
+label = "my-rl-model"
+
+[rollout]
+env_name = "my-task"
+policy_name = "my-policy"
+max_turns = 10
+episodes_per_batch = 32
+max_concurrent_rollouts = 8
+
+[training]
+num_epochs = 1
+iterations_per_epoch = 20
+batch_size = 16
+group_size = 4
+learning_rate = 5e-5
+```
+
+See Also:
+- Training reference: /training/gspo
+- Job events: /sdk/jobs/rl
+"""
 from __future__ import annotations
 
 from collections.abc import Mapping
@@ -11,11 +50,32 @@ from .shared import AlgorithmConfig, ComputeConfig, ExtraModel, LoraConfig, Poli
 
 
 class RLServicesConfig(ExtraModel):
+    """Service URLs for RL training.
+
+    Attributes:
+        task_url: URL of your task app (typically a Cloudflare tunnel URL).
+            Required for rollout execution.
+        verifier_url: Optional URL for verifier service. Defaults to Synth's
+            hosted verifier at https://synth-backend.onrender.com/api.
+    """
     task_url: str
-    judge_url: str | None = None
+    verifier_url: str | None = None
 
 
 class ModelConfig(ExtraModel):
+    """Model configuration for RL training.
+
+    Specify either `base` (for a new model) or `source` (to continue from
+    a checkpoint), but not both.
+
+    Attributes:
+        source: Checkpoint ID to continue training from (e.g., "ft:job_abc123").
+            Mutually exclusive with `base`.
+        base: Base model to fine-tune (e.g., "Qwen/Qwen3-4B").
+            Mutually exclusive with `source`.
+        trainer_mode: Training mode - "lora", "qlora", or "full".
+        label: Human-readable identifier for this model.
+    """
     source: str | None = None
     base: str | None = None
     trainer_mode: str
@@ -29,6 +89,20 @@ class ModelConfig(ExtraModel):
 
 
 class RolloutConfig(ExtraModel):
+    """Rollout configuration for episode collection.
+
+    Controls how episodes are collected from the task app during training.
+
+    Attributes:
+        env_name: Environment/task name registered in your task app.
+        policy_name: Policy identifier for the rollout.
+        env_config: Optional environment-specific configuration dict.
+        policy_config: Optional policy-specific configuration dict.
+        max_turns: Maximum steps per episode before truncation.
+        episodes_per_batch: Number of episodes to collect per training batch.
+        max_concurrent_rollouts: Maximum parallel rollouts to the task app.
+        batches_per_step: Batches to collect per training step. Default: 1.
+    """
     env_name: str
     policy_name: str
     env_config: dict[str, Any] | None = None
@@ -37,10 +111,20 @@ class RolloutConfig(ExtraModel):
     episodes_per_batch: int
     max_concurrent_rollouts: int
     batches_per_step: int | None = None
-    ops: list[str] | None = None
 
 
 class WeightSyncConfig(ExtraModel):
+    """Weight synchronization configuration.
+
+    Controls how model weights are synchronized between training and inference.
+
+    Attributes:
+        enable: Whether to enable weight sync. Default: True.
+        targets: Sync targets, typically ["policy"].
+        mode: Sync mode (advanced).
+        direct: Use direct sync method.
+        verify_every_k: Verify sync every K iterations.
+    """
     enable: bool | None = None
     targets: list[str] | None = None
     mode: str | None = None
@@ -49,7 +133,18 @@ class WeightSyncConfig(ExtraModel):
 
 
 class RewardsConfig(ExtraModel):
-    """Rewards configuration for RL training."""
+    """Rewards configuration for RL training.
+
+    Controls step-level and event-level reward computation.
+
+    Attributes:
+        step_rewards_enabled: Enable step-level rewards. Default: False.
+        step_rewards_mode: Reward mode - "off", "decision_stepwise", or "env_sparse".
+        step_rewards_indicator_lambda: Lambda coefficient for indicator rewards.
+        step_rewards_beta: Beta coefficient for step rewards.
+        step_rewards_strategy: Reward computation strategy.
+        event_rewards_kind: Event reward aggregation - "unique" or "absolute".
+    """
     step_rewards_enabled: bool | None = None
     step_rewards_mode: str | None = None
     step_rewards_indicator_lambda: float | None = None
@@ -59,6 +154,23 @@
 
 
 class RLTrainingConfig(ExtraModel):
+    """Training hyperparameters for RL.
+
+    Attributes:
+        num_epochs: Number of training epochs.
+        iterations_per_epoch: Training iterations per epoch.
+        gradient_accumulation_steps: Steps to accumulate gradients. Default: 1.
+        max_accumulated_minibatch: Maximum accumulated minibatch size.
+        max_turns: Maximum turns during training rollouts.
+        batch_size: Training batch size.
+        group_size: GSPO group size for advantage estimation.
+        learning_rate: Optimizer learning rate (e.g., 5e-5).
+        log_interval: Log metrics every N steps.
+        weight_sync_interval: Sync weights every N steps.
+        weight_sync: Nested weight sync configuration.
+        lora: LoRA configuration (r, alpha, dropout, target_modules).
+        rewards: Nested rewards configuration.
+    """
     num_epochs: int
     iterations_per_epoch: int
     gradient_accumulation_steps: int | None = None
@@ -83,12 +195,32 @@
 
 
 class EvaluationConfig(ExtraModel):
+    """Evaluation configuration during training.
+
+    Attributes:
+        instances: Number of evaluation instances to run.
+        every_n_iters: Run evaluation every N training iterations.
+        seeds: List of seeds for reproducible evaluation.
+    """
     instances: int
     every_n_iters: int
     seeds: list[int]
 
 
-class JudgeOptionsConfig(ExtraModel):
+class VerifierOptionsConfig(ExtraModel):
+    """Verifier scoring options.
+
+    Attributes:
+        event: Enable event-level verification.
+        outcome: Enable outcome-level verification.
+        provider: Verifier provider - "synth" for Synth's hosted verifier.
+        model: Verifier model identifier.
+        rubric_id: Optional rubric identifier.
+        rubric_overrides: Override specific rubric parameters.
+        tracks: Tracks to evaluate.
+        weights: Per-track scoring weights.
+        max_concurrency: Maximum concurrent verifier API calls.
+    """
     event: bool | None = None
     outcome: bool | None = None
     provider: str | None = None
@@ -101,22 +233,61 @@
 
 
 class RubricConfig(ExtraModel):
-    """Rubric configuration for reward blending."""
+    """Rubric configuration for reward blending.
+
+    Attributes:
+        enabled: Enable rubric-based scoring. Default: False.
+        reward_blend: Weights for reward sources - {"env": 1.0, "event": 0.0, "outcome": 0.0}.
+    """
     enabled: bool = False
     reward_blend: dict[str, float] | None = None  # env, event, outcome weights
 
 
-class JudgeConfig(ExtraModel):
+class VerifierConfig(ExtraModel):
+    """Verifier configuration for LLM-based reward scoring.
+
+    Attributes:
+        type: Verifier type - "synth" for Synth's hosted verifier.
+        timeout_s: Timeout in seconds for verifier API calls.
+        enabled: Master switch to enable/disable verifier scoring.
+        reward_blend: Reward source weights - {"env": 1.0, "event": 0.0, "outcome": 0.0}.
+        rubric: Deprecated - use reward_blend instead.
+        options: Detailed verifier options.
+    """
     type: str | None = None
     timeout_s: int | None = None
-    enabled: bool | None = None  # Master switch for judge/rubric
+    enabled: bool | None = None  # Master switch for verifier/rubric
     reward_blend: dict[str, float] | None = None  # NEW: nested reward blending (replaces rubric.weights)
     rubric: RubricConfig | None = None  # DEPRECATED: use flat fields instead
-    options: JudgeOptionsConfig | None = None
+    options: VerifierOptionsConfig | None = None
 
 
 class SmokeConfig(ExtraModel):
-    """Configuration for local smoke testing (CLI only, ignored by trainer)."""
+    """Configuration for local smoke testing (CLI only, ignored by trainer).
+
+    Use this section to configure quick local tests before submitting
+    a full training job.
+
+    Attributes:
+        task_url: Override task app URL for testing.
+        env_name: Environment name to test.
+        policy_name: Policy name to test.
+        max_steps: Maximum steps for smoke test.
+        policy: Policy type - "mock", "gpt-5-nano", "openai", "groq".
+        model: Model identifier for the policy.
+        mock_backend: Mock backend type - "synthetic" or "openai".
+        mock_port: Port for mock backend.
+        return_trace: Include trace in response.
+        use_mock: Use mock policy.
+        task_app_name: Task app to auto-serve (e.g., "grpo-crafter").
+        task_app_port: Port for auto-served task app. Default: 8765.
+        task_app_env_file: Path to .env file for task app.
+        task_app_force: Use --force flag when serving.
+        sqld_auto_start: Auto-start sqld server.
+        sqld_db_path: Database path. Default: ./traces/local.db.
+        sqld_hrana_port: Hrana WebSocket port. Default: 8080.
+        sqld_http_port: HTTP API port. Default: 8081.
+    """
     # Test parameters
     task_url: str | None = None
     env_name: str | None = None
@@ -128,13 +299,13 @@ class SmokeConfig(ExtraModel):
     mock_port: int | None = None
     return_trace: bool | None = None
     use_mock: bool | None = None
-
+
     # Task app auto-start configuration
     task_app_name: str | None = None  # Task app to serve (e.g., "grpo-crafter")
     task_app_port: int | None = None  # Port for task app (default: 8765)
     task_app_env_file: str | None = None  # Path to .env file for task app
     task_app_force: bool | None = None  # Use --force flag when serving
-
+
     # sqld auto-start configuration
     sqld_auto_start: bool | None = None  # Auto-start sqld server
     sqld_db_path: str | None = None  # Database path (default: ./traces/local.db)
@@ -143,6 +314,67 @@
 
 
 class RLConfig(ExtraModel):
+    """Root configuration for RL (Reinforcement Learning) training jobs.
+
+    This is the top-level config loaded from a TOML file. Use `RLConfig.from_path()`
+    to load from a file, or `RLConfig.from_mapping()` to load from a dict.
+
+    Example:
+        ```python
+        from synth_ai.sdk.api.train.configs.rl import RLConfig
+
+        # Load from file
+        config = RLConfig.from_path("rl_config.toml")
+
+        # Or from dict
+        config = RLConfig.from_mapping({
+            "algorithm": {"type": "online", "method": "policy_gradient", "variety": "gspo"},
+            "services": {"task_url": "https://my-tunnel.trycloudflare.com"},
+            "model": {"base": "Qwen/Qwen3-4B", "trainer_mode": "lora", "label": "my-model"},
+            ...
+        })
+        ```
+
+    Attributes:
+        algorithm: Algorithm configuration (type, method, variety).
+        services: Service URLs (task_url, verifier_url).
+        compute: GPU and compute configuration.
+        topology: Deprecated - use compute.topology.
+        vllm: vLLM inference server configuration.
+        reference: Deprecated - use compute.topology.reference_placement.
+        model: Deprecated - use policy instead.
+        policy: Policy/model configuration (preferred).
+        lora: Deprecated - use training.lora.
+        rollout: Rollout/episode collection configuration.
+        evaluation: Evaluation configuration.
+        training: Training hyperparameters.
+        rubric: Deprecated - use verifier.reward_blend.
+        verifier: Verifier/reward configuration.
+        tags: Optional metadata tags.
+        smoke: CLI-only smoke testing configuration.
+
+    Returns:
+        After training completes, you receive a result dict:
+        ```python
+        {
+            "status": "succeeded",
+            "final_reward": 0.85,
+            "model_id": "ft:Qwen/Qwen3-0.6B:job_abc123",
+            "checkpoints": [
+                {"step": 100, "path": "..."},
+                {"step": 200, "path": "..."},
+            ],
+        }
+        ```
+
+    Events:
+        During training, you'll receive streaming events:
+        - `rl.created` - Job created
+        - `rl.running` - Training started
+        - `rl.iteration.complete` - Iteration finished with metrics
+        - `rl.evaluation.complete` - Evaluation finished with scores
+        - `rl.succeeded` / `rl.failed` - Terminal states
+    """
     algorithm: AlgorithmConfig
     services: RLServicesConfig
     compute: ComputeConfig | None = None
@@ -155,29 +387,45 @@
     rollout: RolloutConfig | None = None
     evaluation: EvaluationConfig | None = None
     training: RLTrainingConfig | None = None
-    rubric: dict[str, Any] | None = None  # DEPRECATED: use judge.reward_blend and judge.enabled instead
-    judge: JudgeConfig | None = None
+    rubric: dict[str, Any] | None = None  # DEPRECATED: use verifier.reward_blend and verifier.enabled instead
+    verifier: VerifierConfig | None = None
     tags: dict[str, Any] | None = None
     smoke: SmokeConfig | None = None  # CLI-only: local smoke testing config (ignored by trainer)
 
     def to_dict(self) -> dict[str, Any]:
+        """Convert config to a dictionary."""
         return self.model_dump(mode="python", exclude_none=True)
 
     @classmethod
     def from_mapping(cls, data: Mapping[str, Any]) -> RLConfig:
-        """Load RL config from dict/TOML mapping."""
+        """Load RL config from dict/TOML mapping.
+
+        Args:
+            data: Dictionary or TOML mapping with configuration.
+
+        Returns:
+            Validated RLConfig instance.
+        """
         return cls.model_validate(data)
 
     @classmethod
     def from_path(cls, path: Path) -> RLConfig:
+        """Load RL config from a TOML file.
+
+        Args:
+            path: Path to the TOML configuration file.
+
+        Returns:
+            Validated RLConfig instance.
+        """
         content = load_toml(path)
         return cls.from_mapping(content)
 
 
 __all__ = [
     "EvaluationConfig",
-    "JudgeConfig",
-    "JudgeOptionsConfig",
+    "VerifierConfig",
+    "VerifierOptionsConfig",
     "ModelConfig",
     "RLConfig",
     "RLServicesConfig",
synth_ai/sdk/api/train/configs/sft.py
@@ -1,3 +1,40 @@
+"""SFT (Supervised Fine-Tuning) configuration models.
+
+This module defines the configuration schema for SFT training jobs.
+
+Example TOML configuration:
+```toml
+[algorithm]
+type = "offline"
+method = "sft"
+
+[job]
+model = "Qwen/Qwen3-4B"
+data_path = "training_data.jsonl"
+
+[compute]
+gpu_type = "H100"
+gpu_count = 1
+
+[training]
+mode = "lora"
+
+[training.lora]
+r = 16
+alpha = 32
+dropout = 0.1
+
+[hyperparameters]
+n_epochs = 3
+batch_size = 4
+learning_rate = 2e-5
+sequence_length = 2048
+```
+
+See Also:
+- Training reference: /training/sft
+- Quickstart: /quickstart/supervised-fine-tuning
+"""
 from __future__ import annotations
 
 from collections.abc import Mapping
@@ -11,6 +48,14 @@ from .shared import AlgorithmConfig, ComputeConfig, ExtraModel, LoraConfig, Poli
 
 
 class JobConfig(ExtraModel):
+    """Core job configuration for SFT.
+
+    Attributes:
+        model: Base model to fine-tune (e.g., "Qwen/Qwen3-4B", "meta-llama/Llama-3-8B").
+        data: Dataset identifier (if using registered datasets).
+        data_path: Path to JSONL training data file.
+        poll_seconds: Polling interval for job status. Default: 10.
+    """
     model: str
     data: str | None = None
     data_path: str | None = None
@@ -18,11 +63,27 @@
 
 
 class SFTDataConfig(ExtraModel):
+    """Data configuration for SFT training.
+
+    Attributes:
+        topology: Data loading topology configuration.
+        validation_path: Path to validation JSONL file for eval during training.
+    """
     topology: dict[str, Any] | None = None
     validation_path: str | None = None
 
 
 class TrainingValidationConfig(ExtraModel):
+    """Validation configuration during training.
+
+    Attributes:
+        enabled: Enable validation during training. Default: False.
+        evaluation_strategy: When to evaluate - "steps" or "epoch".
+        eval_steps: Evaluate every N steps (if strategy is "steps").
+        save_best_model_at_end: Save only the best model checkpoint.
+        metric_for_best_model: Metric to use for best model selection (e.g., "eval_loss").
+        greater_is_better: Whether higher metric is better. Default: False for loss.
+    """
     enabled: bool | None = None
     evaluation_strategy: str | None = None
     eval_steps: int | None = None
@@ -32,6 +93,14 @@ class TrainingValidationConfig(ExtraModel):
 
 
 class TrainingConfig(ExtraModel):
+    """Training mode configuration.
+
+    Attributes:
+        mode: Training mode - "lora", "qlora", or "full".
+        use_qlora: Enable QLoRA (4-bit quantized LoRA). Default: False.
+        validation: Validation configuration.
+        lora: LoRA hyperparameters (r, alpha, dropout, target_modules).
+    """
     mode: str | None = None
     use_qlora: bool | None = None
     validation: TrainingValidationConfig | None = None
@@ -39,6 +108,18 @@
 
 
 class HyperparametersParallelism(ExtraModel):
+    """Parallelism configuration for distributed training.
+
+    Attributes:
+        use_deepspeed: Enable DeepSpeed. Default: False.
+        deepspeed_stage: DeepSpeed ZeRO stage (1, 2, or 3).
+        fsdp: Enable PyTorch FSDP. Default: False.
+        bf16: Use bfloat16 precision. Default: True on supported hardware.
+        fp16: Use float16 precision. Default: False.
+        activation_checkpointing: Enable gradient checkpointing. Default: False.
+        tensor_parallel_size: Tensor parallelism degree.
+        pipeline_parallel_size: Pipeline parallelism degree.
+    """
     use_deepspeed: bool | None = None
     deepspeed_stage: int | None = None
     fsdp: bool | None = None
@@ -50,6 +131,21 @@
 
 
 class HyperparametersConfig(ExtraModel):
+    """Training hyperparameters for SFT.
+
+    Attributes:
+        n_epochs: Number of training epochs. Default: 1.
+        batch_size: Training batch size (alias for global_batch).
+        global_batch: Global batch size across all GPUs.
+        per_device_batch: Per-device batch size.
+        gradient_accumulation_steps: Steps to accumulate gradients. Default: 1.
+        sequence_length: Maximum sequence length. Default: 2048.
+        learning_rate: Optimizer learning rate (e.g., 2e-5).
+        warmup_ratio: Fraction of steps for LR warmup. Default: 0.1.
+        train_kind: Training variant (advanced).
+        weight_decay: Weight decay coefficient. Default: 0.01.
+        parallelism: Distributed training configuration.
+    """
    n_epochs: int = 1
    batch_size: int | None = None
    global_batch: int | None = None
@@ -64,6 +160,58 @@
 
 
 class SFTConfig(ExtraModel):
+    """Root configuration for SFT (Supervised Fine-Tuning) jobs.
+
+    This is the top-level config loaded from a TOML file.
+
+    Example:
+        ```python
+        from synth_ai.sdk.api.train.configs.sft import SFTConfig
+
+        # Load from file
+        config = SFTConfig.from_path("sft_config.toml")
+
+        # Or from dict
+        config = SFTConfig.from_mapping({
+            "job": {"model": "Qwen/Qwen3-4B", "data_path": "data.jsonl"},
+            "hyperparameters": {"n_epochs": 3, "learning_rate": 2e-5},
+        })
+        ```
+
+    Attributes:
+        algorithm: Algorithm configuration (type="offline", method="sft").
+        job: Core job configuration (model, data_path).
+        policy: Policy configuration (preferred over job.model).
+        compute: GPU and compute configuration.
+        data: Data loading configuration.
+        training: Training mode (lora, full) and LoRA config.
+        hyperparameters: Training hyperparameters.
+        lora: Deprecated - use training.lora instead.
+        tags: Optional metadata tags.
+
+    Returns:
+        After training completes, you receive a result dict:
+        ```python
+        {
+            "status": "succeeded",
+            "model_id": "ft:Qwen/Qwen3-4B:sft_abc123",
+            "final_loss": 0.42,
+            "checkpoints": [
+                {"epoch": 1, "loss": 0.65, "path": "..."},
+                {"epoch": 2, "loss": 0.52, "path": "..."},
+                {"epoch": 3, "loss": 0.42, "path": "..."},
+            ],
+        }
+        ```
+
+    Events:
+        During training, you'll receive streaming events:
+        - `sft.created` - Job created
+        - `sft.running` - Training started
+        - `sft.epoch.complete` - Epoch finished with loss
+        - `sft.checkpoint.saved` - Checkpoint saved
+        - `sft.succeeded` / `sft.failed` - Terminal states
+    """
     algorithm: AlgorithmConfig | None = None
     job: JobConfig
     policy: PolicyConfig | None = None  # NEW: unified policy section
@@ -75,15 +223,31 @@
     tags: dict[str, Any] | None = None
 
     def to_dict(self) -> dict[str, Any]:
+        """Convert config to a dictionary."""
         return self.model_dump(mode="python", exclude_none=True)
 
     @classmethod
     def from_mapping(cls, data: Mapping[str, Any]) -> SFTConfig:
-        """Load SFT config from dict/TOML mapping."""
+        """Load SFT config from dict/TOML mapping.
+
+        Args:
+            data: Dictionary or TOML mapping with configuration.
+
+        Returns:
+            Validated SFTConfig instance.
+        """
         return cls.model_validate(data)
 
     @classmethod
     def from_path(cls, path: Path) -> SFTConfig:
+        """Load SFT config from a TOML file.
+
+        Args:
+            path: Path to the TOML configuration file.
+
+        Returns:
+            Validated SFTConfig instance.
+        """
         content = load_toml(path)
         return cls.from_mapping(content)
 
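The sft.py hunks are almost entirely additive documentation, but they also surface the validation knobs (`SFTDataConfig.validation_path`, `TrainingValidationConfig`). A sketch wiring those together through `SFTConfig.from_mapping`, with key names taken from the attribute docs above and values that are purely illustrative:

```python
# Sketch: enabling mid-training validation via the nested config documented above.
from synth_ai.sdk.api.train.configs.sft import SFTConfig

config = SFTConfig.from_mapping({
    "job": {"model": "Qwen/Qwen3-4B", "data_path": "data.jsonl"},
    "data": {"validation_path": "val.jsonl"},
    "training": {
        "mode": "lora",
        "validation": {
            "enabled": True,
            "evaluation_strategy": "steps",
            "eval_steps": 50,
            "metric_for_best_model": "eval_loss",
            "greater_is_better": False,
        },
    },
    "hyperparameters": {"n_epochs": 3, "learning_rate": 2e-5, "sequence_length": 2048},
})
# to_dict() drops None fields, so only the keys set above survive.
print(config.to_dict()["training"]["validation"]["eval_steps"])  # 50
```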
synth_ai/sdk/api/train/graphgen.py
@@ -1,7 +1,7 @@
-"""TOML schema + validation for ADAS/Graphs jobs.
+"""TOML schema + validation for Graph Opt (GraphGen) jobs.
 
-Graphs jobs (aka ADAS jobs) are JSON-dataset-first, but for convenience we also
-support a small TOML wrapper that points at an GraphGenTaskSet JSON file plus a few
+Graph Opt jobs are JSON-dataset-first, but for convenience we also
+support a small TOML wrapper that points at a GraphGenTaskSet JSON file plus a few
 optimization knobs.
 
 Example `graph.toml`:
@@ -16,7 +16,7 @@ auto_start = true # optional
 
 [graph.metadata]
 session_id = "sess_123"
-parent_job_id = "adas_parent"
+parent_job_id = "graph_opt_parent"
 population_size = 4
 num_generations = 5
 ```
@@ -29,7 +29,7 @@ from dataclasses import dataclass
 from pathlib import Path
 from typing import Any, Dict, List, Optional, cast, Literal
 
-from .graphgen_models import GraphGenJobConfig, GraphGenTaskSet, load_graphgen_taskset
+from .graphgen_models import GraphGenJobConfig, GraphGenTaskSet, load_graphgen_taskset
 from .graphgen_validators import GraphGenValidationError, validate_graphgen_job_config
 
 
@@ -112,8 +112,8 @@ def validate_graph_job_section(
         policy_provider=section.get("policy_provider"),
         rollout_budget=int(rollout_budget) if rollout_budget is not None else 100,
         proposer_effort=cast(Literal["low", "medium", "high"], str(proposer_effort)) if proposer_effort is not None else "medium",
-        judge_model=section.get("judge_model"),
-        judge_provider=section.get("judge_provider"),
+        verifier_model=section.get("verifier_model"),
+        verifier_provider=section.get("verifier_provider"),
         population_size=section.get("population_size", 4),
         num_generations=section.get("num_generations"),
     )
@@ -151,17 +151,17 @@ def load_graph_job_toml(path: str | Path) -> GraphTomlResult:
     with open(path, "rb") as f:
         cfg = tomllib.load(f)
 
-    section = cfg.get("graph") or cfg.get("adas") or {}
+    section = cfg.get("graph") or {}
     return validate_graph_job_section(section, base_dir=path.parent)
 
 
 def validate_graph_job_payload(payload: Dict[str, Any]) -> None:
-    """Validate a graph/ADAS job payload (matching backend create request).
+    """Validate a graph job payload (matching backend create request).
 
     Expected keys:
     - dataset: GraphGenTaskSet dict
     - policy_model, rollout_budget, proposer_effort
-    - optional judge_model/judge_provider
+    - optional verifier_model/verifier_provider
     - optional metadata (population_size/num_generations)
     """
     errors: List[Dict[str, Any]] = []
@@ -188,8 +188,8 @@ def validate_graph_job_payload(payload: Dict[str, Any]) -> None:
         policy_provider=payload.get("policy_provider"),
         rollout_budget=int(payload.get("rollout_budget") or 100),
         proposer_effort=cast(Literal["low", "medium", "high"], str(payload.get("proposer_effort") or "medium")),
-        judge_model=payload.get("judge_model"),
-        judge_provider=payload.get("judge_provider"),
+        verifier_model=payload.get("verifier_model"),
+        verifier_provider=payload.get("verifier_provider"),
         population_size=metadata.get("population_size", 4),
         num_generations=metadata.get("num_generations"),
     )
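With the `cfg.get("adas")` fallback removed, graph job TOMLs must use the `[graph]` table, and payloads must use `verifier_model`/`verifier_provider`. A sketch of a payload matching the expected keys listed in the `validate_graph_job_payload` docstring above; the import path is inferred from the file path, the dataset body and model names are placeholders, and the exact failure behavior (raising vs. collecting errors) is not shown in this diff:

```python
# Hypothetical 0.4.4 graph job payload using the renamed verifier keys.
from synth_ai.sdk.api.train.graphgen import validate_graph_job_payload

payload = {
    "dataset": {},  # GraphGenTaskSet dict; real task contents elided here
    "policy_model": "example-policy-model",
    "rollout_budget": 100,
    "proposer_effort": "medium",
    "verifier_model": "example-verifier-model",  # was judge_model in 0.4.1
    "verifier_provider": "example-provider",     # was judge_provider in 0.4.1
    "metadata": {"population_size": 4, "num_generations": 5},
}
validate_graph_job_payload(payload)
```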