synth-ai 0.2.4.dev6__py3-none-any.whl → 0.2.4.dev8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (256)
  1. synth_ai/__init__.py +18 -9
  2. synth_ai/cli/__init__.py +10 -5
  3. synth_ai/cli/balance.py +25 -32
  4. synth_ai/cli/calc.py +2 -3
  5. synth_ai/cli/demo.py +3 -5
  6. synth_ai/cli/legacy_root_backup.py +58 -32
  7. synth_ai/cli/man.py +22 -19
  8. synth_ai/cli/recent.py +9 -8
  9. synth_ai/cli/root.py +58 -13
  10. synth_ai/cli/status.py +13 -6
  11. synth_ai/cli/traces.py +45 -21
  12. synth_ai/cli/watch.py +40 -37
  13. synth_ai/config/base_url.py +47 -2
  14. synth_ai/core/experiment.py +1 -2
  15. synth_ai/environments/__init__.py +2 -6
  16. synth_ai/environments/environment/artifacts/base.py +3 -1
  17. synth_ai/environments/environment/db/sqlite.py +1 -1
  18. synth_ai/environments/environment/registry.py +19 -20
  19. synth_ai/environments/environment/resources/sqlite.py +2 -3
  20. synth_ai/environments/environment/rewards/core.py +3 -2
  21. synth_ai/environments/environment/tools/__init__.py +6 -4
  22. synth_ai/environments/examples/crafter_classic/__init__.py +1 -1
  23. synth_ai/environments/examples/crafter_classic/engine.py +13 -13
  24. synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +1 -0
  25. synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +2 -1
  26. synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +2 -1
  27. synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +3 -2
  28. synth_ai/environments/examples/crafter_classic/environment.py +16 -15
  29. synth_ai/environments/examples/crafter_classic/taskset.py +2 -2
  30. synth_ai/environments/examples/crafter_classic/trace_hooks_v3.py +2 -3
  31. synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +2 -1
  32. synth_ai/environments/examples/crafter_custom/crafter/__init__.py +2 -2
  33. synth_ai/environments/examples/crafter_custom/crafter/config.py +2 -2
  34. synth_ai/environments/examples/crafter_custom/crafter/env.py +1 -5
  35. synth_ai/environments/examples/crafter_custom/crafter/objects.py +1 -2
  36. synth_ai/environments/examples/crafter_custom/crafter/worldgen.py +1 -2
  37. synth_ai/environments/examples/crafter_custom/dataset_builder.py +5 -5
  38. synth_ai/environments/examples/crafter_custom/environment.py +13 -13
  39. synth_ai/environments/examples/crafter_custom/run_dataset.py +5 -5
  40. synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +2 -2
  41. synth_ai/environments/examples/enron/art_helpers/local_email_db.py +5 -4
  42. synth_ai/environments/examples/enron/art_helpers/types_enron.py +2 -1
  43. synth_ai/environments/examples/enron/engine.py +18 -14
  44. synth_ai/environments/examples/enron/environment.py +12 -11
  45. synth_ai/environments/examples/enron/taskset.py +7 -7
  46. synth_ai/environments/examples/minigrid/__init__.py +6 -6
  47. synth_ai/environments/examples/minigrid/engine.py +6 -6
  48. synth_ai/environments/examples/minigrid/environment.py +6 -6
  49. synth_ai/environments/examples/minigrid/puzzle_loader.py +3 -2
  50. synth_ai/environments/examples/minigrid/taskset.py +13 -13
  51. synth_ai/environments/examples/nethack/achievements.py +1 -1
  52. synth_ai/environments/examples/nethack/engine.py +8 -7
  53. synth_ai/environments/examples/nethack/environment.py +10 -9
  54. synth_ai/environments/examples/nethack/helpers/__init__.py +8 -9
  55. synth_ai/environments/examples/nethack/helpers/action_mapping.py +1 -1
  56. synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +2 -1
  57. synth_ai/environments/examples/nethack/helpers/observation_utils.py +1 -1
  58. synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +3 -4
  59. synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +6 -5
  60. synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +5 -5
  61. synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +7 -6
  62. synth_ai/environments/examples/nethack/taskset.py +5 -5
  63. synth_ai/environments/examples/red/engine.py +9 -8
  64. synth_ai/environments/examples/red/engine_helpers/reward_components.py +2 -1
  65. synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +7 -7
  66. synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +2 -1
  67. synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +2 -1
  68. synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +2 -1
  69. synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +2 -1
  70. synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +2 -1
  71. synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +2 -1
  72. synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +2 -1
  73. synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +2 -1
  74. synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +2 -1
  75. synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +2 -1
  76. synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +2 -1
  77. synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +3 -2
  78. synth_ai/environments/examples/red/engine_helpers/state_extraction.py +2 -1
  79. synth_ai/environments/examples/red/environment.py +18 -15
  80. synth_ai/environments/examples/red/taskset.py +5 -3
  81. synth_ai/environments/examples/sokoban/engine.py +16 -13
  82. synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +3 -2
  83. synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +2 -1
  84. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +1 -1
  85. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +7 -5
  86. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +1 -1
  87. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +2 -1
  88. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +5 -4
  89. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +3 -2
  90. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +2 -1
  91. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +5 -4
  92. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +1 -1
  93. synth_ai/environments/examples/sokoban/environment.py +15 -14
  94. synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +5 -3
  95. synth_ai/environments/examples/sokoban/puzzle_loader.py +3 -2
  96. synth_ai/environments/examples/sokoban/taskset.py +13 -10
  97. synth_ai/environments/examples/tictactoe/engine.py +6 -6
  98. synth_ai/environments/examples/tictactoe/environment.py +8 -7
  99. synth_ai/environments/examples/tictactoe/taskset.py +6 -5
  100. synth_ai/environments/examples/verilog/engine.py +4 -3
  101. synth_ai/environments/examples/verilog/environment.py +11 -10
  102. synth_ai/environments/examples/verilog/taskset.py +14 -12
  103. synth_ai/environments/examples/wordle/__init__.py +5 -5
  104. synth_ai/environments/examples/wordle/engine.py +32 -25
  105. synth_ai/environments/examples/wordle/environment.py +21 -16
  106. synth_ai/environments/examples/wordle/helpers/generate_instances_wordfreq.py +6 -6
  107. synth_ai/environments/examples/wordle/taskset.py +20 -12
  108. synth_ai/environments/reproducibility/core.py +1 -1
  109. synth_ai/environments/reproducibility/tree.py +21 -21
  110. synth_ai/environments/service/app.py +3 -2
  111. synth_ai/environments/service/core_routes.py +104 -110
  112. synth_ai/environments/service/external_registry.py +1 -2
  113. synth_ai/environments/service/registry.py +1 -1
  114. synth_ai/environments/stateful/core.py +1 -2
  115. synth_ai/environments/stateful/engine.py +1 -1
  116. synth_ai/environments/tasks/api.py +4 -4
  117. synth_ai/environments/tasks/core.py +14 -12
  118. synth_ai/environments/tasks/filters.py +6 -4
  119. synth_ai/environments/tasks/utils.py +13 -11
  120. synth_ai/evals/base.py +2 -3
  121. synth_ai/experimental/synth_oss.py +4 -4
  122. synth_ai/http.py +102 -0
  123. synth_ai/inference/__init__.py +7 -0
  124. synth_ai/inference/client.py +20 -0
  125. synth_ai/jobs/client.py +246 -0
  126. synth_ai/learning/__init__.py +24 -0
  127. synth_ai/learning/client.py +149 -0
  128. synth_ai/learning/config.py +43 -0
  129. synth_ai/learning/constants.py +29 -0
  130. synth_ai/learning/ft_client.py +59 -0
  131. synth_ai/learning/gateway.py +1 -3
  132. synth_ai/learning/health.py +43 -0
  133. synth_ai/learning/jobs.py +205 -0
  134. synth_ai/learning/prompts/banking77_injection_eval.py +15 -10
  135. synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +26 -14
  136. synth_ai/learning/prompts/mipro.py +61 -52
  137. synth_ai/learning/prompts/random_search.py +42 -43
  138. synth_ai/learning/prompts/run_mipro_banking77.py +32 -20
  139. synth_ai/learning/prompts/run_random_search_banking77.py +71 -52
  140. synth_ai/learning/rl_client.py +256 -0
  141. synth_ai/learning/sse.py +58 -0
  142. synth_ai/learning/validators.py +48 -0
  143. synth_ai/lm/__init__.py +5 -5
  144. synth_ai/lm/caching/ephemeral.py +9 -9
  145. synth_ai/lm/caching/handler.py +20 -20
  146. synth_ai/lm/caching/persistent.py +10 -10
  147. synth_ai/lm/config.py +3 -3
  148. synth_ai/lm/constants.py +7 -7
  149. synth_ai/lm/core/all.py +17 -3
  150. synth_ai/lm/core/exceptions.py +0 -2
  151. synth_ai/lm/core/main.py +26 -41
  152. synth_ai/lm/core/main_v3.py +33 -10
  153. synth_ai/lm/core/synth_models.py +48 -0
  154. synth_ai/lm/core/vendor_clients.py +26 -22
  155. synth_ai/lm/injection.py +7 -8
  156. synth_ai/lm/overrides.py +21 -19
  157. synth_ai/lm/provider_support/__init__.py +1 -1
  158. synth_ai/lm/provider_support/anthropic.py +15 -15
  159. synth_ai/lm/provider_support/openai.py +23 -21
  160. synth_ai/lm/structured_outputs/handler.py +34 -32
  161. synth_ai/lm/structured_outputs/inject.py +24 -27
  162. synth_ai/lm/structured_outputs/rehabilitate.py +19 -15
  163. synth_ai/lm/tools/base.py +17 -16
  164. synth_ai/lm/unified_interface.py +17 -18
  165. synth_ai/lm/vendors/base.py +20 -18
  166. synth_ai/lm/vendors/core/anthropic_api.py +36 -27
  167. synth_ai/lm/vendors/core/gemini_api.py +31 -36
  168. synth_ai/lm/vendors/core/mistral_api.py +19 -19
  169. synth_ai/lm/vendors/core/openai_api.py +42 -13
  170. synth_ai/lm/vendors/openai_standard.py +158 -101
  171. synth_ai/lm/vendors/openai_standard_responses.py +74 -61
  172. synth_ai/lm/vendors/retries.py +9 -1
  173. synth_ai/lm/vendors/supported/custom_endpoint.py +38 -28
  174. synth_ai/lm/vendors/supported/deepseek.py +10 -10
  175. synth_ai/lm/vendors/supported/grok.py +8 -8
  176. synth_ai/lm/vendors/supported/ollama.py +2 -1
  177. synth_ai/lm/vendors/supported/openrouter.py +11 -9
  178. synth_ai/lm/vendors/synth_client.py +425 -75
  179. synth_ai/lm/warmup.py +8 -7
  180. synth_ai/rl/__init__.py +30 -0
  181. synth_ai/rl/contracts.py +32 -0
  182. synth_ai/rl/env_keys.py +137 -0
  183. synth_ai/rl/secrets.py +19 -0
  184. synth_ai/scripts/verify_rewards.py +100 -0
  185. synth_ai/task/__init__.py +10 -0
  186. synth_ai/task/contracts.py +120 -0
  187. synth_ai/task/health.py +28 -0
  188. synth_ai/task/validators.py +12 -0
  189. synth_ai/tracing/__init__.py +22 -10
  190. synth_ai/tracing_v1/__init__.py +22 -20
  191. synth_ai/tracing_v3/__init__.py +7 -7
  192. synth_ai/tracing_v3/abstractions.py +56 -52
  193. synth_ai/tracing_v3/config.py +4 -2
  194. synth_ai/tracing_v3/db_config.py +6 -8
  195. synth_ai/tracing_v3/decorators.py +29 -30
  196. synth_ai/tracing_v3/examples/basic_usage.py +12 -12
  197. synth_ai/tracing_v3/hooks.py +24 -22
  198. synth_ai/tracing_v3/llm_call_record_helpers.py +85 -98
  199. synth_ai/tracing_v3/lm_call_record_abstractions.py +2 -4
  200. synth_ai/tracing_v3/migration_helper.py +3 -5
  201. synth_ai/tracing_v3/replica_sync.py +30 -32
  202. synth_ai/tracing_v3/session_tracer.py +158 -31
  203. synth_ai/tracing_v3/storage/__init__.py +1 -1
  204. synth_ai/tracing_v3/storage/base.py +8 -7
  205. synth_ai/tracing_v3/storage/config.py +4 -4
  206. synth_ai/tracing_v3/storage/factory.py +4 -4
  207. synth_ai/tracing_v3/storage/utils.py +9 -9
  208. synth_ai/tracing_v3/turso/__init__.py +3 -3
  209. synth_ai/tracing_v3/turso/daemon.py +9 -9
  210. synth_ai/tracing_v3/turso/manager.py +278 -48
  211. synth_ai/tracing_v3/turso/models.py +77 -19
  212. synth_ai/tracing_v3/utils.py +5 -5
  213. synth_ai/v0/tracing/abstractions.py +28 -28
  214. synth_ai/v0/tracing/base_client.py +9 -9
  215. synth_ai/v0/tracing/client_manager.py +7 -7
  216. synth_ai/v0/tracing/config.py +7 -7
  217. synth_ai/v0/tracing/context.py +6 -6
  218. synth_ai/v0/tracing/decorators.py +6 -5
  219. synth_ai/v0/tracing/events/manage.py +1 -1
  220. synth_ai/v0/tracing/events/store.py +5 -4
  221. synth_ai/v0/tracing/immediate_client.py +4 -5
  222. synth_ai/v0/tracing/local.py +3 -3
  223. synth_ai/v0/tracing/log_client_base.py +4 -5
  224. synth_ai/v0/tracing/retry_queue.py +5 -6
  225. synth_ai/v0/tracing/trackers.py +25 -25
  226. synth_ai/v0/tracing/upload.py +6 -0
  227. synth_ai/v0/tracing_v1/__init__.py +1 -1
  228. synth_ai/v0/tracing_v1/abstractions.py +28 -28
  229. synth_ai/v0/tracing_v1/base_client.py +9 -9
  230. synth_ai/v0/tracing_v1/client_manager.py +7 -7
  231. synth_ai/v0/tracing_v1/config.py +7 -7
  232. synth_ai/v0/tracing_v1/context.py +6 -6
  233. synth_ai/v0/tracing_v1/decorators.py +7 -6
  234. synth_ai/v0/tracing_v1/events/manage.py +1 -1
  235. synth_ai/v0/tracing_v1/events/store.py +5 -4
  236. synth_ai/v0/tracing_v1/immediate_client.py +4 -5
  237. synth_ai/v0/tracing_v1/local.py +3 -3
  238. synth_ai/v0/tracing_v1/log_client_base.py +4 -5
  239. synth_ai/v0/tracing_v1/retry_queue.py +5 -6
  240. synth_ai/v0/tracing_v1/trackers.py +25 -25
  241. synth_ai/v0/tracing_v1/upload.py +25 -24
  242. synth_ai/zyk/__init__.py +1 -0
  243. synth_ai-0.2.4.dev8.dist-info/METADATA +635 -0
  244. synth_ai-0.2.4.dev8.dist-info/RECORD +317 -0
  245. synth_ai/tui/__init__.py +0 -1
  246. synth_ai/tui/__main__.py +0 -13
  247. synth_ai/tui/cli/__init__.py +0 -1
  248. synth_ai/tui/cli/query_experiments.py +0 -165
  249. synth_ai/tui/cli/query_experiments_v3.py +0 -165
  250. synth_ai/tui/dashboard.py +0 -329
  251. synth_ai-0.2.4.dev6.dist-info/METADATA +0 -203
  252. synth_ai-0.2.4.dev6.dist-info/RECORD +0 -299
  253. {synth_ai-0.2.4.dev6.dist-info → synth_ai-0.2.4.dev8.dist-info}/WHEEL +0 -0
  254. {synth_ai-0.2.4.dev6.dist-info → synth_ai-0.2.4.dev8.dist-info}/entry_points.txt +0 -0
  255. {synth_ai-0.2.4.dev6.dist-info → synth_ai-0.2.4.dev8.dist-info}/licenses/LICENSE +0 -0
  256. {synth_ai-0.2.4.dev6.dist-info → synth_ai-0.2.4.dev8.dist-info}/top_level.txt +0 -0
@@ -12,44 +12,46 @@ Run:
12
12
  from __future__ import annotations
13
13
 
14
14
  import asyncio
15
+ import json
15
16
  import os
16
17
  import random
18
+ import time
19
+ from collections.abc import Sequence
17
20
  from dataclasses import dataclass, replace
21
+ from pathlib import Path
18
22
  from types import SimpleNamespace
19
- from tqdm import tqdm
20
- from typing import Any, Dict, List, Sequence, Tuple
23
+ from typing import Any
21
24
 
22
- from dotenv import load_dotenv
23
25
  from datasets import load_dataset
24
-
25
- from synth_ai.lm.core.main_v3 import LM, build_messages
26
- import json
27
- import time
28
- from pathlib import Path
26
+ from dotenv import load_dotenv
29
27
  from synth_ai.learning.prompts.random_search import random_search_compile
28
+ from synth_ai.lm.core.main_v3 import LM, build_messages
29
+ from tqdm import tqdm
30
30
 
31
31
 
32
def choose_label(pred: str, label_names: list[str]) -> str:
    """Map a free-form model prediction onto one of ``label_names``.

    A case-insensitive, whitespace-trimmed exact match wins outright. Otherwise
    each label is scored by (a) how many of its whitespace-separated tokens occur
    as substrings of the prediction and, as a tie-breaker, (b) how many words it
    shares with the prediction when underscores are treated as spaces. The
    highest-scoring label is returned; remaining ties go to the earliest label.

    Raises:
        ValueError: if ``label_names`` is empty.
    """
    if not label_names:
        raise ValueError("label_names must not be empty")
    norm = (pred or "").strip().lower()
    exact = {ln.lower(): ln for ln in label_names}
    if norm in exact:
        return exact[norm]

    # Banking77 label names are snake_case ("card_arrival"), so the token/substring
    # score alone is 0 for a prediction like "card arrival" and max() would silently
    # return the first label. Word overlap on underscore-split words fixes that.
    pred_words = set(norm.replace("_", " ").split())

    def score(cand: str) -> tuple[int, int]:
        c = cand.lower()
        whole = sum(1 for w in c.split() if w in norm)
        overlap = len(pred_words & set(c.replace("_", " ").split()))
        return whole, overlap

    return max(label_names, key=score)
41
43
 
42
44
 
43
def accuracy(pred: str, gold: str, labels: list[str]) -> float:
    """Score one prediction: 1.0 when ``pred`` resolves (via ``choose_label``) to ``gold``, else 0.0."""
    resolved = choose_label(pred, labels)
    return float(resolved == gold)
45
47
 
46
48
 
47
49
  @dataclass
48
50
  class StudentProgram:
49
51
  lm: LM
50
- label_names: List[str]
52
+ label_names: list[str]
51
53
  instruction: str
52
- demos: List[Tuple[str, str]]
54
+ demos: list[tuple[str, str]]
53
55
 
54
56
  def reset_copy(self):
55
57
  return replace(self, instruction=self.instruction, demos=list(self.demos))
@@ -57,7 +59,7 @@ class StudentProgram:
57
59
  def deepcopy(self):
58
60
  return replace(self, instruction=str(self.instruction), demos=list(self.demos))
59
61
 
60
- def with_demos(self, demos: List[Tuple[str, str]]):
62
+ def with_demos(self, demos: list[tuple[str, str]]):
61
63
  return replace(self, demos=list(demos))
62
64
 
63
65
  def run(self, x: str) -> str:
@@ -66,10 +68,12 @@ class StudentProgram:
66
68
  sys = self.instruction or "You are an intent classifier for Banking77."
67
69
  user = (f"Examples:\n{examples}\n\n" if examples else "") + f"Message: {x}\nLabel:"
68
70
  messages = build_messages(sys, user, images_bytes=None, model_name=self.lm.model)
71
+
69
72
  # Call LM synchronously via asyncio
70
73
  async def _call():
71
74
  resp = await self.lm.respond_async(messages=messages)
72
75
  return (resp.raw_response or "").strip()
76
+
73
77
  return asyncio.run(_call())
74
78
 
75
79
  async def _apredict(self, x: str):
@@ -91,13 +95,13 @@ def main():
91
95
 
92
96
  print("Loading Banking77 dataset (train/dev split of test for demo)...")
93
97
  ds = load_dataset("banking77")
94
- label_names: List[str] = ds["test"].features["label"].names # type: ignore
98
+ label_names: list[str] = ds["test"].features["label"].names # type: ignore
95
99
 
96
100
  # Create small train/val from the test split for speed
97
101
  all_items = [(r["text"], label_names[int(r["label"])]) for r in ds["test"]]
98
102
  random.shuffle(all_items)
99
- trainset: Sequence[Tuple[str, str]] = all_items[:40]
100
- valset: Sequence[Tuple[str, str]] = all_items[40:60] # 20 examples
103
+ trainset: Sequence[tuple[str, str]] = all_items[:40]
104
+ valset: Sequence[tuple[str, str]] = all_items[40:60] # 20 examples
101
105
 
102
106
  student = StudentProgram(
103
107
  lm=lm,
@@ -110,17 +114,20 @@ def main():
110
114
  return accuracy(yhat, y, label_names)
111
115
 
112
116
  total_candidates = 3 + 3 # zero-shot, labeled few-shot, bootstrapped + 3 random seeds
113
- print(f"Running Random Search optimizer ({total_candidates} candidates, parallel eval of 20 questions)...")
117
+ print(
118
+ f"Running Random Search optimizer ({total_candidates} candidates, parallel eval of 20 questions)..."
119
+ )
114
120
 
115
- def eval_parallel(program: StudentProgram, dataset: Sequence[Tuple[str, str]], metric_fn):
121
+ def eval_parallel(program: StudentProgram, dataset: Sequence[tuple[str, str]], metric_fn):
116
122
  async def _run():
117
123
  xs = [x for x, _ in dataset]
118
124
  ys = [y for _, y in dataset]
119
- preds: List[Optional[str]] = [None] * len(xs)
125
+ preds: list[Optional[str]] = [None] * len(xs)
120
126
  sem = asyncio.Semaphore(int(os.getenv("CONCURRENCY", "5")))
121
127
 
122
128
  async def worker(i: int, x: str, y: str):
123
129
  import time
130
+
124
131
  t_start = time.monotonic()
125
132
  try:
126
133
  async with sem:
@@ -138,16 +145,18 @@ def main():
138
145
  t_end = time.monotonic()
139
146
  return i, y, "", t_start, t_end, {}
140
147
 
141
- tasks = [asyncio.create_task(worker(i, x, y)) for i, (x, y) in enumerate(zip(xs, ys))]
148
+ tasks = [asyncio.create_task(worker(i, x, y)) for i, (x, y) in enumerate(zip(xs, ys, strict=False))]
142
149
  correct_sum = 0.0
143
150
  processed = 0
144
- import time, statistics
145
- durations: List[float] = []
151
+ import statistics
152
+ import time
153
+
154
+ durations: list[float] = []
146
155
  in_tok_sum = 0
147
156
  out_tok_sum = 0
148
157
  in_tok_count = 0
149
158
  out_tok_count = 0
150
- details: List[Dict[str, Any]] = []
159
+ details: list[dict[str, Any]] = []
151
160
  t_batch_start = time.monotonic()
152
161
  deadline = float(os.getenv("BATCH_DEADLINE_S", "20"))
153
162
  with tqdm(total=len(tasks), desc="Rollouts", leave=False) as pbar:
@@ -172,7 +181,10 @@ def main():
172
181
  break
173
182
  # Wait for at least one completion within remaining time (polling granularity <= 1s)
174
183
  timeout = min(1.0, remaining)
175
- done, pending = await asyncio.wait(pending, timeout=timeout, return_when=asyncio.FIRST_COMPLETED)
184
+ done, pending = await asyncio.wait(
185
+ pending, timeout=timeout, return_when=asyncio.FIRST_COMPLETED
186
+ )
187
+ import contextlib
176
188
  for task in done:
177
189
  try:
178
190
  i, y_true, pred, t_start, t_end, usage = task.result()
@@ -182,11 +194,9 @@ def main():
182
194
  durations.append(max(0.0, t_end - t_start))
183
195
  preds[i] = pred
184
196
  processed += 1
185
- try:
197
+ with contextlib.suppress(Exception):
186
198
  correct_sum += float(metric_fn(pred, y_true))
187
- except Exception:
188
- pass
189
- try:
199
+ with contextlib.suppress(Exception):
190
200
  pt = usage.get("prompt_tokens") or usage.get("input_tokens")
191
201
  ct = usage.get("completion_tokens") or usage.get("output_tokens")
192
202
  if isinstance(pt, (int, float)):
@@ -195,30 +205,34 @@ def main():
195
205
  if isinstance(ct, (int, float)):
196
206
  out_tok_sum += int(ct)
197
207
  out_tok_count += 1
198
- except Exception:
199
- pass
200
- details.append({
201
- "index": i,
202
- "seconds": max(0.0, t_end - t_start),
203
- "score": float(metric_fn(pred, y_true)),
204
- "usage": {
205
- "prompt_tokens": usage.get("prompt_tokens") or usage.get("input_tokens"),
206
- "completion_tokens": usage.get("completion_tokens") or usage.get("output_tokens"),
207
- },
208
- })
208
+ details.append(
209
+ {
210
+ "index": i,
211
+ "seconds": max(0.0, t_end - t_start),
212
+ "score": float(metric_fn(pred, y_true)),
213
+ "usage": {
214
+ "prompt_tokens": usage.get("prompt_tokens")
215
+ or usage.get("input_tokens"),
216
+ "completion_tokens": usage.get("completion_tokens")
217
+ or usage.get("output_tokens"),
218
+ },
219
+ }
220
+ )
209
221
  pbar.update(1)
210
222
  med = statistics.median(durations) if durations else 0.0
211
223
  mx = max(durations) if durations else 0.0
212
224
  avg_in = (in_tok_sum / in_tok_count) if in_tok_count else 0.0
213
225
  avg_out = (out_tok_sum / out_tok_count) if out_tok_count else 0.0
214
- pbar.set_postfix({
215
- "acc": f"{(correct_sum/processed):.2f}",
216
- "done": f"{processed}/{len(tasks)}",
217
- "med_s": f"{med:.1f}",
218
- "max_s": f"{mx:.1f}",
219
- "tin": f"{avg_in:.1f}",
220
- "tout": f"{avg_out:.1f}",
221
- })
226
+ pbar.set_postfix(
227
+ {
228
+ "acc": f"{(correct_sum / processed):.2f}",
229
+ "done": f"{processed}/{len(tasks)}",
230
+ "med_s": f"{med:.1f}",
231
+ "max_s": f"{mx:.1f}",
232
+ "tin": f"{avg_in:.1f}",
233
+ "tout": f"{avg_out:.1f}",
234
+ }
235
+ )
222
236
  # Compute score only from completed/successful rollouts (drop timeouts/cancelled)
223
237
  subs = [float(d.get("score", 0.0)) for d in details]
224
238
  result = SimpleNamespace(score=(sum(subs) / max(1, len(subs))), subscores=subs)
@@ -226,28 +240,33 @@ def main():
226
240
  result.mean_in = (in_tok_sum / in_tok_count) if in_tok_count else 0.0
227
241
  result.mean_out = (out_tok_sum / out_tok_count) if out_tok_count else 0.0
228
242
  return result
243
+
229
244
  return asyncio.run(_run())
245
+
230
246
  pbar = tqdm(total=total_candidates, desc="Candidates")
231
- candidate_eval_details: Dict[int, Any] = {}
247
+ candidate_eval_details: dict[int, Any] = {}
248
+
232
249
  def on_cand(idx: int, score: float, res, intervention):
233
250
  pbar.update(1)
234
251
  pbar.set_postfix({"score": f"{score:.2f}"})
235
252
  # store per-instance details (for apples-to-apples)
236
- try:
253
+ import contextlib
254
+ with contextlib.suppress(Exception):
237
255
  candidate_eval_details[idx] = {
238
256
  "score": score,
239
257
  "mean_in": getattr(res, "mean_in", None),
240
258
  "mean_out": getattr(res, "mean_out", None),
241
259
  "instances": getattr(res, "details", None),
242
260
  }
243
- except Exception:
244
- pass
245
261
  # visible summary line per candidate
246
- kind = intervention.get("kind", "candidate") if isinstance(intervention, dict) else "candidate"
262
+ kind = (
263
+ intervention.get("kind", "candidate") if isinstance(intervention, dict) else "candidate"
264
+ )
247
265
  label = intervention.get("label") if isinstance(intervention, dict) else None
248
266
  seed = intervention.get("seed") if isinstance(intervention, dict) else None
249
267
  processed = len(getattr(res, "details", []) or [])
250
268
  from tqdm import tqdm as _tqdm
269
+
251
270
  _tqdm.write(
252
271
  f"Candidate {idx}/{total_candidates} [{kind}{'' if label is None else f', label={label}'}{'' if seed is None else f', seed={seed}'}]: "
253
272
  f"score={score:.2f} | mean tin/tout={getattr(res, 'mean_in', 0):.1f}/{getattr(res, 'mean_out', 0):.1f} | N={processed}"
@@ -0,0 +1,256 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Dict, List, Optional, Callable
4
+ import os
5
+ import time
6
+
7
+ from ..http import AsyncHttpClient, HTTPError, sleep
8
+
9
+
10
+ def _api_base(b: str) -> str:
11
+ b = (b or "").rstrip("/")
12
+ return b if b.endswith("/api") else f"{b}/api"
13
+
14
+
15
class RlClient:
    """Lightweight RL client for provider-agnostic job control.

    Notes:
    - Uses learning/* for status/events/metrics and rl/* for creation/start.
    - Trainer endpoints are resolved server-side via trainer_id.

    Every request targets ``<base_url>/api/...`` (built with ``_api_base``) and
    opens a short-lived ``AsyncHttpClient`` for the duration of the call.
    """

    def __init__(self, base_url: str, api_key: str, *, timeout: float = 600.0) -> None:
        self._base_url = base_url.rstrip("/")
        self._api_key = api_key
        # Used for job creation only; the other calls use a fixed 30s timeout.
        self._timeout = timeout

    @staticmethod
    def _log(message: str) -> None:
        """Best-effort diagnostic print; a failing stdout must never break polling."""
        try:
            print(message)
        except Exception:
            pass

    @staticmethod
    def _as_int(value: Any, default: int) -> int:
        """Coerce ``value`` to int; return ``default`` for None or unparseable input."""
        if value is None:
            return default
        try:
            return int(value)
        except (TypeError, ValueError):
            return default

    async def resolve_trainer_start_url(self, trainer_id: str) -> str:
        """GET /api/rl/services/{id} → { training_start_url }

        Raises:
            HTTPError: on request failure, or with status 500 when the body is not
                an object or lacks a non-empty ``training_start_url`` string.
        """
        # Built like every other endpoint so a base URL already ending in /api
        # is not turned into /api/api/....
        path = f"{_api_base(self._base_url)}/rl/services/{trainer_id}"
        async with AsyncHttpClient(self._base_url, self._api_key, timeout=30.0) as http:
            js = await http.get(path)
        if not isinstance(js, dict):
            raise HTTPError(status=500, url=path, message="invalid_service_response", body_snippet=str(js)[:200])
        start_url = js.get("training_start_url")
        if not isinstance(start_url, str) or not start_url:
            raise HTTPError(status=500, url=path, message="missing_training_start_url", body_snippet=str(js)[:200])
        return start_url

    async def create_job(
        self,
        *,
        model: str,
        task_app_url: str,
        trainer: Dict[str, Any],
        trainer_id: Optional[str] = None,
        job_config_id: Optional[str] = None,
        inline_config: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """POST /api/rl/jobs and return the created-job object.

        ``trainer`` supplies ``batch_size`` (default 1) and ``group_size``
        (default 2, clamped to at least 2). ``job_config_id`` and
        ``inline_config`` are only sent when truthy.

        NOTE(review): ``trainer_id`` is accepted but never sent in the request
        body — confirm whether the server expects it here.

        Raises:
            HTTPError: on request failure, or with status 500 if the response is
                not an object.
        """
        body = {
            "job_type": "rl",
            "data": {
                "model": model,
                "endpoint_base_url": task_app_url,
                **({"job_config_id": job_config_id} if job_config_id else {}),
                **({"config": inline_config} if inline_config else {}),
                "trainer": {
                    "batch_size": int(trainer.get("batch_size", 1)),
                    "group_size": max(2, int(trainer.get("group_size", 2))),
                },
            },
        }
        url = f"{_api_base(self._base_url)}/rl/jobs"
        async with AsyncHttpClient(self._base_url, self._api_key, timeout=self._timeout) as http:
            js = await http.post_json(url, json=body)
        if not isinstance(js, dict):
            raise HTTPError(status=500, url=url, message="invalid_create_response", body_snippet=str(js)[:200])
        return js

    async def start_job_if_supported(self, job_id: str) -> Optional[Dict[str, Any]]:
        """POST /api/rl/jobs/{id}/start; return None if the endpoint answers 404.

        Any other HTTPError is re-raised.
        """
        path = f"{_api_base(self._base_url)}/rl/jobs/{job_id}/start"
        try:
            async with AsyncHttpClient(self._base_url, self._api_key, timeout=30.0) as http:
                return await http.post_json(path, json={})
        except HTTPError as he:
            if he.status == 404:
                return None
            raise

    async def get_job(self, job_id: str) -> Dict[str, Any]:
        """GET /api/learning/jobs/{id} and return the decoded body unchanged."""
        async with AsyncHttpClient(self._base_url, self._api_key, timeout=30.0) as http:
            return await http.get(f"{_api_base(self._base_url)}/learning/jobs/{job_id}")

    async def get_events(self, job_id: str, *, since_seq: int = 0, limit: int = 200) -> List[Dict[str, Any]]:
        """GET /api/learning/jobs/{id}/events after ``since_seq``.

        Returns the ``events`` (falling back to ``data``) list of the response
        object, or ``[]`` when neither is a list. HTTP errors are logged, then
        re-raised.
        """
        params = {"since_seq": since_seq, "limit": limit}
        async with AsyncHttpClient(self._base_url, self._api_key, timeout=30.0) as http:
            try:
                js = await http.get(f"{_api_base(self._base_url)}/learning/jobs/{job_id}/events", params=params)
            except HTTPError as he:
                self._log(
                    f"[poll] events HTTPError status={he.status} url={he.url} since_seq={since_seq} body={(he.body_snippet or '')[:200]}"
                )
                raise
        if isinstance(js, dict):
            evs = js.get("events") or js.get("data")
            if isinstance(evs, list):
                return evs
        return []

    async def get_metrics(self, job_id: str, *, after_step: int = -1, limit: int = 200) -> List[Dict[str, Any]]:
        """GET /api/learning/jobs/{id}/metrics after ``after_step``; returns ``points`` or ``[]``."""
        params = {"after_step": after_step, "limit": limit}
        async with AsyncHttpClient(self._base_url, self._api_key, timeout=30.0) as http:
            js = await http.get(f"{_api_base(self._base_url)}/learning/jobs/{job_id}/metrics", params=params)
        if isinstance(js, dict) and isinstance(js.get("points"), list):
            return js["points"]
        return []

    async def poll_until_terminal(
        self,
        job_id: str,
        *,
        interval_seconds: float = 2.0,
        max_seconds: float | None = None,
        empty_polls_threshold: int = 5,
        startup_deadline_s: int = 45,
        on_event: Optional[Callable[[Dict[str, Any]], None]] = None,
        on_metric: Optional[Callable[[Dict[str, Any]], None]] = None,
    ) -> Dict[str, Any]:
        """Poll status, events and metrics until the job reaches a terminal state.

        Each cycle fetches the job (following ``linked_job_id`` to a second event
        stream), pulls new events from every stream, pulls new metric points, and
        then sleeps ``interval_seconds``. Status changes are reported to
        ``on_event`` as ``{"type": "rl.status", "message": status}``. Callback
        exceptions are ignored.

        Returns:
            ``{"status": ..., "job_id": job_id}`` once a completion/failure event
            is seen or the job status is terminal.

        Raises:
            AssertionError: after ``empty_polls_threshold`` consecutive cycles with
                no events, or when no event arrives within ``startup_deadline_s``.
            TimeoutError: once ``max_seconds`` have elapsed.
        """
        last_seq_by_stream: Dict[str, int] = {}
        events_job_id: Optional[str] = None
        last_status: Optional[str] = None
        last_step_by_name: Dict[str, int] = {}
        empty_polls = 0
        saw_any_event = False
        start_t = time.time()
        terminal = {"succeeded", "failed", "cancelled", "canceled", "error", "completed"}

        while True:
            # --- Status (and discovery of a linked event stream) ---
            try:
                fetched = await self.get_job(job_id)
            except Exception:
                fetched = None
            # A non-object body is treated like a failed fetch rather than
            # crashing on ``.get`` below.
            status_data: Optional[Dict[str, Any]] = fetched if isinstance(fetched, dict) else None
            if status_data is None:
                self._log(f"[poll] get_job returned None base={self._base_url} job_id={job_id}")
            status = str((status_data or {}).get("status") or "").lower()
            if status_data:
                linked = status_data.get("linked_job_id")
                if isinstance(linked, str) and linked and linked != events_job_id:
                    events_job_id = linked
                    self._log(f"[poll] discovered linked_job_id stream={events_job_id}")
            if status and status != last_status:
                last_status = status
                # Status transitions only to avoid log spam
                if on_event:
                    try:
                        on_event({"type": "rl.status", "message": status})
                    except Exception:
                        pass

            # --- Events ---
            stream_ids = [job_id]
            if events_job_id and events_job_id not in stream_ids:
                stream_ids.append(events_job_id)
            self._log(
                f"[poll] streams={stream_ids} intervals={interval_seconds}s since_map={last_seq_by_stream} empty_polls={empty_polls}"
            )
            total_events_this_cycle = 0
            terminal_event_seen = False
            terminal_event_status: Optional[str] = None
            for ev_id in stream_ids:
                since = last_seq_by_stream.get(ev_id, 0)
                try:
                    events = await self.get_events(ev_id, since_seq=since, limit=200)
                except HTTPError as he:
                    self._log(
                        f"[poll] get_events error status={he.status} url={he.url} since={since} body={(he.body_snippet or '')[:200]}"
                    )
                    events = []
                except Exception as e:
                    self._log(
                        f"[poll] get_events unexpected error ev_id={ev_id} since={since} err={type(e).__name__}: {e}"
                    )
                    events = []
                total_events_this_cycle += len(events)
                if events:
                    saw_any_event = True
                for e in events:
                    if not isinstance(e, dict):
                        continue  # malformed entry; don't let it kill the poll loop
                    seq_val = self._as_int(e.get("seq"), 0)
                    if seq_val <= last_seq_by_stream.get(ev_id, 0):
                        continue
                    last_seq_by_stream[ev_id] = seq_val
                    if on_event:
                        try:
                            on_event(e)
                        except Exception:
                            pass
                    et = str(e.get("type") or e.get("event_type") or "").lower()
                    if et in ("rl.job.completed", "workflow.completed", "rl.train.completed"):
                        terminal_event_seen = True
                        terminal_event_status = "succeeded"
                    elif et in ("rl.job.failed", "workflow.failed"):
                        terminal_event_seen = True
                        terminal_event_status = "failed"

            # --- Metrics (best effort) ---
            try:
                after = max(last_step_by_name.values()) if last_step_by_name else -1
                points = await self.get_metrics(job_id, after_step=after, limit=200)
                for p in points:
                    name = str(p.get("name") or "")
                    # ``int(step or -1)`` would turn a real step 0 into -1 and drop it.
                    step = self._as_int(p.get("step"), -1)
                    if step <= last_step_by_name.get(name, -1):
                        continue
                    last_step_by_name[name] = step
                    if on_metric:
                        try:
                            on_metric(p)
                        except Exception:
                            pass
            except Exception:
                pass

            # --- Termination checks ---
            if terminal_event_seen:
                return {"status": terminal_event_status or status or "completed", "job_id": job_id}
            if status and status in terminal:
                return {"status": status, "job_id": job_id}

            if total_events_this_cycle == 0:
                empty_polls += 1
            else:
                empty_polls = 0
            if empty_polls >= max(1, int(empty_polls_threshold)):
                self._log(
                    f"[poll] threshold hit: empty_polls={empty_polls} >= {empty_polls_threshold} streams={stream_ids} last_seq_map={last_seq_by_stream}"
                )
                raise AssertionError(f"No new events detected for {empty_polls_threshold} consecutive polls. Check event ingestion.")

            if not saw_any_event and (time.time() - start_t) > int(startup_deadline_s):
                self._log(
                    f"[poll] startup window exceeded: {startup_deadline_s}s base={self._base_url} job={job_id} streams={stream_ids} last_seq_map={last_seq_by_stream}"
                )
                raise AssertionError(f"No events observed within startup window ({startup_deadline_s}s). Investigate event streaming.")

            await sleep(interval_seconds)
            if max_seconds is not None and (time.time() - start_t) >= max_seconds:
                raise TimeoutError(f"Polling timed out after {max_seconds}s for job {job_id}")
255
+
256
+
@@ -0,0 +1,58 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import time
5
+ from typing import Any, Callable, Optional
6
+
7
+ import aiohttp
8
+
9
+
10
def _api_base(b: str) -> str:
    """Return ``b`` without trailing slashes, guaranteed to end in ``/api``.

    A falsy ``b`` (``""`` or ``None``) is treated as the empty string, which
    yields ``"/api"``.
    """
    root = (b or "").rstrip("/")
    if root.endswith("/api"):
        return root
    return root + "/api"
13
+
14
+
15
async def stream_events(
    base_url: str,
    api_key: str,
    job_id: str,
    *,
    seconds: int = 60,
    on_event: Optional[Callable[[dict], None]] = None,
) -> None:
    """Stream a job's server-sent events (SSE) to ``on_event`` for up to ``seconds``.

    The RL endpoint (``/rl/jobs/{id}/events``) is tried first and the learning
    endpoint (``/learning/jobs/{id}/events``) second, both from ``since_seq=0``.
    Only the first endpoint that answers HTTP 200 is streamed. Once a stream
    has been opened, the function returns when that stream ends, errors, or
    the time budget runs out. It does not replay the same events from the
    other endpoint.

    Args:
        base_url: Backend base URL; ``/api`` is appended if missing.
        api_key: Sent as a ``Bearer`` token.
        job_id: Job whose events are streamed.
        seconds: Total wall-clock budget for the whole call, shared by both
            endpoints. A value ``<= 0`` returns immediately.
        on_event: Called with each JSON payload parsed from a ``data:`` line.
            Exceptions it raises are swallowed.

    Errors from connecting or streaming are swallowed (best-effort); the
    function always returns ``None``.
    """
    if seconds <= 0:
        return
    headers = {"Accept": "text/event-stream", "Authorization": f"Bearer {api_key}"}
    api = _api_base(base_url)
    candidates = [
        f"{api}/rl/jobs/{job_id}/events?since_seq=0",
        f"{api}/learning/jobs/{job_id}/events?since_seq=0",
    ]
    # A single deadline for the whole call. Falling back to the second
    # endpoint therefore cannot extend the total time past `seconds`.
    deadline = time.monotonic() + seconds
    for url in candidates:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return
        connected = False
        try:
            timeout = aiohttp.ClientTimeout(total=None, sock_connect=remaining)
            async with aiohttp.ClientSession(timeout=timeout) as session:
                async with session.get(url, headers=headers) as resp:
                    if resp.status != 200:
                        continue
                    connected = True
                    await _consume_sse_stream(resp, deadline, on_event)
        except Exception:
            # Only fall back to the next endpoint if this one never streamed.
            # Otherwise the fallback would re-deliver events from seq 0.
            if not connected:
                continue
        return


def _parse_sse_data_line(raw: bytes) -> Any:
    """Return the JSON payload of one raw SSE ``data:`` line, or ``None``.

    Blank lines, ``:`` comment/keep-alive lines, other SSE fields
    (``event:``, ``id:``, ...) and payloads that are not valid JSON all
    yield ``None``.
    """
    line = raw.decode(errors="ignore").strip()
    if not line.startswith("data:"):
        return None
    try:
        return json.loads(line[5:].strip())
    except (ValueError, RecursionError):
        return None


async def _consume_sse_stream(
    resp: Any,
    deadline: float,
    on_event: Optional[Callable[[dict], None]],
) -> None:
    """Read SSE lines from ``resp.content`` until EOF or the monotonic ``deadline``."""
    import asyncio

    while True:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return
        try:
            # Bound every read. An idle stream (no events, or only
            # keep-alives) must not hold the caller past its time budget.
            raw = await asyncio.wait_for(resp.content.readline(), timeout=remaining)
        except asyncio.TimeoutError:
            return
        if not raw:  # EOF: the server closed the stream.
            return
        obj = _parse_sse_data_line(raw)
        if obj is None or not on_event:
            continue
        try:
            on_event(obj)
        except Exception:
            # Callback failures are deliberately isolated from the stream.
            pass
57
+
58
+
@@ -0,0 +1,48 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ import json
5
+ from typing import Any, Dict
6
+ from urllib.parse import urlparse
7
+
8
+
9
def validate_training_jsonl(path: str | Path, *, sample_lines: int = 50) -> None:
    """Sanity-check the first lines of a chat-format training JSONL file.

    Physical lines ``1..max(1, sample_lines)`` are inspected; blank lines are
    skipped. Each non-blank sampled line must be a JSON object whose
    ``messages`` value is a list of at least two dicts. Each of those dicts
    needs a string ``role`` and non-blank string ``content``.

    The file is read lazily as UTF-8 and split only on real line endings.
    Unlike ``str.splitlines()``, this does not break a record at U+2028 and
    similar separators, which may appear unescaped inside JSON strings.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        ValueError: if the file has no non-blank lines, or a sampled line is
            malformed (the message names the 1-based line number).
    """
    p = Path(path)
    if not p.exists():
        raise FileNotFoundError(str(p))
    limit = max(1, sample_lines)
    saw_content = False
    with p.open("r", encoding="utf-8") as fh:
        for i, line in enumerate(fh, start=1):
            if not line.strip():
                continue
            saw_content = True
            if i > limit:
                # Past the sample window. Reaching here only proves the file
                # is not blank, so stop reading.
                break
            _validate_chat_record(line, i)
    if not saw_content:
        raise ValueError("empty JSONL")


def _validate_chat_record(line: str, lineno: int) -> None:
    """Validate one JSONL line as a chat record; raise ValueError naming ``lineno``."""
    try:
        obj = json.loads(line)
    except (ValueError, RecursionError) as e:
        raise ValueError(f"invalid json on line {lineno}: {e}") from e
    if not isinstance(obj, dict):
        raise ValueError(f"line {lineno}: expected a JSON object")
    msgs = obj.get("messages")
    if not isinstance(msgs, list) or len(msgs) < 2:
        raise ValueError(f"line {lineno}: missing messages[] with at least 2 turns")
    roles = [m.get("role") for m in msgs if isinstance(m, dict)]
    if not roles or not isinstance(roles[0], str):
        raise ValueError(f"line {lineno}: missing first role")
    for m in msgs:
        if not isinstance(m, dict):
            raise ValueError(f"line {lineno}: non-dict message")
        content = m.get("content")
        if not isinstance(m.get("role"), str) or not isinstance(content, str) or not content.strip():
            raise ValueError(f"line {lineno}: invalid role/content")
34
+
35
+
36
def validate_task_app_url(url: str, *, name: str = "TASK_APP_BASE_URL") -> None:
    """Validate a task-app base URL via ``synth_ai.task.validators``.

    The import happens at call time rather than at module import, presumably
    to avoid an import cycle (TODO confirm). Anything the delegate raises
    propagates unchanged, and its return value is discarded.

    Args:
        url: The URL to check.
        name: Label forwarded to the delegate for use in its messages.
    """
    from synth_ai.task.validators import validate_task_app_url as _delegate_validate

    _delegate_validate(url, name=name)
    return None
40
+
41
+
42
def validate_trainer_cfg_rl(trainer: Dict[str, Any]) -> None:
    """Check the RL trainer's batch and group sizes.

    Args:
        trainer: Trainer config mapping. ``batch_size`` defaults to 1 and
            ``group_size`` to 2 when absent. Numeric strings such as ``"4"``
            and integral floats such as ``4.0`` are accepted.

    Raises:
        ValueError: if a value is not an integer (including ``None`` and
            non-integral floats such as ``2.5``), if ``batch_size < 1``, or
            if ``group_size < 2``.
    """
    bs = _as_int_setting(trainer, "batch_size", 1)
    gs = _as_int_setting(trainer, "group_size", 2)
    if bs < 1:
        raise ValueError("trainer.batch_size must be >= 1")
    if gs < 2:
        raise ValueError("trainer.group_size must be >= 2")


def _as_int_setting(trainer: Dict[str, Any], key: str, default: int) -> int:
    """Read ``trainer[key]`` (or ``default``) as an int, raising ValueError naming the field."""
    raw = trainer.get(key, default)
    try:
        value = int(raw)
    except (TypeError, ValueError, OverflowError) as e:
        raise ValueError(f"trainer.{key} must be an integer, got {raw!r}") from e
    if isinstance(raw, float) and raw != value:
        # int() would silently truncate (2.5 -> 2) and let the value pass the bound check.
        raise ValueError(f"trainer.{key} must be an integer, got {raw!r}")
    return value
synth_ai/lm/__init__.py CHANGED
@@ -4,24 +4,24 @@ Synth AI Language Model Interface.
4
4
  Provides a unified interface for multiple LLM providers including OpenAI and Synth.
5
5
  """
6
6
 
7
- from .config import SynthConfig, OpenAIConfig
8
- from .warmup import warmup_synth_model, get_warmup_status
7
+ from .config import OpenAIConfig, SynthConfig
8
+ from .core.main_v3 import LM
9
9
  from .unified_interface import (
10
- UnifiedLMProvider,
11
10
  OpenAIProvider,
12
11
  SynthProvider,
13
12
  UnifiedLMClient,
13
+ UnifiedLMProvider,
14
14
  create_provider,
15
15
  )
16
16
  from .vendors.synth_client import (
17
17
  AsyncSynthClient,
18
18
  SyncSynthClient,
19
19
  create_async_client,
20
- create_sync_client,
21
20
  create_chat_completion_async,
22
21
  create_chat_completion_sync,
22
+ create_sync_client,
23
23
  )
24
- from .core.main_v3 import LM
24
+ from .warmup import get_warmup_status, warmup_synth_model
25
25
 
26
26
  __all__ = [
27
27
  # Configuration