synth-ai 0.2.4.dev4__py3-none-any.whl → 0.2.4.dev6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. synth_ai/environments/examples/__init__.py +1 -0
  2. synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
  3. synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
  4. synth_ai/environments/examples/crafter_classic/debug_translation.py +0 -0
  5. synth_ai/environments/examples/crafter_classic/engine.py +579 -0
  6. synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
  7. synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
  8. synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
  9. synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +266 -0
  10. synth_ai/environments/examples/crafter_classic/environment.py +364 -0
  11. synth_ai/environments/examples/crafter_classic/taskset.py +233 -0
  12. synth_ai/environments/examples/crafter_classic/trace_hooks_v3.py +229 -0
  13. synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +298 -0
  14. synth_ai/environments/examples/crafter_custom/__init__.py +4 -0
  15. synth_ai/environments/examples/crafter_custom/crafter/__init__.py +7 -0
  16. synth_ai/environments/examples/crafter_custom/crafter/config.py +182 -0
  17. synth_ai/environments/examples/crafter_custom/crafter/constants.py +8 -0
  18. synth_ai/environments/examples/crafter_custom/crafter/engine.py +269 -0
  19. synth_ai/environments/examples/crafter_custom/crafter/env.py +266 -0
  20. synth_ai/environments/examples/crafter_custom/crafter/objects.py +418 -0
  21. synth_ai/environments/examples/crafter_custom/crafter/recorder.py +187 -0
  22. synth_ai/environments/examples/crafter_custom/crafter/worldgen.py +119 -0
  23. synth_ai/environments/examples/crafter_custom/dataset_builder.py +373 -0
  24. synth_ai/environments/examples/crafter_custom/environment.py +312 -0
  25. synth_ai/environments/examples/crafter_custom/run_dataset.py +305 -0
  26. synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
  27. synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
  28. synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
  29. synth_ai/environments/examples/enron/engine.py +291 -0
  30. synth_ai/environments/examples/enron/environment.py +165 -0
  31. synth_ai/environments/examples/enron/taskset.py +112 -0
  32. synth_ai/environments/examples/minigrid/__init__.py +48 -0
  33. synth_ai/environments/examples/minigrid/engine.py +589 -0
  34. synth_ai/environments/examples/minigrid/environment.py +274 -0
  35. synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
  36. synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
  37. synth_ai/environments/examples/minigrid/taskset.py +583 -0
  38. synth_ai/environments/examples/nethack/__init__.py +7 -0
  39. synth_ai/environments/examples/nethack/achievements.py +337 -0
  40. synth_ai/environments/examples/nethack/engine.py +738 -0
  41. synth_ai/environments/examples/nethack/environment.py +255 -0
  42. synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
  43. synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
  44. synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
  45. synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
  46. synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
  47. synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
  48. synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
  49. synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
  50. synth_ai/environments/examples/nethack/taskset.py +323 -0
  51. synth_ai/environments/examples/red/__init__.py +7 -0
  52. synth_ai/environments/examples/red/config_logging.py +110 -0
  53. synth_ai/environments/examples/red/engine.py +693 -0
  54. synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
  55. synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
  56. synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
  57. synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
  58. synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
  59. synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
  60. synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
  61. synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
  62. synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
  63. synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
  64. synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
  65. synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
  66. synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
  67. synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
  68. synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
  69. synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
  70. synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
  71. synth_ai/environments/examples/red/environment.py +235 -0
  72. synth_ai/environments/examples/red/taskset.py +77 -0
  73. synth_ai/environments/examples/sokoban/__init__.py +1 -0
  74. synth_ai/environments/examples/sokoban/engine.py +675 -0
  75. synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
  76. synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
  77. synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
  78. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
  79. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
  80. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
  81. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
  82. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
  83. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
  84. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
  85. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
  86. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
  87. synth_ai/environments/examples/sokoban/environment.py +228 -0
  88. synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
  89. synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
  90. synth_ai/environments/examples/sokoban/taskset.py +425 -0
  91. synth_ai/environments/examples/tictactoe/__init__.py +1 -0
  92. synth_ai/environments/examples/tictactoe/engine.py +368 -0
  93. synth_ai/environments/examples/tictactoe/environment.py +239 -0
  94. synth_ai/environments/examples/tictactoe/taskset.py +214 -0
  95. synth_ai/environments/examples/verilog/__init__.py +10 -0
  96. synth_ai/environments/examples/verilog/engine.py +328 -0
  97. synth_ai/environments/examples/verilog/environment.py +349 -0
  98. synth_ai/environments/examples/verilog/taskset.py +418 -0
  99. synth_ai/environments/examples/wordle/__init__.py +29 -0
  100. synth_ai/environments/examples/wordle/engine.py +391 -0
  101. synth_ai/environments/examples/wordle/environment.py +154 -0
  102. synth_ai/environments/examples/wordle/helpers/generate_instances_wordfreq.py +75 -0
  103. synth_ai/environments/examples/wordle/taskset.py +222 -0
  104. synth_ai/environments/service/app.py +8 -0
  105. synth_ai/environments/service/core_routes.py +38 -0
  106. synth_ai/learning/prompts/banking77_injection_eval.py +163 -0
  107. synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +201 -0
  108. synth_ai/learning/prompts/mipro.py +273 -1
  109. synth_ai/learning/prompts/random_search.py +247 -0
  110. synth_ai/learning/prompts/run_mipro_banking77.py +160 -0
  111. synth_ai/learning/prompts/run_random_search_banking77.py +305 -0
  112. synth_ai/lm/injection.py +81 -0
  113. synth_ai/lm/overrides.py +204 -0
  114. synth_ai/lm/provider_support/anthropic.py +39 -12
  115. synth_ai/lm/provider_support/openai.py +31 -4
  116. synth_ai/lm/vendors/core/anthropic_api.py +16 -0
  117. synth_ai/lm/vendors/openai_standard.py +35 -5
  118. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/METADATA +2 -1
  119. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/RECORD +123 -13
  120. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/WHEEL +0 -0
  121. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/entry_points.txt +0 -0
  122. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/licenses/LICENSE +0 -0
  123. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,305 @@
1
+ """
2
+ Example: Random Search optimizer on Banking77 using Groq gpt-oss-20b.
3
+
4
+ Requires:
5
+ - .env with GROQ_API_KEY
6
+ - datasets (`uv add datasets` if needed)
7
+
8
+ Run:
9
+ - uv run -q python -m synth_ai.learning.prompts.run_random_search_banking77
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import asyncio
15
+ import os
16
+ import random
17
+ from dataclasses import dataclass, replace
18
+ from types import SimpleNamespace
19
+ from tqdm import tqdm
20
+ from typing import Any, Dict, List, Sequence, Tuple
21
+
22
+ from dotenv import load_dotenv
23
+ from datasets import load_dataset
24
+
25
+ from synth_ai.lm.core.main_v3 import LM, build_messages
26
+ import json
27
+ import time
28
+ from pathlib import Path
29
+ from synth_ai.learning.prompts.random_search import random_search_compile
30
+
31
+
32
def choose_label(pred: str, label_names: List[str]) -> str:
    """Map a free-form model prediction onto one of ``label_names``.

    Resolution order:
      1. Exact, case-insensitive match against a label name.
      2. Exact match after canonicalising both sides (lower-case, any
         non-alphanumeric run such as ``_``, ``-`` or punctuation treated as a
         single space), so ``"activate my card"`` resolves to
         ``"activate_my_card"``.
      3. Otherwise the label sharing the most words with the prediction. Ties,
         including "no overlap at all", go to the earliest label in
         ``label_names``.

    Args:
        pred: Raw model output; ``None``/empty is treated as ``""``.
        label_names: Candidate labels.

    Returns:
        One element of ``label_names``.

    Raises:
        ValueError: if ``label_names`` is empty (raised by ``max``).
    """
    norm = (pred or "").strip().lower()
    exact = {ln.lower(): ln for ln in label_names}
    if norm in exact:
        return exact[norm]

    def canon(text: str) -> str:
        # Banking77 label names use underscores ("card_arrival") while models
        # usually answer with spaces and trailing punctuation, so compare on a
        # separator-agnostic word form.
        return " ".join("".join(ch if ch.isalnum() else " " for ch in text.lower()).split())

    pred_canon = canon(norm)
    canonical: Dict[str, str] = {}
    for ln in label_names:
        canonical.setdefault(canon(ln), ln)
    if pred_canon in canonical:
        return canonical[pred_canon]

    pred_words = set(pred_canon.split())

    def score(cand: str) -> int:
        # Number of the candidate label's words that appear in the prediction.
        return sum(1 for w in canon(cand).split() if w in pred_words)

    return max(label_names, key=score)
41
+
42
+
43
def accuracy(pred: str, gold: str, labels: List[str]) -> float:
    """Return 1.0 if ``pred`` resolves to ``gold`` via ``choose_label``, else 0.0.

    Empty or whitespace-only predictions always score 0.0. The evaluator
    records failed or timed-out LM calls as ``""``. Without this guard,
    ``choose_label`` would fall back to a default label for such a prediction
    and silently award credit whenever that label happens to be the gold one.
    """
    if not (pred or "").strip():
        return 0.0
    return 1.0 if choose_label(pred, labels) == gold else 0.0
45
+
46
+
47
@dataclass
class StudentProgram:
    """Few-shot intent-classification program optimised by random search.

    Attributes (dataclass fields):
        lm: LM client; this class uses only ``lm.model`` and ``lm.respond_async``.
        label_names: Candidate label names (not rendered into the prompt by this class).
        instruction: System prompt; a falsy value falls back to a generic default.
        demos: ``(input, label)`` pairs rendered as in-context examples.
    """

    lm: LM
    label_names: List[str]
    instruction: str
    demos: List[Tuple[str, str]]

    # Plain class attribute (no annotation), so it is not a dataclass field.
    _DEFAULT_INSTRUCTION = "You are an intent classifier for Banking77."

    def reset_copy(self):
        """Return a copy that owns its own ``demos`` list."""
        return replace(self, instruction=self.instruction, demos=list(self.demos))

    def deepcopy(self):
        """Return a copy with its own ``demos`` list and ``str``-coerced instruction.

        ``lm`` and ``label_names`` are shared with the original object.
        """
        return replace(self, instruction=str(self.instruction), demos=list(self.demos))

    def with_demos(self, demos: List[Tuple[str, str]]):
        """Return a copy whose demos are replaced by (a list copy of) ``demos``."""
        return replace(self, demos=list(demos))

    def _prompt_parts(self, x: str) -> Tuple[str, str]:
        """Build the ``(system, user)`` prompt strings for input ``x``."""
        examples = "\n".join(f"Input: {a}\nLabel: {b}" for a, b in self.demos)
        system = self.instruction or self._DEFAULT_INSTRUCTION
        user = (f"Examples:\n{examples}\n\n" if examples else "") + f"Message: {x}\nLabel:"
        return system, user

    def _build_messages(self, x: str):
        """Render the prompt for ``x`` into the LM's message format."""
        system, user = self._prompt_parts(x)
        return build_messages(system, user, images_bytes=None, model_name=self.lm.model)

    def run(self, x: str) -> str:
        """Classify ``x`` synchronously and return the stripped raw response text.

        Uses ``asyncio.run``, which raises ``RuntimeError`` if called while an
        event loop is already running; use ``_apredict`` from async code.
        """
        text, _usage = asyncio.run(self._apredict(x))
        return text

    async def _apredict(self, x: str):
        """Classify ``x`` asynchronously.

        Returns:
            ``(text, usage)``: the stripped raw response and the response's
            usage record (``{}`` when absent or falsy).
        """
        resp = await self.lm.respond_async(messages=self._build_messages(x))
        return (resp.raw_response or "").strip(), (getattr(resp, "usage", None) or {})
82
+
83
+
84
def main():
    """Run random-search prompt optimisation on Banking77 and save a JSON report.

    Loads Banking77 via ``datasets`` and shuffles its test split (seed 0). It
    takes 40 items as the train set and the next 20 as the validation set, then
    runs ``random_search_compile`` over ``StudentProgram`` candidates. The best
    program is re-evaluated on the validation set, and a report is written to
    ``random_search_banking77_<unix-time>.json`` next to this module.

    Environment variables (all optional):
        MODEL, VENDOR      LM model id / vendor (defaults ``openai/gpt-oss-20b`` / ``groq``).
        CONCURRENCY        max concurrent LM calls per evaluation (default 5).
        TIMEOUT_S          per-call timeout in seconds (default 45). A call that
                           fails or times out is kept with an empty prediction.
        BATCH_DEADLINE_S   wall-clock budget per evaluation batch (default 20).
                           Rollouts still running at the deadline are cancelled
                           and excluded from the score.
    """
    import statistics

    load_dotenv()
    random.seed(0)

    model = os.getenv("MODEL", "openai/gpt-oss-20b")
    vendor = os.getenv("VENDOR", "groq")
    lm = LM(model=model, vendor=vendor, temperature=0.0)

    print("Loading Banking77 dataset (train/dev split of test for demo)...")
    ds = load_dataset("banking77")
    label_names: List[str] = ds["test"].features["label"].names  # type: ignore

    # Create small train/val from the test split for speed
    all_items = [(r["text"], label_names[int(r["label"])]) for r in ds["test"]]
    random.shuffle(all_items)
    trainset: Sequence[Tuple[str, str]] = all_items[:40]
    valset: Sequence[Tuple[str, str]] = all_items[40:60]  # 20 examples

    student = StudentProgram(
        lm=lm,
        label_names=label_names,
        instruction="You are an intent classifier for the Banking77 dataset. Return exactly one label.",
        demos=[],
    )

    def metric(yhat: str, y: str) -> float:
        return accuracy(yhat, y, label_names)

    total_candidates = 3 + 3  # zero-shot, labeled few-shot, bootstrapped + 3 random seeds
    print(
        f"Running Random Search optimizer ({total_candidates} candidates, "
        f"parallel eval of {len(valset)} questions)..."
    )

    def _usage_value(usage: Any, primary: str, fallback: str) -> Any:
        """Read a token count as ``usage[primary] or usage[fallback]``.

        Works with dict-style usage records or attribute-style objects.
        """

        def get(key: str) -> Any:
            if isinstance(usage, dict):
                return usage.get(key)
            return getattr(usage, key, None)

        return get(primary) or get(fallback)

    def eval_parallel(program: StudentProgram, dataset: Sequence[Tuple[str, str]], metric_fn):
        """Evaluate ``program`` on ``dataset`` with bounded concurrency.

        Returns a ``SimpleNamespace`` with these attributes:
            score: mean of ``subscores``, over rollouts finished before the batch deadline.
            subscores: per-rollout metric values.
            details: one dict per finished rollout (index, seconds, score, usage).
            mean_in / mean_out: average prompt/completion tokens, 0.0 when unknown.
        """
        concurrency = int(os.getenv("CONCURRENCY", "5"))
        call_timeout = float(os.getenv("TIMEOUT_S", "45"))
        deadline = float(os.getenv("BATCH_DEADLINE_S", "20"))

        async def _run():
            # Created inside the running loop (asyncio.run starts a fresh loop per call).
            sem = asyncio.Semaphore(concurrency)

            async def worker(i: int, x: str, y: str):
                t_start = time.monotonic()
                try:
                    async with sem:
                        pred, usage = await asyncio.wait_for(program._apredict(x), timeout=call_timeout)
                except asyncio.CancelledError:
                    # Cancelled at the batch deadline: propagate; the caller drops this rollout.
                    raise
                except Exception:
                    # Failed or timed-out call: keep the rollout with an empty prediction.
                    pred, usage = "", {}
                return i, y, pred, t_start, time.monotonic(), usage or {}

            tasks = [asyncio.create_task(worker(i, x, y)) for i, (x, y) in enumerate(dataset)]
            durations: List[float] = []
            details: List[Dict[str, Any]] = []
            correct_sum = 0.0
            in_tok_sum = out_tok_sum = 0
            in_tok_count = out_tok_count = 0
            t_batch_start = time.monotonic()
            pending = set(tasks)
            with tqdm(total=len(tasks), desc="Rollouts", leave=False) as rollout_bar:
                while pending:
                    remaining = deadline - (time.monotonic() - t_batch_start)
                    if remaining <= 0.0:
                        # Deadline reached: cancel stragglers and wait for them to unwind.
                        # Their results are intentionally not recorded.
                        for t in pending:
                            t.cancel()
                        await asyncio.wait(pending)
                        break
                    # Wake at least once per second so the deadline is checked promptly.
                    done, pending = await asyncio.wait(
                        pending, timeout=min(1.0, remaining), return_when=asyncio.FIRST_COMPLETED
                    )
                    for task in done:
                        try:
                            i, y_true, pred, t_start, t_end, usage = task.result()
                        except (asyncio.CancelledError, Exception):
                            continue
                        seconds = max(0.0, t_end - t_start)
                        # Compute the metric once; a metric error counts as 0 rather than crashing.
                        try:
                            item_score = float(metric_fn(pred, y_true))
                        except Exception:
                            item_score = 0.0
                        pt = _usage_value(usage, "prompt_tokens", "input_tokens")
                        ct = _usage_value(usage, "completion_tokens", "output_tokens")
                        if isinstance(pt, (int, float)):
                            in_tok_sum += int(pt)
                            in_tok_count += 1
                        if isinstance(ct, (int, float)):
                            out_tok_sum += int(ct)
                            out_tok_count += 1
                        durations.append(seconds)
                        correct_sum += item_score
                        details.append({
                            "index": i,
                            "seconds": seconds,
                            "score": item_score,
                            "usage": {
                                "prompt_tokens": pt,
                                "completion_tokens": ct,
                            },
                        })
                        rollout_bar.update(1)
                        processed = len(details)
                        rollout_bar.set_postfix({
                            "acc": f"{(correct_sum / processed):.2f}",
                            "done": f"{processed}/{len(tasks)}",
                            "med_s": f"{statistics.median(durations):.1f}",
                            "max_s": f"{max(durations):.1f}",
                            "tin": f"{(in_tok_sum / in_tok_count) if in_tok_count else 0.0:.1f}",
                            "tout": f"{(out_tok_sum / out_tok_count) if out_tok_count else 0.0:.1f}",
                        })
            # Score only completed rollouts (deadline-cancelled ones are dropped).
            subs = [d["score"] for d in details]
            return SimpleNamespace(
                score=sum(subs) / max(1, len(subs)),
                subscores=subs,
                details=details,
                mean_in=(in_tok_sum / in_tok_count) if in_tok_count else 0.0,
                mean_out=(out_tok_sum / out_tok_count) if out_tok_count else 0.0,
            )

        return asyncio.run(_run())

    candidate_eval_details: Dict[int, Any] = {}
    # The bar is closed even if the search raises.
    with tqdm(total=total_candidates, desc="Candidates") as candidate_bar:

        def on_cand(idx: int, score: float, res, intervention):
            candidate_bar.update(1)
            candidate_bar.set_postfix({"score": f"{score:.2f}"})
            # Store per-instance details so candidates can be compared apples-to-apples.
            candidate_eval_details[idx] = {
                "score": score,
                "mean_in": getattr(res, "mean_in", None),
                "mean_out": getattr(res, "mean_out", None),
                "instances": getattr(res, "details", None),
            }
            # Visible summary line per candidate.
            info = intervention if isinstance(intervention, dict) else {}
            kind = info.get("kind", "candidate")
            label = info.get("label")
            seed = info.get("seed")
            n_done = len(getattr(res, "details", []) or [])
            label_part = "" if label is None else f", label={label}"
            seed_part = "" if seed is None else f", seed={seed}"
            tqdm.write(
                f"Candidate {idx}/{total_candidates} [{kind}{label_part}{seed_part}]: "
                f"score={score:.2f} | mean tin/tout={getattr(res, 'mean_in', 0):.1f}/"
                f"{getattr(res, 'mean_out', 0):.1f} | N={n_done}"
            )

        best, records = random_search_compile(
            student=student,
            trainset=trainset,
            valset=valset,
            metric=metric,
            evaluate_fn=eval_parallel,
            max_bootstrapped_demos=0,
            max_labeled_demos=4,
            max_rounds=2,
            num_candidate_programs=3,
            on_candidate_evaluated=on_cand,
        )

    # Evaluate best on holdout (valset) with parallel rollouts
    print("Evaluating best program on val (parallel rollouts)...")
    best_res = eval_parallel(best, valset, metric)
    # The score covers only completed rollouts, so report against that count.
    n_scored = len(best_res.subscores)
    correct = int(round(best_res.score * n_scored))
    dropped = len(valset) - n_scored
    dropped_note = f"[{dropped} dropped at deadline] " if dropped else ""
    print(
        "Best program accuracy on val: "
        f"{correct}/{n_scored} ({best_res.score:.2%}) "
        f"{dropped_note}"
        f"| mean tokens in/out: {getattr(best_res, 'mean_in', 0):.1f}/{getattr(best_res, 'mean_out', 0):.1f}"
    )

    # Save per-candidate scores and interventions
    out = {
        "context": {
            "model": model,
            "vendor": vendor,
            "train_size": len(trainset),
            "val_size": len(valset),
        },
        "candidates": records,
        "candidate_eval_details": candidate_eval_details,
        "best_eval_details": {
            "score": best_res.score,
            "mean_in": getattr(best_res, "mean_in", None),
            "mean_out": getattr(best_res, "mean_out", None),
            "instances": getattr(best_res, "details", None),
        },
    }
    fname = str(Path(__file__).parent / f"random_search_banking77_{int(time.time())}.json")
    with open(fname, "w", encoding="utf-8") as f:
        json.dump(out, f, indent=2)
    print(f"Saved candidate records to {fname}")
302
+
303
+
304
# Script entry point, e.g.
#   uv run -q python -m synth_ai.learning.prompts.run_random_search_banking77
# Importing the module does not run the experiment.
if __name__ == "__main__":
    main()
@@ -0,0 +1,81 @@
1
+ from __future__ import annotations
2
+
3
+ import contextvars
4
+ from contextlib import contextmanager
5
+ from typing import Any, Dict, List, Optional
6
+
7
# One prompt-injection rule, e.g. {"find": "...", "replace": "...", "roles": ["user"]}.
Rule = Dict[str, Any]

# Rules active in the current context; ``None`` (the default) means "no rules".
# ContextVar values are per-context, so each asyncio task sees its own copy.
_rules_ctx: contextvars.ContextVar[Optional[List[Rule]]] = contextvars.ContextVar(
    "injection_rules",
    default=None,
)
12
+
13
+
14
def set_injection_rules(rules: List[Rule]):
    """Install prompt-injection rules for the current context.

    Each rule is a dict that must contain ``"find"`` and ``"replace"``; an
    optional ``"roles"`` list limits the rule to messages with those roles.

    Returns:
        The ContextVar token; pass it to ``clear_injection_rules`` to restore
        the previous rules.

    Raises:
        ValueError: if ``rules`` is not a list of dicts with both required keys.
    """

    def _well_formed(rule) -> bool:
        return isinstance(rule, dict) and "find" in rule and "replace" in rule

    if isinstance(rules, list) and all(_well_formed(rule) for rule in rules):
        return _rules_ctx.set(rules)
    raise ValueError("Injection rules must be a list of dicts with 'find' and 'replace'")
25
+
26
+
27
def get_injection_rules() -> Optional[List[Rule]]:
    """Return the rules installed in the current context, or ``None`` if none are set."""
    active_rules = _rules_ctx.get()
    return active_rules
30
+
31
+
32
def clear_injection_rules(token) -> None:
    """Restore the rules that were active before the matching ``set_injection_rules`` call.

    ``token`` is the value returned by ``set_injection_rules``. A token can be
    used only once; ``ContextVar.reset`` raises ``RuntimeError`` if it is reused.
    """
    rules_var = _rules_ctx
    rules_var.reset(token)
35
+
36
+
37
def apply_injection(
    messages: List[Dict[str, Any]],
    rules: Optional[List[Dict[str, Any]]] = None,
) -> List[Dict[str, Any]]:
    """Apply ordered substring replacements to the text of ``messages`` in place.

    - Only ``str`` content and list parts that are dicts with
      ``part["type"] == "text"`` are rewritten. Other content (images, tool
      payloads, non-dict parts, non-string ``text``) is left untouched.
    - Rules are applied in order. A rule's optional ``roles`` list restricts it
      to messages whose ``role`` is in that list.
    - Rules whose ``find`` is empty are skipped, because ``str.replace("", x)``
      would insert ``x`` between every character.

    Args:
        messages: Chat messages (dicts with ``role`` / ``content``); mutated in place.
        rules: Rules to apply. When ``None`` (the default), the current
            context's rules from ``get_injection_rules()`` are used.

    Returns:
        The same ``messages`` list, for convenience.
    """
    if rules is None:
        rules = get_injection_rules()
    if not rules:
        return messages

    def _rewrite(text: str, role: Any) -> str:
        # Apply every rule that is in scope for ``role``, in order.
        for rule in rules:
            allowed_roles = rule.get("roles")
            if allowed_roles is not None and role not in allowed_roles:
                continue
            find = str(rule["find"])
            if not find:
                continue
            text = text.replace(find, str(rule["replace"]))
        return text

    for m in messages:
        role = m.get("role")
        content = m.get("content")
        if isinstance(content, str):
            m["content"] = _rewrite(content, role)
        elif isinstance(content, list):
            for part in content:
                if not isinstance(part, dict) or part.get("type") != "text":
                    continue
                text = part.get("text", "")
                if isinstance(text, str):
                    part["text"] = _rewrite(text, role)
    return messages
71
+
72
+
73
@contextmanager
def injection_rules_ctx(rules: List[Rule]):
    """Install ``rules`` for the duration of a ``with`` block.

    The previously active rules are restored on exit, even if the block raises.
    ``rules`` is validated by ``set_injection_rules`` (``ValueError`` on bad input).
    """
    restore_token = set_injection_rules(rules)
    try:
        yield
    finally:
        clear_injection_rules(restore_token)
81
+
@@ -0,0 +1,204 @@
1
+ from __future__ import annotations
2
+
3
+ from contextlib import contextmanager
4
+ from typing import Any, Dict, List, Optional, Tuple
5
+ import contextvars
6
+
7
+ from synth_ai.lm.injection import (
8
+ set_injection_rules,
9
+ clear_injection_rules,
10
+ apply_injection as _apply_injection,
11
+ )
12
+
13
# Override specs registered for the current context, evaluated on every LM call.
# Minimal v1 shape of one spec:
#   {
#     "match": {"contains": "atm", "role": "user" | "system" | None},
#     "injection_rules": [{"find": str, "replace": str, "roles": Optional[List[str]]}],
#     "params": { ... api params to override ... },
#     "tools": { ... optional tools overrides ... },
#     "label": optional name exposed while the spec is active,
#   }
_override_specs_ctx: contextvars.ContextVar[Optional[List[Dict[str, Any]]]] = contextvars.ContextVar(
    "override_specs",
    default=None,
)

# Per-call values, installed only while a matched spec is in effect.
_param_overrides_ctx: contextvars.ContextVar[Optional[Dict[str, Any]]] = contextvars.ContextVar(
    "param_overrides",
    default=None,
)
_tool_overrides_ctx: contextvars.ContextVar[Optional[Dict[str, Any]]] = contextvars.ContextVar(
    "tool_overrides",
    default=None,
)
_current_override_label_ctx: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
    "override_label",
    default=None,
)
35
+
36
+
37
def set_override_specs(specs: List[Dict[str, Any]]):
    """Register override ``specs`` for the current context.

    Returns:
        The ContextVar token to pass to ``clear_override_specs``.

    Raises:
        ValueError: if ``specs`` is not a list.
    """
    if isinstance(specs, list):
        return _override_specs_ctx.set(specs)
    raise ValueError("override specs must be a list of dicts")
41
+
42
+
43
def get_override_specs() -> Optional[List[Dict[str, Any]]]:
    """Return the override specs registered in the current context, or ``None``."""
    registered = _override_specs_ctx.get()
    return registered
45
+
46
+
47
def clear_override_specs(token) -> None:
    """Restore the specs that were registered before the matching ``set_override_specs`` call."""
    specs_var = _override_specs_ctx
    specs_var.reset(token)
49
+
50
+
51
def _matches(spec: Dict[str, Any], messages: List[Dict[str, Any]]) -> bool:
    """Return True if ``spec["match"]`` selects this call's ``messages``.

    ``match["contains"]`` is a case-insensitive substring searched for in
    message text: string ``content``, or list parts that are dicts with
    ``type == "text"``. ``match["role"]``, when given, limits the search to
    messages with that role. A spec without ``contains`` matches every call;
    ``role`` on its own does not filter. Non-dict content parts (e.g. raw
    strings) are ignored instead of raising.
    """
    match = spec.get("match") or {}
    contains = match.get("contains")
    if not contains:
        # No match criteria means always apply.
        return True
    needle = str(contains).lower()
    role = match.get("role")  # optional
    for m in messages:
        if role and m.get("role") != role:
            continue
        content = m.get("content")
        if isinstance(content, str):
            if needle in content.lower():
                return True
        elif isinstance(content, list):
            for part in content:
                if (
                    isinstance(part, dict)
                    and part.get("type") == "text"
                    and needle in str(part.get("text", "")).lower()
                ):
                    return True
    return False
70
+
71
+
72
def resolve_override_for_messages(messages: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Return the first registered override spec that matches ``messages``.

    Specs are tried in registration order. A spec whose matcher raises is
    skipped. Returns ``None`` when no spec matches or none are registered.
    """
    for candidate in get_override_specs() or []:
        try:
            selected = _matches(candidate, messages)
        except Exception:
            # Malformed spec: skip it rather than failing the LM call.
            continue
        if selected:
            return candidate
    return None
82
+
83
+
84
def apply_injection(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Apply the current context's injection rules to ``messages``.

    Thin re-export of ``synth_ai.lm.injection.apply_injection``, which rewrites
    messages in place and returns the same list.
    """
    rewritten = _apply_injection(messages)
    return rewritten
87
+
88
+
89
def apply_param_overrides(api_params: Dict[str, Any]) -> Dict[str, Any]:
    """Merge the active per-call parameter overrides into ``api_params`` in place.

    This is a shallow merge in which override values replace existing keys.
    Returns ``api_params`` (unchanged when no overrides are active).
    """
    overrides = _param_overrides_ctx.get()
    if overrides:
        api_params.update(overrides.items())
    return api_params
97
+
98
+
99
def _tool_name(tool: Any) -> Optional[str]:
    """Return the function name of a tool definition, or ``None`` if unknown.

    Handles OpenAI chat-style dicts (``{"function": {"name": ...}}``), dicts
    with a top-level ``"name"`` (Anthropic / flat function style), and
    objects exposing ``function_name``.
    """
    if isinstance(tool, dict):
        fn = tool.get("function")
        if isinstance(fn, dict) and fn.get("name") is not None:
            return fn["name"]
        return tool.get("name")
    return getattr(tool, "function_name", None)


def _remove_tools_by_name(tools: List[Any], names: Set[Any]) -> List[Any]:
    """Return ``tools`` without entries whose name is in ``names``; unnamed tools are kept."""
    kept = []
    for tool in tools:
        name = _tool_name(tool)
        if name is None or name not in names:
            kept.append(tool)
    return kept


def apply_tool_overrides(api_params: Dict[str, Any]) -> Dict[str, Any]:
    """Apply tool overrides to OpenAI/Anthropic-like ``api_params`` in place.

    Supported keys under spec["tools"] (applied in this order):
      - set_tools: replace tools entirely
      - add_tools: append tools
      - remove_tools_by_name: remove by function name (a single string is
        treated as one name, not as a set of characters)
      - tool_choice: set the tool_choice param

    Returns ``api_params``; it is unchanged when no tool overrides are active.
    """
    ov = _tool_overrides_ctx.get()
    if not ov:
        return api_params
    tov = ov.get("tools") if isinstance(ov, dict) else None
    if not tov:
        return api_params

    tools = api_params.get("tools")
    if "set_tools" in tov:
        tools = tov["set_tools"]
    if "add_tools" in tov:
        # list() so tuples or other sequences concatenate cleanly.
        tools = list(tools or []) + list(tov["add_tools"])
    if "remove_tools_by_name" in tov and tools:
        raw_names = tov["remove_tools_by_name"]
        names = {raw_names} if isinstance(raw_names, str) else set(raw_names)
        tools = _remove_tools_by_name(tools, names)
    if tools is not None:
        api_params["tools"] = tools
    if "tool_choice" in tov:
        api_params["tool_choice"] = tov["tool_choice"]
    return api_params
138
+
139
+
140
@contextmanager
def use_overrides_for_messages(messages: List[Dict[str, Any]]):
    """Apply the first override spec that matches ``messages`` for the duration of the block.

    For the matched spec, the truthy ``injection_rules``, ``params``, ``tools``
    and ``label`` entries are installed into their ContextVars. The body then
    runs, and every ContextVar that was changed is reset afterwards, in the
    order injection rules, params, tools, label, even if the block raises. If
    no spec matches, nothing is installed.
    """
    spec = resolve_override_for_messages(messages) or {}
    undo = []  # (reset_fn, token) pairs, undone in insertion order
    try:
        rules = spec.get("injection_rules")
        if rules:
            undo.append((clear_injection_rules, set_injection_rules(rules)))
        params = spec.get("params")
        if params:
            undo.append((_param_overrides_ctx.reset, _param_overrides_ctx.set(params)))
        tools = spec.get("tools")
        if tools:
            undo.append((_tool_overrides_ctx.reset, _tool_overrides_ctx.set({"tools": tools})))
        label = spec.get("label")
        if label:
            undo.append((_current_override_label_ctx.reset, _current_override_label_ctx.set(str(label))))
        yield
    finally:
        for reset_fn, token in undo:
            reset_fn(token)
175
+
176
+
177
def get_current_override_label() -> Optional[str]:
    """Return the ``label`` of the override spec currently in effect, or ``None``."""
    active_label = _current_override_label_ctx.get()
    return active_label
179
+
180
+
181
class LMOverridesContext:
    """Context manager to register per-call override specs.

    Usage:
        with LMOverridesContext([
            {"match": {"contains": "atm", "role": "user"}, "injection_rules": [...], "params": {...}},
            {"match": {"contains": "refund"}, "params": {"temperature": 0.0}},
        ]):
            run_pipeline()

    A single dict is accepted and wrapped in a list. One instance may be
    entered repeatedly, including in nested ``with`` blocks. Each exit restores
    the specs that were active before the matching enter. The token stack
    belongs to the instance, so concurrently running tasks should each use
    their own instance.
    """

    def __init__(self, override_specs: Optional[List[Dict[str, Any]]] | Dict[str, Any] = None):
        if isinstance(override_specs, dict):
            override_specs = [override_specs]
        self._specs = override_specs or []
        # One ContextVar token per active ``with`` entry (LIFO).
        self._tokens: List[Any] = []
        # Most recent token, kept for backward compatibility.
        self._tok = None

    def __enter__(self):
        tok = set_override_specs(self._specs)
        self._tokens.append(tok)
        self._tok = tok
        return self

    def __exit__(self, exc_type, exc, tb):
        if self._tokens:
            clear_override_specs(self._tokens.pop())
        self._tok = self._tokens[-1] if self._tokens else None
        # Returning None lets any exception from the block propagate.