synth-ai 0.2.17__py3-none-any.whl → 0.2.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: the registry flags this version of synth-ai as possibly problematic; see the registry listing for details.

Files changed (169)
  1. examples/baseline/banking77_baseline.py +204 -0
  2. examples/baseline/crafter_baseline.py +407 -0
  3. examples/baseline/pokemon_red_baseline.py +326 -0
  4. examples/baseline/simple_baseline.py +56 -0
  5. examples/baseline/warming_up_to_rl_baseline.py +239 -0
  6. examples/blog_posts/gepa/README.md +355 -0
  7. examples/blog_posts/gepa/configs/banking77_gepa_local.toml +95 -0
  8. examples/blog_posts/gepa/configs/banking77_gepa_test.toml +82 -0
  9. examples/blog_posts/gepa/configs/banking77_mipro_local.toml +52 -0
  10. examples/blog_posts/gepa/configs/hotpotqa_gepa_local.toml +59 -0
  11. examples/blog_posts/gepa/configs/hotpotqa_gepa_qwen.toml +36 -0
  12. examples/blog_posts/gepa/configs/hotpotqa_mipro_local.toml +53 -0
  13. examples/blog_posts/gepa/configs/hover_gepa_local.toml +59 -0
  14. examples/blog_posts/gepa/configs/hover_gepa_qwen.toml +36 -0
  15. examples/blog_posts/gepa/configs/hover_mipro_local.toml +53 -0
  16. examples/blog_posts/gepa/configs/ifbench_gepa_local.toml +59 -0
  17. examples/blog_posts/gepa/configs/ifbench_gepa_qwen.toml +36 -0
  18. examples/blog_posts/gepa/configs/ifbench_mipro_local.toml +53 -0
  19. examples/blog_posts/gepa/configs/pupa_gepa_local.toml +60 -0
  20. examples/blog_posts/gepa/configs/pupa_mipro_local.toml +54 -0
  21. examples/blog_posts/gepa/deploy_banking77_task_app.sh +41 -0
  22. examples/blog_posts/gepa/gepa_baseline.py +204 -0
  23. examples/blog_posts/gepa/query_prompts_example.py +97 -0
  24. examples/blog_posts/gepa/run_gepa_banking77.sh +87 -0
  25. examples/blog_posts/gepa/task_apps.py +105 -0
  26. examples/blog_posts/gepa/test_gepa_local.sh +67 -0
  27. examples/blog_posts/gepa/verify_banking77_setup.sh +123 -0
  28. examples/blog_posts/pokemon_vl/configs/eval_gpt5nano.toml +26 -0
  29. examples/blog_posts/pokemon_vl/configs/eval_qwen3_vl.toml +12 -10
  30. examples/blog_posts/pokemon_vl/configs/train_rl_from_sft.toml +1 -0
  31. examples/blog_posts/pokemon_vl/extract_images.py +239 -0
  32. examples/blog_posts/pokemon_vl/pokemon_vl_baseline.py +326 -0
  33. examples/blog_posts/pokemon_vl/run_eval_extract_images.py +209 -0
  34. examples/blog_posts/pokemon_vl/run_qwen_eval_extract_images.py +212 -0
  35. examples/blog_posts/pokemon_vl/text_box_analysis.md +106 -0
  36. examples/blog_posts/warming_up_to_rl/ARCHITECTURE.md +195 -0
  37. examples/blog_posts/warming_up_to_rl/FINAL_TEST_RESULTS.md +127 -0
  38. examples/blog_posts/warming_up_to_rl/INFERENCE_SUCCESS.md +132 -0
  39. examples/blog_posts/warming_up_to_rl/SMOKE_TESTING.md +164 -0
  40. examples/blog_posts/warming_up_to_rl/SMOKE_TEST_COMPLETE.md +253 -0
  41. examples/blog_posts/warming_up_to_rl/configs/eval_baseline_qwen32b_10x20.toml +25 -0
  42. examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b_10x20.toml +26 -0
  43. examples/blog_posts/warming_up_to_rl/configs/filter_high_reward_dataset.toml +1 -1
  44. examples/blog_posts/warming_up_to_rl/configs/smoke_test.toml +75 -0
  45. examples/blog_posts/warming_up_to_rl/configs/train_rl_from_sft.toml +60 -10
  46. examples/blog_posts/warming_up_to_rl/configs/train_sft_qwen4b.toml +1 -1
  47. examples/blog_posts/warming_up_to_rl/warming_up_to_rl_baseline.py +187 -0
  48. examples/multi_step/configs/VERILOG_REWARDS.md +4 -0
  49. examples/multi_step/configs/VERILOG_RL_CHECKLIST.md +4 -0
  50. examples/multi_step/configs/crafter_rl_outcome.toml +1 -0
  51. examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +1 -0
  52. examples/multi_step/configs/crafter_rl_stepwise_simple.toml +1 -0
  53. examples/rl/configs/rl_from_base_qwen17.toml +1 -0
  54. examples/swe/task_app/hosted/inference/openai_client.py +0 -34
  55. examples/swe/task_app/hosted/policy_routes.py +17 -0
  56. examples/swe/task_app/hosted/rollout.py +4 -2
  57. examples/task_apps/banking77/__init__.py +6 -0
  58. examples/task_apps/banking77/banking77_task_app.py +841 -0
  59. examples/task_apps/banking77/deploy_wrapper.py +46 -0
  60. examples/task_apps/crafter/CREATE_SFT_DATASET.md +4 -0
  61. examples/task_apps/crafter/FILTER_COMMAND_STATUS.md +4 -0
  62. examples/task_apps/crafter/FILTER_COMMAND_SUCCESS.md +4 -0
  63. examples/task_apps/crafter/task_app/grpo_crafter.py +24 -2
  64. examples/task_apps/crafter/task_app/synth_envs_hosted/hosted_app.py +49 -0
  65. examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +355 -58
  66. examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +68 -7
  67. examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +78 -21
  68. examples/task_apps/crafter/task_app/synth_envs_hosted/utils.py +194 -1
  69. examples/task_apps/gepa_benchmarks/__init__.py +7 -0
  70. examples/task_apps/gepa_benchmarks/common.py +260 -0
  71. examples/task_apps/gepa_benchmarks/hotpotqa_task_app.py +507 -0
  72. examples/task_apps/gepa_benchmarks/hover_task_app.py +436 -0
  73. examples/task_apps/gepa_benchmarks/ifbench_task_app.py +563 -0
  74. examples/task_apps/gepa_benchmarks/pupa_task_app.py +460 -0
  75. examples/task_apps/pokemon_red/README_IMAGE_ONLY_EVAL.md +4 -0
  76. examples/task_apps/pokemon_red/task_app.py +254 -36
  77. examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +1 -0
  78. examples/warming_up_to_rl/task_app/grpo_crafter.py +53 -4
  79. examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +49 -0
  80. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +152 -41
  81. examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +31 -1
  82. examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +33 -3
  83. examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +67 -0
  84. examples/workflows/math_rl/configs/rl_from_base_qwen17.toml +1 -0
  85. synth_ai/api/train/builders.py +90 -1
  86. synth_ai/api/train/cli.py +396 -21
  87. synth_ai/api/train/config_finder.py +13 -2
  88. synth_ai/api/train/configs/__init__.py +15 -1
  89. synth_ai/api/train/configs/prompt_learning.py +442 -0
  90. synth_ai/api/train/configs/rl.py +29 -0
  91. synth_ai/api/train/task_app.py +1 -1
  92. synth_ai/api/train/validators.py +277 -0
  93. synth_ai/baseline/__init__.py +25 -0
  94. synth_ai/baseline/config.py +209 -0
  95. synth_ai/baseline/discovery.py +214 -0
  96. synth_ai/baseline/execution.py +146 -0
  97. synth_ai/cli/__init__.py +85 -17
  98. synth_ai/cli/__main__.py +0 -0
  99. synth_ai/cli/claude.py +70 -0
  100. synth_ai/cli/codex.py +84 -0
  101. synth_ai/cli/commands/__init__.py +1 -0
  102. synth_ai/cli/commands/baseline/__init__.py +12 -0
  103. synth_ai/cli/commands/baseline/core.py +637 -0
  104. synth_ai/cli/commands/baseline/list.py +93 -0
  105. synth_ai/cli/commands/eval/core.py +13 -10
  106. synth_ai/cli/commands/filter/core.py +53 -17
  107. synth_ai/cli/commands/help/core.py +0 -1
  108. synth_ai/cli/commands/smoke/__init__.py +7 -0
  109. synth_ai/cli/commands/smoke/core.py +1436 -0
  110. synth_ai/cli/commands/status/subcommands/pricing.py +22 -0
  111. synth_ai/cli/commands/status/subcommands/usage.py +203 -0
  112. synth_ai/cli/commands/train/judge_schemas.py +1 -0
  113. synth_ai/cli/commands/train/judge_validation.py +1 -0
  114. synth_ai/cli/commands/train/validation.py +0 -57
  115. synth_ai/cli/demo.py +35 -3
  116. synth_ai/cli/deploy/__init__.py +40 -25
  117. synth_ai/cli/deploy.py +162 -0
  118. synth_ai/cli/legacy_root_backup.py +14 -8
  119. synth_ai/cli/opencode.py +107 -0
  120. synth_ai/cli/root.py +9 -5
  121. synth_ai/cli/task_app_deploy.py +1 -1
  122. synth_ai/cli/task_apps.py +53 -53
  123. synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +7 -4
  124. synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +9 -5
  125. synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +4 -3
  126. synth_ai/judge_schemas.py +1 -0
  127. synth_ai/learning/__init__.py +10 -0
  128. synth_ai/learning/prompt_learning_client.py +276 -0
  129. synth_ai/learning/prompt_learning_types.py +184 -0
  130. synth_ai/pricing/__init__.py +2 -0
  131. synth_ai/pricing/model_pricing.py +57 -0
  132. synth_ai/streaming/handlers.py +53 -4
  133. synth_ai/streaming/streamer.py +19 -0
  134. synth_ai/task/apps/__init__.py +1 -0
  135. synth_ai/task/config.py +2 -0
  136. synth_ai/task/tracing_utils.py +25 -25
  137. synth_ai/task/validators.py +44 -8
  138. synth_ai/task_app_cfgs.py +21 -0
  139. synth_ai/tracing_v3/config.py +162 -19
  140. synth_ai/tracing_v3/constants.py +1 -1
  141. synth_ai/tracing_v3/db_config.py +24 -38
  142. synth_ai/tracing_v3/storage/config.py +47 -13
  143. synth_ai/tracing_v3/storage/factory.py +3 -3
  144. synth_ai/tracing_v3/turso/daemon.py +113 -11
  145. synth_ai/tracing_v3/turso/native_manager.py +92 -16
  146. synth_ai/types.py +8 -0
  147. synth_ai/urls.py +11 -0
  148. synth_ai/utils/__init__.py +30 -1
  149. synth_ai/utils/agents.py +74 -0
  150. synth_ai/utils/bin.py +39 -0
  151. synth_ai/utils/cli.py +149 -5
  152. synth_ai/utils/env.py +17 -17
  153. synth_ai/utils/json.py +72 -0
  154. synth_ai/utils/modal.py +283 -1
  155. synth_ai/utils/paths.py +48 -0
  156. synth_ai/utils/uvicorn.py +113 -0
  157. {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/METADATA +102 -4
  158. {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/RECORD +162 -88
  159. synth_ai/cli/commands/deploy/__init__.py +0 -23
  160. synth_ai/cli/commands/deploy/core.py +0 -614
  161. synth_ai/cli/commands/deploy/errors.py +0 -72
  162. synth_ai/cli/commands/deploy/validation.py +0 -11
  163. synth_ai/cli/deploy/core.py +0 -5
  164. synth_ai/cli/deploy/errors.py +0 -23
  165. synth_ai/cli/deploy/validation.py +0 -5
  166. {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/WHEEL +0 -0
  167. {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/entry_points.txt +0 -0
  168. {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/licenses/LICENSE +0 -0
  169. {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/top_level.txt +0 -0
synth_ai/api/train/validators.py (added)
@@ -0,0 +1,277 @@
+ """SDK-side validation for training configs - catch errors BEFORE sending to backend."""
+
+ from pathlib import Path
+ from typing import Any
+
+ import click
+
+
+ class ConfigValidationError(Exception):
+     """Raised when a training config is invalid."""
+     pass
+
+
+ def validate_prompt_learning_config(config_data: dict[str, Any], config_path: Path) -> None:
+     """
+     Validate prompt learning config BEFORE sending to backend.
+
+     This catches common errors early with clear messages instead of cryptic backend errors.
+
+     Args:
+         config_data: Parsed TOML/JSON config
+         config_path: Path to config file (for error messages)
+
+     Raises:
+         ConfigValidationError: If config is invalid
+         click.ClickException: If validation fails (for CLI)
+     """
+     errors: list[str] = []
+
+     # Check for prompt_learning section
+     pl_section = config_data.get("prompt_learning")
+     if not pl_section:
+         errors.append(
+             "Missing [prompt_learning] section in config. "
+             "Expected: [prompt_learning] with algorithm, task_app_url, etc."
+         )
+         _raise_validation_errors(errors, config_path)
+         return
+
+     if not isinstance(pl_section, dict):
+         errors.append(
+             f"[prompt_learning] must be a table/dict, got {type(pl_section).__name__}"
+         )
+         _raise_validation_errors(errors, config_path)
+         return
+
+     # CRITICAL: Validate algorithm field
+     algorithm = pl_section.get("algorithm")
+     if not algorithm:
+         errors.append(
+             "Missing required field: prompt_learning.algorithm\n"
+             " Must be one of: 'gepa', 'mipro'\n"
+             " Example:\n"
+             " [prompt_learning]\n"
+             " algorithm = \"gepa\""
+         )
+     elif algorithm not in ("gepa", "mipro"):
+         errors.append(
+             f"Invalid algorithm: '{algorithm}'\n"
+             f" Must be one of: 'gepa', 'mipro' (Note: MIPRO not yet implemented)\n"
+             f" Got: '{algorithm}'"
+         )
+
+     # Validate task_app_url
+     task_app_url = pl_section.get("task_app_url")
+     if not task_app_url:
+         errors.append(
+             "Missing required field: prompt_learning.task_app_url\n"
+             " Example:\n"
+             " task_app_url = \"http://127.0.0.1:8102\""
+         )
+     elif not isinstance(task_app_url, str):
+         errors.append(
+             f"task_app_url must be a string, got {type(task_app_url).__name__}"
+         )
+     elif not task_app_url.startswith(("http://", "https://")):
+         errors.append(
+             f"task_app_url must start with http:// or https://, got: '{task_app_url}'"
+         )
+
+     # Validate initial_prompt if present
+     initial_prompt = pl_section.get("initial_prompt")
+     if initial_prompt:
+         if not isinstance(initial_prompt, dict):
+             errors.append(
+                 f"prompt_learning.initial_prompt must be a table/dict, got {type(initial_prompt).__name__}"
+             )
+         else:
+             # Validate messages array
+             messages = initial_prompt.get("messages")
+             if messages is not None:
+                 if not isinstance(messages, list):
+                     errors.append(
+                         f"prompt_learning.initial_prompt.messages must be an array, got {type(messages).__name__}"
+                     )
+                 elif len(messages) == 0:
+                     errors.append(
+                         "prompt_learning.initial_prompt.messages is empty (must have at least one message)"
+                     )
+
+     # Validate policy config
+     policy = pl_section.get("policy")
+     if not policy or not isinstance(policy, dict):
+         errors.append("Missing [prompt_learning.policy] section or not a table")
+     else:
+         # Enforce inference_mode
+         mode = str(policy.get("inference_mode", "")).strip().lower()
+         if not mode:
+             errors.append("Missing required field: prompt_learning.policy.inference_mode (must be 'synth_hosted')")
+         elif mode != "synth_hosted":
+             errors.append("prompt_learning.policy.inference_mode must be 'synth_hosted' (bring_your_own unsupported)")
+         # Required fields for synth_hosted
+         provider = (policy.get("provider") or "").strip().lower()
+         model = (policy.get("model") or "").strip()
+         inference_url = (policy.get("inference_url") or "").strip()
+         if not provider:
+             errors.append("Missing required field: prompt_learning.policy.provider")
+         if not model:
+             errors.append("Missing required field: prompt_learning.policy.model")
+         if not inference_url:
+             errors.append("Missing required field: prompt_learning.policy.inference_url")
+         elif not isinstance(inference_url, str) or not inference_url.startswith(("http://", "https://")):
+             errors.append(f"policy.inference_url must start with http:// or https://, got: '{inference_url}'")
+
+     # Validate algorithm-specific config
+     if algorithm == "gepa":
+         gepa_config = pl_section.get("gepa")
+         if not gepa_config or not isinstance(gepa_config, dict):
+             errors.append("Missing [prompt_learning.gepa] section for GEPA algorithm")
+         else:
+             # Numeric sanity checks
+             def _pos_int(name: str) -> None:
+                 val = gepa_config.get(name)
+                 if val is not None:
+                     try:
+                         ival = int(val)
+                         if ival <= 0:
+                             errors.append(f"prompt_learning.gepa.{name} must be > 0")
+                     except Exception:
+                         errors.append(f"prompt_learning.gepa.{name} must be an integer")
+             for fld in ("initial_population_size", "num_generations", "children_per_generation", "max_concurrent_rollouts"):
+                 _pos_int(fld)
+             # Budget cap
+             if "max_spend_usd" in gepa_config and gepa_config.get("max_spend_usd") is not None:
+                 try:
+                     f = float(gepa_config.get("max_spend_usd"))
+                     if f <= 0:
+                         errors.append("prompt_learning.gepa.max_spend_usd must be > 0 when provided")
+                 except Exception:
+                     errors.append("prompt_learning.gepa.max_spend_usd must be numeric")
+
+     elif algorithm == "mipro":
+         # MIPRO is not yet implemented in synth-ai
+         errors.append(
+             "MIPRO algorithm is not yet implemented in synth-ai.\n"
+             " Please use 'gepa' algorithm for prompt optimization.\n"
+             " MIPRO support is planned for a future release.\n"
+             " Example:\n"
+             " [prompt_learning]\n"
+             " algorithm = \"gepa\"\n"
+             " [prompt_learning.gepa]\n"
+             " # ... gepa configuration"
+         )
+
+     # Raise all errors at once for better UX
+     if errors:
+         _raise_validation_errors(errors, config_path)
+
+
+ def _raise_validation_errors(errors: list[str], config_path: Path) -> None:
+     """Format and raise validation errors."""
+     error_msg = (
+         f"\n❌ Invalid prompt learning config: {config_path}\n\n"
+         f"Found {len(errors)} error(s):\n\n"
+     )
+
+     for i, error in enumerate(errors, 1):
+         # Indent multi-line errors
+         indented_error = "\n ".join(error.split("\n"))
+         error_msg += f"{i}. {indented_error}\n\n"
+
+     error_msg += (
+         "📖 See example configs:\n"
+         " - examples/blog_posts/gepa/configs/banking77_gepa_local.toml\n"
+         " - examples/blog_posts/mipro/configs/banking77_mipro_local.toml\n"
+     )
+
+     raise click.ClickException(error_msg)
+
+
+ def validate_rl_config(config_data: dict[str, Any], config_path: Path) -> None:
+     """
+     Validate RL config BEFORE sending to backend.
+
+     Args:
+         config_data: Parsed TOML/JSON config
+         config_path: Path to config file (for error messages)
+
+     Raises:
+         ConfigValidationError: If config is invalid
+         click.ClickException: If validation fails (for CLI)
+     """
+     errors: list[str] = []
+
+     # Check for rl section
+     rl_section = config_data.get("rl") or config_data.get("online_rl")
+     if not rl_section:
+         errors.append(
+             "Missing [rl] or [online_rl] section in config"
+         )
+         _raise_validation_errors(errors, config_path)
+         return
+
+     # Validate algorithm
+     algorithm = rl_section.get("algorithm")
+     if not algorithm:
+         errors.append(
+             "Missing required field: rl.algorithm\n"
+             " Must be one of: 'grpo', 'ppo', etc."
+         )
+
+     # Validate task_url
+     task_url = rl_section.get("task_url")
+     if not task_url:
+         errors.append(
+             "Missing required field: rl.task_url"
+         )
+     elif not isinstance(task_url, str):
+         errors.append(
+             f"task_url must be a string, got {type(task_url).__name__}"
+         )
+
+     if errors:
+         _raise_validation_errors(errors, config_path)
+
+
+ def validate_sft_config(config_data: dict[str, Any], config_path: Path) -> None:
+     """
+     Validate SFT config BEFORE sending to backend.
+
+     Args:
+         config_data: Parsed TOML/JSON config
+         config_path: Path to config file (for error messages)
+
+     Raises:
+         ConfigValidationError: If config is invalid
+         click.ClickException: If validation fails (for CLI)
+     """
+     errors: list[str] = []
+
+     # Check for sft section
+     sft_section = config_data.get("sft")
+     if not sft_section:
+         errors.append(
+             "Missing [sft] section in config"
+         )
+         _raise_validation_errors(errors, config_path)
+         return
+
+     # Validate model
+     model = sft_section.get("model")
+     if not model:
+         errors.append(
+             "Missing required field: sft.model"
+         )
+
+     if errors:
+         _raise_validation_errors(errors, config_path)
+
+
+ __all__ = [
+     "ConfigValidationError",
+     "validate_prompt_learning_config",
+     "validate_rl_config",
+     "validate_sft_config",
+ ]
+
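For orientation, here is a minimal sketch of how the new validator can be driven from the SDK side. The module path, function name, and signature come from the diff above; the concrete values (provider, model id, URLs, ports, GEPA numbers, and the config filename) are illustrative placeholders, not synth-ai defaults.

# Sketch only: exercising validate_prompt_learning_config(); all config values are placeholders.
from pathlib import Path

import click

from synth_ai.api.train.validators import validate_prompt_learning_config

# Shaped like a parsed [prompt_learning] TOML table.
config_data = {
    "prompt_learning": {
        "algorithm": "gepa",
        "task_app_url": "http://127.0.0.1:8102",
        "policy": {
            "inference_mode": "synth_hosted",
            "provider": "openai",            # placeholder provider
            "model": "gpt-4o-mini",          # placeholder model id
            "inference_url": "https://example.invalid/v1",
        },
        "gepa": {
            "initial_population_size": 8,
            "num_generations": 3,
            "max_spend_usd": 5.0,
        },
    }
}

try:
    validate_prompt_learning_config(config_data, Path("banking77_gepa_local.toml"))
    print("config OK")
except click.ClickException as exc:
    # All problems are reported at once in a single numbered message.
    print(exc.format_message())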
synth_ai/baseline/__init__.py (added)
@@ -0,0 +1,25 @@
+ """Baseline file system for self-contained task evaluation.
+
+ This package provides abstractions for defining and executing baseline evaluations
+ without requiring deployed task apps. Supports both class-based and function-based
+ task runners with first-class train/val/test split support.
+ """
+
+ from __future__ import annotations
+
+ from synth_ai.baseline.config import (
+     BaselineConfig,
+     BaselineResults,
+     BaselineTaskRunner,
+     DataSplit,
+     TaskResult,
+ )
+
+ __all__ = [
+     "BaselineConfig",
+     "BaselineTaskRunner",
+     "DataSplit",
+     "TaskResult",
+     "BaselineResults",
+ ]
+
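The docstring above stresses first-class train/val/test splits. A sketch of what a split table built from these re-exports could look like, with seed ranges chosen purely for illustration:

# Hypothetical split definitions; the seed ranges are arbitrary examples.
from synth_ai.baseline import DataSplit

splits = {
    "train": DataSplit(name="train", seeds=list(range(0, 50))),
    "val": DataSplit(name="val", seeds=list(range(50, 60))),
    "test": DataSplit(name="test", seeds=list(range(60, 80)), metadata={"held_out": True}),
}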
synth_ai/baseline/config.py (added)
@@ -0,0 +1,209 @@
+ """Core dataclasses for baseline configuration and results."""
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass, field
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
+
+
+ class BaselineTaskRunner:
+     """
+     Base class for task runners.
+
+     Subclasses should implement `run_task` method for class-based approach,
+     or you can use standalone async functions for function-based approach.
+     """
+
+     def __init__(
+         self,
+         policy_config: Dict[str, Any],
+         env_config: Dict[str, Any],
+     ):
+         """
+         Initialize task runner with configuration.
+
+         Args:
+             policy_config: Policy configuration (model, temperature, etc.)
+             env_config: Environment configuration (max_steps, difficulty, etc.)
+         """
+         self.policy_config = policy_config
+         self.env_config = env_config
+
+     async def run_task(self, seed: int) -> TaskResult:
+         """
+         Execute a single task instance.
+
+         This method is called for each seed in the selected split.
+
+         Args:
+             seed: The seed/index for this task instance
+
+         Returns:
+             TaskResult: Structured result containing success, rewards, metadata, trace
+         """
+         raise NotImplementedError("Subclasses must implement run_task method")
+
+
+ @dataclass
+ class DataSplit:
+     """Definition of a data split (train/val/test)."""
+
+     name: str  # "train", "val", "test"
+     seeds: List[int]  # Seed/index values for this split
+     metadata: Dict[str, Any] = field(default_factory=dict)  # Optional metadata
+
+
+ @dataclass
+ class TaskResult:
+     """Result from a single task execution."""
+
+     # Required: Seed/index that was evaluated
+     seed: int
+
+     # Required: Did the task complete successfully?
+     success: bool
+
+     # Required: Outcome reward for the episode
+     outcome_reward: float
+
+     # Optional: Event rewards (step-level)
+     event_rewards: List[Dict[str, Any]] = field(default_factory=list)
+
+     # Optional: Total steps/turns taken
+     total_steps: int = 0
+
+     # Optional: Metadata (achievements, completion info, etc.)
+     metadata: Dict[str, Any] = field(default_factory=dict)
+
+     # Optional: Error information if success=False
+     error: Optional[str] = None
+
+     # Optional: v3 trace (SessionTrace dict)
+     trace: Optional[Dict[str, Any]] = None
+
+
+ # Type alias for task runner (can be class or function)
+ TaskRunnerType = (
+     type[BaselineTaskRunner]
+     | Callable[[int, dict[str, Any], dict[str, Any]], Any]  # Function signature
+ )
+
+ # Type alias for result aggregator (can be class or function)
+ AggregatorType = (
+     type[Any]  # Class with aggregate() method
+     | Callable[[list[TaskResult]], dict[str, Any]]  # Function signature
+ )
+
+
+ @dataclass
+ class BaselineConfig:
+     """Configuration for a baseline file.
+
+     A baseline file defines how to evaluate a task without requiring
+     a deployed task app. It provides self-contained evaluation logic
+     with first-class support for train/val/test splits.
+
+     Supports both class-based and function-based task runners:
+     - Class-based: Pass a class that inherits from BaselineTaskRunner
+     - Function-based: Pass an async function with signature:
+         async def task_runner(seed: int, policy_config: Dict[str, Any],
+                               env_config: Dict[str, Any]) -> TaskResult
+     """
+
+     # Required: Unique identifier for this baseline config
+     baseline_id: str
+
+     # Required: Human-readable name
+     name: str
+
+     # Required: Task runner (class or function)
+     #   Class-based: Pass a class inheriting from BaselineTaskRunner
+     #     The class will be instantiated with policy_config and env_config,
+     #     and run_task(seed) will be called for each seed.
+     #   Function-based: Pass an async function with signature:
+     #     async def task_runner(seed: int, policy_config: Dict[str, Any],
+     #                           env_config: Dict[str, Any]) -> TaskResult
+     task_runner: TaskRunnerType
+
+     # Required: Data splits (train/val/test)
+     splits: Dict[str, DataSplit]
+
+     # Optional: Description for documentation
+     description: str = ""
+
+     # Optional: Default policy configuration
+     default_policy_config: Dict[str, Any] = field(default_factory=dict)
+
+     # Optional: Default environment configuration
+     default_env_config: Dict[str, Any] = field(default_factory=dict)
+
+     # Optional: Metadata for filtering/organization
+     metadata: Dict[str, Any] = field(default_factory=dict)
+
+     # Optional: Tags for filtering and discovery
+     tags: List[str] = field(default_factory=list)
+
+     # Optional: Custom result aggregator (class or function)
+     #   Class-based: Pass a class with aggregate(results: List[TaskResult]) method
+     #     The class will be instantiated and aggregate() called.
+     #   Function-based: Pass a function with signature:
+     #     def aggregate_results(results: List[TaskResult]) -> Dict[str, Any]
+     result_aggregator: Optional[AggregatorType] = None
+
+     # Optional: Path to this baseline file (set by discovery)
+     _source_path: Optional[Path] = None
+
+     def matches_tag(self, tag: str) -> bool:
+         """Check if baseline matches a tag (case-insensitive)."""
+         return tag.lower() in [t.lower() for t in self.tags]
+
+     def matches_metadata(self, key: str, value: Any) -> bool:
+         """Check if baseline metadata matches key-value pair."""
+         return self.metadata.get(key) == value
+
+
+ @dataclass
+ class BaselineResults:
+     """Aggregate results from a baseline evaluation."""
+
+     # Configuration that was used
+     config: BaselineConfig
+
+     # Split that was evaluated
+     split_name: str
+
+     # Per-seed results
+     results: List[TaskResult]
+
+     # Aggregate metrics
+     aggregate_metrics: Dict[str, Any]
+
+     # Execution metadata
+     execution_time_seconds: float
+     model_name: str
+     timestamp: str
+
+     def to_dict(self) -> Dict[str, Any]:
+         """Serialize to dictionary for JSON output."""
+         return {
+             "baseline_id": self.config.baseline_id,
+             "name": self.config.name,
+             "split": self.split_name,
+             "model": self.model_name,
+             "timestamp": self.timestamp,
+             "execution_time_seconds": self.execution_time_seconds,
+             "aggregate_metrics": self.aggregate_metrics,
+             "results": [
+                 {
+                     "seed": r.seed,
+                     "success": r.success,
+                     "outcome_reward": r.outcome_reward,
+                     "total_steps": r.total_steps,
+                     "metadata": r.metadata,
+                     "error": r.error,
+                 }
+                 for r in self.results
+             ],
+         }
+
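To tie the pieces together, here is a sketch of a function-based baseline built only from the dataclasses defined above. The task logic is a trivial stand-in and the model id is a placeholder; the new discovery/execution modules and the `baseline` CLI command added in this release are presumably what pick up and run such a config, though their exact invocation is not shown in this diff.

# Hypothetical function-based baseline; task logic and model id are placeholders.
from typing import Any, Dict

from synth_ai.baseline import BaselineConfig, DataSplit, TaskResult


async def run_demo_task(seed: int, policy_config: Dict[str, Any], env_config: Dict[str, Any]) -> TaskResult:
    # Stand-in logic so the example is self-contained: even seeds "succeed".
    success = seed % 2 == 0
    return TaskResult(
        seed=seed,
        success=success,
        outcome_reward=1.0 if success else 0.0,
        total_steps=1,
        metadata={"model": policy_config.get("model")},
    )


DEMO_BASELINE = BaselineConfig(
    baseline_id="demo_even_seeds",
    name="Demo baseline (function-based)",
    task_runner=run_demo_task,
    splits={
        "train": DataSplit(name="train", seeds=list(range(10))),
        "test": DataSplit(name="test", seeds=list(range(10, 15))),
    },
    default_policy_config={"model": "gpt-4o-mini"},  # placeholder model id
    tags=["demo"],
)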