synth-ai 0.2.4.dev5__py3-none-any.whl → 0.2.4.dev6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26) hide show
  1. synth_ai/environments/examples/crafter_classic/engine.py +8 -4
  2. synth_ai/environments/examples/wordle/__init__.py +29 -0
  3. synth_ai/environments/examples/wordle/engine.py +391 -0
  4. synth_ai/environments/examples/wordle/environment.py +154 -0
  5. synth_ai/environments/examples/wordle/helpers/generate_instances_wordfreq.py +75 -0
  6. synth_ai/environments/examples/wordle/taskset.py +222 -0
  7. synth_ai/environments/service/app.py +8 -0
  8. synth_ai/environments/service/core_routes.py +38 -0
  9. synth_ai/learning/prompts/banking77_injection_eval.py +163 -0
  10. synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +201 -0
  11. synth_ai/learning/prompts/mipro.py +273 -1
  12. synth_ai/learning/prompts/random_search.py +247 -0
  13. synth_ai/learning/prompts/run_mipro_banking77.py +160 -0
  14. synth_ai/learning/prompts/run_random_search_banking77.py +305 -0
  15. synth_ai/lm/injection.py +81 -0
  16. synth_ai/lm/overrides.py +204 -0
  17. synth_ai/lm/provider_support/anthropic.py +39 -12
  18. synth_ai/lm/provider_support/openai.py +31 -4
  19. synth_ai/lm/vendors/core/anthropic_api.py +16 -0
  20. synth_ai/lm/vendors/openai_standard.py +35 -5
  21. {synth_ai-0.2.4.dev5.dist-info → synth_ai-0.2.4.dev6.dist-info}/METADATA +2 -1
  22. {synth_ai-0.2.4.dev5.dist-info → synth_ai-0.2.4.dev6.dist-info}/RECORD +26 -14
  23. {synth_ai-0.2.4.dev5.dist-info → synth_ai-0.2.4.dev6.dist-info}/WHEEL +0 -0
  24. {synth_ai-0.2.4.dev5.dist-info → synth_ai-0.2.4.dev6.dist-info}/entry_points.txt +0 -0
  25. {synth_ai-0.2.4.dev5.dist-info → synth_ai-0.2.4.dev6.dist-info}/licenses/LICENSE +0 -0
  26. {synth_ai-0.2.4.dev5.dist-info → synth_ai-0.2.4.dev6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,222 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import random
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+ from typing import List, Tuple
8
+ from uuid import uuid4, UUID
9
+
10
+ from synth_ai.environments.tasks.core import (
11
+ TaskInstance,
12
+ TaskInstanceMetadata,
13
+ TaskInstanceSet,
14
+ Impetus,
15
+ Intent,
16
+ SplitInfo,
17
+ )
18
+
19
+ from .engine import DEFAULT_SOLUTIONS
20
+
21
+
22
@dataclass
class WordleTaskInstanceMetadata(TaskInstanceMetadata):
    """Per-instance Wordle settings stored on a task instance and round-tripped by (de)serialization."""

    # Length of the target word; interpolated into the task instructions.
    word_length: int
    # Guess budget; interpolated into the task instructions.
    max_guesses: int
    # The hidden solution word for this instance.
    target_word: str
    # NOTE(review): presumably makes the engine reject guesses outside its word list - confirm in engine.py.
    enforce_wordlist: bool
    # Optional seed; create_wordle_taskset sets it to the instance's index.
    seed: int | None = None
    # NOTE(review): presumably controls whether an invalid guess uses up an attempt - confirm in engine.py.
    consume_invalid_attempts: bool = True
30
+
31
+
32
@dataclass
class WordleTaskInstance(TaskInstance):
    """Wordle task instance that can be converted to and from a plain dictionary."""

    async def serialize(self) -> dict:
        """Return a dictionary snapshot of this instance.

        The id is stringified; impetus, intent and metadata are expanded into
        nested dictionaries field by field.
        """
        md = self.metadata
        goal = self.intent

        intent_payload = {
            "rubric": goal.rubric,
            "gold_trajectories": goal.gold_trajectories,
            "gold_state_diff": goal.gold_state_diff,
        }
        metadata_payload = {
            "word_length": md.word_length,
            "max_guesses": md.max_guesses,
            "target_word": md.target_word,
            "enforce_wordlist": md.enforce_wordlist,
            "seed": md.seed,
            "consume_invalid_attempts": md.consume_invalid_attempts,
        }

        payload: dict = {"id": str(self.id)}
        payload["impetus"] = {"instructions": self.impetus.instructions}
        payload["intent"] = intent_payload
        payload["metadata"] = metadata_payload
        payload["is_reproducible"] = self.is_reproducible
        payload["initial_engine_snapshot"] = self.initial_engine_snapshot
        return payload

    @classmethod
    async def deserialize(cls, data: dict) -> "WordleTaskInstance":
        """Rebuild an instance from a dictionary shaped like the output of ``serialize``.

        ``seed`` and ``consume_invalid_attempts`` are optional in the metadata
        (defaulting to ``None`` and ``True``); every other key is required and a
        missing one raises ``KeyError``.
        """
        raw_md = data["metadata"]
        metadata = WordleTaskInstanceMetadata(
            word_length=raw_md["word_length"],
            max_guesses=raw_md["max_guesses"],
            target_word=raw_md["target_word"],
            enforce_wordlist=raw_md["enforce_wordlist"],
            seed=raw_md.get("seed"),
            consume_invalid_attempts=raw_md.get("consume_invalid_attempts", True),
        )

        instance_id = UUID(data["id"])
        impetus = Impetus(instructions=data["impetus"]["instructions"])
        intent = Intent(
            rubric=data["intent"]["rubric"],
            gold_trajectories=data["intent"]["gold_trajectories"],
            gold_state_diff=data["intent"]["gold_state_diff"],
        )
        return cls(
            id=instance_id,
            impetus=impetus,
            intent=intent,
            metadata=metadata,
            is_reproducible=data["is_reproducible"],
            initial_engine_snapshot=data["initial_engine_snapshot"],
        )
80
+
81
+
82
def _stable_uuid_for_instance(idx: int, target: str) -> UUID:
    """Derive a deterministic UUID for the instance at position ``idx`` with word ``target``.

    The id is a version-5 UUID in the URL namespace over the string
    ``wordle-fixed-v1:<idx>:<target>``, so the same inputs always give the same id.
    """
    from uuid import NAMESPACE_URL, uuid5

    seed_name = f"wordle-fixed-v1:{idx}:{target}"
    return uuid5(NAMESPACE_URL, seed_name)
85
+
86
+
87
def _load_fixed_instances_json() -> tuple[List[dict], dict]:
    """Load the fixed instance definitions from a JSON file, if one is usable.

    The file is the path in the ``WORDLE_INSTANCES_JSON`` environment variable
    when that is set, otherwise ``instances.json`` next to this module. The
    expected layout is an object with an ``instances`` list and an optional
    ``defaults`` object.

    Returns:
        A tuple ``(instances, defaults)``. ``instances`` contains only the
        dict entries of the ``instances`` list (anything else is dropped so
        callers can safely call ``.get`` on each entry); ``defaults`` is the
        ``defaults`` object or ``{}``. A missing, unreadable, undecodable or
        wrongly shaped file yields ``([], {})`` - loading is best-effort and
        never raises for those cases.
    """
    import os

    # Allow override via env var
    override = os.getenv("WORDLE_INSTANCES_JSON")
    path = Path(override) if override else Path(__file__).with_name("instances.json")

    try:
        # Explicit encoding: do not depend on the platform's locale default.
        payload = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # OSError: missing/unreadable file. ValueError: invalid JSON or bad UTF-8.
        return [], {}

    if not isinstance(payload, dict):
        return [], {}

    raw_instances = payload.get("instances")
    raw_defaults = payload.get("defaults")
    instances = (
        [entry for entry in raw_instances if isinstance(entry, dict)]
        if isinstance(raw_instances, list)
        else []
    )
    defaults = raw_defaults if isinstance(raw_defaults, dict) else {}
    return instances, defaults
106
+
107
+
108
+ # Note: generation helpers were removed from the runtime path. Use the standalone generator script under helpers/ (generate_instances_wordfreq.py) instead.
109
+ _ = None
110
+
111
+
112
def _build_wordle_instance(
    idx: int,
    target: str,
    *,
    word_length: int,
    max_guesses: int,
    enforce_wordlist: bool,
    consume_invalid_attempts: bool,
) -> WordleTaskInstance:
    """Build one reproducible Wordle task instance for ``target`` at position ``idx``."""
    md = WordleTaskInstanceMetadata(
        word_length=word_length,
        max_guesses=max_guesses,
        target_word=target,
        enforce_wordlist=enforce_wordlist,
        seed=idx,
        consume_invalid_attempts=consume_invalid_attempts,
    )
    impetus = Impetus(
        instructions=(
            "Play Wordle. Submit one word per turn consisting only of letters. "
            f"You have up to {max_guesses} guesses to find the {word_length}-letter target word. "
            "Feedback per letter: G=correct position, Y=present elsewhere, B=absent."
        )
    )
    intent = Intent(
        rubric={"goal": "Guess the target word in as few moves as possible"},
        gold_trajectories=None,
        gold_state_diff={"target_known": False},
    )
    return WordleTaskInstance(
        # The id depends only on position and word, so it is stable across runs.
        id=_stable_uuid_for_instance(idx, target),
        impetus=impetus,
        intent=intent,
        metadata=md,
        is_reproducible=True,
        initial_engine_snapshot=None,
    )


async def create_wordle_taskset(
    *,
    word_length: int = 5,
    max_guesses: int = 6,
    enforce_wordlist: bool = False,
    sample_size: int = 30,
    consume_invalid_attempts: bool = True,
) -> TaskInstanceSet:
    """Create a Wordle taskset.

    Priority:
    1) If instances.json provides target words, use them (in file order) to
       produce a fixed, stable taskset with deterministic IDs.
    2) Otherwise, fall back to a procedural slice of DEFAULT_SOLUTIONS (stable
       ordering), preferring words of ``word_length`` and falling back to
       5-letter words when none match.

    Args:
        word_length: Word length recorded in metadata and shown in the instructions.
        max_guesses: Guess budget recorded in metadata and shown in the instructions.
        enforce_wordlist: Stored on each instance's metadata.
        sample_size: Maximum number of instances; values below zero are treated as zero.
        consume_invalid_attempts: Stored on each instance's metadata.

    Returns:
        A ``TaskInstanceSet`` whose validation split is every 5th instance and
        whose test split is every 7th instance not already used for validation.
    """
    # Normalize once so both sources build instances from identical values.
    word_length = int(word_length)
    max_guesses = int(max_guesses)
    enforce_wordlist = bool(enforce_wordlist)
    consume_invalid_attempts = bool(consume_invalid_attempts)
    # A negative slice bound would silently drop items from the end instead.
    limit = max(0, sample_size)

    # The "defaults" section of the JSON file is loaded but not applied here:
    # the keyword arguments of this function are authoritative.
    json_insts, _json_defaults = _load_fixed_instances_json()

    # Assemble fixed targets from JSON only (no runtime generation)
    normalized = (
        str(row.get("target_word", "")).strip().lower()
        for row in json_insts
        if row.get("target_word")
    )
    fixed_targets: List[str] = [word for word in normalized if word]

    use_fixed = bool(fixed_targets)
    if use_fixed:
        targets = fixed_targets[:limit]
    else:
        # Procedural fallback: stable ordering from DEFAULT_SOLUTIONS
        pool = [w for w in DEFAULT_SOLUTIONS if len(w) == word_length] or [
            w for w in DEFAULT_SOLUTIONS if len(w) == 5
        ]
        targets = pool[:limit]

    instances: List[WordleTaskInstance] = [
        _build_wordle_instance(
            idx,
            target,
            word_length=word_length,
            max_guesses=max_guesses,
            enforce_wordlist=enforce_wordlist,
            consume_invalid_attempts=consume_invalid_attempts,
        )
        for idx, target in enumerate(targets)
    ]

    # Deterministic split based on index positions. Both strides start at
    # index 0, so validation ids are removed from the test set to keep the
    # two splits disjoint.
    val_ids = {inst.id for inst in instances[::5]}
    test_ids = {inst.id for inst in instances[::7]} - val_ids
    split = SplitInfo(val_instance_ids=val_ids, test_instance_ids=test_ids, _is_split_defined=True)

    # Label the set by the source that actually produced the instances.
    return TaskInstanceSet(
        name="Wordle Fixed TaskSet" if use_fixed else "Wordle Example TaskSet",
        description=(
            "Fixed set from instances.json (stable ordering)."
            if use_fixed
            else "Lightweight Wordle tasks with fixed targets and seeds."
        ),
        instances=instances,
        split_info=split,
    )
219
+
220
+
221
+ # Alias
222
+ taskset = create_wordle_taskset
@@ -38,6 +38,14 @@ import synth_ai.environments.examples.crafter_custom.environment as ccustom
38
38
 
39
39
  register_environment("CrafterCustom", ccustom.CrafterCustomEnvironment)
40
40
 
41
# Register the Wordle example environment under the name "Wordle".
# The import sits inside the try block so that a failure while importing the
# example module is caught here rather than propagating out of this module.
try:
    import synth_ai.environments.examples.wordle.environment as wordle_mod
    register_environment("Wordle", wordle_mod.WordleEnvironment)
except Exception as _e:
    # Keep service robust even if example env import fails: log a warning and
    # continue without the Wordle environment.
    logging.getLogger(__name__).warning(f"Wordle env not registered: {_e}")
48
+
41
49
  app = FastAPI(title="Environment Service")
42
50
 
43
51
 
@@ -7,6 +7,7 @@ import os
7
7
  import json
8
8
  import pickle
9
9
  import base64
10
+ from io import BytesIO
10
11
  import numpy as np
11
12
  import tempfile
12
13
  from dataclasses import dataclass
@@ -727,6 +728,43 @@ async def terminate_env(env_name: str, request: TerminateRequest = Body(...)) ->
727
728
  raise HTTPException(status_code=400, detail=str(e))
728
729
 
729
730
 
731
@api_router.get("/env/{env_name}/frame")
async def get_env_frame(env_name: str, env_id: str) -> Dict[str, Any]:
    """Return the current rendered frame of the environment as base64 PNG.

    This provides a lightweight way for clients to capture before/after snapshots
    around steps without modifying core step responses.

    Args:
        env_name: Path parameter from the route; not used to look up the instance.
        env_id: Identifier of the stored environment instance to render.

    Returns:
        ``{"env_id": <env_id>, "image_base64": <PNG bytes, base64-encoded ASCII>}``.

    Raises:
        HTTPException: 404 when no instance is stored under ``env_id``; 500 when
            the environment cannot be rendered or the frame cannot be encoded.
    """
    env = await storage.get(env_id)
    if not env:
        raise HTTPException(status_code=404, detail=f"Environment instance {env_id} not found")

    try:
        # For CrafterClassic, underlying engine exposes env.render() -> RGB ndarray
        inner_env = getattr(getattr(env, "engine", None), "env", None)
        if not hasattr(inner_env, "render"):
            raise RuntimeError("Environment does not support render()")

        rgb = inner_env.render()
        if rgb is None:
            raise RuntimeError("render() returned None")

        # Encode to PNG base64
        try:
            # Imported lazily so the route module does not require Pillow at import time.
            from PIL import Image  # type: ignore

            img = Image.fromarray(rgb.astype("uint8"), "RGB")
            buf = BytesIO()
            img.save(buf, format="PNG")
            b64 = base64.b64encode(buf.getvalue()).decode("ascii")
        except Exception as e:
            # Chain the cause so the original traceback is not lost.
            raise RuntimeError(f"failed to encode frame: {e}") from e

        return {"env_id": env_id, "image_base64": b64}
    except Exception as e:
        logger.error(f"Error rendering frame for {env_id}: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
766
+
767
+
730
768
  @api_router.get("/env/{env_name}/metadata")
731
769
  async def get_env_metadata(env_name: str, env_id: str) -> Dict[str, Any]:
732
770
  """Get metadata about an environment instance."""
@@ -0,0 +1,163 @@
1
+ """
2
+ Banking77 in-context injection evals (async, not tests)
3
+
4
+ Samples a handful of Banking77 prompts and evaluates multiple override
5
+ contexts in parallel, printing simple accuracy for each.
6
+
7
+ Usage
8
+ - Keys in .env (GROQ_API_KEY, etc.)
9
+ - Run: uv run -q python -m synth_ai.learning.prompts.banking77_injection_eval
10
+ Optional env:
11
+ - N_SAMPLES=20 (default)
12
+ - MODEL=openai/gpt-oss-20b (default)
13
+ - VENDOR=groq (default)
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import asyncio
19
+ import os
20
+ import random
21
+ from typing import List, Dict, Any, Tuple
22
+
23
+ from dotenv import load_dotenv
24
+ from datasets import load_dataset
25
+
26
+ from synth_ai.lm.core.main_v3 import LM, build_messages
27
+ from synth_ai.lm.overrides import LMOverridesContext
28
+
29
+
30
async def classify_one(lm: LM, text: str, label_names: List[str]) -> str:
    """Send one Banking77 message to ``lm`` and return the model's stripped raw reply.

    The system prompt lists every entry of ``label_names`` and asks for exactly
    one of them; the reply is returned as-is (no mapping onto a label).
    """
    instructions = (
        "You are an intent classifier for the Banking77 dataset. "
        "Given a customer message, respond with exactly one label from the list. "
        "Return only the label text with no extra words.\n\n"
    )
    system_message = instructions + "Valid labels: " + ", ".join(label_names)
    user_message = "Message: {}\nLabel:".format(text)

    messages = build_messages(system_message, user_message, images_bytes=None, model_name=lm.model)
    reply = await lm.respond_async(messages=messages)
    raw = reply.raw_response or ""
    return raw.strip()
42
+
43
+
44
def choose_label(pred: str, label_names: List[str]) -> str:
    """Map a raw model prediction onto one of ``label_names``.

    An exact, case-insensitive match (after stripping whitespace) wins.
    Otherwise the label sharing the most tokens with the prediction is chosen;
    ties, including the case where nothing overlaps, go to the earliest label.

    Args:
        pred: Raw text returned by the model.
        label_names: Candidate labels, e.g. ``"card_arrival"``.

    Returns:
        The selected entry of ``label_names``.

    Raises:
        ValueError: If ``label_names`` is empty (no label can be chosen).
    """
    norm_pred = pred.strip().lower()
    label_lookup = {ln.lower(): ln for ln in label_names}
    mapped = label_lookup.get(norm_pred)
    if mapped is not None:
        return mapped

    # Fallback: choose the label with the highest naive token overlap
    def score(cand: str) -> int:
        # Labels are snake_case, so underscores must separate tokens as well;
        # splitting on whitespace alone would leave each label as one token.
        tokens = cand.lower().replace("_", " ").split()
        return sum(1 for w in tokens if w in norm_pred)

    return max(label_names, key=score)
57
+
58
+
59
async def eval_context(lm: LM, items: List[Tuple[str, str]], label_names: List[str], ctx_name: str, specs: List[Dict[str, Any]]) -> Tuple[str, int, int]:
    """Classify every item concurrently under one override context and count the hits.

    Args:
        lm: Model wrapper used for every request.
        items: ``(text, gold_label)`` pairs to evaluate.
        label_names: All valid label names.
        ctx_name: Display name of the context, echoed back in the result.
        specs: Override specs handed to ``LMOverridesContext``.

    Returns:
        ``(ctx_name, number_correct, number_of_items)``. A request that raised
        is counted as incorrect.
    """
    with LMOverridesContext(specs):
        pending = [classify_one(lm, message, label_names) for message, _gold in items]
        outcomes = await asyncio.gather(*pending, return_exceptions=True)

    hits = 0
    for (_message, gold), outcome in zip(items, outcomes):
        # Treat exceptions as incorrect
        if not isinstance(outcome, Exception):
            hits += int(choose_label(outcome, label_names) == gold)
    return (ctx_name, hits, len(items))
71
+
72
+
73
async def main() -> None:
    """Sample Banking77 test items and print accuracy for several override contexts.

    Reads ``N_SAMPLES``, ``MODEL`` and ``VENDOR`` from the environment (after
    loading ``.env``), evaluates each context one after another, and prints a
    summary line per context.
    """
    load_dotenv()

    sample_count = int(os.getenv("N_SAMPLES", "20"))
    model_name = os.getenv("MODEL", "openai/gpt-oss-20b")
    vendor_name = os.getenv("VENDOR", "groq")

    lm = LM(model=model_name, vendor=vendor_name, temperature=0.0)

    print("Loading Banking77 dataset (split='test')...")
    ds = load_dataset("banking77", split="test")
    label_names: List[str] = ds.features["label"].names  # type: ignore

    picked = random.sample(range(len(ds)), k=min(sample_count, len(ds)))
    items = []
    for row_idx in picked:
        row = ds[row_idx]
        items.append((row["text"], label_names[int(row["label"])]))  # (text, gold_label)

    # Heavily corrupt user text by replacing vowels (lowercase first, then uppercase)
    vowel_rules = [{"find": ch, "replace": "x"} for ch in "aeiou"]
    vowel_rules += [{"find": ch, "replace": "X"} for ch in "AEIOU"]

    match_everything = {"contains": ""}

    # Define a few override contexts to compare
    contexts: List[Dict[str, Any]] = []
    contexts.append({"name": "baseline (no overrides)", "overrides": []})
    contexts.append(
        {
            "name": "nonsense prompt injection (expected worse)",
            "overrides": [
                {"match": {"contains": "", "role": "user"}, "injection_rules": vowel_rules},
            ],
        }
    )
    contexts.append(
        {
            "name": "injection: atm->ATM, txn->transaction",
            "overrides": [
                {
                    "match": {"contains": "atm", "role": "user"},
                    "injection_rules": [
                        {"find": "atm", "replace": "ATM"},
                        {"find": "txn", "replace": "transaction"},
                    ],
                }
            ],
        }
    )
    contexts.append(
        {
            "name": "params: temperature=0.0",
            "overrides": [{"match": dict(match_everything), "params": {"temperature": 0.0}}],
        }
    )
    contexts.append(
        {
            "name": "model override: 20b->120b",
            "overrides": [{"match": dict(match_everything), "params": {"model": "openai/gpt-oss-120b"}}],
        }
    )

    print(f"\nEvaluating {len(contexts)} contexts on {len(items)} Banking77 samples (async)...")

    # Contexts run one at a time; within a context the items are classified concurrently.
    scoreboard: List[Tuple[str, int, int]] = []
    for ctx in contexts:
        ctx_name, ctx_specs = ctx["name"], ctx["overrides"]
        print(f"Evaluating: {ctx_name} ...")
        scoreboard.append(await eval_context(lm, items, label_names, ctx_name, ctx_specs))

    print("\nResults:")
    for ctx_name, n_correct, n_total in scoreboard:
        if n_total:
            acc = n_correct / n_total
        else:
            acc = 0.0
        print(f"- {ctx_name}: {n_correct}/{n_total} correct ({acc:.2%})")
160
+
161
+
162
+ if __name__ == "__main__":
163
+ asyncio.run(main())
@@ -0,0 +1,201 @@
1
+ """
2
+ Hello World: Banking77 intent classification with in-context injection
3
+
4
+ This script shows a minimal text-classification pipeline over the
5
+ Hugging Face Banking77 dataset using the Synth LM interface. It also
6
+ demonstrates a simple pre-send prompt-injection step as outlined in
7
+ `synth_ai/learning/prompts/injection_plan.txt`.
8
+
9
+ Notes
10
+ - Network access is required to download the dataset and call the model.
11
+ - Defaults to Groq with model `openai/gpt-oss-20b`.
12
+ - Export your key: `export GROQ_API_KEY=...`
13
+ - Override if needed: `export MODEL=openai/gpt-oss-20b VENDOR=groq`
14
+
15
+ Run
16
+ - `python -m synth_ai.learning.prompts.hello_world_in_context_injection_ex`
17
+
18
+ What "in-context injection" means here
19
+ - The script registers ordered substring-replacement rules through
20
+ `LMOverridesContext`; matching rules are applied to the outgoing
21
+ `messages` array before the model is called, following the algorithm
22
+ described in `injection_plan.txt`. You can adapt `INJECTION_RULES` to your needs.
23
+ """
24
+
25
+ from __future__ import annotations
26
+
27
+ import asyncio
28
+ import os
29
+ import random
30
+ from typing import Any, Dict, List, Optional
31
+
32
+ from datasets import load_dataset
33
+
34
+ # Use the v3 LM class present in this repo
35
+ from synth_ai.lm.core.main_v3 import LM, build_messages
36
+ from synth_ai.tracing_v3.session_tracer import SessionTracer
37
+ from synth_ai.tracing_v3.abstractions import LMCAISEvent
38
+
39
+
40
+ # Use Overrides context to demonstrate matching by content
41
+ from synth_ai.lm.overrides import LMOverridesContext
42
# Find/replace rules passed as "injection_rules" to LMOverridesContext in main().
# Each rule is a dict with a "find" substring and its "replace" text.
# NOTE(review): rules are presumably applied in list order - confirm in synth_ai.lm.injection.
INJECTION_RULES = [
    {"find": "accnt", "replace": "account"},
    {"find": "atm", "replace": "ATM"},
    {"find": "txn", "replace": "transaction"},
]
47
+
48
+
49
async def classify_sample(lm: LM, text: str, label_names: List[str]) -> str:
    """Ask the model to label one Banking77 utterance and return its stripped raw reply."""
    prompt_parts = [
        "You are an intent classifier for the Banking77 dataset. ",
        "Given a customer message, respond with exactly one label from the list. ",
        "Return only the label text with no extra words.\n\n",
        "Valid labels: " + ", ".join(label_names),
    ]
    system_message = "".join(prompt_parts)
    user_message = "Message: {}\nLabel:".format(text)

    # Messages are built in canonical form here; any active override context
    # applies its injection rules later, when the request is sent.
    messages = build_messages(system_message, user_message, images_bytes=None, model_name=lm.model)
    reply = await lm.respond_async(messages=messages)
    return (reply.raw_response or "").strip()
65
+
66
+
67
async def main() -> None:
    """Run the Banking77 demo, then three in-context injection integration checks.

    Part 1 classifies ``N_SAMPLES`` random test utterances (default 8) under an
    override context and prints per-sample results plus accuracy over the
    samples that were actually evaluated (the loop stops at the first model
    error).

    Part 2 runs three checks with the ``atm`` -> ``ATM`` rule active:
      1) LM path with v3 tracing - raises ``AssertionError`` if no LM event was
         traced or the traced user prompt lacks the substitution;
      2) OpenAI wrapper path - best effort, errors are printed and skipped;
      3) Anthropic wrapper path - best effort, skipped without ``ANTHROPIC_API_KEY``.
    """
    # Configurable model/provider via env, with sensible defaults
    # Default to Groq hosting `openai/gpt-oss-20b`
    model = os.getenv("MODEL", "openai/gpt-oss-20b")
    vendor = os.getenv("VENDOR", "groq")

    # Construct LM
    lm = LM(model=model, vendor=vendor, temperature=0.0)

    # Load Banking77 dataset
    # Columns: {"text": str, "label": int}; label names at ds.features["label"].names
    print("Loading Banking77 dataset (split='test')...")
    ds = load_dataset("banking77", split="test")
    label_names: List[str] = ds.features["label"].names  # type: ignore

    # Sample a few items for a quick demo; a negative N_SAMPLES means "none"
    # rather than a ValueError from random.sample.
    n = max(0, int(os.getenv("N_SAMPLES", "8")))
    idxs = random.sample(range(len(ds)), k=min(n, len(ds)))

    # Built once: neither the lookup nor the scorer depends on the sample.
    label_lookup = {ln.lower(): ln for ln in label_names}

    def map_prediction(raw_pred: str) -> str:
        """Map a raw model reply to a label: exact match first, then naive token overlap."""
        norm_pred = raw_pred.strip().lower()
        exact = label_lookup.get(norm_pred)
        if exact is not None:
            return exact

        # Fallback: pick the label with highest substring overlap (very naive).
        # Labels are snake_case, so underscores separate tokens too.
        def score(cand: str) -> int:
            tokens = cand.lower().replace("_", " ").split()
            return sum(1 for w in tokens if w in norm_pred)

        return max(label_names, key=score)

    correct = 0
    evaluated = 0
    # Apply overrides for all calls in this block (match by content)
    overrides = [
        {"match": {"contains": "atm", "role": "user"}, "injection_rules": INJECTION_RULES},
        {"match": {"contains": "refund"}, "params": {"temperature": 0.0}},
    ]
    with LMOverridesContext(overrides):
        for i, idx in enumerate(idxs, start=1):
            text: str = ds[idx]["text"]  # type: ignore
            gold_label_idx: int = int(ds[idx]["label"])  # type: ignore
            gold_label = label_names[gold_label_idx]

            try:
                pred = await classify_sample(lm, text, label_names)
            except Exception as e:
                print(f"[{i}] Error calling model: {e}")
                break

            pred_label = map_prediction(pred)
            is_correct = pred_label == gold_label
            correct += int(is_correct)
            evaluated += 1
            print(f"[{i}] text={text!r}\n gold={gold_label}\n pred={pred} -> mapped={pred_label} {'✅' if is_correct else '❌'}")

    # Score only the samples that were actually classified; samples skipped
    # after an early break must not be counted as wrong answers.
    if evaluated:
        acc = correct / evaluated
        print(f"\nSamples: {evaluated} | Correct: {correct} | Accuracy: {acc:.2%}")

    # ------------------------------
    # Integration tests (three paths)
    # ------------------------------
    print("\nRunning integration tests with in-context injection...")
    test_text = "I used the atm to withdraw cash."

    # 1) LM path with v3 tracing: verify substitution in traced messages
    tracer = SessionTracer()
    await tracer.start_session(metadata={"test": "lm_injection"})
    await tracer.start_timestep(step_id="lm_test")
    try:
        # Use a tracer-bound LM instance
        lm_traced = LM(model=model, vendor=vendor, temperature=0.0, session_tracer=tracer)
        with LMOverridesContext([{"match": {"contains": "atm"}, "injection_rules": INJECTION_RULES}]):
            _ = await classify_sample(lm_traced, test_text, label_names)

        # inspect trace
        history = tracer.current_session.event_history if tracer.current_session else []
        events = [e for e in history if isinstance(e, LMCAISEvent)]
        # Explicit raises instead of `assert`, so the checks survive `python -O`.
        if not events:
            raise AssertionError("No LMCAISEvent recorded by SessionTracer")
        cr = events[-1].call_records[0]
        traced_user = "".join(
            (part.text or "")
            for m in cr.input_messages
            if m.role == "user"
            for part in m.parts
            if getattr(part, "type", None) == "text"
        )
        if "ATM" not in traced_user:
            raise AssertionError(f"Expected substitution in traced prompt; got: {traced_user!r}")
        print("LM path trace verified: substitution present in traced prompt.")
    finally:
        # Close the timestep and session even when a verification above fails.
        await tracer.end_timestep()
        await tracer.end_session()

    # 2) OpenAI wrapper path (AsyncOpenAI to Groq): ensure apply_injection is active
    try:
        import synth_ai.lm.provider_support.openai as _synth_openai_patch  # noqa: F401
        from openai import AsyncOpenAI

        base_url = os.getenv("OPENAI_BASE_URL", "https://api.groq.com/openai/v1")
        api_key = os.getenv("OPENAI_API_KEY") or os.getenv("GROQ_API_KEY") or ""
        client = AsyncOpenAI(base_url=base_url, api_key=api_key)
        messages = [
            {"role": "system", "content": "Echo user label."},
            {"role": "user", "content": f"Please classify: {test_text}"},
        ]
        with LMOverridesContext([{"match": {"contains": "atm"}, "injection_rules": INJECTION_RULES}]):
            _ = await client.chat.completions.create(model=model, messages=messages, temperature=0)
        # Not all models echo input; instead, verify that our injected expectation matches.
        # This only compares the local message with its expected rewrite; it
        # does not inspect what was actually sent over the wire.
        expected_user = messages[1]["content"].replace("atm", "ATM")
        if messages[1]["content"] == expected_user:
            print("OpenAI wrapper: input already normalized; skipping assertion.")
        else:
            print("OpenAI wrapper: sent message contains substitution expectation:", expected_user)
    except Exception as e:
        # Best effort: missing packages, keys or network access skip this check.
        print("OpenAI wrapper test skipped due to error:", e)

    # 3) Anthropic wrapper path (AsyncClient): ensure apply_injection is active
    try:
        import synth_ai.lm.provider_support.anthropic as _synth_anthropic_patch  # noqa: F401
        import anthropic

        a_model = os.getenv("ANTHROPIC_MODEL", "claude-3-5-haiku-20241022")
        a_key = os.getenv("ANTHROPIC_API_KEY")
        if a_key:
            a_client = anthropic.AsyncClient(api_key=a_key)
            with LMOverridesContext([{"match": {"contains": "atm"}, "injection_rules": INJECTION_RULES}]):
                _ = await a_client.messages.create(
                    model=a_model,
                    system="Echo user label.",
                    max_tokens=64,
                    temperature=0,
                    messages=[{"role": "user", "content": [{"type": "text", "text": test_text}]}],
                )
            print("Anthropic wrapper call completed (cannot reliably assert echo).")
        else:
            print("Anthropic wrapper test skipped: ANTHROPIC_API_KEY not set.")
    except Exception as e:
        # Best effort: missing packages, keys or network access skip this check.
        print("Anthropic wrapper test skipped due to error:", e)
198
+
199
+
200
+ if __name__ == "__main__":
201
+ asyncio.run(main())