lm-deluge 0.0.67__py3-none-any.whl → 0.0.90__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (108)
  1. lm_deluge/__init__.py +1 -2
  2. lm_deluge/api_requests/anthropic.py +117 -22
  3. lm_deluge/api_requests/base.py +84 -11
  4. lm_deluge/api_requests/bedrock.py +30 -6
  5. lm_deluge/api_requests/chat_reasoning.py +4 -0
  6. lm_deluge/api_requests/gemini.py +166 -20
  7. lm_deluge/api_requests/openai.py +145 -25
  8. lm_deluge/batches.py +15 -45
  9. lm_deluge/client.py +309 -50
  10. lm_deluge/config.py +15 -3
  11. lm_deluge/models/__init__.py +14 -1
  12. lm_deluge/models/anthropic.py +29 -14
  13. lm_deluge/models/arcee.py +16 -0
  14. lm_deluge/models/deepseek.py +36 -4
  15. lm_deluge/models/google.py +42 -0
  16. lm_deluge/models/grok.py +24 -0
  17. lm_deluge/models/kimi.py +36 -0
  18. lm_deluge/models/minimax.py +18 -0
  19. lm_deluge/models/openai.py +100 -0
  20. lm_deluge/models/openrouter.py +133 -7
  21. lm_deluge/models/together.py +11 -0
  22. lm_deluge/models/zai.py +50 -0
  23. lm_deluge/pipelines/gepa/__init__.py +95 -0
  24. lm_deluge/pipelines/gepa/core.py +354 -0
  25. lm_deluge/pipelines/gepa/docs/samples.py +705 -0
  26. lm_deluge/pipelines/gepa/examples/01_synthetic_keywords.py +140 -0
  27. lm_deluge/pipelines/gepa/examples/02_gsm8k_math.py +261 -0
  28. lm_deluge/pipelines/gepa/examples/03_hotpotqa_multihop.py +300 -0
  29. lm_deluge/pipelines/gepa/examples/04_batch_classification.py +271 -0
  30. lm_deluge/pipelines/gepa/examples/simple_qa.py +129 -0
  31. lm_deluge/pipelines/gepa/optimizer.py +435 -0
  32. lm_deluge/pipelines/gepa/proposer.py +235 -0
  33. lm_deluge/pipelines/gepa/util.py +165 -0
  34. lm_deluge/{llm_tools → pipelines}/score.py +2 -2
  35. lm_deluge/{llm_tools → pipelines}/translate.py +5 -3
  36. lm_deluge/prompt.py +537 -88
  37. lm_deluge/request_context.py +7 -2
  38. lm_deluge/server/__init__.py +24 -0
  39. lm_deluge/server/__main__.py +144 -0
  40. lm_deluge/server/adapters.py +369 -0
  41. lm_deluge/server/app.py +388 -0
  42. lm_deluge/server/auth.py +71 -0
  43. lm_deluge/server/model_policy.py +215 -0
  44. lm_deluge/server/models_anthropic.py +172 -0
  45. lm_deluge/server/models_openai.py +175 -0
  46. lm_deluge/tool/__init__.py +1130 -0
  47. lm_deluge/tool/builtin/anthropic/__init__.py +300 -0
  48. lm_deluge/tool/builtin/anthropic/bash.py +0 -0
  49. lm_deluge/tool/builtin/anthropic/computer_use.py +0 -0
  50. lm_deluge/tool/builtin/gemini.py +59 -0
  51. lm_deluge/tool/builtin/openai.py +74 -0
  52. lm_deluge/tool/cua/__init__.py +173 -0
  53. lm_deluge/tool/cua/actions.py +148 -0
  54. lm_deluge/tool/cua/base.py +27 -0
  55. lm_deluge/tool/cua/batch.py +215 -0
  56. lm_deluge/tool/cua/converters.py +466 -0
  57. lm_deluge/tool/cua/kernel.py +702 -0
  58. lm_deluge/tool/cua/trycua.py +989 -0
  59. lm_deluge/tool/prefab/__init__.py +45 -0
  60. lm_deluge/tool/prefab/batch_tool.py +156 -0
  61. lm_deluge/tool/prefab/docs.py +1119 -0
  62. lm_deluge/tool/prefab/email.py +294 -0
  63. lm_deluge/tool/prefab/filesystem.py +1711 -0
  64. lm_deluge/tool/prefab/full_text_search/__init__.py +285 -0
  65. lm_deluge/tool/prefab/full_text_search/tantivy_index.py +396 -0
  66. lm_deluge/tool/prefab/memory.py +458 -0
  67. lm_deluge/tool/prefab/otc/__init__.py +165 -0
  68. lm_deluge/tool/prefab/otc/executor.py +281 -0
  69. lm_deluge/tool/prefab/otc/parse.py +188 -0
  70. lm_deluge/tool/prefab/random.py +212 -0
  71. lm_deluge/tool/prefab/rlm/__init__.py +296 -0
  72. lm_deluge/tool/prefab/rlm/executor.py +349 -0
  73. lm_deluge/tool/prefab/rlm/parse.py +144 -0
  74. lm_deluge/tool/prefab/sandbox/__init__.py +19 -0
  75. lm_deluge/tool/prefab/sandbox/daytona_sandbox.py +483 -0
  76. lm_deluge/tool/prefab/sandbox/docker_sandbox.py +609 -0
  77. lm_deluge/tool/prefab/sandbox/fargate_sandbox.py +546 -0
  78. lm_deluge/tool/prefab/sandbox/modal_sandbox.py +469 -0
  79. lm_deluge/tool/prefab/sandbox/seatbelt_sandbox.py +827 -0
  80. lm_deluge/tool/prefab/sheets.py +385 -0
  81. lm_deluge/tool/prefab/skills.py +0 -0
  82. lm_deluge/tool/prefab/subagents.py +233 -0
  83. lm_deluge/tool/prefab/todos.py +342 -0
  84. lm_deluge/tool/prefab/tool_search.py +169 -0
  85. lm_deluge/tool/prefab/web_search.py +199 -0
  86. lm_deluge/tracker.py +16 -13
  87. lm_deluge/util/schema.py +412 -0
  88. lm_deluge/warnings.py +8 -0
  89. {lm_deluge-0.0.67.dist-info → lm_deluge-0.0.90.dist-info}/METADATA +23 -9
  90. lm_deluge-0.0.90.dist-info/RECORD +132 -0
  91. lm_deluge/built_in_tools/anthropic/__init__.py +0 -128
  92. lm_deluge/built_in_tools/openai.py +0 -28
  93. lm_deluge/presets/cerebras.py +0 -17
  94. lm_deluge/presets/meta.py +0 -13
  95. lm_deluge/tool.py +0 -849
  96. lm_deluge-0.0.67.dist-info/RECORD +0 -72
  97. lm_deluge/{llm_tools → pipelines}/__init__.py +1 -1
  98. /lm_deluge/{llm_tools → pipelines}/classify.py +0 -0
  99. /lm_deluge/{llm_tools → pipelines}/extract.py +0 -0
  100. /lm_deluge/{llm_tools → pipelines}/locate.py +0 -0
  101. /lm_deluge/{llm_tools → pipelines}/ocr.py +0 -0
  102. /lm_deluge/{built_in_tools/anthropic/bash.py → skills/anthropic.py} +0 -0
  103. /lm_deluge/{built_in_tools/anthropic/computer_use.py → skills/compat.py} +0 -0
  104. /lm_deluge/{built_in_tools → tool/builtin}/anthropic/editor.py +0 -0
  105. /lm_deluge/{built_in_tools → tool/builtin}/base.py +0 -0
  106. {lm_deluge-0.0.67.dist-info → lm_deluge-0.0.90.dist-info}/WHEEL +0 -0
  107. {lm_deluge-0.0.67.dist-info → lm_deluge-0.0.90.dist-info}/licenses/LICENSE +0 -0
  108. {lm_deluge-0.0.67.dist-info → lm_deluge-0.0.90.dist-info}/top_level.txt +0 -0
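
The headline addition in this release is the new lm_deluge.pipelines.gepa prompt-optimization module (entries 23-33 above), exercised by the three example scripts whose full diffs follow. For orientation, below is a minimal sketch of how those examples drive the optimize entry point. The toy dataset, scorer, and model names are illustrative placeholders rather than package content; only the call shapes (Component, EvalResult, optimize, LLMClient, Conversation, Message) are taken from the example files themselves.

# Illustrative sketch only; not part of the package diff below.
from lm_deluge import LLMClient
from lm_deluge.pipelines.gepa import Component, EvalResult, optimize
from lm_deluge.prompt import Conversation, Message


def evaluate(client: LLMClient, component_values: dict[str, str], example: dict) -> EvalResult:
    # Run the task with the current prompt candidate and score the output.
    conv = Conversation.system(component_values["system_prompt"])
    conv = conv.add(Message.user(example["question"]))
    response = client.process_prompts_sync([conv], show_progress=False)[0]
    output = response.completion or ""
    score = 1.0 if example["answer"] in output else 0.0  # toy exact-match scorer
    return EvalResult(
        conversation=conv.add(Message.ai(output)),
        score=score,
        feedback=f"Score: {score}, expected {example['answer']!r}",
    )


result = optimize(
    components={
        "system_prompt": Component(
            description="System prompt to optimize",
            value="You are a helpful assistant.",
        ),
    },
    evaluate_fn=evaluate,
    dataset=[{"question": "What is 2 + 2?", "answer": "4"}],  # toy data
    val_dataset=[{"question": "What is 3 + 3?", "answer": "6"}],
    task_client=LLMClient("gpt-4.1-nano"),
    proposer_client=LLMClient("gpt-4.1-mini"),
    max_iterations=5,
    minibatch_size=2,
    seed=0,
)
print(result.best_score, result.best_candidate["system_prompt"])
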
lm_deluge/pipelines/gepa/examples/01_synthetic_keywords.py (new file)
@@ -0,0 +1,140 @@
+ """
+ Example 1: Synthetic Keyword Matching Task
+
+ This is the simplest possible GEPA example - no LLM calls needed for the task itself.
+ The goal is to evolve a prompt that contains certain target keywords.
+
+ This example is useful for:
+ - Understanding GEPA's basic mechanics
+ - Testing without API costs
+ - Debugging your setup
+
+ Run:
+     python 01_synthetic_keywords.py
+ """
+
+ from lm_deluge import LLMClient
+ from lm_deluge.pipelines.gepa import Component, EvalResult, GEPAEngine
+ from lm_deluge.prompt import Conversation, Message
+
+
+ def main():
+     # Target: We want a prompt that mentions these concepts
+     TARGET_KEYWORDS = {
+         "step",
+         "by",
+         "think",
+         "carefully",
+         "show",
+         "work",
+         "verify",
+         "answer",
+     }
+
+     def evaluate(
+         client: LLMClient,  # type: ignore
+         component_values: dict[str, str],
+         example: dict,
+     ) -> EvalResult:
+         """Score based on keyword coverage - no LLM needed."""
+         prompt = component_values["system_prompt"]
+         words = set(prompt.lower().split())
+         matches = len(words & TARGET_KEYWORDS)
+         score = matches / len(TARGET_KEYWORDS)
+
+         # Build a minimal conversation (required by EvalResult)
+         conv = Conversation.system(prompt)
+         conv = conv.add(Message.user(example["question"]))
+         conv = conv.add(
+             Message.ai(f"[Keyword score: {matches}/{len(TARGET_KEYWORDS)}]")
+         )
+
+         feedback = f"""Score: {score:.2f}
+ Keywords found: {words & TARGET_KEYWORDS}
+ Keywords missing: {TARGET_KEYWORDS - words}"""
+
+         return EvalResult(conversation=conv, score=score, feedback=feedback)
+
+     # Mock proposer that simulates what an LLM might suggest
+     # In real usage, this would be an actual LLM
+     iteration = [0]
+     improvements = [
+         "Think step by step before answering.",
+         "Think carefully, step by step. Show your work.",
+         "Think carefully, step by step. Show your work and verify your answer.",
+         "Think carefully and reason step by step. Show your work, then verify your answer is correct.",
+     ]
+
+     class MockProposerClient:
+         """Fake LLMClient that returns predetermined improvements."""
+
+         def process_prompts_sync(self, prompts, **kwargs):
+             iteration[0] += 1
+             idx = min(iteration[0], len(improvements) - 1)
+
+             class FakeResponse:
+                 completion = f"""COMPONENT: system_prompt
+ REASONING: Adding more target keywords to improve coverage.
+ NEW_VALUE:
+ ```
+ {improvements[idx]}
+ ```"""
+
+             return [FakeResponse()]
+
+     # Simple dataset (content doesn't matter for this toy task)
+     dataset = [{"question": f"Question {i}"} for i in range(10)]
+
+     print("=" * 60)
+     print("GEPA Example 1: Synthetic Keyword Matching")
+     print("=" * 60)
+     print(f"Target keywords: {TARGET_KEYWORDS}")
+     print()
+
+     # Define component to optimize
+     components = {
+         "system_prompt": Component(
+             description="System prompt to optimize for keyword coverage",
+             value="You are a helpful assistant.",
+         ),
+     }
+
+     # Create engine with mock clients
+     engine = GEPAEngine(
+         components=components,
+         evaluate_fn=evaluate,  # type: ignore[arg-type]
+         dataset=dataset[:7],
+         val_dataset=dataset[7:],
+         task_client=MockProposerClient(),  # type: ignore[arg-type]
+         proposer_client=MockProposerClient(),  # type: ignore[arg-type]
+         max_iterations=10,
+         max_evals=100,
+         minibatch_size=2,
+         seed=42,
+     )
+
+     # Run optimization
+     result = engine.run()  # type: ignore[func-returns-value]
+
+     print()
+     print("=" * 60)
+     print("Results")
+     print("=" * 60)
+     print(f"Candidates discovered: {result.num_candidates}")  # type: ignore[union-attr]
+     print(f"Best score: {result.best_score:.2%}")  # type: ignore[union-attr]
+     print(f"Total evaluations: {result.total_evals}")  # type: ignore[union-attr]
+     print()
+     print("Best prompt found:")
+     print("-" * 40)
+     print(result.best_candidate["system_prompt"])  # type: ignore[union-attr]
+     print("-" * 40)
+
+     # Show evolution
+     print()
+     print("Evolution of candidates:")
+     for i, (idx, candidate, score) in enumerate(result.best_k(5)):  # type: ignore[union-attr]
+         print(f" {i+1}. Score={score:.2%}: {candidate['system_prompt'][:60]}...")
+
+
+ if __name__ == "__main__":
+     main()

lm_deluge/pipelines/gepa/examples/02_gsm8k_math.py (new file)
@@ -0,0 +1,261 @@
+ """
+ Example: GSM8K Math Reasoning
+
+ Optimize a system prompt for grade school math problems.
+ This is a classic prompt optimization task - improve accuracy on GSM8K.
+
+ The task:
+ - Input: Math word problem
+ - Output: Numerical answer
+ - Score: Exact match with ground truth
+
+ Run:
+     python 02_gsm8k_math.py
+
+ Requirements:
+     pip install datasets
+     # Set OPENAI_API_KEY or ANTHROPIC_API_KEY environment variable
+ """
+
+ import os
+ import re
+ import sys
+ from typing import cast
+
+ import dotenv
+
+ from lm_deluge.client import LLMClient, _LLMClient
+ from lm_deluge.pipelines.gepa import Component, EvalResult, optimize
+ from lm_deluge.prompt import Conversation, Message
+
+ dotenv.load_dotenv()
+
+
+ def load_gsm8k_sample(
+     n_train: int = 50, n_val: int = 20
+ ) -> tuple[list[dict], list[dict]]:
+     """Load a sample of GSM8K problems."""
+     try:
+         from datasets import load_dataset
+     except ImportError:
+         print("Please install datasets: pip install datasets")
+         sys.exit(1)
+
+     print("Loading GSM8K dataset...")
+     ds = load_dataset("openai/gsm8k", "main", split="train")
+
+     # Extract answer from "#### <number>" format
+     def extract_answer(answer_text: str) -> str:
+         match = re.search(r"####\s*(-?\d[\d,]*\.?\d*)", answer_text)
+         if match:
+             return match.group(1).replace(",", "")
+         return answer_text.strip()
+
+     data = []
+     for item in ds:
+         item = cast(dict[str, str], item)
+         data.append(
+             {
+                 "question": item["question"],
+                 "answer": extract_answer(item["answer"]),
+             }
+         )
+
+     # Shuffle and split
+     import random
+
+     random.seed(42)
+     random.shuffle(data)
+
+     return data[:n_train], data[n_train : n_train + n_val]
+
+
+ def extract_final_number(text: str) -> str | None:
+     """Extract the final number from model output."""
+     # Look for common patterns
+     patterns = [
+         r"(?:answer|result|total|=)\s*[:=]?\s*\$?(-?\d[\d,]*\.?\d*)",
+         r"####\s*(-?\d[\d,]*\.?\d*)",
+         r"\*\*(-?\d[\d,]*\.?\d*)\*\*",
+         r"(-?\d[\d,]*\.?\d*)\s*$",  # Last number in text
+     ]
+
+     for pattern in patterns:
+         matches = re.findall(pattern, text, re.IGNORECASE)
+         if matches:
+             return matches[-1].replace(",", "")
+
+     # Fallback: find all numbers and return the last one
+     numbers = re.findall(r"-?\d[\d,]*\.?\d*", text)
+     if numbers:
+         return numbers[-1].replace(",", "")
+
+     return None
+
+
+ async def evaluate(
+     client: _LLMClient, component_values: dict[str, str], example: dict
+ ) -> EvalResult:
+     """Evaluate one math problem."""
+     # Build conversation
+     conv = Conversation.system(component_values["system_prompt"])
+     user_msg = f"""Problem: {example["question"]}
+
+ Solve this step by step, then provide your final numerical answer."""
+     conv = conv.add(Message.user(user_msg))
+
+     # Run inference (async)
+     response = await client.start(conv)
+     output = response.completion or ""
+
+     # Extract and score
+     predicted = extract_final_number(output)
+     expected = example["answer"]
+
+     if predicted is None:
+         score = 0.0
+     else:
+         try:
+             pred_num = float(predicted)
+             exp_num = float(expected)
+             score = 1.0 if abs(pred_num - exp_num) < 0.01 else 0.0
+         except ValueError:
+             score = 1.0 if predicted.strip() == expected.strip() else 0.0
+
+     # Build feedback for the proposer
+     if score == 1.0:
+         feedback = f"""Score: 1.0 (CORRECT)
+ Question: {example["question"][:100]}...
+ Expected: {expected}
+ Got: {predicted}"""
+     else:
+         feedback = f"""Score: 0.0 (INCORRECT)
+ Question: {example["question"]}
+ Expected answer: {expected}
+ Model extracted answer: {predicted}
+ Model reasoning: {output[:500]}{"..." if len(output) > 500 else ""}
+
+ The model either made a calculation error or failed to extract the answer properly."""
+
+     # Return full trajectory
+     full_conv = conv.add(Message.ai(output))
+     return EvalResult(conversation=full_conv, score=score, feedback=feedback)
+
+
+ def main():
+     # Check for API keys
+     model = None
+     proposer_model = None
+
+     if os.getenv("OPENAI_API_KEY"):
+         model = "gpt-4.1-nano"
+         proposer_model = "gpt-5-mini"
+     elif os.getenv("ANTHROPIC_API_KEY"):
+         model = "claude-3-5-haiku-latest"
+         proposer_model = "claude-sonnet-4-20250514"
+     else:
+         print("Please set OPENAI_API_KEY or ANTHROPIC_API_KEY")
+         sys.exit(1)
+
+     print(f"Using task model: {model}")
+     print(f"Using proposer model: {proposer_model}")
+
+     # Load data - larger val set for more room to improve
+     trainset, valset = load_gsm8k_sample(n_train=50, n_val=50)
+     print(f"Loaded {len(trainset)} training, {len(valset)} validation examples")
+
+     # Create clients
+     task_client = LLMClient(
+         model,
+         max_requests_per_minute=100,
+         max_new_tokens=512,
+         temperature=0.0,
+     )
+     proposer_client = LLMClient(
+         proposer_model,
+         max_requests_per_minute=50,
+         max_new_tokens=1024,
+     )
+
+     # Define component to optimize
+     components = {
+         "system_prompt": Component(
+             description="System prompt that instructs the model how to solve math problems step by step",
+             value="You are a helpful math tutor. Solve math problems step by step.",
+         ),
+     }
+
+     print()
+     print("=" * 60)
+     print("GEPA Example: GSM8K Math Reasoning")
+     print("=" * 60)
+     print(f"Seed prompt: {components['system_prompt'].value}")
+     print()
+
+     # Meta-instructions to guide the proposer
+     meta_instructions = """
+ - Focus on GENERAL improvements that will help across many problems, not just the specific example shown
+ - Don't add problem-specific details (like "mosquito infections" or "baguette rates") to the prompt
+ - Keep the prompt concise - longer is not always better
+ - Prioritize clarity and unambiguous instructions over covering edge cases
+ """.strip()
+
+     # Run optimization
+     result = optimize(
+         components=components,
+         evaluate_fn=evaluate,
+         dataset=trainset,
+         val_dataset=valset,
+         task_client=task_client,
+         proposer_client=proposer_client,
+         max_iterations=30,
+         max_evals=500,
+         minibatch_size=5,
+         meta_instructions=meta_instructions,
+         run_dir="./gsm8k_gepa",
+         save_trajectories=True,
+         seed=42,
+     )
+
+     print()
+     print("=" * 60)
+     print("Results")
+     print("=" * 60)
+     print(f"Candidates discovered: {result.num_candidates}")
+     print(f"Best validation accuracy: {result.best_score:.1%}")
+     print(f"Total evaluations: {result.total_evals}")
+     print()
+     print("Best prompt found:")
+     print("-" * 40)
+     print(result.best_candidate["system_prompt"])
+     print("-" * 40)
+
+     # Show top candidates
+     print()
+     print("Top 3 candidates:")
+     for i, (idx, candidate, score) in enumerate(result.best_k(3)):
+         print(f"\n{i + 1}. Score={score:.1%}")
+         print(f" {candidate['system_prompt'][:100]}...")
+
+     # Show lineage of best
+     print()
+     print(f"Lineage of best candidate (idx={result.best_idx}):")
+     lineage = result.lineage(result.best_idx)
+     for i, idx in enumerate(lineage):
+         score = result.candidate_avg_scores[idx]
+         prompt_preview = result.candidates[idx]["system_prompt"][:50]
+         print(f" {'→ ' if i > 0 else ''}{idx}: {score:.1%} - {prompt_preview}...")
+
+     # Print cost summary
+     print()
+     print("=" * 60)
+     print("Cost Summary")
+     print("=" * 60)
+     print("Task client (evaluations):")
+     task_client.print_usage()
+     print("\nProposer client (proposals):")
+     proposer_client.print_usage()
+
+
+ if __name__ == "__main__":
+     main()

lm_deluge/pipelines/gepa/examples/03_hotpotqa_multihop.py (new file)
@@ -0,0 +1,300 @@
+ """
+ Example: HotpotQA Multi-hop Question Answering
+
+ Optimize a system prompt for multi-hop reasoning questions.
+ This task requires combining information from multiple sources.
+
+ The task:
+ - Input: Question + supporting context paragraphs
+ - Output: Short answer
+ - Score: F1 overlap with ground truth
+
+ This example demonstrates:
+ - Multi-component optimization (system_prompt + answer_format)
+ - More complex scoring (F1 instead of exact match)
+ - Richer trajectory information for reflection
+
+ Run:
+     python 03_hotpotqa_multihop.py
+
+ Requirements:
+     pip install datasets
+     # Set OPENAI_API_KEY or ANTHROPIC_API_KEY environment variable
+ """
+
+ import os
+ import re
+ import string
+ import sys
+ from collections import Counter
+
+ import dotenv
+
+ from lm_deluge import LLMClient
+ from lm_deluge.pipelines.gepa import Component, EvalResult, optimize
+ from lm_deluge.prompt import Conversation, Message
+
+ dotenv.load_dotenv()
+
+
+ def load_hotpotqa_sample(
+     n_train: int = 40, n_val: int = 20
+ ) -> tuple[list[dict], list[dict]]:
+     """Load a sample of HotpotQA problems."""
+     try:
+         from datasets import load_dataset
+     except ImportError:
+         print("Please install datasets: pip install datasets")
+         sys.exit(1)
+
+     print("Loading HotpotQA dataset...")
+     ds = load_dataset(
+         "hotpot_qa", "distractor", split="validation", trust_remote_code=True
+     )
+
+     data = []
+     for item in ds:  # type: ignore
+         # Combine supporting facts into context
+         context_parts = []
+         for title, sentences in zip(
+             item["context"]["title"],  # type: ignore
+             item["context"]["sentences"],  # type: ignore
+         ):
+             context_parts.append(f"[{title}]\n" + " ".join(sentences))
+
+         data.append(
+             {
+                 "question": item["question"],  # type: ignore
+                 "context": "\n\n".join(context_parts),
+                 "answer": item["answer"],  # type: ignore
+                 "type": item["type"],  # type: ignore # 'comparison' or 'bridge'
+             }
+         )
+
+     # Shuffle and split
+     import random
+
+     random.seed(42)
+     random.shuffle(data)
+
+     return data[:n_train], data[n_train : n_train + n_val]
+
+
+ def normalize_answer(s: str) -> str:
+     """Normalize answer for comparison."""
+
+     def remove_articles(text):
+         return re.sub(r"\b(a|an|the)\b", " ", text)
+
+     def white_space_fix(text):
+         return " ".join(text.split())
+
+     def remove_punc(text):
+         exclude = set(string.punctuation)
+         return "".join(ch for ch in text if ch not in exclude)
+
+     def lower(text):
+         return text.lower()
+
+     return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+
+ def compute_f1(prediction: str, ground_truth: str) -> float:
+     """Compute F1 score between prediction and ground truth."""
+     pred_tokens = normalize_answer(prediction).split()
+     gt_tokens = normalize_answer(ground_truth).split()
+
+     if not pred_tokens or not gt_tokens:
+         return float(pred_tokens == gt_tokens)
+
+     common = Counter(pred_tokens) & Counter(gt_tokens)
+     num_same = sum(common.values())
+
+     if num_same == 0:
+         return 0.0
+
+     precision = num_same / len(pred_tokens)
+     recall = num_same / len(gt_tokens)
+     f1 = (2 * precision * recall) / (precision + recall)
+
+     return f1
+
+
+ def extract_answer(output: str) -> str:
+     """Extract answer from model output."""
+     if "Answer:" in output:
+         return output.split("Answer:")[-1].strip()
+     elif "answer:" in output.lower():
+         return output.lower().split("answer:")[-1].strip()
+     else:
+         # Take first sentence/line
+         return output.split("\n")[0].split(".")[0].strip()
+
+
+ def make_evaluate_fn(task_client: LLMClient):  # type: ignore
+     """Create the evaluate function."""
+
+     def evaluate(
+         client: LLMClient,  # type: ignore
+         component_values: dict[str, str],
+         example: dict,
+     ) -> EvalResult:
+         """Evaluate one HotpotQA question."""
+         # Build conversation
+         conv = Conversation.system(component_values["system_prompt"])
+
+         user_msg = f"""Context:
+ {example['context']}
+
+ Question: {example['question']}
+
+ {component_values['answer_format']}"""
+         conv = conv.add(Message.user(user_msg))
+
+         # Run inference
+         response = client.process_prompts_sync([conv], show_progress=False)[0]
+         output = response.completion or ""
+
+         # Extract answer and compute F1
+         extracted = extract_answer(output)
+         f1 = compute_f1(extracted, example["answer"])
+
+         # Build detailed feedback
+         if f1 >= 0.8:
+             feedback = f"""Score: {f1:.2f} (GOOD)
+ Question type: {example['type']}
+ Expected: {example['answer']}
+ Got: {extracted}"""
+         else:
+             hint = ""
+             if example["type"] == "comparison":
+                 hint = "This requires comparing two entities."
+             else:
+                 hint = "This requires following a chain of reasoning."
+
+             feedback = f"""Score: {f1:.2f} (NEEDS IMPROVEMENT)
+ Question type: {example['type']}
+ Expected: {example['answer']}
+ Got: {extracted}
+ Model output: {output[:300]}{'...' if len(output) > 300 else ''}
+ Hint: {hint}"""
+
+         # Return full trajectory
+         full_conv = conv.add(Message.ai(output))
+         return EvalResult(conversation=full_conv, score=f1, feedback=feedback)
+
+     return evaluate
+
+
+ def main():
+     # Check for API keys
+     model = None
+     proposer_model = None
+
+     if os.getenv("OPENAI_API_KEY"):
+         model = "gpt-4.1-nano"
+         proposer_model = "gpt-4.1-mini"
+     elif os.getenv("ANTHROPIC_API_KEY"):
+         model = "claude-3-5-haiku-latest"
+         proposer_model = "claude-sonnet-4-20250514"
+     else:
+         print("Please set OPENAI_API_KEY or ANTHROPIC_API_KEY")
+         sys.exit(1)
+
+     print(f"Using task model: {model}")
+     print(f"Using proposer model: {proposer_model}")
+
+     # Load data
+     trainset, valset = load_hotpotqa_sample(n_train=30, n_val=15)
+     print(f"Loaded {len(trainset)} training, {len(valset)} validation examples")
+
+     # Show question type distribution
+     train_types = Counter(x["type"] for x in trainset)
+     print(f"Training set types: {dict(train_types)}")
+
+     # Create clients
+     task_client = LLMClient(  # type: ignore[operator]
+         model,
+         max_requests_per_minute=100,
+         max_new_tokens=256,
+         temperature=0.0,
+     )
+     proposer_client = LLMClient(  # type: ignore[operator]
+         proposer_model,
+         max_requests_per_minute=50,
+         max_new_tokens=1024,
+     )
+
+     # Define components to optimize (two components this time)
+     components = {
+         "system_prompt": Component(
+             description="System prompt that guides the model's reasoning approach",
+             value="You are a helpful assistant that answers questions based on the provided context.",
+         ),
+         "answer_format": Component(
+             description="Instructions for how the model should format its answer",
+             value="Provide a short, direct answer to the question.",
+         ),
+     }
+
+     print()
+     print("=" * 60)
+     print("GEPA Example: HotpotQA Multi-hop QA")
+     print("=" * 60)
+     print("Components being optimized:")
+     for name, comp in components.items():
+         print(f" - {name}: {comp.value[:50]}...")
+     print()
+
+     # Run optimization
+     result = optimize(
+         components=components,
+         evaluate_fn=make_evaluate_fn(task_client),  # type: ignore[arg-type]
+         dataset=trainset,
+         val_dataset=valset,
+         task_client=task_client,
+         proposer_client=proposer_client,
+         max_iterations=20,
+         max_evals=250,
+         minibatch_size=3,
+         run_dir="./hotpotqa_gepa",
+         save_trajectories=True,
+         seed=42,
+     )
+
+     print()
+     print("=" * 60)
+     print("Results")
+     print("=" * 60)
+     print(f"Candidates discovered: {result.num_candidates}")
+     print(f"Best validation F1: {result.best_score:.1%}")
+     print(f"Total evaluations: {result.total_evals}")
+     print()
+     print("Best candidate found:")
+     print("-" * 40)
+     for name, text in result.best_candidate.items():
+         print(f"{name}:")
+         print(f" {text}")
+         print()
+     print("-" * 40)
+
+     # Show improvement
+     seed_score = result.candidate_avg_scores[0]
+     improvement = result.best_score - seed_score
+     print(
+         f"\nImprovement over seed: {seed_score:.1%} → {result.best_score:.1%} (+{improvement:.1%})"
+     )
+
+     # Show which component changed most
+     if result.num_candidates > 1:
+         diff = result.diff(0, result.best_idx)
+         print("\nChanges from seed to best:")
+         for comp, (old, new) in diff.items():
+             if old != new:
+                 print(f" {comp}:")
+                 print(f" OLD: {old[:60]}...")
+                 print(f" NEW: {new[:60]}...")
+
+
+ if __name__ == "__main__":
+     main()