synth-ai 0.2.4.dev5__py3-none-any.whl → 0.2.4.dev6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/environments/examples/crafter_classic/engine.py +8 -4
- synth_ai/environments/examples/wordle/__init__.py +29 -0
- synth_ai/environments/examples/wordle/engine.py +391 -0
- synth_ai/environments/examples/wordle/environment.py +154 -0
- synth_ai/environments/examples/wordle/helpers/generate_instances_wordfreq.py +75 -0
- synth_ai/environments/examples/wordle/taskset.py +222 -0
- synth_ai/environments/service/app.py +8 -0
- synth_ai/environments/service/core_routes.py +38 -0
- synth_ai/learning/prompts/banking77_injection_eval.py +163 -0
- synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +201 -0
- synth_ai/learning/prompts/mipro.py +273 -1
- synth_ai/learning/prompts/random_search.py +247 -0
- synth_ai/learning/prompts/run_mipro_banking77.py +160 -0
- synth_ai/learning/prompts/run_random_search_banking77.py +305 -0
- synth_ai/lm/injection.py +81 -0
- synth_ai/lm/overrides.py +204 -0
- synth_ai/lm/provider_support/anthropic.py +39 -12
- synth_ai/lm/provider_support/openai.py +31 -4
- synth_ai/lm/vendors/core/anthropic_api.py +16 -0
- synth_ai/lm/vendors/openai_standard.py +35 -5
- {synth_ai-0.2.4.dev5.dist-info → synth_ai-0.2.4.dev6.dist-info}/METADATA +2 -1
- {synth_ai-0.2.4.dev5.dist-info → synth_ai-0.2.4.dev6.dist-info}/RECORD +26 -14
- {synth_ai-0.2.4.dev5.dist-info → synth_ai-0.2.4.dev6.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.4.dev5.dist-info → synth_ai-0.2.4.dev6.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.4.dev5.dist-info → synth_ai-0.2.4.dev6.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.4.dev5.dist-info → synth_ai-0.2.4.dev6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,222 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import json
|
4
|
+
import random
|
5
|
+
from dataclasses import dataclass
|
6
|
+
from pathlib import Path
|
7
|
+
from typing import List, Tuple
|
8
|
+
from uuid import uuid4, UUID
|
9
|
+
|
10
|
+
from synth_ai.environments.tasks.core import (
|
11
|
+
TaskInstance,
|
12
|
+
TaskInstanceMetadata,
|
13
|
+
TaskInstanceSet,
|
14
|
+
Impetus,
|
15
|
+
Intent,
|
16
|
+
SplitInfo,
|
17
|
+
)
|
18
|
+
|
19
|
+
from .engine import DEFAULT_SOLUTIONS
|
20
|
+
|
21
|
+
|
22
|
+
@dataclass
class WordleTaskInstanceMetadata(TaskInstanceMetadata):
    """Per-instance Wordle configuration attached to each task instance.

    Field order matters: after any fields inherited from TaskInstanceMetadata,
    it defines the positional order of the generated ``__init__`` (required
    fields first, then the ones with defaults).
    """

    # Number of letters in the target word; quoted in the task instructions.
    word_length: int
    # Maximum number of guesses allowed; quoted in the task instructions.
    max_guesses: int
    # The word the agent must find.
    target_word: str
    # Whether guesses must come from a word list (presumably enforced by the engine — not visible here).
    enforce_wordlist: bool
    # Seed recorded for the instance; create_wordle_taskset sets it to the instance index.
    seed: int | None = None
    # Whether invalid guesses use up an attempt (semantics live in the engine — TODO confirm).
    consume_invalid_attempts: bool = True
|
30
|
+
|
31
|
+
|
32
|
+
@dataclass
class WordleTaskInstance(TaskInstance):
    """Wordle task instance with a plain-dict (JSON-friendly) round trip."""

    async def serialize(self) -> dict:
        """Return a plain-dict snapshot of this instance; the UUID is rendered as ``str``."""
        instance_id = str(self.id)
        impetus_payload = {"instructions": self.impetus.instructions}
        goal = self.intent
        intent_payload = {
            "rubric": goal.rubric,
            "gold_trajectories": goal.gold_trajectories,
            "gold_state_diff": goal.gold_state_diff,
        }
        meta = self.metadata
        metadata_payload = {
            "word_length": meta.word_length,
            "max_guesses": meta.max_guesses,
            "target_word": meta.target_word,
            "enforce_wordlist": meta.enforce_wordlist,
            "seed": meta.seed,
            "consume_invalid_attempts": meta.consume_invalid_attempts,
        }
        return {
            "id": instance_id,
            "impetus": impetus_payload,
            "intent": intent_payload,
            "metadata": metadata_payload,
            "is_reproducible": self.is_reproducible,
            "initial_engine_snapshot": self.initial_engine_snapshot,
        }

    @classmethod
    async def deserialize(cls, data: dict) -> "WordleTaskInstance":
        """Rebuild an instance from the dict produced by ``serialize``.

        ``metadata.seed`` (default ``None``) and ``metadata.consume_invalid_attempts``
        (default ``True``) are optional; every other key is required and a missing
        one raises ``KeyError``.
        """
        raw_meta = data["metadata"]
        metadata = WordleTaskInstanceMetadata(
            word_length=raw_meta["word_length"],
            max_guesses=raw_meta["max_guesses"],
            target_word=raw_meta["target_word"],
            enforce_wordlist=raw_meta["enforce_wordlist"],
            seed=raw_meta.get("seed"),
            consume_invalid_attempts=raw_meta.get("consume_invalid_attempts", True),
        )

        # Same lookup order as the constructor call's keyword arguments.
        instance_id = UUID(data["id"])
        impetus = Impetus(instructions=data["impetus"]["instructions"])
        raw_intent = data["intent"]
        intent = Intent(
            rubric=raw_intent["rubric"],
            gold_trajectories=raw_intent["gold_trajectories"],
            gold_state_diff=raw_intent["gold_state_diff"],
        )
        return cls(
            id=instance_id,
            impetus=impetus,
            intent=intent,
            metadata=metadata,
            is_reproducible=data["is_reproducible"],
            initial_engine_snapshot=data["initial_engine_snapshot"],
        )
|
80
|
+
|
81
|
+
|
82
|
+
def _stable_uuid_for_instance(idx: int, target: str) -> UUID:
    """Derive a deterministic id for the instance at position *idx* with target *target*.

    A name-based (version 5) UUID is used, so the same ``(idx, target)`` pair
    always yields the same id across processes and runs.
    """
    from uuid import NAMESPACE_URL, uuid5

    name = f"wordle-fixed-v1:{idx}:{target}"
    return uuid5(NAMESPACE_URL, name)
|
85
|
+
|
86
|
+
|
87
|
+
def _load_fixed_instances_json() -> tuple[List[dict], dict]:
    """Load the fixed Wordle instance definitions from ``instances.json``, if present.

    The file is read from the path in the ``WORDLE_INSTANCES_JSON`` environment
    variable when it is set (non-empty), otherwise from ``instances.json`` next to
    this module. Expected shape::

        {"defaults": {...}, "instances": [{"target_word": "...", ...}, ...]}

    Returns:
        ``(instances, defaults)``. ``instances`` is the list of instance dicts
        (entries that are not dicts are dropped), and ``defaults`` is the defaults
        dict. Both are empty when the file is missing, unreadable, not valid UTF-8
        JSON, or not shaped as above. Loading is deliberately best-effort, so
        callers can fall back to the procedural taskset.
    """
    import os

    override = os.getenv("WORDLE_INSTANCES_JSON")
    path = Path(override) if override else Path(__file__).with_name("instances.json")
    if not path.exists():
        return [], {}
    try:
        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        data = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, ValueError, RecursionError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        return [], {}
    if not isinstance(data, dict):
        return [], {}

    defaults = data.get("defaults") or {}
    insts = data.get("instances") or []
    if not isinstance(defaults, dict):
        defaults = {}
    if not isinstance(insts, list):
        insts = []
    # Callers call .get() on each entry, so keep only dict entries.
    return [r for r in insts if isinstance(r, dict)], defaults
|
106
|
+
|
107
|
+
|
108
|
+
# NOTE: instance-generation helpers were removed from runtime; instances.json is
# meant to be produced offline (presumably with helpers/generate_instances_wordfreq.py,
# which ships in this package). `_` is an unused placeholder binding.
_ = None
|
110
|
+
|
111
|
+
|
112
|
+
def _make_wordle_instance(
    idx: int,
    target: str,
    *,
    max_guesses: int,
    enforce_wordlist: bool,
    consume_invalid_attempts: bool,
) -> WordleTaskInstance:
    """Build the task instance at position *idx* for *target* (seed = idx, stable id)."""
    md = WordleTaskInstanceMetadata(
        # Record the target's real length so metadata and instructions always agree with it.
        word_length=len(target),
        max_guesses=int(max_guesses),
        target_word=target,
        enforce_wordlist=bool(enforce_wordlist),
        seed=idx,
        consume_invalid_attempts=bool(consume_invalid_attempts),
    )
    impetus = Impetus(
        instructions=(
            "Play Wordle. Submit one word per turn consisting only of letters. "
            f"You have up to {md.max_guesses} guesses to find the {md.word_length}-letter target word. "
            "Feedback per letter: G=correct position, Y=present elsewhere, B=absent."
        )
    )
    intent = Intent(
        rubric={"goal": "Guess the target word in as few moves as possible"},
        gold_trajectories=None,
        gold_state_diff={"target_known": False},
    )
    return WordleTaskInstance(
        id=_stable_uuid_for_instance(idx, md.target_word),
        impetus=impetus,
        intent=intent,
        metadata=md,
        is_reproducible=True,
        initial_engine_snapshot=None,
    )


async def create_wordle_taskset(
    *,
    word_length: int = 5,
    max_guesses: int = 6,
    enforce_wordlist: bool = False,
    sample_size: int = 30,
    consume_invalid_attempts: bool = True,
) -> TaskInstanceSet:
    """Create a Wordle taskset.

    Priority:
      1) If instances.json yields at least one non-empty ``target_word``, use those
         targets (file order, stripped and lowercased) to produce a fixed, stable
         taskset with deterministic IDs.
      2) Otherwise, fall back to a procedural slice of DEFAULT_SOLUTIONS (stable
         ordering), using words of ``word_length`` letters, or 5-letter words when
         none of that length exist.

    Each instance records ``len(target_word)`` as its ``word_length``, so the
    instructions never advertise a length that differs from the actual target.
    The instances.json ``defaults`` section is currently not applied; the keyword
    arguments always win.

    Args:
        word_length: Preferred target length for the procedural fallback.
        max_guesses: Guess budget for every instance.
        enforce_wordlist: Passed through to every instance's metadata.
        sample_size: Maximum number of instances (negative values give none;
            ``None`` means all).
        consume_invalid_attempts: Passed through to every instance's metadata.

    Returns:
        A TaskInstanceSet with a deterministic split. Every 5th instance
        (0, 5, 10, ...) is validation. Every 7th instance (0, 7, 14, ...) that is
        not already validation is test, so the two splits never overlap.
    """
    json_insts, _json_defaults = _load_fixed_instances_json()

    # Fixed targets come from JSON only (no runtime generation); skip blank entries.
    fixed_targets: List[str] = [
        target
        for target in (
            str(r.get("target_word") or "").strip().lower()
            for r in json_insts
            if isinstance(r, dict)
        )
        if target
    ]
    use_fixed = bool(fixed_targets)

    limit = None if sample_size is None else max(0, int(sample_size))
    if use_fixed:
        targets = fixed_targets[:limit]
    else:
        # Procedural fallback: stable ordering from DEFAULT_SOLUTIONS.
        pool = [w for w in DEFAULT_SOLUTIONS if len(w) == word_length] or [
            w for w in DEFAULT_SOLUTIONS if len(w) == 5
        ]
        targets = pool[:limit]

    instances: List[WordleTaskInstance] = [
        _make_wordle_instance(
            i,
            target,
            max_guesses=max_guesses,
            enforce_wordlist=enforce_wordlist,
            consume_invalid_attempts=consume_invalid_attempts,
        )
        for i, target in enumerate(targets)
    ]

    # Deterministic split based on index positions; test excludes validation ids
    # (otherwise index 0, 35, ... would sit in both splits).
    val_ids = {inst.id for inst in instances[::5]}
    test_ids = {inst.id for inst in instances[::7]} - val_ids
    split = SplitInfo(val_instance_ids=val_ids, test_instance_ids=test_ids, _is_split_defined=True)

    return TaskInstanceSet(
        name="Wordle Fixed TaskSet" if use_fixed else "Wordle Example TaskSet",
        description=(
            "Fixed set from instances.json (stable ordering)."
            if use_fixed
            else "Lightweight Wordle tasks with fixed targets and seeds."
        ),
        instances=instances,
        split_info=split,
    )
|
219
|
+
|
220
|
+
|
221
|
+
# Alias: exposes the factory under the shorter module-level name `taskset`
# (presumably for loaders that look up `taskset` by name — not visible here).
taskset = create_wordle_taskset
|
@@ -38,6 +38,14 @@ import synth_ai.environments.examples.crafter_custom.environment as ccustom
|
|
38
38
|
|
39
39
|
register_environment("CrafterCustom", ccustom.CrafterCustomEnvironment)
|
40
40
|
|
41
|
+
# Register the Wordle example environment. This is best-effort: a failure to
# import the example must not take down the whole service, so the error is
# logged (with its traceback, to keep it diagnosable) and otherwise ignored.
try:
    import synth_ai.environments.examples.wordle.environment as wordle_mod

    register_environment("Wordle", wordle_mod.WordleEnvironment)
except Exception as _e:
    logging.getLogger(__name__).warning("Wordle env not registered: %s", _e, exc_info=True)
|
48
|
+
|
41
49
|
app = FastAPI(title="Environment Service")
|
42
50
|
|
43
51
|
|
@@ -7,6 +7,7 @@ import os
|
|
7
7
|
import json
|
8
8
|
import pickle
|
9
9
|
import base64
|
10
|
+
from io import BytesIO
|
10
11
|
import numpy as np
|
11
12
|
import tempfile
|
12
13
|
from dataclasses import dataclass
|
@@ -727,6 +728,43 @@ async def terminate_env(env_name: str, request: TerminateRequest = Body(...)) ->
|
|
727
728
|
raise HTTPException(status_code=400, detail=str(e))
|
728
729
|
|
729
730
|
|
731
|
+
@api_router.get("/env/{env_name}/frame")
async def get_env_frame(env_name: str, env_id: str) -> Dict[str, Any]:
    """Return the current rendered frame of an environment instance as a base64 PNG.

    This provides a lightweight way for clients to capture before/after snapshots
    around steps without modifying core step responses. Only instances exposing
    ``env.engine.env.render()`` are supported (e.g. CrafterClassic, whose
    underlying engine returns an RGB array).

    Args:
        env_name: Environment type from the URL path (not used for the lookup).
        env_id: Instance id, looked up in ``storage``.

    Returns:
        ``{"env_id": env_id, "image_base64": <base64-encoded PNG>}``.

    Raises:
        HTTPException: 404 if the instance is not found; 500 if the instance
            cannot render or the frame cannot be encoded.
    """
    env = await storage.get(env_id)
    if not env:
        raise HTTPException(status_code=404, detail=f"Environment instance {env_id} not found")

    try:
        # For CrafterClassic, the underlying engine exposes env.render() -> RGB ndarray.
        if hasattr(env, "engine") and hasattr(env.engine, "env") and hasattr(env.engine.env, "render"):
            rgb = env.engine.env.render()
        else:
            raise RuntimeError("Environment does not support render()")

        if rgb is None:
            raise RuntimeError("render() returned None")

        # Encode to PNG base64.
        try:
            from PIL import Image  # type: ignore

            # np.asarray also accepts array-like render output, not only ndarrays.
            frame = np.asarray(rgb).astype("uint8")
            img = Image.fromarray(frame, "RGB")
            with BytesIO() as buf:
                img.save(buf, format="PNG")
                b64 = base64.b64encode(buf.getvalue()).decode("ascii")
        except Exception as e:
            raise RuntimeError(f"failed to encode frame: {e}") from e

        return {"env_id": env_id, "image_base64": b64}
    except Exception as e:
        # logger.exception keeps the traceback in the server log; the client gets a 500.
        logger.exception("Error rendering frame for %s: %s", env_id, e)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
766
|
+
|
767
|
+
|
730
768
|
@api_router.get("/env/{env_name}/metadata")
|
731
769
|
async def get_env_metadata(env_name: str, env_id: str) -> Dict[str, Any]:
|
732
770
|
"""Get metadata about an environment instance."""
|
@@ -0,0 +1,163 @@
|
|
1
|
+
"""
|
2
|
+
Banking77 in-context injection evals (async, not tests)
|
3
|
+
|
4
|
+
Samples a handful of Banking77 prompts and evaluates several override
contexts one after another (each context classifies its samples
concurrently), printing simple accuracy for each.
|
6
|
+
|
7
|
+
Usage
|
8
|
+
- Keys in .env (GROQ_API_KEY, etc.)
|
9
|
+
- Run: uv run -q python -m synth_ai.learning.prompts.banking77_injection_eval
|
10
|
+
Optional env:
|
11
|
+
- N_SAMPLES=20 (default)
|
12
|
+
- MODEL=openai/gpt-oss-20b (default)
|
13
|
+
- VENDOR=groq (default)
|
14
|
+
"""
|
15
|
+
|
16
|
+
from __future__ import annotations
|
17
|
+
|
18
|
+
import asyncio
|
19
|
+
import os
|
20
|
+
import random
|
21
|
+
from typing import List, Dict, Any, Tuple
|
22
|
+
|
23
|
+
from dotenv import load_dotenv
|
24
|
+
from datasets import load_dataset
|
25
|
+
|
26
|
+
from synth_ai.lm.core.main_v3 import LM, build_messages
|
27
|
+
from synth_ai.lm.overrides import LMOverridesContext
|
28
|
+
|
29
|
+
|
30
|
+
async def classify_one(lm: LM, text: str, label_names: List[str]) -> str:
    """Ask *lm* to label one Banking77 message.

    Returns the model's raw reply with surrounding whitespace stripped, or ``""``
    when the reply is empty or missing. Map it to a label with ``choose_label``.
    """
    instructions = (
        "You are an intent classifier for the Banking77 dataset. "
        "Given a customer message, respond with exactly one label from the list. "
        "Return only the label text with no extra words.\n\n"
    )
    system_message = instructions + f"Valid labels: {', '.join(label_names)}"
    user_message = f"Message: {text}\nLabel:"

    messages = build_messages(system_message, user_message, images_bytes=None, model_name=lm.model)
    response = await lm.respond_async(messages=messages)
    reply = response.raw_response
    return reply.strip() if reply else ""
|
42
|
+
|
43
|
+
|
44
|
+
def _label_tokens(text: str) -> List[str]:
    """Split *text* into lowercase alphanumeric words; underscores, hyphens and punctuation separate."""
    cleaned = "".join(ch if ch.isalnum() else " " for ch in text.lower())
    return cleaned.split()


def choose_label(pred: str, label_names: List[str]) -> str:
    """Map a raw model reply onto one of *label_names*.

    Resolution order:
      1. Case-insensitive exact match after stripping whitespace.
      2. Exact match ignoring separators and punctuation, so ``"card arrival"`` or
         ``"Card-Arrival"`` resolve to ``"card_arrival"``.
      3. The label sharing the most words with the reply. Banking77 labels are
         underscore-joined, so they are split into words first. The earliest
         label wins ties, including the all-zero case (e.g. an empty reply).

    Raises:
        ValueError: if *label_names* is empty (from ``max``).
    """
    norm_pred = pred.strip().lower()
    label_lookup = {ln.lower(): ln for ln in label_names}
    mapped = label_lookup.get(norm_pred)
    if mapped is not None:
        return mapped

    pred_words = _label_tokens(norm_pred)
    canonical = {" ".join(_label_tokens(ln)): ln for ln in label_names}
    mapped = canonical.get(" ".join(pred_words)) if pred_words else None
    if mapped is not None:
        return mapped

    # Fallback: the label with the highest word overlap with the reply.
    pred_set = set(pred_words)

    def score(cand: str) -> int:
        return sum(1 for w in _label_tokens(cand) if w in pred_set)

    return max(label_names, key=score)
|
57
|
+
|
58
|
+
|
59
|
+
async def eval_context(lm: LM, items: List[Tuple[str, str]], label_names: List[str], ctx_name: str, specs: List[Dict[str, Any]]) -> Tuple[str, int, int]:
    """Score one override context.

    Every ``(text, gold_label)`` item is classified concurrently while *specs*
    are active. A call that raises counts as incorrect.

    Returns:
        ``(ctx_name, n_correct, n_items)``.
    """
    with LMOverridesContext(specs):
        outcomes = await asyncio.gather(
            *(classify_one(lm, text, label_names) for text, _ in items),
            return_exceptions=True,
        )

    golds = [gold for _, gold in items]
    n_correct = sum(
        1
        for gold, outcome in zip(golds, outcomes)
        if not isinstance(outcome, Exception) and choose_label(outcome, label_names) == gold
    )
    return ctx_name, n_correct, len(items)
|
71
|
+
|
72
|
+
|
73
|
+
def _build_contexts() -> List[Dict[str, Any]]:
    """Return the named override contexts compared by ``main``, in evaluation order."""
    # Heavily corrupt user text by replacing every vowel (lowercase first, then uppercase).
    vowel_rules = [{"find": v, "replace": "x"} for v in "aeiou"]
    vowel_rules += [{"find": v, "replace": "X"} for v in "AEIOU"]
    return [
        {
            "name": "baseline (no overrides)",
            "overrides": [],
        },
        {
            "name": "nonsense prompt injection (expected worse)",
            "overrides": [
                {"match": {"contains": "", "role": "user"}, "injection_rules": vowel_rules},
            ],
        },
        {
            "name": "injection: atm->ATM, txn->transaction",
            "overrides": [
                {
                    "match": {"contains": "atm", "role": "user"},
                    "injection_rules": [
                        {"find": "atm", "replace": "ATM"},
                        {"find": "txn", "replace": "transaction"},
                    ],
                }
            ],
        },
        {
            # NOTE(review): main() already builds the LM with temperature=0.0, so this
            # context is expected to match the baseline.
            "name": "params: temperature=0.0",
            "overrides": [
                {"match": {"contains": ""}, "params": {"temperature": 0.0}},
            ],
        },
        {
            "name": "model override: 20b->120b",
            "overrides": [
                {"match": {"contains": ""}, "params": {"model": "openai/gpt-oss-120b"}},
            ],
        },
    ]


async def main() -> None:
    """Evaluate each override context on one shared Banking77 sample and print accuracy.

    Environment variables (read after ``load_dotenv()``):
        N_SAMPLES: number of test examples (default 20; negative values mean 0).
        MODEL / VENDOR: model and vendor for ``LM`` (defaults openai/gpt-oss-20b, groq).
        SEED: optional integer. When set, the sample is reproducible across runs;
            when unset, sampling uses the unseeded global ``random`` module.
    """
    load_dotenv()

    n = max(0, int(os.getenv("N_SAMPLES", "20")))
    model = os.getenv("MODEL", "openai/gpt-oss-20b")
    vendor = os.getenv("VENDOR", "groq")
    seed = os.getenv("SEED")
    rng = random.Random(int(seed)) if seed else random

    lm = LM(model=model, vendor=vendor, temperature=0.0)

    print("Loading Banking77 dataset (split='test')...")
    ds = load_dataset("banking77", split="test")
    label_names: List[str] = ds.features["label"].names  # type: ignore

    # All contexts are scored on the same (text, gold_label) sample, so accuracies are comparable.
    idxs = rng.sample(range(len(ds)), k=min(n, len(ds)))
    items = [(ds[i]["text"], label_names[int(ds[i]["label"])]) for i in idxs]

    contexts = _build_contexts()
    print(f"\nEvaluating {len(contexts)} contexts on {len(items)} Banking77 samples (async)...")

    # Evaluate each context sequentially but batched (each context classifies in parallel).
    results: List[Tuple[str, int, int]] = []
    for ctx in contexts:
        name = ctx["name"]
        print(f"Evaluating: {name} ...")
        results.append(await eval_context(lm, items, label_names, name, ctx["overrides"]))

    print("\nResults:")
    for name, correct, total in results:
        acc = correct / total if total else 0.0
        print(f"- {name}: {correct}/{total} correct ({acc:.2%})")
|
160
|
+
|
161
|
+
|
162
|
+
if __name__ == "__main__":
    # Script entry point: run the async evaluation on a fresh event loop.
    asyncio.run(main())
|
@@ -0,0 +1,201 @@
|
|
1
|
+
"""
|
2
|
+
Hello World: Banking77 intent classification with in-context injection
|
3
|
+
|
4
|
+
This script shows a minimal text-classification pipeline over the
|
5
|
+
Hugging Face Banking77 dataset using the Synth LM interface. It also
|
6
|
+
demonstrates a simple pre-send prompt-injection step as outlined in
|
7
|
+
`synth_ai/learning/prompts/injection_plan.txt`.
|
8
|
+
|
9
|
+
Notes
|
10
|
+
- Network access is required to download the dataset and call the model.
|
11
|
+
- Defaults to Groq with model `openai/gpt-oss-20b`.
|
12
|
+
- Export your key: `export GROQ_API_KEY=...`
|
13
|
+
- Override if needed: `export MODEL=openai/gpt-oss-20b VENDOR=groq`
|
14
|
+
|
15
|
+
Run
|
16
|
+
- `python -m synth_ai.learning.prompts.hello_world_in_context_injection_ex`
|
17
|
+
|
18
|
+
What "in-context injection" means here
|
19
|
+
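- The script wraps its model calls in `LMOverridesContext` (from
  `synth_ai.lm.overrides`), which applies the ordered find/replace rules in
  `INJECTION_RULES` to matching outgoing messages before they reach the
  model, following the algorithm described in `injection_plan.txt`. You can
  adapt `INJECTION_RULES` to your needs.
- After the demo, it runs three integration checks: the traced LM path, the
  OpenAI wrapper and the Anthropic wrapper.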
|
23
|
+
"""
|
24
|
+
|
25
|
+
from __future__ import annotations
|
26
|
+
|
27
|
+
import asyncio
|
28
|
+
import os
|
29
|
+
import random
|
30
|
+
from typing import Any, Dict, List, Optional
|
31
|
+
|
32
|
+
from datasets import load_dataset
|
33
|
+
|
34
|
+
# Use the v3 LM class present in this repo
|
35
|
+
from synth_ai.lm.core.main_v3 import LM, build_messages
|
36
|
+
from synth_ai.tracing_v3.session_tracer import SessionTracer
|
37
|
+
from synth_ai.tracing_v3.abstractions import LMCAISEvent
|
38
|
+
|
39
|
+
|
40
|
+
# Use Overrides context to demonstrate matching by content
|
41
|
+
from synth_ai.lm.overrides import LMOverridesContext
|
42
|
+
# Ordered find/replace rules applied to matching outgoing messages via
# LMOverridesContext in main(). NOTE(review): if these are applied as plain
# substring replacements (as the module docstring describes), "atm" also
# matches inside words such as "treatment".
INJECTION_RULES = [
    {"find": "accnt", "replace": "account"},
    {"find": "atm", "replace": "ATM"},
    {"find": "txn", "replace": "transaction"},
]
|
47
|
+
|
48
|
+
|
49
|
+
async def classify_sample(lm: LM, text: str, label_names: List[str]) -> str:
    """Ask *lm* to label one Banking77 utterance.

    Returns the model's raw reply, stripped, or ``""`` if the reply is empty or
    missing. It is NOT yet mapped onto a label name; callers do that.
    Any active LMOverridesContext is applied inside the vendor call.
    """
    header = (
        "You are an intent classifier for the Banking77 dataset. "
        "Given a customer message, respond with exactly one label from the list. "
        "Return only the label text with no extra words.\n\n"
    )
    system_message = header + f"Valid labels: {', '.join(label_names)}"
    user_message = f"Message: {text}\nLabel:"

    messages = build_messages(system_message, user_message, images_bytes=None, model_name=lm.model)
    response = await lm.respond_async(messages=messages)
    reply = response.raw_response
    if not reply:
        return ""
    return reply.strip()
|
65
|
+
|
66
|
+
|
67
|
+
def _map_to_label(pred: str, label_names: List[str], label_lookup: Dict[str, str]) -> str:
    """Map a raw model reply onto a label name.

    An exact case-insensitive match (via *label_lookup*, lowercase -> label) wins.
    Otherwise, very naively, pick the label with the most whitespace-separated
    words occurring as substrings of the reply; the earliest label wins ties.
    """
    norm_pred = pred.strip().lower()
    exact = label_lookup.get(norm_pred)
    if exact is not None:
        return exact

    # This avoids extra deps; feel free to replace with a better matcher.
    def score(cand: str) -> int:
        return sum(1 for w in cand.lower().split() if w in norm_pred)

    return max(label_names, key=score)


async def _check_lm_trace(model: str, vendor: str, label_names: List[str], test_text: str) -> None:
    """Verify via v3 tracing that the injection rules rewrote the traced user prompt.

    Raises:
        AssertionError: if no LM event or call record was traced, or the traced
            user text lacks the substitution.
    """
    tracer = SessionTracer()
    await tracer.start_session(metadata={"test": "lm_injection"})
    await tracer.start_timestep(step_id="lm_test")
    try:
        # Use a tracer-bound LM instance.
        lm_traced = LM(model=model, vendor=vendor, temperature=0.0, session_tracer=tracer)
        with LMOverridesContext([{"match": {"contains": "atm"}, "injection_rules": INJECTION_RULES}]):
            _ = await classify_sample(lm_traced, test_text, label_names)

        session = tracer.current_session
        events = [e for e in (session.event_history if session else []) if isinstance(e, LMCAISEvent)]
        assert events, "No LMCAISEvent recorded by SessionTracer"
        records = events[-1].call_records
        assert records, "Last LMCAISEvent has no call records"
        traced_user = "".join(
            part.text or ""
            for m in records[0].input_messages
            if m.role == "user"
            for part in m.parts
            if getattr(part, "type", None) == "text"
        )
        assert "ATM" in traced_user, f"Expected substitution in traced prompt; got: {traced_user!r}"
        print("LM path trace verified: substitution present in traced prompt.")
    finally:
        # Close the timestep/session even when a check fails.
        await tracer.end_timestep()
        await tracer.end_session()


async def _check_openai_wrapper(model: str, test_text: str) -> None:
    """Call the patched AsyncOpenAI client (Groq by default) under the injection rules.

    Best-effort: any error is printed and the check is skipped.
    """
    try:
        import synth_ai.lm.provider_support.openai as _synth_openai_patch  # noqa: F401
        from openai import AsyncOpenAI

        base_url = os.getenv("OPENAI_BASE_URL", "https://api.groq.com/openai/v1")
        api_key = os.getenv("OPENAI_API_KEY") or os.getenv("GROQ_API_KEY") or ""
        client = AsyncOpenAI(base_url=base_url, api_key=api_key)
        messages = [
            {"role": "system", "content": "Echo user label."},
            {"role": "user", "content": f"Please classify: {test_text}"},
        ]
        with LMOverridesContext([{"match": {"contains": "atm"}, "injection_rules": INJECTION_RULES}]):
            await client.chat.completions.create(model=model, messages=messages, temperature=0)
        # Not all models echo input, so this only reports whether the caller's
        # `messages` list was rewritten in place; it does not assert the injection.
        expected_user = messages[1]["content"].replace("atm", "ATM")
        if messages[1]["content"] == expected_user:
            print("OpenAI wrapper: input already normalized; skipping assertion.")
        else:
            print("OpenAI wrapper: sent message contains substitution expectation:", expected_user)
    except Exception as e:
        print("OpenAI wrapper test skipped due to error:", e)


async def _check_anthropic_wrapper(test_text: str) -> None:
    """Call the patched Anthropic AsyncClient under the injection rules, if a key is set.

    Best-effort: any error is printed and the check is skipped.
    """
    try:
        import synth_ai.lm.provider_support.anthropic as _synth_anthropic_patch  # noqa: F401
        import anthropic

        a_model = os.getenv("ANTHROPIC_MODEL", "claude-3-5-haiku-20241022")
        a_key = os.getenv("ANTHROPIC_API_KEY")
        if not a_key:
            print("Anthropic wrapper test skipped: ANTHROPIC_API_KEY not set.")
            return
        a_client = anthropic.AsyncClient(api_key=a_key)
        with LMOverridesContext([{"match": {"contains": "atm"}, "injection_rules": INJECTION_RULES}]):
            _ = await a_client.messages.create(
                model=a_model,
                system="Echo user label.",
                max_tokens=64,
                temperature=0,
                messages=[{"role": "user", "content": [{"type": "text", "text": test_text}]}],
            )
        print("Anthropic wrapper call completed (cannot reliably assert echo).")
    except Exception as e:
        print("Anthropic wrapper test skipped due to error:", e)


async def main() -> None:
    """Classify a few random Banking77 samples under overrides, then run integration checks.

    MODEL / VENDOR choose the model (default Groq-hosted ``openai/gpt-oss-20b``).
    N_SAMPLES sets the sample count (default 8). The demo stops at the first
    model error, and accuracy is computed over the samples actually classified.
    """
    model = os.getenv("MODEL", "openai/gpt-oss-20b")
    vendor = os.getenv("VENDOR", "groq")
    lm = LM(model=model, vendor=vendor, temperature=0.0)

    # Columns: {"text": str, "label": int}; label names at ds.features["label"].names
    print("Loading Banking77 dataset (split='test')...")
    ds = load_dataset("banking77", split="test")
    label_names: List[str] = ds.features["label"].names  # type: ignore

    n = int(os.getenv("N_SAMPLES", "8"))
    idxs = random.sample(range(len(ds)), k=min(n, len(ds)))

    label_lookup = {ln.lower(): ln for ln in label_names}
    correct = 0
    attempted = 0
    # Apply overrides for all calls in this block (match by content).
    overrides = [
        {"match": {"contains": "atm", "role": "user"}, "injection_rules": INJECTION_RULES},
        {"match": {"contains": "refund"}, "params": {"temperature": 0.0}},
    ]
    with LMOverridesContext(overrides):
        for i, idx in enumerate(idxs, start=1):
            text: str = ds[idx]["text"]  # type: ignore
            gold_label = label_names[int(ds[idx]["label"])]  # type: ignore

            try:
                pred = await classify_sample(lm, text, label_names)
            except Exception as e:
                print(f"[{i}] Error calling model: {e}")
                break
            attempted += 1

            pred_label = _map_to_label(pred, label_names, label_lookup)
            is_correct = pred_label == gold_label
            correct += int(is_correct)
            print(f"[{i}] text={text!r}\n gold={gold_label}\n pred={pred} -> mapped={pred_label} {'✅' if is_correct else '❌'}")

    # Divide by the samples actually classified, not by len(idxs) (the loop may stop early).
    if attempted:
        acc = correct / attempted
        print(f"\nSamples: {attempted} | Correct: {correct} | Accuracy: {acc:.2%}")

    # ------------------------------
    # Integration tests (three paths)
    # ------------------------------
    print("\nRunning integration tests with in-context injection...")
    test_text = "I used the atm to withdraw cash."
    await _check_lm_trace(model, vendor, label_names, test_text)
    await _check_openai_wrapper(model, test_text)
    await _check_anthropic_wrapper(test_text)
|
198
|
+
|
199
|
+
|
200
|
+
if __name__ == "__main__":
    # Script entry point: run the async demo and integration checks on a fresh event loop.
    asyncio.run(main())
|