lm-deluge 0.0.67__py3-none-any.whl → 0.0.90__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of lm-deluge might be problematic.

Files changed (108)
  1. lm_deluge/__init__.py +1 -2
  2. lm_deluge/api_requests/anthropic.py +117 -22
  3. lm_deluge/api_requests/base.py +84 -11
  4. lm_deluge/api_requests/bedrock.py +30 -6
  5. lm_deluge/api_requests/chat_reasoning.py +4 -0
  6. lm_deluge/api_requests/gemini.py +166 -20
  7. lm_deluge/api_requests/openai.py +145 -25
  8. lm_deluge/batches.py +15 -45
  9. lm_deluge/client.py +309 -50
  10. lm_deluge/config.py +15 -3
  11. lm_deluge/models/__init__.py +14 -1
  12. lm_deluge/models/anthropic.py +29 -14
  13. lm_deluge/models/arcee.py +16 -0
  14. lm_deluge/models/deepseek.py +36 -4
  15. lm_deluge/models/google.py +42 -0
  16. lm_deluge/models/grok.py +24 -0
  17. lm_deluge/models/kimi.py +36 -0
  18. lm_deluge/models/minimax.py +18 -0
  19. lm_deluge/models/openai.py +100 -0
  20. lm_deluge/models/openrouter.py +133 -7
  21. lm_deluge/models/together.py +11 -0
  22. lm_deluge/models/zai.py +50 -0
  23. lm_deluge/pipelines/gepa/__init__.py +95 -0
  24. lm_deluge/pipelines/gepa/core.py +354 -0
  25. lm_deluge/pipelines/gepa/docs/samples.py +705 -0
  26. lm_deluge/pipelines/gepa/examples/01_synthetic_keywords.py +140 -0
  27. lm_deluge/pipelines/gepa/examples/02_gsm8k_math.py +261 -0
  28. lm_deluge/pipelines/gepa/examples/03_hotpotqa_multihop.py +300 -0
  29. lm_deluge/pipelines/gepa/examples/04_batch_classification.py +271 -0
  30. lm_deluge/pipelines/gepa/examples/simple_qa.py +129 -0
  31. lm_deluge/pipelines/gepa/optimizer.py +435 -0
  32. lm_deluge/pipelines/gepa/proposer.py +235 -0
  33. lm_deluge/pipelines/gepa/util.py +165 -0
  34. lm_deluge/{llm_tools → pipelines}/score.py +2 -2
  35. lm_deluge/{llm_tools → pipelines}/translate.py +5 -3
  36. lm_deluge/prompt.py +537 -88
  37. lm_deluge/request_context.py +7 -2
  38. lm_deluge/server/__init__.py +24 -0
  39. lm_deluge/server/__main__.py +144 -0
  40. lm_deluge/server/adapters.py +369 -0
  41. lm_deluge/server/app.py +388 -0
  42. lm_deluge/server/auth.py +71 -0
  43. lm_deluge/server/model_policy.py +215 -0
  44. lm_deluge/server/models_anthropic.py +172 -0
  45. lm_deluge/server/models_openai.py +175 -0
  46. lm_deluge/tool/__init__.py +1130 -0
  47. lm_deluge/tool/builtin/anthropic/__init__.py +300 -0
  48. lm_deluge/tool/builtin/anthropic/bash.py +0 -0
  49. lm_deluge/tool/builtin/anthropic/computer_use.py +0 -0
  50. lm_deluge/tool/builtin/gemini.py +59 -0
  51. lm_deluge/tool/builtin/openai.py +74 -0
  52. lm_deluge/tool/cua/__init__.py +173 -0
  53. lm_deluge/tool/cua/actions.py +148 -0
  54. lm_deluge/tool/cua/base.py +27 -0
  55. lm_deluge/tool/cua/batch.py +215 -0
  56. lm_deluge/tool/cua/converters.py +466 -0
  57. lm_deluge/tool/cua/kernel.py +702 -0
  58. lm_deluge/tool/cua/trycua.py +989 -0
  59. lm_deluge/tool/prefab/__init__.py +45 -0
  60. lm_deluge/tool/prefab/batch_tool.py +156 -0
  61. lm_deluge/tool/prefab/docs.py +1119 -0
  62. lm_deluge/tool/prefab/email.py +294 -0
  63. lm_deluge/tool/prefab/filesystem.py +1711 -0
  64. lm_deluge/tool/prefab/full_text_search/__init__.py +285 -0
  65. lm_deluge/tool/prefab/full_text_search/tantivy_index.py +396 -0
  66. lm_deluge/tool/prefab/memory.py +458 -0
  67. lm_deluge/tool/prefab/otc/__init__.py +165 -0
  68. lm_deluge/tool/prefab/otc/executor.py +281 -0
  69. lm_deluge/tool/prefab/otc/parse.py +188 -0
  70. lm_deluge/tool/prefab/random.py +212 -0
  71. lm_deluge/tool/prefab/rlm/__init__.py +296 -0
  72. lm_deluge/tool/prefab/rlm/executor.py +349 -0
  73. lm_deluge/tool/prefab/rlm/parse.py +144 -0
  74. lm_deluge/tool/prefab/sandbox/__init__.py +19 -0
  75. lm_deluge/tool/prefab/sandbox/daytona_sandbox.py +483 -0
  76. lm_deluge/tool/prefab/sandbox/docker_sandbox.py +609 -0
  77. lm_deluge/tool/prefab/sandbox/fargate_sandbox.py +546 -0
  78. lm_deluge/tool/prefab/sandbox/modal_sandbox.py +469 -0
  79. lm_deluge/tool/prefab/sandbox/seatbelt_sandbox.py +827 -0
  80. lm_deluge/tool/prefab/sheets.py +385 -0
  81. lm_deluge/tool/prefab/skills.py +0 -0
  82. lm_deluge/tool/prefab/subagents.py +233 -0
  83. lm_deluge/tool/prefab/todos.py +342 -0
  84. lm_deluge/tool/prefab/tool_search.py +169 -0
  85. lm_deluge/tool/prefab/web_search.py +199 -0
  86. lm_deluge/tracker.py +16 -13
  87. lm_deluge/util/schema.py +412 -0
  88. lm_deluge/warnings.py +8 -0
  89. {lm_deluge-0.0.67.dist-info → lm_deluge-0.0.90.dist-info}/METADATA +23 -9
  90. lm_deluge-0.0.90.dist-info/RECORD +132 -0
  91. lm_deluge/built_in_tools/anthropic/__init__.py +0 -128
  92. lm_deluge/built_in_tools/openai.py +0 -28
  93. lm_deluge/presets/cerebras.py +0 -17
  94. lm_deluge/presets/meta.py +0 -13
  95. lm_deluge/tool.py +0 -849
  96. lm_deluge-0.0.67.dist-info/RECORD +0 -72
  97. lm_deluge/{llm_tools → pipelines}/__init__.py +1 -1
  98. /lm_deluge/{llm_tools → pipelines}/classify.py +0 -0
  99. /lm_deluge/{llm_tools → pipelines}/extract.py +0 -0
  100. /lm_deluge/{llm_tools → pipelines}/locate.py +0 -0
  101. /lm_deluge/{llm_tools → pipelines}/ocr.py +0 -0
  102. /lm_deluge/{built_in_tools/anthropic/bash.py → skills/anthropic.py} +0 -0
  103. /lm_deluge/{built_in_tools/anthropic/computer_use.py → skills/compat.py} +0 -0
  104. /lm_deluge/{built_in_tools → tool/builtin}/anthropic/editor.py +0 -0
  105. /lm_deluge/{built_in_tools → tool/builtin}/base.py +0 -0
  106. {lm_deluge-0.0.67.dist-info → lm_deluge-0.0.90.dist-info}/WHEEL +0 -0
  107. {lm_deluge-0.0.67.dist-info → lm_deluge-0.0.90.dist-info}/licenses/LICENSE +0 -0
  108. {lm_deluge-0.0.67.dist-info → lm_deluge-0.0.90.dist-info}/top_level.txt +0 -0
lm_deluge/pipelines/gepa/docs/samples.py (new file)
@@ -0,0 +1,705 @@
+ """
+ GEPA-style population optimizer for fts-bench using lm-deluge (no litellm dependency).
+
+ Features:
+ - Maintains a pool of candidates with per-example validation scores.
+ - Selects a parent (best-by-val), mutates a single component, and accepts only if
+   minibatch reward improves; accepted candidates get a full val eval and join the pool.
+ - Components: system_prompt, search_docstring, fetch_docstring.
+ - Rollouts are run via verifiers + OpenAI SDK (pointing to lm-deluge proxy server); reflection uses LLMClient.
+
+ Prerequisites:
+     Start the lm-deluge proxy server first:
+         python -m lm_deluge.server --port 8000
+
+ Run:
+     uv run python gepa_lm_deluge_full.py --corpus-file ... --queries-file ... --env-file ...
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import asyncio
+ import random
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Any
+
+ import verifiers as vf  # type: ignore
+ from datasets import Dataset  # type: ignore
+ from dotenv import load_dotenv
+ from fts_bench import (  # type: ignore
+     DEFAULT_FETCH_DOCSTRING,
+     DEFAULT_SEARCH_DOCSTRING,
+     DEFAULT_SYSTEM_PROMPT,
+ )
+ from verifiers.utils.tool_utils import convert_func_to_oai_tool  # type: ignore
+
+ from openai import AsyncOpenAI  # type: ignore
+
+ from lm_deluge.client import LLMClient  # type: ignore
+ from lm_deluge.util.json import try_load_json  # type: ignore
+
+ # ---------------------- Helpers ---------------------- #
+
+
+ def _clean_state(state: dict[str, Any]) -> dict[str, Any]:
+     drop = {"prompt", "completion", "responses"}
+     return {k: v for k, v in state.items() if k not in drop}
+
+
+ def _extract_assistant_message(messages: list[dict[str, Any]]) -> str:
+     for msg in reversed(messages):
+         if msg.get("role") == "assistant":
+             content = msg.get("content", "")
+             return content if isinstance(content, str) else str(content)
+     return ""
+
+
+ def _count_tool_calls(messages: list[dict[str, Any]]) -> int:
+     total = 0
+     for msg in messages:
+         if msg.get("role") == "assistant" and isinstance(msg.get("tool_calls"), list):
+             total += len(msg["tool_calls"])
+     return total
+
+
+ def _summarize_tool_calls(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+     """Compact view of assistant tool calls (name + truncated args)."""
+     calls: list[dict[str, Any]] = []
+     for msg in messages:
+         if msg.get("role") != "assistant":
+             continue
+         tool_calls = msg.get("tool_calls") or []
+         if not isinstance(tool_calls, list):
+             continue
+         for tc in tool_calls:
+             fn = (tc.get("function") or {}).get("name", "")
+             assert fn, "tool call missing name"
+             args_raw = (tc.get("function") or {}).get("arguments", "")
+             args_str = str(args_raw)
+             calls.append({"name": fn, "args": args_str})
+     return calls
+
+
+ def _parse_documents_from_completion(messages: list[dict[str, Any]]) -> list[str]:
+     assistant_msg = _extract_assistant_message(messages)
+     if "{" in assistant_msg:
+         assistant_msg = "{" + assistant_msg.split("{", 1)[1]
+     parsed = try_load_json(assistant_msg)
+     if isinstance(parsed, dict):
+         docs = parsed.get("documents", [])
+         if isinstance(docs, list):
+             return [str(doc) for doc in docs]
+     return []
+
+
+ def _question_key_from_records(records: list[dict[str, Any]]) -> str:
+     if not records:
+         return "question"
+     keys = records[0].keys()
+     if "question" in keys:
+         return "question"
+     if "query" in keys:
+         return "query"
+     return "question"
+
+
+ def _format_dataset(
+     env: vf.Environment,
+     records: list[dict[str, Any]],
+     system_prompt: str,
+     question_key: str,
+ ) -> Dataset:
+     ds = Dataset.from_list(records)
+     if "prompt" in ds.column_names:
+         ds = ds.remove_columns("prompt")
+     return env.format_dataset(
+         ds,
+         system_prompt=system_prompt,
+         few_shot=env.few_shot,
+         question_key=question_key,
+     )
+
+
+ def _prepare_env(env: vf.ToolEnv, candidate: dict[str, str]) -> None:
+     # Update text components and rebuild tool schemas.
+     if env.tools:
+         if len(env.tools) >= 1:
+             env.tools[0].__doc__ = candidate["search_docstring"]
+         if len(env.tools) >= 2:
+             env.tools[1].__doc__ = candidate["fetch_docstring"]
+         env.oai_tools = [convert_func_to_oai_tool(tool) for tool in env.tools]
+         env.tool_map = {
+             getattr(tool, "__name__", tool.__class__.__name__): tool
+             for tool in env.tools
+         }
+     env.system_prompt = candidate["system_prompt"]
+
+
+ def _run_generate_sync(
+     env: vf.Environment,
+     dataset: Dataset,
+     client: Any,
+     model: str,
+     max_concurrency: int,
+     rollouts_per_example: int,
+ ):
+     async def _run():
+         outputs: vf.GenerateOutputs = await env.generate(
+             inputs=dataset,
+             client=client,  # type: ignore[arg-type]
+             model=model,
+             rollouts_per_example=rollouts_per_example,
+             max_concurrent=max_concurrency,
+             use_tqdm=False,
+         )
+         return outputs
+
+     try:
+         return asyncio.run(_run())
+     except RuntimeError:
+         loop = asyncio.get_event_loop()
+         if loop.is_running():
+             return asyncio.run_coroutine_threadsafe(_run(), loop).result()
+         return loop.run_until_complete(_run())
+
+
+ @dataclass
+ class EvalResult:
+     scores: list[float]
+     trajectories: list[dict[str, Any]]
+     avg_score: float
+     example_ids: list[Any]
+     subscores: dict[Any, float]
+
+
+ def evaluate_candidate(
+     env: vf.ToolEnv,
+     candidate: dict[str, str],
+     records: list[dict[str, Any]],
+     client: Any,
+     model: str,
+     max_concurrency: int,
+     capture_traces: bool,
+     rollouts_per_example: int,
+     return_subscores: bool = False,
+ ) -> EvalResult:
+     _prepare_env(env, candidate)
+     question_key = _question_key_from_records(records)
+     formatted = _format_dataset(env, records, candidate["system_prompt"], question_key)
+     results = _run_generate_sync(
+         env,
+         formatted,
+         client,
+         model,
+         max_concurrency=max_concurrency,
+         rollouts_per_example=rollouts_per_example,
+     )
+
+     trajectories: list[dict[str, Any]] = []
+     scores = [float(r) for r in results.reward]
+     example_ids: list[Any] = []
+     subscores: dict[Any, float] = {}
+     for idx in range(len(formatted)):
+         completion_messages = results.completion[idx]
+         ex_id = results.example_id[idx]
+         example_ids.append(ex_id)
+         if return_subscores:
+             subscores[ex_id] = scores[idx]
+         traj = {
+             "example_id": ex_id,
+             "question": formatted[idx].get(question_key, ""),
+             "answer": str(results.answer[idx]),
+             "reward": scores[idx],
+             "tool_calls": _count_tool_calls(completion_messages),  # type: ignore
+             "tool_calls_detail": _summarize_tool_calls(completion_messages),  # type: ignore
+             "assistant_message": _extract_assistant_message(completion_messages),  # type: ignore
+             "predicted_documents": _parse_documents_from_completion(
+                 completion_messages  # type: ignore
+             ),
+             "prompt_messages": results.prompt[idx],
+             "completion_messages": completion_messages,
+             "state": _clean_state(results.state[idx]),
+         }
+         trajectories.append(traj)
+
+     avg_score = sum(scores) / max(len(scores), 1)
+     if not capture_traces:
+         trajectories = []
+     if not return_subscores:
+         subscores = {}
+     return EvalResult(
+         scores=scores,
+         trajectories=trajectories,
+         avg_score=avg_score,
+         example_ids=example_ids,
+         subscores=subscores,
+     )
+
+
+ def _build_reflection_prompt(
+     component: str, current_text: str, trajectories: list[dict[str, Any]], k: int = 4
+ ) -> str:
+     worst = sorted(trajectories, key=lambda t: t.get("reward", 0.0))[:k]
+     intro = {
+         "system_prompt": "Refine the system prompt for the search agent.",
+         "search_docstring": "Refine the SEARCH tool description so the model issues higher-recall queries.",
+         "fetch_docstring": "Refine the FETCH tool description so the model inspects the right snippets and returns correct doc IDs.",
+     }[component]
+     lines = [
+         intro,
+         f"Return ONLY the improved text for the {component.replace('_', ' ')}.",
+         "",
+         "Current text:",
+         current_text,
+         "",
+         "Trajectories:",
+     ]
+     for t in worst:
+         tool_calls_detail = t.get("tool_calls_detail", [])
+         tool_calls_str = (
+             "; ".join(
+                 [f"{c.get('name', '')}({c.get('args', '')})" for c in tool_calls_detail]
+             )
+             if tool_calls_detail
+             else "none"
+         )
+         lines.append(
+             f"- Q: {t.get('question', '')}\n"
+             f" Truth doc: {t.get('answer', '')}\n"
+             f" Predicted: {t.get('predicted_documents', [])}\n"
+             f" Reward: {t.get('reward', 0.0)} | Tool calls: {t.get('tool_calls', 0)} ({tool_calls_str})\n"
+             f" Assistant msg: {t.get('assistant_message', '')}\n"
+         )
+     lines.append(
+         "Improve the component to boost recall/precision and ensure the final JSON includes correct doc IDs."
+     )
+     return "\n".join(lines)
+
+
+ def propose_new_text(
+     reflection_client: LLMClient,  # type: ignore
+     component: str,
+     current_text: str,
+     trajectories: list[dict[str, Any]],
+ ) -> str:
+     prompt = _build_reflection_prompt(component, current_text, trajectories)
+     resp = reflection_client.process_prompts_sync([prompt], show_progress=False)[0]
+     text = resp.completion.strip()
+     return text if text else current_text
+
+
+ # ---------------------- Frontier / merge helpers ---------------------- #
+
+
+ def compute_val_frontier(population: list["CandidateRecord"]) -> dict[Any, set[int]]:
+     per_val_best: dict[Any, set[int]] = {}
+     max_score: dict[Any, float] = {}
+     for idx, cand in enumerate(population):
+         for val_id, score in cand.val_subscores.items():
+             best = max_score.get(val_id)
+             if best is None or score > best:
+                 max_score[val_id] = score
+                 per_val_best[val_id] = {idx}
+             elif score == best:
+                 per_val_best[val_id].add(idx)
+     return per_val_best
+
+
+ def frontier_union(frontier: dict[Any, set[int]]) -> set[int]:
+     all_ids: set[int] = set()
+     for ids in frontier.values():
+         all_ids.update(ids)
+     return all_ids
+
+
+ def choose_merge_parents(
+     frontier_union_set: set[int],
+     population: list["CandidateRecord"],
+     rng: random.Random,
+ ) -> tuple[int, int] | None:
+     if len(frontier_union_set) < 2:
+         return None
+     choices = list(frontier_union_set)
+     p1 = rng.choice(choices)
+     choices.remove(p1)
+     p2 = rng.choice(choices)
+     return p1, p2
+
+
+ def merge_candidates(
+     parent_a: dict[str, str],
+     parent_b: dict[str, str],
+     components: list[str],
+     rng: random.Random,
+ ) -> dict[str, str]:
+     child = {}
+     for comp in components:
+         child[comp] = rng.choice([parent_a[comp], parent_b[comp]])
+     return child
+
+
+ # ---------------------- GEPA loop ---------------------- #
+
+
+ @dataclass
+ class CandidateRecord:
+     candidate: dict[str, str]
+     val_scores: list[float]
+     val_avg: float
+     parents: list[int]
+     val_subscores: dict[Any, float]
+
+
+ def parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser(
+         description="GEPA-style optimizer for fts-bench using lm-deluge."
+     )
+     parser.add_argument(
+         "--corpus-file", default="/Users/benjamin/building_codes_corpus.jsonl"
+     )
+     parser.add_argument(
+         "--queries-file",
+         default="/Users/benjamin/building_codes_queries_with_labels.jsonl",
+     )
+     parser.add_argument("--env-file", default="/Users/benjamin/Desktop/llm_tokens.env")
+     parser.add_argument(
+         "--model",
+         default="claude-5-mini",
+         help="Model for rollouts via lm-deluge proxy server.",
+     )
+     parser.add_argument(
+         "--proxy-url",
+         default="http://localhost:8000/v1",
+         help="URL of the lm-deluge proxy server.",
+     )
+     parser.add_argument(
+         "--reflection-model",
+         default="gpt-4.1-mini",
+         help="Model for reflection via LLMClient.",
+     )
+     parser.add_argument("--train-examples", type=int, default=48)
+     parser.add_argument("--val-examples", type=int, default=16)
+     parser.add_argument("--max-concurrency", type=int, default=4)
+     parser.add_argument("--iterations", type=int, default=40)
+     parser.add_argument("--minibatch-size", type=int, default=6)
+     parser.add_argument(
+         "--eval-every", type=int, default=5, help="Val evaluation cadence for logging."
+     )
+     parser.add_argument(
+         "--max-metric-calls",
+         type=int,
+         default=1000,
+         help="Budget in rollout evaluations.",
+     )
+     parser.add_argument("--rollouts-per-example", type=int, default=1)
+     parser.add_argument(
+         "--use-merge", action="store_true", help="Enable merge proposals."
+     )
+     parser.add_argument(
+         "--max-merge-invocations", type=int, default=5, help="Max merge attempts."
+     )
+     parser.add_argument(
+         "--merge-period",
+         type=int,
+         default=3,
+         help="Try merge every N iters when merges remain.",
+     )
+     parser.add_argument("--seed", type=int, default=0)
+     return parser.parse_args()
+
+
+ def main() -> None:
+     args = parse_args()
+     load_dotenv(args.env_file)
+     rng = random.Random(args.seed)
+
+     # Build base environment once (keeps the index hot).
+     base_env = vf.load_environment(
+         "fts-bench",
+         corpus_file=args.corpus_file,
+         queries_file=args.queries_file,
+         max_turns=12,
+         system_prompt=DEFAULT_SYSTEM_PROMPT,
+         search_docstring=DEFAULT_SEARCH_DOCSTRING,
+         fetch_docstring=DEFAULT_FETCH_DOCSTRING,
+     )
+     if base_env.dataset is None:
+         raise ValueError("fts-bench environment did not return a dataset.")
+
+     # Strip prompts to get raw records for re-formatting per candidate.
+     if "prompt" in base_env.dataset.column_names:
+         raw_ds = base_env.dataset.remove_columns("prompt")
+     else:
+         raw_ds = base_env.dataset
+
+     train_ds = raw_ds.select(range(min(len(raw_ds), args.train_examples)))
+     remaining_start = len(train_ds)
+     val_end = min(len(raw_ds), remaining_start + args.val_examples)
+     val_ds = (
+         raw_ds.select(range(remaining_start, val_end))
+         if val_end > remaining_start
+         else train_ds
+     )
+
+     train_records = [train_ds[i] for i in range(len(train_ds))]
+     val_records = [val_ds[i] for i in range(len(val_ds))]
+     question_key = _question_key_from_records(train_records or val_records)  # noqa
+
+     # Create OpenAI client pointing to lm-deluge proxy server
+     rollout_client = AsyncOpenAI(base_url=args.proxy_url, api_key="not-needed")
+     reflection_client = LLMClient(args.reflection_model, progress="tqdm")
+
+     seed_candidate = {
+         "system_prompt": DEFAULT_SYSTEM_PROMPT,
+         "search_docstring": DEFAULT_SEARCH_DOCSTRING,
+         "fetch_docstring": DEFAULT_FETCH_DOCSTRING,
+     }
+
+     # Evaluate seed on val set.
+     seed_eval = evaluate_candidate(
+         base_env,
+         seed_candidate,
+         val_records,
+         rollout_client,
+         args.model,
+         args.max_concurrency,
+         capture_traces=False,
+         rollouts_per_example=args.rollouts_per_example,
+         return_subscores=True,
+     )
+     population: list[CandidateRecord] = [
+         CandidateRecord(
+             candidate=seed_candidate,
+             val_scores=seed_eval.scores,
+             val_avg=seed_eval.avg_score,
+             parents=[],
+             val_subscores=seed_eval.subscores,
+         )
+     ]
+     best_idx = 0
+     metric_calls = len(val_records) * args.rollouts_per_example
+     print(
+         f"Seed val avg reward: {seed_eval.avg_score:.3f} over {len(val_records)} examples"
+     )
+
+     components = ["system_prompt", "search_docstring", "fetch_docstring"]
+     merges_due = 0
+     merges_tested = 0
+     frontier = compute_val_frontier(population)
+
+     def print_rollout_usage(rollout_client: AsyncOpenAI):
+         # Usage tracking not available via proxy - would need server-side tracking
+         print("Rollout client: using lm-deluge proxy server")
+
+     for it in range(1, args.iterations + 1):
+         print(f"=== Starting iteration {it} ===")
+         print_rollout_usage(rollout_client)
+         # print(rollout_client._clients)
+         if metric_calls >= args.max_metric_calls:
+             print(f"Stopping: reached metric budget {metric_calls}")
+             break
+
+         # Attempt merge first if scheduled
+         if (
+             args.use_merge
+             and merges_due > 0
+             and merges_tested < args.max_merge_invocations
+             and frontier_union(frontier)
+         ):
+             parent_pair = choose_merge_parents(
+                 frontier_union(frontier), population, rng
+             )
+             if parent_pair is not None:
+                 p1_idx, p2_idx = parent_pair
+                 parent_a = population[p1_idx].candidate
+                 parent_b = population[p2_idx].candidate
+
+                 minibatch = rng.sample(
+                     train_records, k=min(args.minibatch_size, len(train_records))
+                 )
+                 eval_p1 = evaluate_candidate(
+                     base_env,
+                     parent_a,
+                     minibatch,
+                     rollout_client,
+                     args.model,
+                     args.max_concurrency,
+                     capture_traces=False,
+                     rollouts_per_example=args.rollouts_per_example,
+                 )
+                 eval_p2 = evaluate_candidate(
+                     base_env,
+                     parent_b,
+                     minibatch,
+                     rollout_client,
+                     args.model,
+                     args.max_concurrency,
+                     capture_traces=False,
+                     rollouts_per_example=args.rollouts_per_example,
+                 )
+                 metric_calls += 2 * len(minibatch) * args.rollouts_per_example
+
+                 child_candidate = merge_candidates(parent_a, parent_b, components, rng)
+                 eval_child = evaluate_candidate(
+                     base_env,
+                     child_candidate,
+                     minibatch,
+                     rollout_client,
+                     args.model,
+                     args.max_concurrency,
+                     capture_traces=False,
+                     rollouts_per_example=args.rollouts_per_example,
+                 )
+                 metric_calls += len(minibatch) * args.rollouts_per_example
+
+                 parent_max = max(sum(eval_p1.scores), sum(eval_p2.scores))
+                 child_sum = sum(eval_child.scores)
+                 improved = child_sum > parent_max
+                 print(
+                     f"[Iter {it}][MERGE] parents {p1_idx},{p2_idx} child_sum={child_sum:.2f} "
+                     f"parent_max={parent_max:.2f} -> {'ACCEPT' if improved else 'REJECT'} "
+                     f"| metric_calls={metric_calls}"
+                 )
+
+                 if improved:
+                     val_eval = evaluate_candidate(
+                         base_env,
+                         child_candidate,
+                         val_records,
+                         rollout_client,
+                         args.model,
+                         args.max_concurrency,
+                         capture_traces=False,
+                         rollouts_per_example=args.rollouts_per_example,
+                         return_subscores=True,
+                     )
+                     metric_calls += len(val_records) * args.rollouts_per_example
+                     population.append(
+                         CandidateRecord(
+                             candidate=child_candidate,
+                             val_scores=val_eval.scores,
+                             val_avg=val_eval.avg_score,
+                             parents=[p1_idx, p2_idx],
+                             val_subscores=val_eval.subscores,
+                         )
+                     )
+                     merges_due = max(0, merges_due - 1)
+                     merges_tested += 1
+                     frontier = compute_val_frontier(population)
+                     if val_eval.avg_score >= population[best_idx].val_avg:
+                         best_idx = len(population) - 1
+                 else:
+                     # rejected merge; leave merges_due unchanged so it can be retried later
+                     pass
+
+         # Parent selection: best by val avg.
+         parent_idx = max(range(len(population)), key=lambda i: population[i].val_avg)
+         parent = population[parent_idx].candidate
+         component = components[(it - 1) % len(components)]
+
+         # Minibatch for reflection.
+         minibatch = rng.sample(
+             train_records, k=min(args.minibatch_size, len(train_records))
+         )
+         eval_curr = evaluate_candidate(
+             base_env,
+             parent,
+             minibatch,
+             rollout_client,
+             args.model,
+             args.max_concurrency,
+             capture_traces=True,
+             rollouts_per_example=args.rollouts_per_example,
+         )
+         metric_calls += len(minibatch) * args.rollouts_per_example
+
+         new_text = propose_new_text(
+             reflection_client, component, parent[component], eval_curr.trajectories
+         )
+         candidate_new = dict(parent)
+         candidate_new[component] = new_text
+
+         eval_new = evaluate_candidate(
+             base_env,
+             candidate_new,
+             minibatch,
+             rollout_client,
+             args.model,
+             args.max_concurrency,
+             capture_traces=False,
+             rollouts_per_example=args.rollouts_per_example,
+         )
+         metric_calls += len(minibatch) * args.rollouts_per_example
+
+         old_sum = sum(eval_curr.scores)
+         new_sum = sum(eval_new.scores)
+         improved = new_sum > old_sum
+         print(
+             f"[Iter {it}] parent {parent_idx} comp={component} old_sum={old_sum:.2f} new_sum={new_sum:.2f} -> "
+             f"{'ACCEPT' if improved else 'REJECT'} | metric_calls={metric_calls}"
+         )
+
+         if not improved:
+             continue
+
+         # Full val eval for accepted candidate.
+         val_eval = evaluate_candidate(
+             base_env,
+             candidate_new,
+             val_records,
+             rollout_client,
+             args.model,
+             args.max_concurrency,
+             capture_traces=False,
+             rollouts_per_example=args.rollouts_per_example,
+             return_subscores=True,
+         )
+         metric_calls += len(val_records) * args.rollouts_per_example
+         population.append(
+             CandidateRecord(
+                 candidate=candidate_new,
+                 val_scores=val_eval.scores,
+                 val_avg=val_eval.avg_score,
+                 parents=[parent_idx],
+                 val_subscores=val_eval.subscores,
+             )
+         )
+         if val_eval.avg_score >= population[best_idx].val_avg:
+             best_idx = len(population) - 1
+         frontier = compute_val_frontier(population)
+         if args.use_merge and merges_tested < args.max_merge_invocations:
+             merges_due = min(merges_due + 1, args.max_merge_invocations - merges_tested)
+
+         if it % args.eval_every == 0:
+             print(
+                 f" Val avg {val_eval.avg_score:.3f} (best {population[best_idx].val_avg:.3f}, pool {len(population)})"
+             )
+
+     best = population[best_idx]
+     out_dir = Path("debug_runs/gepa_lm_deluge_full")
+     out_dir.mkdir(parents=True, exist_ok=True)
+     (out_dir / "best_system_prompt.txt").write_text(
+         best.candidate["system_prompt"], encoding="utf-8"
+     )
+     (out_dir / "best_search_docstring.txt").write_text(
+         best.candidate["search_docstring"], encoding="utf-8"
+     )
+     (out_dir / "best_fetch_docstring.txt").write_text(
+         best.candidate["fetch_docstring"], encoding="utf-8"
+     )
+     print(
+         f"Done. Best val {best.val_avg:.3f} (pool {len(population)}, metric calls {metric_calls}). "
+         f"Artifacts in {out_dir.resolve()}"
+     )
+
+
+ if __name__ == "__main__":
+     main()
+
+ # uv run python gepa_lm_deluge_full.py \
+ #     --use-merge --max-merge-invocations 5 --merge-period 3 \
+ #     --corpus-file /Users/benjamin/ccr_corpus.jsonl \
+ #     --queries-file /Users/benjamin/ccr_queries_with_labels.jsonl \
+ #     --env-file .env --model gpt-5-mini --reflection-model gpt-5-mini