synth-ai 0.2.4.dev5__py3-none-any.whl → 0.2.4.dev7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (229)
  1. synth_ai/__init__.py +18 -9
  2. synth_ai/cli/__init__.py +10 -5
  3. synth_ai/cli/balance.py +22 -17
  4. synth_ai/cli/calc.py +2 -3
  5. synth_ai/cli/demo.py +3 -5
  6. synth_ai/cli/legacy_root_backup.py +58 -32
  7. synth_ai/cli/man.py +22 -19
  8. synth_ai/cli/recent.py +9 -8
  9. synth_ai/cli/root.py +58 -13
  10. synth_ai/cli/status.py +13 -6
  11. synth_ai/cli/traces.py +45 -21
  12. synth_ai/cli/watch.py +40 -37
  13. synth_ai/config/base_url.py +1 -3
  14. synth_ai/core/experiment.py +1 -2
  15. synth_ai/environments/__init__.py +2 -6
  16. synth_ai/environments/environment/artifacts/base.py +3 -1
  17. synth_ai/environments/environment/db/sqlite.py +1 -1
  18. synth_ai/environments/environment/registry.py +19 -20
  19. synth_ai/environments/environment/resources/sqlite.py +2 -3
  20. synth_ai/environments/environment/rewards/core.py +3 -2
  21. synth_ai/environments/environment/tools/__init__.py +6 -4
  22. synth_ai/environments/examples/crafter_classic/__init__.py +1 -1
  23. synth_ai/environments/examples/crafter_classic/engine.py +21 -17
  24. synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +1 -0
  25. synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +2 -1
  26. synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +2 -1
  27. synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +3 -2
  28. synth_ai/environments/examples/crafter_classic/environment.py +16 -15
  29. synth_ai/environments/examples/crafter_classic/taskset.py +2 -2
  30. synth_ai/environments/examples/crafter_classic/trace_hooks_v3.py +2 -3
  31. synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +2 -1
  32. synth_ai/environments/examples/crafter_custom/crafter/__init__.py +2 -2
  33. synth_ai/environments/examples/crafter_custom/crafter/config.py +2 -2
  34. synth_ai/environments/examples/crafter_custom/crafter/env.py +1 -5
  35. synth_ai/environments/examples/crafter_custom/crafter/objects.py +1 -2
  36. synth_ai/environments/examples/crafter_custom/crafter/worldgen.py +1 -2
  37. synth_ai/environments/examples/crafter_custom/dataset_builder.py +5 -5
  38. synth_ai/environments/examples/crafter_custom/environment.py +13 -13
  39. synth_ai/environments/examples/crafter_custom/run_dataset.py +5 -5
  40. synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +2 -2
  41. synth_ai/environments/examples/enron/art_helpers/local_email_db.py +5 -4
  42. synth_ai/environments/examples/enron/art_helpers/types_enron.py +2 -1
  43. synth_ai/environments/examples/enron/engine.py +18 -14
  44. synth_ai/environments/examples/enron/environment.py +12 -11
  45. synth_ai/environments/examples/enron/taskset.py +7 -7
  46. synth_ai/environments/examples/minigrid/__init__.py +6 -6
  47. synth_ai/environments/examples/minigrid/engine.py +6 -6
  48. synth_ai/environments/examples/minigrid/environment.py +6 -6
  49. synth_ai/environments/examples/minigrid/puzzle_loader.py +3 -2
  50. synth_ai/environments/examples/minigrid/taskset.py +13 -13
  51. synth_ai/environments/examples/nethack/achievements.py +1 -1
  52. synth_ai/environments/examples/nethack/engine.py +8 -7
  53. synth_ai/environments/examples/nethack/environment.py +10 -9
  54. synth_ai/environments/examples/nethack/helpers/__init__.py +8 -9
  55. synth_ai/environments/examples/nethack/helpers/action_mapping.py +1 -1
  56. synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +2 -1
  57. synth_ai/environments/examples/nethack/helpers/observation_utils.py +1 -1
  58. synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +3 -4
  59. synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +6 -5
  60. synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +5 -5
  61. synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +7 -6
  62. synth_ai/environments/examples/nethack/taskset.py +5 -5
  63. synth_ai/environments/examples/red/engine.py +9 -8
  64. synth_ai/environments/examples/red/engine_helpers/reward_components.py +2 -1
  65. synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +7 -7
  66. synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +2 -1
  67. synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +2 -1
  68. synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +2 -1
  69. synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +2 -1
  70. synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +2 -1
  71. synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +2 -1
  72. synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +2 -1
  73. synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +2 -1
  74. synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +2 -1
  75. synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +2 -1
  76. synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +2 -1
  77. synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +3 -2
  78. synth_ai/environments/examples/red/engine_helpers/state_extraction.py +2 -1
  79. synth_ai/environments/examples/red/environment.py +18 -15
  80. synth_ai/environments/examples/red/taskset.py +5 -3
  81. synth_ai/environments/examples/sokoban/engine.py +16 -13
  82. synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +3 -2
  83. synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +2 -1
  84. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +1 -1
  85. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +7 -5
  86. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +1 -1
  87. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +2 -1
  88. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +5 -4
  89. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +3 -2
  90. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +2 -1
  91. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +5 -4
  92. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +1 -1
  93. synth_ai/environments/examples/sokoban/environment.py +15 -14
  94. synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +5 -3
  95. synth_ai/environments/examples/sokoban/puzzle_loader.py +3 -2
  96. synth_ai/environments/examples/sokoban/taskset.py +13 -10
  97. synth_ai/environments/examples/tictactoe/engine.py +6 -6
  98. synth_ai/environments/examples/tictactoe/environment.py +8 -7
  99. synth_ai/environments/examples/tictactoe/taskset.py +6 -5
  100. synth_ai/environments/examples/verilog/engine.py +4 -3
  101. synth_ai/environments/examples/verilog/environment.py +11 -10
  102. synth_ai/environments/examples/verilog/taskset.py +14 -12
  103. synth_ai/environments/examples/wordle/__init__.py +29 -0
  104. synth_ai/environments/examples/wordle/engine.py +398 -0
  105. synth_ai/environments/examples/wordle/environment.py +159 -0
  106. synth_ai/environments/examples/wordle/helpers/generate_instances_wordfreq.py +75 -0
  107. synth_ai/environments/examples/wordle/taskset.py +230 -0
  108. synth_ai/environments/reproducibility/core.py +1 -1
  109. synth_ai/environments/reproducibility/tree.py +21 -21
  110. synth_ai/environments/service/app.py +11 -2
  111. synth_ai/environments/service/core_routes.py +137 -105
  112. synth_ai/environments/service/external_registry.py +1 -2
  113. synth_ai/environments/service/registry.py +1 -1
  114. synth_ai/environments/stateful/core.py +1 -2
  115. synth_ai/environments/stateful/engine.py +1 -1
  116. synth_ai/environments/tasks/api.py +4 -4
  117. synth_ai/environments/tasks/core.py +14 -12
  118. synth_ai/environments/tasks/filters.py +6 -4
  119. synth_ai/environments/tasks/utils.py +13 -11
  120. synth_ai/evals/base.py +2 -3
  121. synth_ai/experimental/synth_oss.py +4 -4
  122. synth_ai/learning/gateway.py +1 -3
  123. synth_ai/learning/prompts/banking77_injection_eval.py +168 -0
  124. synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +213 -0
  125. synth_ai/learning/prompts/mipro.py +282 -1
  126. synth_ai/learning/prompts/random_search.py +246 -0
  127. synth_ai/learning/prompts/run_mipro_banking77.py +172 -0
  128. synth_ai/learning/prompts/run_random_search_banking77.py +324 -0
  129. synth_ai/lm/__init__.py +5 -5
  130. synth_ai/lm/caching/ephemeral.py +9 -9
  131. synth_ai/lm/caching/handler.py +20 -20
  132. synth_ai/lm/caching/persistent.py +10 -10
  133. synth_ai/lm/config.py +3 -3
  134. synth_ai/lm/constants.py +7 -7
  135. synth_ai/lm/core/all.py +17 -3
  136. synth_ai/lm/core/exceptions.py +0 -2
  137. synth_ai/lm/core/main.py +26 -41
  138. synth_ai/lm/core/main_v3.py +20 -10
  139. synth_ai/lm/core/vendor_clients.py +18 -17
  140. synth_ai/lm/injection.py +80 -0
  141. synth_ai/lm/overrides.py +206 -0
  142. synth_ai/lm/provider_support/__init__.py +1 -1
  143. synth_ai/lm/provider_support/anthropic.py +51 -24
  144. synth_ai/lm/provider_support/openai.py +51 -22
  145. synth_ai/lm/structured_outputs/handler.py +34 -32
  146. synth_ai/lm/structured_outputs/inject.py +24 -27
  147. synth_ai/lm/structured_outputs/rehabilitate.py +19 -15
  148. synth_ai/lm/tools/base.py +17 -16
  149. synth_ai/lm/unified_interface.py +17 -18
  150. synth_ai/lm/vendors/base.py +20 -18
  151. synth_ai/lm/vendors/core/anthropic_api.py +50 -25
  152. synth_ai/lm/vendors/core/gemini_api.py +31 -36
  153. synth_ai/lm/vendors/core/mistral_api.py +19 -19
  154. synth_ai/lm/vendors/core/openai_api.py +11 -10
  155. synth_ai/lm/vendors/openai_standard.py +144 -88
  156. synth_ai/lm/vendors/openai_standard_responses.py +74 -61
  157. synth_ai/lm/vendors/retries.py +9 -1
  158. synth_ai/lm/vendors/supported/custom_endpoint.py +26 -26
  159. synth_ai/lm/vendors/supported/deepseek.py +10 -10
  160. synth_ai/lm/vendors/supported/grok.py +8 -8
  161. synth_ai/lm/vendors/supported/ollama.py +2 -1
  162. synth_ai/lm/vendors/supported/openrouter.py +11 -9
  163. synth_ai/lm/vendors/synth_client.py +69 -63
  164. synth_ai/lm/warmup.py +8 -7
  165. synth_ai/tracing/__init__.py +22 -10
  166. synth_ai/tracing_v1/__init__.py +22 -20
  167. synth_ai/tracing_v3/__init__.py +7 -7
  168. synth_ai/tracing_v3/abstractions.py +56 -52
  169. synth_ai/tracing_v3/config.py +4 -2
  170. synth_ai/tracing_v3/db_config.py +6 -8
  171. synth_ai/tracing_v3/decorators.py +29 -30
  172. synth_ai/tracing_v3/examples/basic_usage.py +12 -12
  173. synth_ai/tracing_v3/hooks.py +21 -21
  174. synth_ai/tracing_v3/llm_call_record_helpers.py +85 -98
  175. synth_ai/tracing_v3/lm_call_record_abstractions.py +2 -4
  176. synth_ai/tracing_v3/migration_helper.py +3 -5
  177. synth_ai/tracing_v3/replica_sync.py +30 -32
  178. synth_ai/tracing_v3/session_tracer.py +35 -29
  179. synth_ai/tracing_v3/storage/__init__.py +1 -1
  180. synth_ai/tracing_v3/storage/base.py +8 -7
  181. synth_ai/tracing_v3/storage/config.py +4 -4
  182. synth_ai/tracing_v3/storage/factory.py +4 -4
  183. synth_ai/tracing_v3/storage/utils.py +9 -9
  184. synth_ai/tracing_v3/turso/__init__.py +3 -3
  185. synth_ai/tracing_v3/turso/daemon.py +9 -9
  186. synth_ai/tracing_v3/turso/manager.py +60 -48
  187. synth_ai/tracing_v3/turso/models.py +24 -19
  188. synth_ai/tracing_v3/utils.py +5 -5
  189. synth_ai/tui/__main__.py +1 -1
  190. synth_ai/tui/cli/query_experiments.py +2 -3
  191. synth_ai/tui/cli/query_experiments_v3.py +2 -3
  192. synth_ai/tui/dashboard.py +97 -86
  193. synth_ai/v0/tracing/abstractions.py +28 -28
  194. synth_ai/v0/tracing/base_client.py +9 -9
  195. synth_ai/v0/tracing/client_manager.py +7 -7
  196. synth_ai/v0/tracing/config.py +7 -7
  197. synth_ai/v0/tracing/context.py +6 -6
  198. synth_ai/v0/tracing/decorators.py +6 -5
  199. synth_ai/v0/tracing/events/manage.py +1 -1
  200. synth_ai/v0/tracing/events/store.py +5 -4
  201. synth_ai/v0/tracing/immediate_client.py +4 -5
  202. synth_ai/v0/tracing/local.py +3 -3
  203. synth_ai/v0/tracing/log_client_base.py +4 -5
  204. synth_ai/v0/tracing/retry_queue.py +5 -6
  205. synth_ai/v0/tracing/trackers.py +25 -25
  206. synth_ai/v0/tracing/upload.py +6 -0
  207. synth_ai/v0/tracing_v1/__init__.py +1 -1
  208. synth_ai/v0/tracing_v1/abstractions.py +28 -28
  209. synth_ai/v0/tracing_v1/base_client.py +9 -9
  210. synth_ai/v0/tracing_v1/client_manager.py +7 -7
  211. synth_ai/v0/tracing_v1/config.py +7 -7
  212. synth_ai/v0/tracing_v1/context.py +6 -6
  213. synth_ai/v0/tracing_v1/decorators.py +7 -6
  214. synth_ai/v0/tracing_v1/events/manage.py +1 -1
  215. synth_ai/v0/tracing_v1/events/store.py +5 -4
  216. synth_ai/v0/tracing_v1/immediate_client.py +4 -5
  217. synth_ai/v0/tracing_v1/local.py +3 -3
  218. synth_ai/v0/tracing_v1/log_client_base.py +4 -5
  219. synth_ai/v0/tracing_v1/retry_queue.py +5 -6
  220. synth_ai/v0/tracing_v1/trackers.py +25 -25
  221. synth_ai/v0/tracing_v1/upload.py +25 -24
  222. synth_ai/zyk/__init__.py +1 -0
  223. {synth_ai-0.2.4.dev5.dist-info → synth_ai-0.2.4.dev7.dist-info}/METADATA +2 -11
  224. synth_ai-0.2.4.dev7.dist-info/RECORD +299 -0
  225. synth_ai-0.2.4.dev5.dist-info/RECORD +0 -287
  226. {synth_ai-0.2.4.dev5.dist-info → synth_ai-0.2.4.dev7.dist-info}/WHEEL +0 -0
  227. {synth_ai-0.2.4.dev5.dist-info → synth_ai-0.2.4.dev7.dist-info}/entry_points.txt +0 -0
  228. {synth_ai-0.2.4.dev5.dist-info → synth_ai-0.2.4.dev7.dist-info}/licenses/LICENSE +0 -0
  229. {synth_ai-0.2.4.dev5.dist-info → synth_ai-0.2.4.dev7.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,324 @@
1
+ """
2
+ Example: Random Search optimizer on Banking77 using Groq gpt-oss-20b.
3
+
4
+ Requires:
5
+ - .env with GROQ_API_KEY
6
+ - datasets (`uv add datasets` if needed)
7
+
8
+ Run:
9
+ - uv run -q python -m synth_ai.learning.prompts.run_random_search_banking77
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import asyncio
15
+ import json
16
+ import os
17
+ import random
18
+ import time
19
+ from collections.abc import Sequence
20
+ from dataclasses import dataclass, replace
21
+ from pathlib import Path
22
+ from types import SimpleNamespace
23
+ from typing import Any
24
+
25
+ from datasets import load_dataset
26
+ from dotenv import load_dotenv
27
+ from synth_ai.learning.prompts.random_search import random_search_compile
28
+ from synth_ai.lm.core.main_v3 import LM, build_messages
29
+ from tqdm import tqdm
30
+
31
+
32
def choose_label(pred: str, label_names: list[str]) -> str:
    """Map a free-form model prediction onto one of ``label_names``.

    Matching is case-insensitive and ignores punctuation and separators, so a
    prediction of ``"Card Arrival"`` or ``"card-arrival."`` resolves to the
    label ``"card_arrival"``. An exact (normalized) match wins. Otherwise the
    label sharing the most whole words with the prediction is returned; ties go
    to the earliest label.

    Returns ``""`` when the prediction is empty or shares no word with any
    label (or when ``label_names`` is empty). This way a failed/timed-out call,
    which yields an empty prediction, can never be scored as correct just
    because it fell back to the first label.
    """

    def tokens(text: str) -> list[str]:
        # Lowercase, then treat every non-alphanumeric char (``_``, ``-``,
        # punctuation, whitespace) as a word separator.
        return "".join(ch if ch.isalnum() else " " for ch in text.lower()).split()

    pred_tokens = tokens(pred or "")
    if not pred_tokens:
        return ""

    pred_key = " ".join(pred_tokens)
    by_key = {" ".join(tokens(name)): name for name in label_names}
    if pred_key in by_key:
        return by_key[pred_key]

    # Whole-word overlap (not substring containment), so short label words such
    # as "my"/"up" do not match inside unrelated words.
    pred_words = set(pred_tokens)
    best_label, best_overlap = "", 0
    for name in label_names:
        overlap = sum(1 for word in tokens(name) if word in pred_words)
        if overlap > best_overlap:
            best_label, best_overlap = name, overlap
    return best_label
43
+
44
+
45
def accuracy(pred: str, gold: str, labels: list[str]) -> float:
    """Score a single prediction against its gold label.

    The raw model text is first resolved to a label via ``choose_label``.

    Returns:
        1.0 if the resolved label equals ``gold``, otherwise 0.0.
    """
    resolved = choose_label(pred, labels)
    return float(resolved == gold)
47
+
48
+
49
@dataclass
class StudentProgram:
    """Prompted Banking77 intent classifier driven by ``lm``.

    Each prediction builds a chat prompt from ``instruction`` (system message)
    plus optional few-shot ``demos`` and the input message. ``with_demos``,
    ``reset_copy`` and ``deepcopy`` derive variants without mutating the
    original's ``demos`` list.
    """

    lm: LM  # client exposing ``.model`` and an awaitable ``respond_async(messages=...)``
    label_names: list[str]  # candidate labels; NOTE(review): not currently included in the prompt
    instruction: str  # system prompt; an empty string falls back to a generic classifier prompt
    demos: list[tuple[str, str]]  # few-shot (input text, label) pairs shown before the query

    def reset_copy(self) -> StudentProgram:
        """Return a copy with its own ``demos`` list (other fields are shared)."""
        return replace(self, demos=list(self.demos))

    def deepcopy(self) -> StudentProgram:
        """Return a copy with its own ``demos`` list.

        NOTE(review): despite the name, ``lm`` and ``label_names`` are shared
        with the original, not deep-copied.
        """
        return replace(self, instruction=str(self.instruction), demos=list(self.demos))

    def with_demos(self, demos: list[tuple[str, str]]) -> StudentProgram:
        """Return a copy that uses (a copy of) ``demos`` as its few-shot examples."""
        return replace(self, demos=list(demos))

    def _prompt_parts(self, x: str) -> tuple[str, str]:
        """Return the ``(system, user)`` prompt strings for classifying ``x``."""
        examples = "\n".join(f"Input: {a}\nLabel: {b}" for a, b in self.demos)
        system = self.instruction or "You are an intent classifier for Banking77."
        user = (f"Examples:\n{examples}\n\n" if examples else "") + f"Message: {x}\nLabel:"
        return system, user

    def _messages(self, x: str) -> Any:
        """Build the chat messages for ``x`` in the format expected by ``self.lm``."""
        system, user = self._prompt_parts(x)
        return build_messages(system, user, images_bytes=None, model_name=self.lm.model)

    def run(self, x: str) -> str:
        """Classify ``x`` synchronously and return the stripped raw model text.

        Uses ``asyncio.run``, so it cannot be called from inside a running
        event loop; use ``_apredict`` there instead.
        """
        text, _usage = asyncio.run(self._apredict(x))
        return text

    async def _apredict(self, x: str) -> tuple[str, Any]:
        """Classify ``x``; return ``(stripped raw text, usage or {})`` from the LM response."""
        resp = await self.lm.respond_async(messages=self._messages(x))
        return (resp.raw_response or "").strip(), (resp.usage or {})
86
+
87
+
88
def _usage_tokens(usage: Any) -> tuple[Any, Any]:
    """Return ``(prompt_tokens, completion_tokens)`` from an LM usage mapping.

    Accepts either ``prompt_tokens``/``completion_tokens`` or
    ``input_tokens``/``output_tokens`` keys. A missing key, or a ``usage`` object
    without ``.get``, yields ``None`` for that slot instead of raising.
    """
    get = getattr(usage, "get", None)
    if get is None:
        return None, None
    return get("prompt_tokens") or get("input_tokens"), get("completion_tokens") or get(
        "output_tokens"
    )


async def _evaluate_parallel_async(
    program: StudentProgram,
    dataset: Sequence[tuple[str, str]],
    metric_fn,
) -> SimpleNamespace:
    """Score ``program`` on ``dataset`` with concurrent LM calls.

    Environment knobs:
      CONCURRENCY       max in-flight LM calls (default 5)
      TIMEOUT_S         per-call timeout in seconds (default 45)
      BATCH_DEADLINE_S  wall-clock budget for the whole batch (default 20)

    A call that raises or exceeds ``TIMEOUT_S`` is scored as an empty
    prediction. Items still running when ``BATCH_DEADLINE_S`` elapses are
    cancelled and left out of the score entirely; their count is returned as
    ``dropped``.

    Returns a namespace with ``score`` (mean over scored items), ``subscores``,
    per-item ``details``, mean prompt/completion tokens (``mean_in``/``mean_out``,
    0.0 when unknown) and ``dropped``.
    """
    import statistics

    sem = asyncio.Semaphore(int(os.getenv("CONCURRENCY", "5")))
    call_timeout = float(os.getenv("TIMEOUT_S", "45"))

    async def worker(i: int, x: str, y: str):
        t_start = time.monotonic()
        try:
            async with sem:
                pred, usage = await asyncio.wait_for(program._apredict(x), timeout=call_timeout)
            return i, y, pred, t_start, time.monotonic(), usage or {}
        except (asyncio.CancelledError, Exception):
            # Cancellation at the batch deadline, an LM error or a per-call
            # timeout: return a placeholder record instead of propagating, so
            # the scheduling loop below can always drain the task.
            return i, y, "", t_start, time.monotonic(), {}

    tasks = [asyncio.create_task(worker(i, x, y)) for i, (x, y) in enumerate(dataset)]

    deadline = float(os.getenv("BATCH_DEADLINE_S", "20"))
    durations: list[float] = []
    details: list[dict[str, Any]] = []
    correct_sum = 0.0
    in_tok_sum = out_tok_sum = 0
    in_tok_count = out_tok_count = 0
    dropped = 0
    t_batch_start = time.monotonic()

    with tqdm(total=len(tasks), desc="Rollouts", leave=False) as pbar:
        pending = set(tasks)
        while pending:
            remaining = deadline - (time.monotonic() - t_batch_start)
            if remaining <= 0.0:
                # Out of time: cancel the stragglers, let them settle, and leave
                # them out of the score. Do not call .result() here: a task that
                # was cancelled before it ever started re-raises CancelledError
                # (a BaseException), which would abort the whole evaluation.
                for task in pending:
                    task.cancel()
                await asyncio.wait(pending)
                dropped = len(pending)
                break

            # Poll at <= 1 s granularity so the deadline is re-checked regularly.
            done, pending = await asyncio.wait(
                pending, timeout=min(1.0, remaining), return_when=asyncio.FIRST_COMPLETED
            )
            for task in done:
                if task.cancelled() or task.exception() is not None:
                    continue  # nothing usable to record
                i, y_true, pred, t_start, t_end, usage = task.result()

                seconds = max(0.0, t_end - t_start)
                durations.append(seconds)
                try:
                    item_score = float(metric_fn(pred, y_true))
                except Exception:
                    # A metric failure counts as a miss rather than aborting the batch.
                    item_score = 0.0
                correct_sum += item_score

                pt, ct = _usage_tokens(usage)
                if isinstance(pt, (int, float)):
                    in_tok_sum += int(pt)
                    in_tok_count += 1
                if isinstance(ct, (int, float)):
                    out_tok_sum += int(ct)
                    out_tok_count += 1

                details.append(
                    {
                        "index": i,
                        "seconds": seconds,
                        "score": item_score,
                        "usage": {"prompt_tokens": pt, "completion_tokens": ct},
                    }
                )
                pbar.update(1)

                avg_in = (in_tok_sum / in_tok_count) if in_tok_count else 0.0
                avg_out = (out_tok_sum / out_tok_count) if out_tok_count else 0.0
                pbar.set_postfix(
                    {
                        "acc": f"{(correct_sum / len(details)):.2f}",
                        "done": f"{len(details)}/{len(tasks)}",
                        "med_s": f"{statistics.median(durations):.1f}",
                        "max_s": f"{max(durations):.1f}",
                        "tin": f"{avg_in:.1f}",
                        "tout": f"{avg_out:.1f}",
                    }
                )

    # Score only over items that actually completed (deadline drops excluded).
    subscores = [d["score"] for d in details]
    return SimpleNamespace(
        score=sum(subscores) / max(1, len(subscores)),
        subscores=subscores,
        details=details,
        mean_in=(in_tok_sum / in_tok_count) if in_tok_count else 0.0,
        mean_out=(out_tok_sum / out_tok_count) if out_tok_count else 0.0,
        dropped=dropped,
    )


def _evaluate_parallel(
    program: StudentProgram,
    dataset: Sequence[tuple[str, str]],
    metric_fn,
) -> SimpleNamespace:
    """Synchronous wrapper around ``_evaluate_parallel_async``.

    This is the ``evaluate_fn`` handed to the optimizer. It uses
    ``asyncio.run``, so it must not be called from inside a running event loop.
    """
    return asyncio.run(_evaluate_parallel_async(program, dataset, metric_fn))


def main():
    """Run random-search prompt optimization on a Banking77 sample and save results.

    Reads ``MODEL``/``VENDOR`` (plus the evaluator knobs documented on
    ``_evaluate_parallel_async``) from the environment or ``.env``. It then
    optimizes over 40 shuffled test-split examples and re-scores the best
    program on a 20-example val split. Finally it writes a timestamped JSON
    report next to this file.
    """
    load_dotenv()
    random.seed(0)

    model = os.getenv("MODEL", "openai/gpt-oss-20b")
    vendor = os.getenv("VENDOR", "groq")
    lm = LM(model=model, vendor=vendor, temperature=0.0)

    print("Loading Banking77 dataset (train/dev split of test for demo)...")
    ds = load_dataset("banking77")
    label_names: list[str] = ds["test"].features["label"].names  # type: ignore

    # Create small train/val subsets from the (shuffled) test split for speed.
    all_items = [(r["text"], label_names[int(r["label"])]) for r in ds["test"]]
    random.shuffle(all_items)
    trainset: Sequence[tuple[str, str]] = all_items[:40]
    valset: Sequence[tuple[str, str]] = all_items[40:60]  # 20 examples

    student = StudentProgram(
        lm=lm,
        label_names=label_names,
        instruction="You are an intent classifier for the Banking77 dataset. Return exactly one label.",
        demos=[],
    )

    def metric(yhat: str, y: str) -> float:
        return accuracy(yhat, y, label_names)

    total_candidates = 3 + 3  # zero-shot, labeled few-shot, bootstrapped + 3 random seeds
    print(
        f"Running Random Search optimizer ({total_candidates} candidates, "
        f"parallel eval of {len(valset)} questions)..."
    )

    pbar = tqdm(total=total_candidates, desc="Candidates")
    candidate_eval_details: dict[int, Any] = {}

    def on_cand(idx: int, score: float, res, intervention):
        pbar.update(1)
        pbar.set_postfix({"score": f"{score:.2f}"})
        # Keep per-instance details so candidates can be compared item by item.
        candidate_eval_details[idx] = {
            "score": score,
            "mean_in": getattr(res, "mean_in", None),
            "mean_out": getattr(res, "mean_out", None),
            "dropped": getattr(res, "dropped", None),
            "instances": getattr(res, "details", None),
        }
        # Visible one-line summary per candidate.
        meta = intervention if isinstance(intervention, dict) else {}
        tag = meta.get("kind", "candidate")
        label = meta.get("label")
        seed = meta.get("seed")
        if label is not None:
            tag += f", label={label}"
        if seed is not None:
            tag += f", seed={seed}"
        n_scored = len(getattr(res, "details", []) or [])
        tqdm.write(
            f"Candidate {idx}/{total_candidates} [{tag}]: "
            f"score={score:.2f} | mean tin/tout={getattr(res, 'mean_in', 0):.1f}/"
            f"{getattr(res, 'mean_out', 0):.1f} | N={n_scored}"
        )

    try:
        best, records = random_search_compile(
            student=student,
            trainset=trainset,
            valset=valset,
            metric=metric,
            evaluate_fn=_evaluate_parallel,
            max_bootstrapped_demos=0,
            max_labeled_demos=4,
            max_rounds=2,
            num_candidate_programs=3,
            on_candidate_evaluated=on_cand,
        )
    finally:
        pbar.close()

    # NOTE: valset is also what the optimizer used to pick ``best``, so this is
    # an optimistic estimate, not a true holdout score.
    print("Evaluating best program on val (parallel rollouts)...")
    best_res = _evaluate_parallel(best, valset, metric)
    n_scored = len(best_res.subscores)
    correct = int(round(best_res.score * n_scored))
    dropped_note = (
        f" [{best_res.dropped} of {len(valset)} dropped at batch deadline]"
        if best_res.dropped
        else ""
    )
    print(
        "Best program accuracy on val: "
        f"{correct}/{n_scored} ({best_res.score:.2%}){dropped_note} "
        f"| mean tokens in/out: {best_res.mean_in:.1f}/{best_res.mean_out:.1f}"
    )

    # Save per-candidate scores and interventions.
    out = {
        "context": {
            "model": model,
            "vendor": vendor,
            "train_size": len(trainset),
            "val_size": len(valset),
        },
        "candidates": records,
        "candidate_eval_details": candidate_eval_details,
        "best_eval_details": {
            "score": best_res.score,
            "mean_in": best_res.mean_in,
            "mean_out": best_res.mean_out,
            "dropped": best_res.dropped,
            "instances": best_res.details,
        },
    }
    fname = Path(__file__).parent / f"random_search_banking77_{int(time.time())}.json"
    with open(fname, "w", encoding="utf-8") as f:
        json.dump(out, f, indent=2)
    print(f"Saved candidate records to {fname}")
321
+
322
+
323
if __name__ == "__main__":
    # Script entry point; see the module docstring for the `uv run` invocation.
    main()
synth_ai/lm/__init__.py CHANGED
@@ -4,24 +4,24 @@ Synth AI Language Model Interface.
4
4
  Provides a unified interface for multiple LLM providers including OpenAI and Synth.
5
5
  """
6
6
 
7
- from .config import SynthConfig, OpenAIConfig
8
- from .warmup import warmup_synth_model, get_warmup_status
7
+ from .config import OpenAIConfig, SynthConfig
8
+ from .core.main_v3 import LM
9
9
  from .unified_interface import (
10
- UnifiedLMProvider,
11
10
  OpenAIProvider,
12
11
  SynthProvider,
13
12
  UnifiedLMClient,
13
+ UnifiedLMProvider,
14
14
  create_provider,
15
15
  )
16
16
  from .vendors.synth_client import (
17
17
  AsyncSynthClient,
18
18
  SyncSynthClient,
19
19
  create_async_client,
20
- create_sync_client,
21
20
  create_chat_completion_async,
22
21
  create_chat_completion_sync,
22
+ create_sync_client,
23
23
  )
24
- from .core.main_v3 import LM
24
+ from .warmup import get_warmup_status, warmup_synth_model
25
25
 
26
26
  __all__ = [
27
27
  # Configuration
@@ -7,7 +7,6 @@ of the application run, useful for avoiding redundant API calls within a session
7
7
 
8
8
  import os
9
9
  from dataclasses import dataclass
10
- from typing import Optional, Union
11
10
 
12
11
  from diskcache import Cache
13
12
  from pydantic import BaseModel
@@ -20,24 +19,25 @@ from synth_ai.lm.vendors.base import BaseLMResponse
20
19
  class EphemeralCache:
21
20
  """
22
21
  Ephemeral cache implementation using diskcache.
23
-
22
+
24
23
  This cache stores LM responses temporarily on disk with a size limit.
25
24
  The cache is cleared when the application restarts.
26
25
  """
26
+
27
27
  def __init__(self, fast_cache_dir: str = ".cache/ephemeral_cache"):
28
28
  os.makedirs(fast_cache_dir, exist_ok=True)
29
29
  self.fast_cache = Cache(fast_cache_dir, size_limit=DISKCACHE_SIZE_LIMIT)
30
30
 
31
31
  def hit_cache(
32
- self, key: str, response_model: Optional[BaseModel] = None
33
- ) -> Optional[BaseLMResponse]:
32
+ self, key: str, response_model: BaseModel | None = None
33
+ ) -> BaseLMResponse | None:
34
34
  """
35
35
  Check if a response exists in cache for the given key.
36
-
36
+
37
37
  Args:
38
38
  key: Cache key to look up
39
39
  response_model: Optional Pydantic model to reconstruct structured output
40
-
40
+
41
41
  Returns:
42
42
  BaseLMResponse if found in cache, None otherwise
43
43
  """
@@ -65,14 +65,14 @@ class EphemeralCache:
65
65
  tool_calls=tool_calls,
66
66
  )
67
67
 
68
- def add_to_cache(self, key: str, response: Union[BaseLMResponse, str]) -> None:
68
+ def add_to_cache(self, key: str, response: BaseLMResponse | str) -> None:
69
69
  """
70
70
  Add a response to the cache.
71
-
71
+
72
72
  Args:
73
73
  key: Cache key to store under
74
74
  response: Either a BaseLMResponse object or raw string response
75
-
75
+
76
76
  Raises:
77
77
  ValueError: If response type is not supported
78
78
  """
@@ -1,5 +1,5 @@
1
1
  import hashlib
2
- from typing import Any, Dict, List, Optional, Type
2
+ from typing import Any
3
3
 
4
4
  from pydantic import BaseModel
5
5
 
@@ -17,11 +17,11 @@ logger = logging.getLogger(__name__)
17
17
 
18
18
 
19
19
  def map_params_to_key(
20
- messages: List[Dict],
20
+ messages: list[dict],
21
21
  model: str,
22
22
  temperature: float,
23
- response_model: Optional[Type[BaseModel]],
24
- tools: Optional[List[BaseTool]] = None,
23
+ response_model: type[BaseModel] | None,
24
+ tools: list[BaseTool] | None = None,
25
25
  reasoning_effort: str = "low",
26
26
  ) -> str:
27
27
  if any(m is None for m in messages):
@@ -76,37 +76,37 @@ class CacheHandler:
76
76
  self.use_persistent_store = use_persistent_store
77
77
  self.use_ephemeral_store = use_ephemeral_store
78
78
 
79
- def _validate_messages(self, messages: List[Dict[str, Any]]) -> None:
79
+ def _validate_messages(self, messages: list[dict[str, Any]]) -> None:
80
80
  """Validate that messages are in the correct format."""
81
- assert all([type(msg["content"]) == str for msg in messages]), (
81
+ assert all(isinstance(msg["content"], str) for msg in messages), (
82
82
  "All message contents must be strings"
83
83
  )
84
84
 
85
85
  def hit_managed_cache(
86
86
  self,
87
87
  model: str,
88
- messages: List[Dict[str, Any]],
89
- lm_config: Dict[str, Any],
90
- tools: Optional[List[BaseTool]] = None,
91
- ) -> Optional[BaseLMResponse]:
88
+ messages: list[dict[str, Any]],
89
+ lm_config: dict[str, Any],
90
+ tools: list[BaseTool] | None = None,
91
+ ) -> BaseLMResponse | None:
92
92
  """Hit the cache with the given key."""
93
93
  self._validate_messages(messages)
94
- assert type(lm_config) == dict, "lm_config must be a dictionary"
94
+ assert isinstance(lm_config, dict), "lm_config must be a dictionary"
95
95
  key = map_params_to_key(
96
96
  messages,
97
97
  model,
98
98
  lm_config.get("temperature", 0.0),
99
- lm_config.get("response_model", None),
99
+ lm_config.get("response_model"),
100
100
  tools,
101
101
  lm_config.get("reasoning_effort", "low"),
102
102
  )
103
103
  if self.use_persistent_store:
104
104
  return persistent_cache.hit_cache(
105
- key=key, response_model=lm_config.get("response_model", None)
105
+ key=key, response_model=lm_config.get("response_model")
106
106
  )
107
107
  elif self.use_ephemeral_store:
108
108
  return ephemeral_cache.hit_cache(
109
- key=key, response_model=lm_config.get("response_model", None)
109
+ key=key, response_model=lm_config.get("response_model")
110
110
  )
111
111
  else:
112
112
  return None
@@ -114,20 +114,20 @@ class CacheHandler:
114
114
  def add_to_managed_cache(
115
115
  self,
116
116
  model: str,
117
- messages: List[Dict[str, Any]],
118
- lm_config: Dict[str, Any],
117
+ messages: list[dict[str, Any]],
118
+ lm_config: dict[str, Any],
119
119
  output: BaseLMResponse,
120
- tools: Optional[List[BaseTool]] = None,
120
+ tools: list[BaseTool] | None = None,
121
121
  ) -> None:
122
122
  """Add the given output to the cache."""
123
123
  self._validate_messages(messages)
124
- assert type(output) == BaseLMResponse, "output must be a BaseLMResponse"
125
- assert type(lm_config) == dict, "lm_config must be a dictionary"
124
+ assert isinstance(output, BaseLMResponse), "output must be a BaseLMResponse"
125
+ assert isinstance(lm_config, dict), "lm_config must be a dictionary"
126
126
  key = map_params_to_key(
127
127
  messages,
128
128
  model,
129
129
  lm_config.get("temperature", 0.0),
130
- lm_config.get("response_model", None),
130
+ lm_config.get("response_model"),
131
131
  tools,
132
132
  lm_config.get("reasoning_effort", "low"),
133
133
  )
@@ -9,7 +9,6 @@ import json
9
9
  import os
10
10
  import sqlite3
11
11
  from dataclasses import dataclass
12
- from typing import Optional, Type, Union
13
12
 
14
13
  from pydantic import BaseModel
15
14
 
@@ -20,10 +19,11 @@ from synth_ai.lm.vendors.base import BaseLMResponse
20
19
  class PersistentCache:
21
20
  """
22
21
  Persistent cache implementation using SQLite.
23
-
22
+
24
23
  This cache stores LM responses in a SQLite database that persists
25
24
  across application restarts.
26
25
  """
26
+
27
27
  def __init__(self, db_path: str = ".cache/persistent_cache.db"):
28
28
  os.makedirs(os.path.dirname(db_path), exist_ok=True)
29
29
  self.conn = sqlite3.connect(db_path)
@@ -33,15 +33,15 @@ class PersistentCache:
33
33
  self.conn.commit()
34
34
 
35
35
  def hit_cache(
36
- self, key: str, response_model: Optional[Type[BaseModel]] = None
37
- ) -> Optional[BaseLMResponse]:
36
+ self, key: str, response_model: type[BaseModel] | None = None
37
+ ) -> BaseLMResponse | None:
38
38
  """
39
39
  Check if a response exists in cache for the given key.
40
-
40
+
41
41
  Args:
42
42
  key: Cache key to look up
43
43
  response_model: Optional Pydantic model class to reconstruct structured output
44
-
44
+
45
45
  Returns:
46
46
  BaseLMResponse if found in cache, None otherwise
47
47
  """
@@ -72,17 +72,17 @@ class PersistentCache:
72
72
  tool_calls=tool_calls,
73
73
  )
74
74
 
75
- def add_to_cache(self, key: str, response: Union[BaseLMResponse, str]) -> None:
75
+ def add_to_cache(self, key: str, response: BaseLMResponse | str) -> None:
76
76
  """
77
77
  Add a response to the cache.
78
-
78
+
79
79
  Args:
80
80
  key: Cache key to store under
81
81
  response: Either a BaseLMResponse object or raw string response
82
-
82
+
83
83
  Raises:
84
84
  ValueError: If response type is not supported
85
-
85
+
86
86
  Note:
87
87
  Uses INSERT OR REPLACE to update existing cache entries.
88
88
  """
synth_ai/lm/config.py CHANGED
@@ -4,8 +4,8 @@ Loads sensitive configuration from environment variables.
4
4
  """
5
5
 
6
6
  import os
7
- from typing import Optional
8
7
  from dataclasses import dataclass
8
+
9
9
  from dotenv import load_dotenv
10
10
 
11
11
  # Load environment variables from .env file
@@ -15,10 +15,10 @@ load_dotenv()
15
15
  def should_use_cache() -> bool:
16
16
  """
17
17
  Check if caching should be enabled based on environment variable.
18
-
18
+
19
19
  Returns:
20
20
  bool: True if caching is enabled (default), False if explicitly disabled.
21
-
21
+
22
22
  Note:
23
23
  Caching is controlled by the USE_ZYK_CACHE environment variable.
24
24
  Set to "false", "0", or "no" to disable caching.
synth_ai/lm/constants.py CHANGED
@@ -13,20 +13,20 @@ GEMINI_REASONING_MODELS = ["gemini-2.5-flash", "gemini-2.5-pro"]
13
13
  # Gemini models that support thinking
14
14
  GEMINI_REASONING_MODELS = ["gemini-2.5-flash", "gemini-2.5-pro"]
15
15
  GEMINI_THINKING_BUDGETS = {
16
- "high": 10000, # High thinking budget for complex reasoning
17
- "medium": 5000, # Medium thinking budget for standard reasoning
18
- "low": 2500, # Low thinking budget for simple reasoning
16
+ "high": 10000, # High thinking budget for complex reasoning
17
+ "medium": 5000, # Medium thinking budget for standard reasoning
18
+ "low": 2500, # Low thinking budget for simple reasoning
19
19
  }
20
20
 
21
21
  # Anthropic Sonnet 3.7 budgets
22
22
  SONNET_37_BUDGETS = {
23
- "high": 8192, # High budget for complex tasks
24
- "medium": 4096, # Medium budget for standard tasks
25
- "low": 2048, # Low budget for simple tasks
23
+ "high": 8192, # High budget for complex tasks
24
+ "medium": 4096, # Medium budget for standard tasks
25
+ "low": 2048, # Low budget for simple tasks
26
26
  }
27
27
 
28
28
  # Combined list of all reasoning models
29
29
  REASONING_MODELS = OPENAI_REASONING_MODELS + CLAUDE_REASONING_MODELS + GEMINI_REASONING_MODELS
30
30
 
31
31
  # Special base temperatures for reasoning models (all set to 1.0)
32
- SPECIAL_BASE_TEMPS = {model: 1 for model in REASONING_MODELS}
32
+ SPECIAL_BASE_TEMPS = dict.fromkeys(REASONING_MODELS, 1)
synth_ai/lm/core/all.py CHANGED
@@ -4,12 +4,12 @@ from synth_ai.lm.vendors.core.openai_api import (
4
4
  OpenAIPrivate,
5
5
  OpenAIStructuredOutputClient,
6
6
  )
7
+ from synth_ai.lm.vendors.supported.custom_endpoint import CustomEndpointAPI
7
8
  from synth_ai.lm.vendors.supported.deepseek import DeepSeekAPI
8
- from synth_ai.lm.vendors.supported.together import TogetherAPI
9
- from synth_ai.lm.vendors.supported.groq import GroqAPI
10
9
  from synth_ai.lm.vendors.supported.grok import GrokAPI
11
- from synth_ai.lm.vendors.supported.custom_endpoint import CustomEndpointAPI
10
+ from synth_ai.lm.vendors.supported.groq import GroqAPI
12
11
  from synth_ai.lm.vendors.supported.openrouter import OpenRouterAPI
12
+ from synth_ai.lm.vendors.supported.together import TogetherAPI
13
13
 
14
14
 
15
15
  class OpenAIClient(OpenAIPrivate):
@@ -57,3 +57,17 @@ class CustomEndpointClient(CustomEndpointAPI):
57
57
  class OpenRouterClient(OpenRouterAPI):
58
58
  def __init__(self):
59
59
  super().__init__()
60
+
61
+
62
+ __all__ = [
63
+ "OpenAIClient",
64
+ "AnthropicClient",
65
+ "GeminiClient",
66
+ "DeepSeekClient",
67
+ "TogetherClient",
68
+ "GroqClient",
69
+ "GrokClient",
70
+ "CustomEndpointClient",
71
+ "OpenRouterClient",
72
+ "OpenAIStructuredOutputClient",
73
+ ]
@@ -1,5 +1,3 @@
1
- from abc import ABC, abstractmethod
2
- from typing import Any, Callable, Dict, List, Literal, Optional, Union
3
1
 
4
2
 
5
3
  class StructuredOutputCoercionFailureException(Exception):