lm-deluge 0.0.67__py3-none-any.whl → 0.0.90__py3-none-any.whl
This diff shows the content changes between two publicly released versions of this package, as they appear in their public registry, and is provided for informational purposes only.
Potentially problematic release: this version of lm-deluge might be problematic.
- lm_deluge/__init__.py +1 -2
- lm_deluge/api_requests/anthropic.py +117 -22
- lm_deluge/api_requests/base.py +84 -11
- lm_deluge/api_requests/bedrock.py +30 -6
- lm_deluge/api_requests/chat_reasoning.py +4 -0
- lm_deluge/api_requests/gemini.py +166 -20
- lm_deluge/api_requests/openai.py +145 -25
- lm_deluge/batches.py +15 -45
- lm_deluge/client.py +309 -50
- lm_deluge/config.py +15 -3
- lm_deluge/models/__init__.py +14 -1
- lm_deluge/models/anthropic.py +29 -14
- lm_deluge/models/arcee.py +16 -0
- lm_deluge/models/deepseek.py +36 -4
- lm_deluge/models/google.py +42 -0
- lm_deluge/models/grok.py +24 -0
- lm_deluge/models/kimi.py +36 -0
- lm_deluge/models/minimax.py +18 -0
- lm_deluge/models/openai.py +100 -0
- lm_deluge/models/openrouter.py +133 -7
- lm_deluge/models/together.py +11 -0
- lm_deluge/models/zai.py +50 -0
- lm_deluge/pipelines/gepa/__init__.py +95 -0
- lm_deluge/pipelines/gepa/core.py +354 -0
- lm_deluge/pipelines/gepa/docs/samples.py +705 -0
- lm_deluge/pipelines/gepa/examples/01_synthetic_keywords.py +140 -0
- lm_deluge/pipelines/gepa/examples/02_gsm8k_math.py +261 -0
- lm_deluge/pipelines/gepa/examples/03_hotpotqa_multihop.py +300 -0
- lm_deluge/pipelines/gepa/examples/04_batch_classification.py +271 -0
- lm_deluge/pipelines/gepa/examples/simple_qa.py +129 -0
- lm_deluge/pipelines/gepa/optimizer.py +435 -0
- lm_deluge/pipelines/gepa/proposer.py +235 -0
- lm_deluge/pipelines/gepa/util.py +165 -0
- lm_deluge/{llm_tools → pipelines}/score.py +2 -2
- lm_deluge/{llm_tools → pipelines}/translate.py +5 -3
- lm_deluge/prompt.py +537 -88
- lm_deluge/request_context.py +7 -2
- lm_deluge/server/__init__.py +24 -0
- lm_deluge/server/__main__.py +144 -0
- lm_deluge/server/adapters.py +369 -0
- lm_deluge/server/app.py +388 -0
- lm_deluge/server/auth.py +71 -0
- lm_deluge/server/model_policy.py +215 -0
- lm_deluge/server/models_anthropic.py +172 -0
- lm_deluge/server/models_openai.py +175 -0
- lm_deluge/tool/__init__.py +1130 -0
- lm_deluge/tool/builtin/anthropic/__init__.py +300 -0
- lm_deluge/tool/builtin/anthropic/bash.py +0 -0
- lm_deluge/tool/builtin/anthropic/computer_use.py +0 -0
- lm_deluge/tool/builtin/gemini.py +59 -0
- lm_deluge/tool/builtin/openai.py +74 -0
- lm_deluge/tool/cua/__init__.py +173 -0
- lm_deluge/tool/cua/actions.py +148 -0
- lm_deluge/tool/cua/base.py +27 -0
- lm_deluge/tool/cua/batch.py +215 -0
- lm_deluge/tool/cua/converters.py +466 -0
- lm_deluge/tool/cua/kernel.py +702 -0
- lm_deluge/tool/cua/trycua.py +989 -0
- lm_deluge/tool/prefab/__init__.py +45 -0
- lm_deluge/tool/prefab/batch_tool.py +156 -0
- lm_deluge/tool/prefab/docs.py +1119 -0
- lm_deluge/tool/prefab/email.py +294 -0
- lm_deluge/tool/prefab/filesystem.py +1711 -0
- lm_deluge/tool/prefab/full_text_search/__init__.py +285 -0
- lm_deluge/tool/prefab/full_text_search/tantivy_index.py +396 -0
- lm_deluge/tool/prefab/memory.py +458 -0
- lm_deluge/tool/prefab/otc/__init__.py +165 -0
- lm_deluge/tool/prefab/otc/executor.py +281 -0
- lm_deluge/tool/prefab/otc/parse.py +188 -0
- lm_deluge/tool/prefab/random.py +212 -0
- lm_deluge/tool/prefab/rlm/__init__.py +296 -0
- lm_deluge/tool/prefab/rlm/executor.py +349 -0
- lm_deluge/tool/prefab/rlm/parse.py +144 -0
- lm_deluge/tool/prefab/sandbox/__init__.py +19 -0
- lm_deluge/tool/prefab/sandbox/daytona_sandbox.py +483 -0
- lm_deluge/tool/prefab/sandbox/docker_sandbox.py +609 -0
- lm_deluge/tool/prefab/sandbox/fargate_sandbox.py +546 -0
- lm_deluge/tool/prefab/sandbox/modal_sandbox.py +469 -0
- lm_deluge/tool/prefab/sandbox/seatbelt_sandbox.py +827 -0
- lm_deluge/tool/prefab/sheets.py +385 -0
- lm_deluge/tool/prefab/skills.py +0 -0
- lm_deluge/tool/prefab/subagents.py +233 -0
- lm_deluge/tool/prefab/todos.py +342 -0
- lm_deluge/tool/prefab/tool_search.py +169 -0
- lm_deluge/tool/prefab/web_search.py +199 -0
- lm_deluge/tracker.py +16 -13
- lm_deluge/util/schema.py +412 -0
- lm_deluge/warnings.py +8 -0
- {lm_deluge-0.0.67.dist-info → lm_deluge-0.0.90.dist-info}/METADATA +23 -9
- lm_deluge-0.0.90.dist-info/RECORD +132 -0
- lm_deluge/built_in_tools/anthropic/__init__.py +0 -128
- lm_deluge/built_in_tools/openai.py +0 -28
- lm_deluge/presets/cerebras.py +0 -17
- lm_deluge/presets/meta.py +0 -13
- lm_deluge/tool.py +0 -849
- lm_deluge-0.0.67.dist-info/RECORD +0 -72
- lm_deluge/{llm_tools → pipelines}/__init__.py +1 -1
- /lm_deluge/{llm_tools → pipelines}/classify.py +0 -0
- /lm_deluge/{llm_tools → pipelines}/extract.py +0 -0
- /lm_deluge/{llm_tools → pipelines}/locate.py +0 -0
- /lm_deluge/{llm_tools → pipelines}/ocr.py +0 -0
- /lm_deluge/{built_in_tools/anthropic/bash.py → skills/anthropic.py} +0 -0
- /lm_deluge/{built_in_tools/anthropic/computer_use.py → skills/compat.py} +0 -0
- /lm_deluge/{built_in_tools → tool/builtin}/anthropic/editor.py +0 -0
- /lm_deluge/{built_in_tools → tool/builtin}/base.py +0 -0
- {lm_deluge-0.0.67.dist-info → lm_deluge-0.0.90.dist-info}/WHEEL +0 -0
- {lm_deluge-0.0.67.dist-info → lm_deluge-0.0.90.dist-info}/licenses/LICENSE +0 -0
- {lm_deluge-0.0.67.dist-info → lm_deluge-0.0.90.dist-info}/top_level.txt +0 -0
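Two structural changes stand out in the listing above: lm_deluge/llm_tools is renamed to lm_deluge/pipelines (so imports like lm_deluge.llm_tools.classify become lm_deluge.pipelines.classify), and a new lm_deluge/server package adds a proxy server that the GEPA sample below starts with python -m lm_deluge.server --port 8000 and then targets with a stock OpenAI SDK client. A minimal sketch of that client side, assuming the defaults used in the sample (base URL http://localhost:8000/v1, API key not checked) and that the server exposes the standard OpenAI chat-completions route:

# Sketch only: assumes the proxy defaults from samples.py below
# (base_url http://localhost:8000/v1; the key is not verified) and an
# OpenAI-compatible /v1/chat/completions endpoint on the server.
# Start the server first: python -m lm_deluge.server --port 8000
import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    # Point the stock OpenAI SDK at the lm-deluge proxy.
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")
    resp = await client.chat.completions.create(
        model="gpt-4.1-mini",  # placeholder; use a model id your server policy allows
        messages=[{"role": "user", "content": "ping"}],
    )
    print(resp.choices[0].message.content)


if __name__ == "__main__":
    asyncio.run(main())

The model name here is a placeholder; samples.py itself routes whatever --model it was given through the same proxy, while reflection calls use LLMClient directly.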
lm_deluge/pipelines/gepa/docs/samples.py
@@ -0,0 +1,705 @@
+"""
+GEPA-style population optimizer for fts-bench using lm-deluge (no litellm dependency).
+
+Features:
+- Maintains a pool of candidates with per-example validation scores.
+- Selects a parent (best-by-val), mutates a single component, and accepts only if
+  minibatch reward improves; accepted candidates get a full val eval and join the pool.
+- Components: system_prompt, search_docstring, fetch_docstring.
+- Rollouts are run via verifiers + OpenAI SDK (pointing to lm-deluge proxy server); reflection uses LLMClient.
+
+Prerequisites:
+    Start the lm-deluge proxy server first:
+        python -m lm_deluge.server --port 8000
+
+Run:
+    uv run python gepa_lm_deluge_full.py --corpus-file ... --queries-file ... --env-file ...
+"""
+
+from __future__ import annotations
+
+import argparse
+import asyncio
+import random
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any
+
+import verifiers as vf  # type: ignore
+from datasets import Dataset  # type: ignore
+from dotenv import load_dotenv
+from fts_bench import (  # type: ignore
+    DEFAULT_FETCH_DOCSTRING,
+    DEFAULT_SEARCH_DOCSTRING,
+    DEFAULT_SYSTEM_PROMPT,
+)
+from verifiers.utils.tool_utils import convert_func_to_oai_tool  # type: ignore
+
+from openai import AsyncOpenAI  # type: ignore
+
+from lm_deluge.client import LLMClient  # type: ignore
+from lm_deluge.util.json import try_load_json  # type: ignore
+
+# ---------------------- Helpers ---------------------- #
+
+
+def _clean_state(state: dict[str, Any]) -> dict[str, Any]:
+    drop = {"prompt", "completion", "responses"}
+    return {k: v for k, v in state.items() if k not in drop}
+
+
+def _extract_assistant_message(messages: list[dict[str, Any]]) -> str:
+    for msg in reversed(messages):
+        if msg.get("role") == "assistant":
+            content = msg.get("content", "")
+            return content if isinstance(content, str) else str(content)
+    return ""
+
+
+def _count_tool_calls(messages: list[dict[str, Any]]) -> int:
+    total = 0
+    for msg in messages:
+        if msg.get("role") == "assistant" and isinstance(msg.get("tool_calls"), list):
+            total += len(msg["tool_calls"])
+    return total
+
+
+def _summarize_tool_calls(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+    """Compact view of assistant tool calls (name + truncated args)."""
+    calls: list[dict[str, Any]] = []
+    for msg in messages:
+        if msg.get("role") != "assistant":
+            continue
+        tool_calls = msg.get("tool_calls") or []
+        if not isinstance(tool_calls, list):
+            continue
+        for tc in tool_calls:
+            fn = (tc.get("function") or {}).get("name", "")
+            assert fn, "tool call missing name"
+            args_raw = (tc.get("function") or {}).get("arguments", "")
+            args_str = str(args_raw)
+            calls.append({"name": fn, "args": args_str})
+    return calls
+
+
+def _parse_documents_from_completion(messages: list[dict[str, Any]]) -> list[str]:
+    assistant_msg = _extract_assistant_message(messages)
+    if "{" in assistant_msg:
+        assistant_msg = "{" + assistant_msg.split("{", 1)[1]
+    parsed = try_load_json(assistant_msg)
+    if isinstance(parsed, dict):
+        docs = parsed.get("documents", [])
+        if isinstance(docs, list):
+            return [str(doc) for doc in docs]
+    return []
+
+
+def _question_key_from_records(records: list[dict[str, Any]]) -> str:
+    if not records:
+        return "question"
+    keys = records[0].keys()
+    if "question" in keys:
+        return "question"
+    if "query" in keys:
+        return "query"
+    return "question"
+
+
+def _format_dataset(
+    env: vf.Environment,
+    records: list[dict[str, Any]],
+    system_prompt: str,
+    question_key: str,
+) -> Dataset:
+    ds = Dataset.from_list(records)
+    if "prompt" in ds.column_names:
+        ds = ds.remove_columns("prompt")
+    return env.format_dataset(
+        ds,
+        system_prompt=system_prompt,
+        few_shot=env.few_shot,
+        question_key=question_key,
+    )
+
+
+def _prepare_env(env: vf.ToolEnv, candidate: dict[str, str]) -> None:
+    # Update text components and rebuild tool schemas.
+    if env.tools:
+        if len(env.tools) >= 1:
+            env.tools[0].__doc__ = candidate["search_docstring"]
+        if len(env.tools) >= 2:
+            env.tools[1].__doc__ = candidate["fetch_docstring"]
+        env.oai_tools = [convert_func_to_oai_tool(tool) for tool in env.tools]
+        env.tool_map = {
+            getattr(tool, "__name__", tool.__class__.__name__): tool
+            for tool in env.tools
+        }
+    env.system_prompt = candidate["system_prompt"]
+
+
+def _run_generate_sync(
+    env: vf.Environment,
+    dataset: Dataset,
+    client: Any,
+    model: str,
+    max_concurrency: int,
+    rollouts_per_example: int,
+):
+    async def _run():
+        outputs: vf.GenerateOutputs = await env.generate(
+            inputs=dataset,
+            client=client,  # type: ignore[arg-type]
+            model=model,
+            rollouts_per_example=rollouts_per_example,
+            max_concurrent=max_concurrency,
+            use_tqdm=False,
+        )
+        return outputs
+
+    try:
+        return asyncio.run(_run())
+    except RuntimeError:
+        loop = asyncio.get_event_loop()
+        if loop.is_running():
+            return asyncio.run_coroutine_threadsafe(_run(), loop).result()
+        return loop.run_until_complete(_run())
+
+
+@dataclass
+class EvalResult:
+    scores: list[float]
+    trajectories: list[dict[str, Any]]
+    avg_score: float
+    example_ids: list[Any]
+    subscores: dict[Any, float]
+
+
+def evaluate_candidate(
+    env: vf.ToolEnv,
+    candidate: dict[str, str],
+    records: list[dict[str, Any]],
+    client: Any,
+    model: str,
+    max_concurrency: int,
+    capture_traces: bool,
+    rollouts_per_example: int,
+    return_subscores: bool = False,
+) -> EvalResult:
+    _prepare_env(env, candidate)
+    question_key = _question_key_from_records(records)
+    formatted = _format_dataset(env, records, candidate["system_prompt"], question_key)
+    results = _run_generate_sync(
+        env,
+        formatted,
+        client,
+        model,
+        max_concurrency=max_concurrency,
+        rollouts_per_example=rollouts_per_example,
+    )
+
+    trajectories: list[dict[str, Any]] = []
+    scores = [float(r) for r in results.reward]
+    example_ids: list[Any] = []
+    subscores: dict[Any, float] = {}
+    for idx in range(len(formatted)):
+        completion_messages = results.completion[idx]
+        ex_id = results.example_id[idx]
+        example_ids.append(ex_id)
+        if return_subscores:
+            subscores[ex_id] = scores[idx]
+        traj = {
+            "example_id": ex_id,
+            "question": formatted[idx].get(question_key, ""),
+            "answer": str(results.answer[idx]),
+            "reward": scores[idx],
+            "tool_calls": _count_tool_calls(completion_messages),  # type: ignore
+            "tool_calls_detail": _summarize_tool_calls(completion_messages),  # type: ignore
+            "assistant_message": _extract_assistant_message(completion_messages),  # type: ignore
+            "predicted_documents": _parse_documents_from_completion(
+                completion_messages  # type: ignore
+            ),
+            "prompt_messages": results.prompt[idx],
+            "completion_messages": completion_messages,
+            "state": _clean_state(results.state[idx]),
+        }
+        trajectories.append(traj)
+
+    avg_score = sum(scores) / max(len(scores), 1)
+    if not capture_traces:
+        trajectories = []
+    if not return_subscores:
+        subscores = {}
+    return EvalResult(
+        scores=scores,
+        trajectories=trajectories,
+        avg_score=avg_score,
+        example_ids=example_ids,
+        subscores=subscores,
+    )
+
+
+def _build_reflection_prompt(
+    component: str, current_text: str, trajectories: list[dict[str, Any]], k: int = 4
+) -> str:
+    worst = sorted(trajectories, key=lambda t: t.get("reward", 0.0))[:k]
+    intro = {
+        "system_prompt": "Refine the system prompt for the search agent.",
+        "search_docstring": "Refine the SEARCH tool description so the model issues higher-recall queries.",
+        "fetch_docstring": "Refine the FETCH tool description so the model inspects the right snippets and returns correct doc IDs.",
+    }[component]
+    lines = [
+        intro,
+        f"Return ONLY the improved text for the {component.replace('_', ' ')}.",
+        "",
+        "Current text:",
+        current_text,
+        "",
+        "Trajectories:",
+    ]
+    for t in worst:
+        tool_calls_detail = t.get("tool_calls_detail", [])
+        tool_calls_str = (
+            "; ".join(
+                [f"{c.get('name', '')}({c.get('args', '')})" for c in tool_calls_detail]
+            )
+            if tool_calls_detail
+            else "none"
+        )
+        lines.append(
+            f"- Q: {t.get('question', '')}\n"
+            f"  Truth doc: {t.get('answer', '')}\n"
+            f"  Predicted: {t.get('predicted_documents', [])}\n"
+            f"  Reward: {t.get('reward', 0.0)} | Tool calls: {t.get('tool_calls', 0)} ({tool_calls_str})\n"
+            f"  Assistant msg: {t.get('assistant_message', '')}\n"
+        )
+    lines.append(
+        "Improve the component to boost recall/precision and ensure the final JSON includes correct doc IDs."
+    )
+    return "\n".join(lines)
+
+
+def propose_new_text(
+    reflection_client: LLMClient,  # type: ignore
+    component: str,
+    current_text: str,
+    trajectories: list[dict[str, Any]],
+) -> str:
+    prompt = _build_reflection_prompt(component, current_text, trajectories)
+    resp = reflection_client.process_prompts_sync([prompt], show_progress=False)[0]
+    text = resp.completion.strip()
+    return text if text else current_text
+
+
+# ---------------------- Frontier / merge helpers ---------------------- #
+
+
+def compute_val_frontier(population: list["CandidateRecord"]) -> dict[Any, set[int]]:
+    per_val_best: dict[Any, set[int]] = {}
+    max_score: dict[Any, float] = {}
+    for idx, cand in enumerate(population):
+        for val_id, score in cand.val_subscores.items():
+            best = max_score.get(val_id)
+            if best is None or score > best:
+                max_score[val_id] = score
+                per_val_best[val_id] = {idx}
+            elif score == best:
+                per_val_best[val_id].add(idx)
+    return per_val_best
+
+
+def frontier_union(frontier: dict[Any, set[int]]) -> set[int]:
+    all_ids: set[int] = set()
+    for ids in frontier.values():
+        all_ids.update(ids)
+    return all_ids
+
+
+def choose_merge_parents(
+    frontier_union_set: set[int],
+    population: list["CandidateRecord"],
+    rng: random.Random,
+) -> tuple[int, int] | None:
+    if len(frontier_union_set) < 2:
+        return None
+    choices = list(frontier_union_set)
+    p1 = rng.choice(choices)
+    choices.remove(p1)
+    p2 = rng.choice(choices)
+    return p1, p2
+
+
+def merge_candidates(
+    parent_a: dict[str, str],
+    parent_b: dict[str, str],
+    components: list[str],
+    rng: random.Random,
+) -> dict[str, str]:
+    child = {}
+    for comp in components:
+        child[comp] = rng.choice([parent_a[comp], parent_b[comp]])
+    return child
+
+
+# ---------------------- GEPA loop ---------------------- #
+
+
+@dataclass
+class CandidateRecord:
+    candidate: dict[str, str]
+    val_scores: list[float]
+    val_avg: float
+    parents: list[int]
+    val_subscores: dict[Any, float]
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(
+        description="GEPA-style optimizer for fts-bench using lm-deluge."
+    )
+    parser.add_argument(
+        "--corpus-file", default="/Users/benjamin/building_codes_corpus.jsonl"
+    )
+    parser.add_argument(
+        "--queries-file",
+        default="/Users/benjamin/building_codes_queries_with_labels.jsonl",
+    )
+    parser.add_argument("--env-file", default="/Users/benjamin/Desktop/llm_tokens.env")
+    parser.add_argument(
+        "--model",
+        default="claude-5-mini",
+        help="Model for rollouts via lm-deluge proxy server.",
+    )
+    parser.add_argument(
+        "--proxy-url",
+        default="http://localhost:8000/v1",
+        help="URL of the lm-deluge proxy server.",
+    )
+    parser.add_argument(
+        "--reflection-model",
+        default="gpt-4.1-mini",
+        help="Model for reflection via LLMClient.",
+    )
+    parser.add_argument("--train-examples", type=int, default=48)
+    parser.add_argument("--val-examples", type=int, default=16)
+    parser.add_argument("--max-concurrency", type=int, default=4)
+    parser.add_argument("--iterations", type=int, default=40)
+    parser.add_argument("--minibatch-size", type=int, default=6)
+    parser.add_argument(
+        "--eval-every", type=int, default=5, help="Val evaluation cadence for logging."
+    )
+    parser.add_argument(
+        "--max-metric-calls",
+        type=int,
+        default=1000,
+        help="Budget in rollout evaluations.",
+    )
+    parser.add_argument("--rollouts-per-example", type=int, default=1)
+    parser.add_argument(
+        "--use-merge", action="store_true", help="Enable merge proposals."
+    )
+    parser.add_argument(
+        "--max-merge-invocations", type=int, default=5, help="Max merge attempts."
+    )
+    parser.add_argument(
+        "--merge-period",
+        type=int,
+        default=3,
+        help="Try merge every N iters when merges remain.",
+    )
+    parser.add_argument("--seed", type=int, default=0)
+    return parser.parse_args()
+
+
+def main() -> None:
+    args = parse_args()
+    load_dotenv(args.env_file)
+    rng = random.Random(args.seed)
+
+    # Build base environment once (keeps the index hot).
+    base_env = vf.load_environment(
+        "fts-bench",
+        corpus_file=args.corpus_file,
+        queries_file=args.queries_file,
+        max_turns=12,
+        system_prompt=DEFAULT_SYSTEM_PROMPT,
+        search_docstring=DEFAULT_SEARCH_DOCSTRING,
+        fetch_docstring=DEFAULT_FETCH_DOCSTRING,
+    )
+    if base_env.dataset is None:
+        raise ValueError("fts-bench environment did not return a dataset.")
+
+    # Strip prompts to get raw records for re-formatting per candidate.
+    if "prompt" in base_env.dataset.column_names:
+        raw_ds = base_env.dataset.remove_columns("prompt")
+    else:
+        raw_ds = base_env.dataset
+
+    train_ds = raw_ds.select(range(min(len(raw_ds), args.train_examples)))
+    remaining_start = len(train_ds)
+    val_end = min(len(raw_ds), remaining_start + args.val_examples)
+    val_ds = (
+        raw_ds.select(range(remaining_start, val_end))
+        if val_end > remaining_start
+        else train_ds
+    )
+
+    train_records = [train_ds[i] for i in range(len(train_ds))]
+    val_records = [val_ds[i] for i in range(len(val_ds))]
+    question_key = _question_key_from_records(train_records or val_records)  # noqa
+
+    # Create OpenAI client pointing to lm-deluge proxy server
+    rollout_client = AsyncOpenAI(base_url=args.proxy_url, api_key="not-needed")
+    reflection_client = LLMClient(args.reflection_model, progress="tqdm")
+
+    seed_candidate = {
+        "system_prompt": DEFAULT_SYSTEM_PROMPT,
+        "search_docstring": DEFAULT_SEARCH_DOCSTRING,
+        "fetch_docstring": DEFAULT_FETCH_DOCSTRING,
+    }
+
+    # Evaluate seed on val set.
+    seed_eval = evaluate_candidate(
+        base_env,
+        seed_candidate,
+        val_records,
+        rollout_client,
+        args.model,
+        args.max_concurrency,
+        capture_traces=False,
+        rollouts_per_example=args.rollouts_per_example,
+        return_subscores=True,
+    )
+    population: list[CandidateRecord] = [
+        CandidateRecord(
+            candidate=seed_candidate,
+            val_scores=seed_eval.scores,
+            val_avg=seed_eval.avg_score,
+            parents=[],
+            val_subscores=seed_eval.subscores,
+        )
+    ]
+    best_idx = 0
+    metric_calls = len(val_records) * args.rollouts_per_example
+    print(
+        f"Seed val avg reward: {seed_eval.avg_score:.3f} over {len(val_records)} examples"
+    )
+
+    components = ["system_prompt", "search_docstring", "fetch_docstring"]
+    merges_due = 0
+    merges_tested = 0
+    frontier = compute_val_frontier(population)
+
+    def print_rollout_usage(rollout_client: AsyncOpenAI):
+        # Usage tracking not available via proxy - would need server-side tracking
+        print("Rollout client: using lm-deluge proxy server")
+
+    for it in range(1, args.iterations + 1):
+        print(f"=== Starting iteration {it} ===")
+        print_rollout_usage(rollout_client)
+        # print(rollout_client._clients)
+        if metric_calls >= args.max_metric_calls:
+            print(f"Stopping: reached metric budget {metric_calls}")
+            break
+
+        # Attempt merge first if scheduled
+        if (
+            args.use_merge
+            and merges_due > 0
+            and merges_tested < args.max_merge_invocations
+            and frontier_union(frontier)
+        ):
+            parent_pair = choose_merge_parents(
+                frontier_union(frontier), population, rng
+            )
+            if parent_pair is not None:
+                p1_idx, p2_idx = parent_pair
+                parent_a = population[p1_idx].candidate
+                parent_b = population[p2_idx].candidate
+
+                minibatch = rng.sample(
+                    train_records, k=min(args.minibatch_size, len(train_records))
+                )
+                eval_p1 = evaluate_candidate(
+                    base_env,
+                    parent_a,
+                    minibatch,
+                    rollout_client,
+                    args.model,
+                    args.max_concurrency,
+                    capture_traces=False,
+                    rollouts_per_example=args.rollouts_per_example,
+                )
+                eval_p2 = evaluate_candidate(
+                    base_env,
+                    parent_b,
+                    minibatch,
+                    rollout_client,
+                    args.model,
+                    args.max_concurrency,
+                    capture_traces=False,
+                    rollouts_per_example=args.rollouts_per_example,
+                )
+                metric_calls += 2 * len(minibatch) * args.rollouts_per_example
+
+                child_candidate = merge_candidates(parent_a, parent_b, components, rng)
+                eval_child = evaluate_candidate(
+                    base_env,
+                    child_candidate,
+                    minibatch,
+                    rollout_client,
+                    args.model,
+                    args.max_concurrency,
+                    capture_traces=False,
+                    rollouts_per_example=args.rollouts_per_example,
+                )
+                metric_calls += len(minibatch) * args.rollouts_per_example
+
+                parent_max = max(sum(eval_p1.scores), sum(eval_p2.scores))
+                child_sum = sum(eval_child.scores)
+                improved = child_sum > parent_max
+                print(
+                    f"[Iter {it}][MERGE] parents {p1_idx},{p2_idx} child_sum={child_sum:.2f} "
+                    f"parent_max={parent_max:.2f} -> {'ACCEPT' if improved else 'REJECT'} "
+                    f"| metric_calls={metric_calls}"
+                )
+
+                if improved:
+                    val_eval = evaluate_candidate(
+                        base_env,
+                        child_candidate,
+                        val_records,
+                        rollout_client,
+                        args.model,
+                        args.max_concurrency,
+                        capture_traces=False,
+                        rollouts_per_example=args.rollouts_per_example,
+                        return_subscores=True,
+                    )
+                    metric_calls += len(val_records) * args.rollouts_per_example
+                    population.append(
+                        CandidateRecord(
+                            candidate=child_candidate,
+                            val_scores=val_eval.scores,
+                            val_avg=val_eval.avg_score,
+                            parents=[p1_idx, p2_idx],
+                            val_subscores=val_eval.subscores,
+                        )
+                    )
+                    merges_due = max(0, merges_due - 1)
+                    merges_tested += 1
+                    frontier = compute_val_frontier(population)
+                    if val_eval.avg_score >= population[best_idx].val_avg:
+                        best_idx = len(population) - 1
+                else:
+                    # rejected merge; leave merges_due unchanged so it can be retried later
+                    pass
+
+        # Parent selection: best by val avg.
+        parent_idx = max(range(len(population)), key=lambda i: population[i].val_avg)
+        parent = population[parent_idx].candidate
+        component = components[(it - 1) % len(components)]
+
+        # Minibatch for reflection.
+        minibatch = rng.sample(
+            train_records, k=min(args.minibatch_size, len(train_records))
+        )
+        eval_curr = evaluate_candidate(
+            base_env,
+            parent,
+            minibatch,
+            rollout_client,
+            args.model,
+            args.max_concurrency,
+            capture_traces=True,
+            rollouts_per_example=args.rollouts_per_example,
+        )
+        metric_calls += len(minibatch) * args.rollouts_per_example
+
+        new_text = propose_new_text(
+            reflection_client, component, parent[component], eval_curr.trajectories
+        )
+        candidate_new = dict(parent)
+        candidate_new[component] = new_text
+
+        eval_new = evaluate_candidate(
+            base_env,
+            candidate_new,
+            minibatch,
+            rollout_client,
+            args.model,
+            args.max_concurrency,
+            capture_traces=False,
+            rollouts_per_example=args.rollouts_per_example,
+        )
+        metric_calls += len(minibatch) * args.rollouts_per_example
+
+        old_sum = sum(eval_curr.scores)
+        new_sum = sum(eval_new.scores)
+        improved = new_sum > old_sum
+        print(
+            f"[Iter {it}] parent {parent_idx} comp={component} old_sum={old_sum:.2f} new_sum={new_sum:.2f} -> "
+            f"{'ACCEPT' if improved else 'REJECT'} | metric_calls={metric_calls}"
+        )
+
+        if not improved:
+            continue
+
+        # Full val eval for accepted candidate.
+        val_eval = evaluate_candidate(
+            base_env,
+            candidate_new,
+            val_records,
+            rollout_client,
+            args.model,
+            args.max_concurrency,
+            capture_traces=False,
+            rollouts_per_example=args.rollouts_per_example,
+            return_subscores=True,
+        )
+        metric_calls += len(val_records) * args.rollouts_per_example
+        population.append(
+            CandidateRecord(
+                candidate=candidate_new,
+                val_scores=val_eval.scores,
+                val_avg=val_eval.avg_score,
+                parents=[parent_idx],
+                val_subscores=val_eval.subscores,
+            )
+        )
+        if val_eval.avg_score >= population[best_idx].val_avg:
+            best_idx = len(population) - 1
+        frontier = compute_val_frontier(population)
+        if args.use_merge and merges_tested < args.max_merge_invocations:
+            merges_due = min(merges_due + 1, args.max_merge_invocations - merges_tested)
+
+        if it % args.eval_every == 0:
+            print(
+                f"  Val avg {val_eval.avg_score:.3f} (best {population[best_idx].val_avg:.3f}, pool {len(population)})"
+            )
+
+    best = population[best_idx]
+    out_dir = Path("debug_runs/gepa_lm_deluge_full")
+    out_dir.mkdir(parents=True, exist_ok=True)
+    (out_dir / "best_system_prompt.txt").write_text(
+        best.candidate["system_prompt"], encoding="utf-8"
+    )
+    (out_dir / "best_search_docstring.txt").write_text(
+        best.candidate["search_docstring"], encoding="utf-8"
+    )
+    (out_dir / "best_fetch_docstring.txt").write_text(
+        best.candidate["fetch_docstring"], encoding="utf-8"
+    )
+    print(
+        f"Done. Best val {best.val_avg:.3f} (pool {len(population)}, metric calls {metric_calls}). "
+        f"Artifacts in {out_dir.resolve()}"
+    )
+
+
+if __name__ == "__main__":
+    main()
+
+# uv run python gepa_lm_deluge_full.py \
+#     --use-merge --max-merge-invocations 5 --merge-period 3 \
+#     --corpus-file /Users/benjamin/ccr_corpus.jsonl \
+#     --queries-file /Users/benjamin/ccr_queries_with_labels.jsonl \
+#     --env-file .env --model gpt-5-mini --reflection-model gpt-5-mini