synth-ai 0.2.4.dev5__py3-none-any.whl → 0.2.4.dev6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/environments/examples/crafter_classic/engine.py +8 -4
- synth_ai/environments/examples/wordle/__init__.py +29 -0
- synth_ai/environments/examples/wordle/engine.py +391 -0
- synth_ai/environments/examples/wordle/environment.py +154 -0
- synth_ai/environments/examples/wordle/helpers/generate_instances_wordfreq.py +75 -0
- synth_ai/environments/examples/wordle/taskset.py +222 -0
- synth_ai/environments/service/app.py +8 -0
- synth_ai/environments/service/core_routes.py +38 -0
- synth_ai/learning/prompts/banking77_injection_eval.py +163 -0
- synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +201 -0
- synth_ai/learning/prompts/mipro.py +273 -1
- synth_ai/learning/prompts/random_search.py +247 -0
- synth_ai/learning/prompts/run_mipro_banking77.py +160 -0
- synth_ai/learning/prompts/run_random_search_banking77.py +305 -0
- synth_ai/lm/injection.py +81 -0
- synth_ai/lm/overrides.py +204 -0
- synth_ai/lm/provider_support/anthropic.py +39 -12
- synth_ai/lm/provider_support/openai.py +31 -4
- synth_ai/lm/vendors/core/anthropic_api.py +16 -0
- synth_ai/lm/vendors/openai_standard.py +35 -5
- {synth_ai-0.2.4.dev5.dist-info → synth_ai-0.2.4.dev6.dist-info}/METADATA +2 -1
- {synth_ai-0.2.4.dev5.dist-info → synth_ai-0.2.4.dev6.dist-info}/RECORD +26 -14
- {synth_ai-0.2.4.dev5.dist-info → synth_ai-0.2.4.dev6.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.4.dev5.dist-info → synth_ai-0.2.4.dev6.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.4.dev5.dist-info → synth_ai-0.2.4.dev6.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.4.dev5.dist-info → synth_ai-0.2.4.dev6.dist-info}/top_level.txt +0 -0
synth_ai/learning/prompts/run_random_search_banking77.py
ADDED
@@ -0,0 +1,305 @@
"""
Example: Random Search optimizer on Banking77 using Groq gpt-oss-20b.

Requires:
- .env with GROQ_API_KEY
- datasets (`uv add datasets` if needed)

Run:
- uv run -q python -m synth_ai.learning.prompts.run_random_search_banking77
"""

from __future__ import annotations

import asyncio
import os
import random
from dataclasses import dataclass, replace
from types import SimpleNamespace
from tqdm import tqdm
from typing import Any, Dict, List, Optional, Sequence, Tuple

from dotenv import load_dotenv
from datasets import load_dataset

from synth_ai.lm.core.main_v3 import LM, build_messages
import json
import time
from pathlib import Path
from synth_ai.learning.prompts.random_search import random_search_compile


def choose_label(pred: str, label_names: List[str]) -> str:
    norm = (pred or "").strip().lower()
    d = {ln.lower(): ln for ln in label_names}
    if norm in d:
        return d[norm]
    def score(cand: str) -> int:
        c = cand.lower()
        return sum(1 for w in c.split() if w in norm)
    return max(label_names, key=score)


def accuracy(pred: str, gold: str, labels: List[str]) -> float:
    return 1.0 if choose_label(pred, labels) == gold else 0.0


@dataclass
class StudentProgram:
    lm: LM
    label_names: List[str]
    instruction: str
    demos: List[Tuple[str, str]]

    def reset_copy(self):
        return replace(self, instruction=self.instruction, demos=list(self.demos))

    def deepcopy(self):
        return replace(self, instruction=str(self.instruction), demos=list(self.demos))

    def with_demos(self, demos: List[Tuple[str, str]]):
        return replace(self, demos=list(demos))

    def run(self, x: str) -> str:
        # Build a prompt with optional demos
        examples = "\n".join(f"Input: {a}\nLabel: {b}" for a, b in self.demos)
        sys = self.instruction or "You are an intent classifier for Banking77."
        user = (f"Examples:\n{examples}\n\n" if examples else "") + f"Message: {x}\nLabel:"
        messages = build_messages(sys, user, images_bytes=None, model_name=self.lm.model)
        # Call LM synchronously via asyncio
        async def _call():
            resp = await self.lm.respond_async(messages=messages)
            return (resp.raw_response or "").strip()
        return asyncio.run(_call())

    async def _apredict(self, x: str):
        examples = "\n".join(f"Input: {a}\nLabel: {b}" for a, b in self.demos)
        sys = self.instruction or "You are an intent classifier for Banking77."
        user = (f"Examples:\n{examples}\n\n" if examples else "") + f"Message: {x}\nLabel:"
        messages = build_messages(sys, user, images_bytes=None, model_name=self.lm.model)
        resp = await self.lm.respond_async(messages=messages)
        return (resp.raw_response or "").strip(), (resp.usage or {})


def main():
    load_dotenv()
    random.seed(0)

    model = os.getenv("MODEL", "openai/gpt-oss-20b")
    vendor = os.getenv("VENDOR", "groq")
    lm = LM(model=model, vendor=vendor, temperature=0.0)

    print("Loading Banking77 dataset (train/dev split of test for demo)...")
    ds = load_dataset("banking77")
    label_names: List[str] = ds["test"].features["label"].names  # type: ignore

    # Create small train/val from the test split for speed
    all_items = [(r["text"], label_names[int(r["label"])]) for r in ds["test"]]
    random.shuffle(all_items)
    trainset: Sequence[Tuple[str, str]] = all_items[:40]
    valset: Sequence[Tuple[str, str]] = all_items[40:60]  # 20 examples

    student = StudentProgram(
        lm=lm,
        label_names=label_names,
        instruction="You are an intent classifier for the Banking77 dataset. Return exactly one label.",
        demos=[],
    )

    def metric(yhat: str, y: str) -> float:
        return accuracy(yhat, y, label_names)

    total_candidates = 3 + 3  # zero-shot, labeled few-shot, bootstrapped + 3 random seeds
    print(f"Running Random Search optimizer ({total_candidates} candidates, parallel eval of 20 questions)...")

    def eval_parallel(program: StudentProgram, dataset: Sequence[Tuple[str, str]], metric_fn):
        async def _run():
            xs = [x for x, _ in dataset]
            ys = [y for _, y in dataset]
            preds: List[Optional[str]] = [None] * len(xs)
            sem = asyncio.Semaphore(int(os.getenv("CONCURRENCY", "5")))

            async def worker(i: int, x: str, y: str):
                import time
                t_start = time.monotonic()
                try:
                    async with sem:
                        pred, usage = await asyncio.wait_for(
                            program._apredict(x),
                            timeout=float(os.getenv("TIMEOUT_S", "45")),
                        )
                    t_end = time.monotonic()
                    return i, y, pred, t_start, t_end, usage or {}
                except asyncio.CancelledError:
                    # Respect cancellation but return a placeholder record so scheduler can proceed
                    t_end = time.monotonic()
                    return i, y, "", t_start, t_end, {}
                except Exception:
                    t_end = time.monotonic()
                    return i, y, "", t_start, t_end, {}

            tasks = [asyncio.create_task(worker(i, x, y)) for i, (x, y) in enumerate(zip(xs, ys))]
            correct_sum = 0.0
            processed = 0
            import time, statistics
            durations: List[float] = []
            in_tok_sum = 0
            out_tok_sum = 0
            in_tok_count = 0
            out_tok_count = 0
            details: List[Dict[str, Any]] = []
            t_batch_start = time.monotonic()
            deadline = float(os.getenv("BATCH_DEADLINE_S", "20"))
            with tqdm(total=len(tasks), desc="Rollouts", leave=False) as pbar:
                pending = set(tasks)
                # Process completions until all done or deadline reached
                while pending:
                    elapsed = time.monotonic() - t_batch_start
                    remaining = max(0.0, deadline - elapsed)
                    if remaining <= 0.0:
                        # Cancel any remaining
                        for t in pending:
                            t.cancel()
                        done, _ = await asyncio.wait(pending, return_when=asyncio.ALL_COMPLETED)
                        # Record canceled as zeros
                        for task in done:
                            try:
                                i, y_true, pred, t_start, t_end, usage = task.result()
                            except Exception:
                                # Unknown index: we can't recover; skip as it's canceled before start
                                continue
                            # Already processed ones shouldn't be in pending; skip
                        break
                    # Wait for at least one completion within remaining time (polling granularity <= 1s)
                    timeout = min(1.0, remaining)
                    done, pending = await asyncio.wait(pending, timeout=timeout, return_when=asyncio.FIRST_COMPLETED)
                    for task in done:
                        try:
                            i, y_true, pred, t_start, t_end, usage = task.result()
                        except BaseException:
                            # Treat as failure/cancelled
                            continue
                        durations.append(max(0.0, t_end - t_start))
                        preds[i] = pred
                        processed += 1
                        try:
                            correct_sum += float(metric_fn(pred, y_true))
                        except Exception:
                            pass
                        try:
                            pt = usage.get("prompt_tokens") or usage.get("input_tokens")
                            ct = usage.get("completion_tokens") or usage.get("output_tokens")
                            if isinstance(pt, (int, float)):
                                in_tok_sum += int(pt)
                                in_tok_count += 1
                            if isinstance(ct, (int, float)):
                                out_tok_sum += int(ct)
                                out_tok_count += 1
                        except Exception:
                            pass
                        details.append({
                            "index": i,
                            "seconds": max(0.0, t_end - t_start),
                            "score": float(metric_fn(pred, y_true)),
                            "usage": {
                                "prompt_tokens": usage.get("prompt_tokens") or usage.get("input_tokens"),
                                "completion_tokens": usage.get("completion_tokens") or usage.get("output_tokens"),
                            },
                        })
                        pbar.update(1)
                        med = statistics.median(durations) if durations else 0.0
                        mx = max(durations) if durations else 0.0
                        avg_in = (in_tok_sum / in_tok_count) if in_tok_count else 0.0
                        avg_out = (out_tok_sum / out_tok_count) if out_tok_count else 0.0
                        pbar.set_postfix({
                            "acc": f"{(correct_sum/processed):.2f}",
                            "done": f"{processed}/{len(tasks)}",
                            "med_s": f"{med:.1f}",
                            "max_s": f"{mx:.1f}",
                            "tin": f"{avg_in:.1f}",
                            "tout": f"{avg_out:.1f}",
                        })
            # Compute score only from completed/successful rollouts (drop timeouts/cancelled)
            subs = [float(d.get("score", 0.0)) for d in details]
            result = SimpleNamespace(score=(sum(subs) / max(1, len(subs))), subscores=subs)
            result.details = details
            result.mean_in = (in_tok_sum / in_tok_count) if in_tok_count else 0.0
            result.mean_out = (out_tok_sum / out_tok_count) if out_tok_count else 0.0
            return result
        return asyncio.run(_run())
    pbar = tqdm(total=total_candidates, desc="Candidates")
    candidate_eval_details: Dict[int, Any] = {}
    def on_cand(idx: int, score: float, res, intervention):
        pbar.update(1)
        pbar.set_postfix({"score": f"{score:.2f}"})
        # store per-instance details (for apples-to-apples)
        try:
            candidate_eval_details[idx] = {
                "score": score,
                "mean_in": getattr(res, "mean_in", None),
                "mean_out": getattr(res, "mean_out", None),
                "instances": getattr(res, "details", None),
            }
        except Exception:
            pass
        # visible summary line per candidate
        kind = intervention.get("kind", "candidate") if isinstance(intervention, dict) else "candidate"
        label = intervention.get("label") if isinstance(intervention, dict) else None
        seed = intervention.get("seed") if isinstance(intervention, dict) else None
        processed = len(getattr(res, "details", []) or [])
        from tqdm import tqdm as _tqdm
        _tqdm.write(
            f"Candidate {idx}/{total_candidates} [{kind}{'' if label is None else f', label={label}'}{'' if seed is None else f', seed={seed}'}]: "
            f"score={score:.2f} | mean tin/tout={getattr(res, 'mean_in', 0):.1f}/{getattr(res, 'mean_out', 0):.1f} | N={processed}"
        )

    best, records = random_search_compile(
        student=student,
        trainset=trainset,
        valset=valset,
        metric=metric,
        evaluate_fn=eval_parallel,
        max_bootstrapped_demos=0,
        max_labeled_demos=4,
        max_rounds=2,
        num_candidate_programs=3,
        on_candidate_evaluated=on_cand,
    )
    pbar.close()

    # Evaluate best on holdout (valset) with parallel rollouts
    print("Evaluating best program on val (parallel rollouts)...")
    best_res = eval_parallel(best, valset, metric)
    correct = int(round(best_res.score * max(1, len(best_res.subscores))))
    print(
        "Best program accuracy on val: "
        f"{correct}/{len(valset)} ({best_res.score:.2%}) "
        f"| mean tokens in/out: {getattr(best_res, 'mean_in', 0):.1f}/{getattr(best_res, 'mean_out', 0):.1f}"
    )

    # Save per-candidate scores and interventions
    out = {
        "context": {
            "model": model,
            "vendor": vendor,
            "train_size": len(trainset),
            "val_size": len(valset),
        },
        "candidates": records,
        "candidate_eval_details": candidate_eval_details,
        "best_eval_details": {
            "score": best_res.score,
            "mean_in": getattr(best_res, "mean_in", None),
            "mean_out": getattr(best_res, "mean_out", None),
            "instances": getattr(best_res, "details", None),
        },
    }
    out_dir = Path(__file__).parent
    fname = str(out_dir / f"random_search_banking77_{int(time.time())}.json")
    with open(fname, "w") as f:
        json.dump(out, f, indent=2)
    print(f"Saved candidate records to {fname}")


if __name__ == "__main__":
    main()
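Note: `random_search_compile` only interacts with the injected `evaluate_fn` through the returned object's `score` and `subscores` attributes, as used above. A minimal sequential stand-in (a hypothetical sketch, shown only to illustrate the contract this script assumes) could look like:

    def eval_sequential(program, dataset, metric_fn):
        # Hypothetical minimal evaluate_fn: same result shape as eval_parallel
        # above, but one rollout at a time and no timing/token bookkeeping.
        from types import SimpleNamespace
        subs = [float(metric_fn(program.run(x), y)) for x, y in dataset]
        return SimpleNamespace(score=sum(subs) / max(1, len(subs)), subscores=subs)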
synth_ai/lm/injection.py
ADDED
@@ -0,0 +1,81 @@
from __future__ import annotations

import contextvars
from contextlib import contextmanager
from typing import Any, Dict, List, Optional

Rule = Dict[str, Any]

_rules_ctx: contextvars.ContextVar[Optional[List[Rule]]] = contextvars.ContextVar(
    "injection_rules", default=None
)


def set_injection_rules(rules: List[Rule]):
    """Set prompt-injection rules for the current context and return a reset token.

    Each rule must be a dict with at least keys: "find" and "replace" (strings).
    Optional: "roles" as a list of role names to scope the replacement.
    """
    if not isinstance(rules, list) or not all(
        isinstance(r, dict) and "find" in r and "replace" in r for r in rules
    ):
        raise ValueError("Injection rules must be a list of dicts with 'find' and 'replace'")
    return _rules_ctx.set(rules)


def get_injection_rules() -> Optional[List[Rule]]:
    """Get the current context's injection rules, if any."""
    return _rules_ctx.get()


def clear_injection_rules(token) -> None:
    """Reset the injection rules to the previous value using the provided token."""
    _rules_ctx.reset(token)


def apply_injection(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Apply ordered substring replacements to text parts of messages in place.

    - Only modifies `str` content or list parts where `part["type"] == "text"`.
    - Honors optional `roles` scoping in each rule.
    - Returns the input list for convenience.
    """
    rules = get_injection_rules()
    if not rules:
        return messages

    for m in messages:
        role = m.get("role")
        content = m.get("content")
        if isinstance(content, str):
            new_content = content
            for r in rules:
                allowed_roles = r.get("roles")
                if allowed_roles is not None and role not in allowed_roles:
                    continue
                new_content = new_content.replace(str(r["find"]), str(r["replace"]))
            m["content"] = new_content
        elif isinstance(content, list):
            for part in content:
                if part.get("type") == "text":
                    text = part.get("text", "")
                    new_text = text
                    for r in rules:
                        allowed_roles = r.get("roles")
                        if allowed_roles is not None and role not in allowed_roles:
                            continue
                        new_text = new_text.replace(str(r["find"]), str(r["replace"]))
                    part["text"] = new_text
    return messages


@contextmanager
def injection_rules_ctx(rules: List[Rule]):
    """Context manager to temporarily apply injection rules within the block."""
    tok = set_injection_rules(rules)
    try:
        yield
    finally:
        clear_injection_rules(tok)
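A minimal usage sketch of the module above (the message payload is illustrative):

    from synth_ai.lm.injection import injection_rules_ctx, apply_injection

    messages = [
        {"role": "system", "content": "You are a banking assistant."},
        {"role": "user", "content": "Where is the nearest ATM?"},
    ]
    with injection_rules_ctx([{"find": "ATM", "replace": "cash machine", "roles": ["user"]}]):
        apply_injection(messages)  # in-place substring replacement, scoped to user messages
    # messages[1]["content"] is now "Where is the nearest cash machine?"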
synth_ai/lm/overrides.py
ADDED
@@ -0,0 +1,204 @@
from __future__ import annotations

from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple
import contextvars

from synth_ai.lm.injection import (
    set_injection_rules,
    clear_injection_rules,
    apply_injection as _apply_injection,
)

# Context to hold a list of override specs to evaluate per-call
# Each spec shape (minimal v1):
# {
#   "match": {"contains": "atm", "role": "user" | "system" | None},
#   "injection_rules": [{"find": str, "replace": str, "roles": Optional[List[str]]}],
#   "params": { ... api params to override ... },
#   "tools": { ... optional tools overrides ... },
# }
_override_specs_ctx: contextvars.ContextVar[Optional[List[Dict[str, Any]]]] = contextvars.ContextVar(
    "override_specs", default=None
)

# ContextVars actually applied for the specific call once matched
_param_overrides_ctx: contextvars.ContextVar[Optional[Dict[str, Any]]] = contextvars.ContextVar(
    "param_overrides", default=None
)
_tool_overrides_ctx: contextvars.ContextVar[Optional[Dict[str, Any]]] = contextvars.ContextVar(
    "tool_overrides", default=None
)
_current_override_label_ctx: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
    "override_label", default=None
)


def set_override_specs(specs: List[Dict[str, Any]]):
    if not isinstance(specs, list):
        raise ValueError("override specs must be a list of dicts")
    return _override_specs_ctx.set(specs)


def get_override_specs() -> Optional[List[Dict[str, Any]]]:
    return _override_specs_ctx.get()


def clear_override_specs(token) -> None:
    _override_specs_ctx.reset(token)


def _matches(spec: Dict[str, Any], messages: List[Dict[str, Any]]) -> bool:
    match = spec.get("match") or {}
    contains = match.get("contains")
    role = match.get("role")  # optional
    if not contains:
        # no match criteria means always apply
        return True
    contains_l = str(contains).lower()
    for m in messages:
        if role and m.get("role") != role:
            continue
        c = m.get("content")
        if isinstance(c, str) and contains_l in c.lower():
            return True
        if isinstance(c, list):
            for part in c:
                if part.get("type") == "text" and contains_l in str(part.get("text", "")).lower():
                    return True
    return False


def resolve_override_for_messages(messages: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    specs = get_override_specs() or []
    for spec in specs:
        try:
            if _matches(spec, messages):
                return spec
        except Exception:
            # On matcher errors, skip spec
            continue
    return None


def apply_injection(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    # Delegate to injection.apply_injection
    return _apply_injection(messages)


def apply_param_overrides(api_params: Dict[str, Any]) -> Dict[str, Any]:
    ov = _param_overrides_ctx.get()
    if not ov:
        return api_params
    # Shallow merge only known keys users provided
    for k, v in ov.items():
        api_params[k] = v
    return api_params


def apply_tool_overrides(api_params: Dict[str, Any]) -> Dict[str, Any]:
    """Apply tool overrides to OpenAI/Anthropic-like api_params in place.

    Supports keys under spec["tools"]:
    - set_tools: replace tools entirely
    - add_tools: append tools
    - remove_tools_by_name: remove by function name
    - tool_choice: set tool_choice param
    """
    ov = _tool_overrides_ctx.get()
    if not ov:
        return api_params
    tov = ov.get("tools") if isinstance(ov, dict) else None
    if tov:
        tools = api_params.get("tools")
        if "set_tools" in tov:
            tools = tov["set_tools"]
        if "add_tools" in tov:
            tools = (tools or []) + tov["add_tools"]
        if "remove_tools_by_name" in tov and tools:
            names = set(tov["remove_tools_by_name"])  # function names
            new_tools = []
            for t in tools:
                try:
                    # OpenAI dict style
                    fn = t.get("function", {}).get("name") if isinstance(t, dict) else None
                except Exception:
                    fn = None
                # If BaseTool objects slipped through
                if fn is None:
                    fn = getattr(t, "function_name", None)
                if fn is None or fn not in names:
                    new_tools.append(t)
            tools = new_tools
        if tools is not None:
            api_params["tools"] = tools
        if "tool_choice" in tov:
            api_params["tool_choice"] = tov["tool_choice"]
    return api_params


@contextmanager
def use_overrides_for_messages(messages: List[Dict[str, Any]]):
    """Resolve an override spec against messages and apply its contexts within the scope.

    - Sets injection rules and param overrides if present on the matched spec.
    - Yields, then resets ContextVars to previous values.
    """
    spec = resolve_override_for_messages(messages) or {}
    inj_rules = spec.get("injection_rules")
    params = spec.get("params")
    inj_tok = None
    param_tok = None
    tool_tok = None
    label_tok = None
    try:
        if inj_rules:
            inj_tok = set_injection_rules(inj_rules)
        if params:
            param_tok = _param_overrides_ctx.set(params)
        tools = spec.get("tools")
        if tools:
            tool_tok = _tool_overrides_ctx.set({"tools": tools})
        lbl = spec.get("label")
        if lbl:
            label_tok = _current_override_label_ctx.set(str(lbl))
        yield
    finally:
        if inj_tok is not None:
            clear_injection_rules(inj_tok)
        if param_tok is not None:
            _param_overrides_ctx.reset(param_tok)
        if tool_tok is not None:
            _tool_overrides_ctx.reset(tool_tok)
        if label_tok is not None:
            _current_override_label_ctx.reset(label_tok)


def get_current_override_label() -> Optional[str]:
    return _current_override_label_ctx.get()


class LMOverridesContext:
    """Context manager to register per-call override specs.

    Usage:
        with LMOverridesContext([
            {"match": {"contains": "atm", "role": "user"}, "injection_rules": [...], "params": {...}},
            {"match": {"contains": "refund"}, "params": {"temperature": 0.0}},
        ]):
            run_pipeline()
    """

    def __init__(self, override_specs: Optional[List[Dict[str, Any]]] | Dict[str, Any] = None):
        if isinstance(override_specs, dict):
            override_specs = [override_specs]
        self._specs = override_specs or []
        self._tok = None

    def __enter__(self):
        self._tok = set_override_specs(self._specs)
        return self

    def __exit__(self, exc_type, exc, tb):
        if self._tok is not None:
            clear_override_specs(self._tok)
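Tying the two modules together, a sketch of how a caller might consult these contexts around a vendor call (the api_params values are illustrative, not the library's defaults):

    from synth_ai.lm.overrides import (
        LMOverridesContext,
        use_overrides_for_messages,
        apply_injection,
        apply_param_overrides,
    )

    messages = [{"role": "user", "content": "I want a refund for this card payment."}]
    with LMOverridesContext([{"match": {"contains": "refund"}, "params": {"temperature": 0.0}}]):
        with use_overrides_for_messages(messages):
            messages = apply_injection(messages)  # no-op here: the matched spec has no injection_rules
            api_params = apply_param_overrides({"model": "example-model", "temperature": 0.7})
    assert api_params["temperature"] == 0.0  # the matched spec's params win within the scope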