mantisdk-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of mantisdk might be problematic.

Files changed (190)
  1. mantisdk/__init__.py +22 -0
  2. mantisdk/adapter/__init__.py +15 -0
  3. mantisdk/adapter/base.py +94 -0
  4. mantisdk/adapter/messages.py +270 -0
  5. mantisdk/adapter/triplet.py +1028 -0
  6. mantisdk/algorithm/__init__.py +39 -0
  7. mantisdk/algorithm/apo/__init__.py +5 -0
  8. mantisdk/algorithm/apo/apo.py +889 -0
  9. mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
  10. mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
  11. mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
  12. mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
  13. mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
  14. mantisdk/algorithm/base.py +162 -0
  15. mantisdk/algorithm/decorator.py +264 -0
  16. mantisdk/algorithm/fast.py +250 -0
  17. mantisdk/algorithm/gepa/__init__.py +59 -0
  18. mantisdk/algorithm/gepa/adapter.py +459 -0
  19. mantisdk/algorithm/gepa/gepa.py +364 -0
  20. mantisdk/algorithm/gepa/lib/__init__.py +18 -0
  21. mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
  22. mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
  23. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
  24. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
  25. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
  26. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
  27. mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
  28. mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
  29. mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
  30. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
  31. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
  32. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
  33. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
  34. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
  35. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
  36. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
  37. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
  38. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
  39. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
  40. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
  41. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
  42. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
  43. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
  44. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
  45. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
  46. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
  47. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
  48. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
  49. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
  50. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
  51. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
  52. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
  53. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
  54. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
  55. mantisdk/algorithm/gepa/lib/api.py +375 -0
  56. mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
  57. mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
  58. mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
  59. mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
  60. mantisdk/algorithm/gepa/lib/core/result.py +233 -0
  61. mantisdk/algorithm/gepa/lib/core/state.py +636 -0
  62. mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
  63. mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
  64. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
  65. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
  66. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
  67. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
  68. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
  69. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
  70. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
  71. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
  72. mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
  73. mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
  74. mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
  75. mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
  76. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
  77. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
  78. mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
  79. mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
  80. mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
  81. mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
  82. mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
  83. mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
  84. mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
  85. mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
  86. mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
  87. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
  88. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
  89. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
  90. mantisdk/algorithm/gepa/lib/py.typed +0 -0
  91. mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
  92. mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
  93. mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
  94. mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
  95. mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
  96. mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
  97. mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
  98. mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
  99. mantisdk/algorithm/gepa/tracing.py +105 -0
  100. mantisdk/algorithm/utils.py +177 -0
  101. mantisdk/algorithm/verl/__init__.py +5 -0
  102. mantisdk/algorithm/verl/interface.py +202 -0
  103. mantisdk/cli/__init__.py +56 -0
  104. mantisdk/cli/prometheus.py +115 -0
  105. mantisdk/cli/store.py +131 -0
  106. mantisdk/cli/vllm.py +29 -0
  107. mantisdk/client.py +408 -0
  108. mantisdk/config.py +348 -0
  109. mantisdk/emitter/__init__.py +43 -0
  110. mantisdk/emitter/annotation.py +370 -0
  111. mantisdk/emitter/exception.py +54 -0
  112. mantisdk/emitter/message.py +61 -0
  113. mantisdk/emitter/object.py +117 -0
  114. mantisdk/emitter/reward.py +320 -0
  115. mantisdk/env_var.py +156 -0
  116. mantisdk/execution/__init__.py +15 -0
  117. mantisdk/execution/base.py +64 -0
  118. mantisdk/execution/client_server.py +443 -0
  119. mantisdk/execution/events.py +69 -0
  120. mantisdk/execution/inter_process.py +16 -0
  121. mantisdk/execution/shared_memory.py +282 -0
  122. mantisdk/instrumentation/__init__.py +119 -0
  123. mantisdk/instrumentation/agentops.py +314 -0
  124. mantisdk/instrumentation/agentops_langchain.py +45 -0
  125. mantisdk/instrumentation/litellm.py +83 -0
  126. mantisdk/instrumentation/vllm.py +81 -0
  127. mantisdk/instrumentation/weave.py +500 -0
  128. mantisdk/litagent/__init__.py +11 -0
  129. mantisdk/litagent/decorator.py +536 -0
  130. mantisdk/litagent/litagent.py +252 -0
  131. mantisdk/llm_proxy.py +1890 -0
  132. mantisdk/logging.py +370 -0
  133. mantisdk/reward.py +7 -0
  134. mantisdk/runner/__init__.py +11 -0
  135. mantisdk/runner/agent.py +845 -0
  136. mantisdk/runner/base.py +182 -0
  137. mantisdk/runner/legacy.py +309 -0
  138. mantisdk/semconv.py +170 -0
  139. mantisdk/server.py +401 -0
  140. mantisdk/store/__init__.py +23 -0
  141. mantisdk/store/base.py +897 -0
  142. mantisdk/store/client_server.py +2092 -0
  143. mantisdk/store/collection/__init__.py +30 -0
  144. mantisdk/store/collection/base.py +587 -0
  145. mantisdk/store/collection/memory.py +970 -0
  146. mantisdk/store/collection/mongo.py +1412 -0
  147. mantisdk/store/collection_based.py +1823 -0
  148. mantisdk/store/insight.py +648 -0
  149. mantisdk/store/listener.py +58 -0
  150. mantisdk/store/memory.py +396 -0
  151. mantisdk/store/mongo.py +165 -0
  152. mantisdk/store/sqlite.py +3 -0
  153. mantisdk/store/threading.py +357 -0
  154. mantisdk/store/utils.py +142 -0
  155. mantisdk/tracer/__init__.py +16 -0
  156. mantisdk/tracer/agentops.py +242 -0
  157. mantisdk/tracer/base.py +287 -0
  158. mantisdk/tracer/dummy.py +106 -0
  159. mantisdk/tracer/otel.py +555 -0
  160. mantisdk/tracer/weave.py +677 -0
  161. mantisdk/trainer/__init__.py +6 -0
  162. mantisdk/trainer/init_utils.py +263 -0
  163. mantisdk/trainer/legacy.py +367 -0
  164. mantisdk/trainer/registry.py +12 -0
  165. mantisdk/trainer/trainer.py +618 -0
  166. mantisdk/types/__init__.py +6 -0
  167. mantisdk/types/core.py +553 -0
  168. mantisdk/types/resources.py +204 -0
  169. mantisdk/types/tracer.py +515 -0
  170. mantisdk/types/tracing.py +218 -0
  171. mantisdk/utils/__init__.py +1 -0
  172. mantisdk/utils/id.py +18 -0
  173. mantisdk/utils/metrics.py +1025 -0
  174. mantisdk/utils/otel.py +578 -0
  175. mantisdk/utils/otlp.py +536 -0
  176. mantisdk/utils/server_launcher.py +1045 -0
  177. mantisdk/utils/system_snapshot.py +81 -0
  178. mantisdk/verl/__init__.py +8 -0
  179. mantisdk/verl/__main__.py +6 -0
  180. mantisdk/verl/async_server.py +46 -0
  181. mantisdk/verl/config.yaml +27 -0
  182. mantisdk/verl/daemon.py +1154 -0
  183. mantisdk/verl/dataset.py +44 -0
  184. mantisdk/verl/entrypoint.py +248 -0
  185. mantisdk/verl/trainer.py +549 -0
  186. mantisdk-0.1.0.dist-info/METADATA +119 -0
  187. mantisdk-0.1.0.dist-info/RECORD +190 -0
  188. mantisdk-0.1.0.dist-info/WHEEL +4 -0
  189. mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
  190. mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
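
The hunk below is the full added content of mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py (entry 50 in the list above, +705 lines).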
@@ -0,0 +1,705 @@
+# Copyright (c) 2025 Lakshya A Agrawal and the GEPA contributors
+# https://github.com/gepa-ai/gepa
+
+"""
+MCP Adapter for GEPA - Optimizes tool descriptions and system prompts.
+
+Supports local (stdio) and remote (SSE/StreamableHTTP) MCP servers.
+Enables optimization of tool descriptions, system prompts, and tool selection
+across single or multiple tools.
+"""
+
+import asyncio
+import json
+import logging
+from typing import Any, Callable, TypedDict
+
+from mantisdk.algorithm.gepa.lib.core.adapter import EvaluationBatch, GEPAAdapter
+
+try:
+    from mcp import StdioServerParameters  # type: ignore[import-untyped]
+except ImportError as e:
+    raise ImportError("MCP Python SDK is required. Install it with: pip install mcp") from e
+
+from .mcp_client import create_mcp_client
+
+logger = logging.getLogger(__name__)
+
+
+# ============================================================================
+# Type Definitions
+# ============================================================================
+
+
+class MCPDataInst(TypedDict):
+    """
+    Dataset item for MCP tool optimization.
+
+    Attributes:
+        user_query: The user's question or request
+        tool_arguments: Expected tool arguments (for validation/guidance)
+        reference_answer: Optional reference answer for scoring
+        additional_context: Optional additional context
+    """
+
+    user_query: str
+    tool_arguments: dict[str, Any]
+    reference_answer: str | None
+    additional_context: dict[str, str]
+
+
+class MCPTrajectory(TypedDict):
+    """
+    Execution trace for MCP tool invocation.
+
+    Captures the full workflow including tool selection, execution,
+    and model behavior at each stage.
+    """
+
+    user_query: str
+    tool_names: list[str]
+    selected_tool: str | None
+    tool_called: bool
+    tool_arguments: dict[str, Any] | None
+    tool_response: str | None
+    tool_description_used: str
+    system_prompt_used: str
+    model_first_pass_output: str
+    model_final_output: str
+    score: float
+
+
+class MCPOutput(TypedDict):
+    """
+    Output from MCP evaluation.
+
+    Attributes:
+        final_answer: The final answer from the model
+        tool_called: Whether a tool was called
+        selected_tool: Which tool was selected (if any)
+        tool_response: The tool's response (if called)
+    """
+
+    final_answer: str
+    tool_called: bool
+    selected_tool: str | None
+    tool_response: str | None
+
+
+# ============================================================================
+# MCP Adapter
+# ============================================================================
+
+
+class MCPAdapter(GEPAAdapter[MCPDataInst, MCPTrajectory, MCPOutput]):
+    """
+    GEPA adapter for optimizing MCP tool usage.
+
+    This adapter enables optimization of:
+    - Tool descriptions (single or multiple tools)
+    - System prompts for tool usage guidance
+    - Tool selection logic
+
+    Features:
+    - Multi-tool support: Optimize multiple tools simultaneously
+    - Two-pass workflow: Tool call + answer generation
+    - Multiple transports: stdio (local), SSE, StreamableHTTP (remote)
+    - Reflective datasets: Generate training data for refinement
+
+    Example (Local):
+        >>> from mcp import StdioServerParameters
+        >>> adapter = MCPAdapter(
+        ...     tool_names=["read_file", "write_file"],
+        ...     task_model="gpt-4o-mini",
+        ...     metric_fn=lambda item, output: 1.0 if item["reference_answer"] in output else 0.0,
+        ...     server_params=StdioServerParameters(
+        ...         command="python",
+        ...         args=["server.py"],
+        ...     ),
+        ... )
+
+    Example (Remote):
+        >>> adapter = MCPAdapter(
+        ...     tool_names="search_web",
+        ...     task_model="gpt-4o-mini",
+        ...     metric_fn=accuracy_metric,
+        ...     remote_url="https://mcp-server.com/sse",
+        ...     remote_transport="sse",
+        ... )
+    """
+
+    def __init__(
+        self,
+        tool_names: str | list[str],
+        task_model: str | Callable,
+        metric_fn: Callable[[MCPDataInst, str], float],
+        # Local server configuration
+        server_params: StdioServerParameters | None = None,
+        # Remote server configuration
+        remote_url: str | None = None,
+        remote_transport: str = "sse",
+        remote_headers: dict[str, str] | None = None,
+        remote_timeout: float = 30,
+        # Adapter configuration
+        base_system_prompt: str = "You are a helpful assistant with access to tools.",
+        enable_two_pass: bool = True,
+        failure_score: float = 0.0,
+    ):
+        """
+        Initialize MCPAdapter.
+
+        Args:
+            tool_names: Name(s) of tool(s) to optimize (str or list[str])
+            task_model: Model for task execution (litellm string or callable)
+            metric_fn: Scoring function: (data_inst, output) -> float
+            server_params: Local MCP server configuration (stdio)
+            remote_url: Remote MCP server URL
+            remote_transport: "sse" or "streamable_http"
+            remote_headers: HTTP headers for remote (e.g., auth tokens)
+            remote_timeout: Timeout for remote HTTP operations
+            base_system_prompt: Base system prompt template
+            enable_two_pass: Use two-pass workflow (tool + answer)
+            failure_score: Score assigned when execution fails
+        """
+        # Store transport configuration
+        self.server_params = server_params
+        self.remote_url = remote_url
+        self.remote_transport = remote_transport
+        self.remote_headers = remote_headers or {}
+        self.remote_timeout = remote_timeout
+
+        # Normalize tool_names to list
+        self.tool_names = [tool_names] if isinstance(tool_names, str) else tool_names
+
+        # Store adapter configuration
+        self.base_system_prompt = base_system_prompt
+        self.enable_two_pass = enable_two_pass
+        self.failure_score = failure_score
+        self.metric_fn = metric_fn
+
+        # Setup model
+        if isinstance(task_model, str):
+            import litellm
+
+            self.litellm = litellm
+        self.task_model = task_model
+
+    def evaluate(
+        self,
+        batch: list[MCPDataInst],
+        candidate: dict[str, str],
+        capture_traces: bool = False,
+    ) -> EvaluationBatch[MCPTrajectory, MCPOutput]:
+        """
+        Evaluate candidate on batch using MCP tools.
+
+        Args:
+            batch: Dataset items to evaluate
+            candidate: Component mapping (e.g., {"tool_description": "..."})
+            capture_traces: Whether to capture detailed trajectories
+
+        Returns:
+            EvaluationBatch with outputs, scores, and optional trajectories
+        """
+        return asyncio.run(self._evaluate_async(batch, candidate, capture_traces))
+
+    async def _evaluate_async(
+        self,
+        batch: list[MCPDataInst],
+        candidate: dict[str, str],
+        capture_traces: bool,
+    ) -> EvaluationBatch[MCPTrajectory, MCPOutput]:
+        """Async implementation of evaluation."""
+        outputs: list[MCPOutput] = []
+        scores: list[float] = []
+        trajectories: list[MCPTrajectory] | None = [] if capture_traces else None
+
+        client = None
+        try:
+            # Create MCP client using factory
+            logger.info(f"Starting MCP session for batch of {len(batch)} items...")
+            client = create_mcp_client(
+                server_params=self.server_params,
+                remote_url=self.remote_url,
+                remote_transport=self.remote_transport,
+                remote_headers=self.remote_headers,
+                remote_timeout=self.remote_timeout,
+            )
+
+            await client.start()
+            init_result = await client.initialize()
+            logger.info(f"MCP session initialized: {init_result.get('serverInfo', {}).get('name', 'unknown')}")
+
+            # Get available tools
+            tools_list = await client.list_tools()
+            available_tools = [t for t in tools_list if t.get("name") in self.tool_names]
+
+            if not available_tools:
+                available_names = [t.get("name") for t in tools_list]
+                raise ValueError(f"Tools {self.tool_names} not found. Available: {available_names}")
+
+            # Build system prompt with tools
+            system_prompt = self._build_system_prompt(candidate, available_tools)
+
+            # Evaluate each item
+            for idx, item in enumerate(batch):
+                try:
+                    logger.info(f"Evaluating item {idx + 1}/{len(batch)}: {item['user_query'][:50]}...")
+
+                    # First pass: Model calls tool
+                    first_pass = await self._first_pass(client, item, system_prompt, available_tools)
+                    logger.info(f"First pass complete for item {idx + 1}")
+
+                    # Second pass: Model uses tool response (if enabled)
+                    if self.enable_two_pass and first_pass["tool_called"]:
+                        final_output = await self._second_pass(client, item, system_prompt, first_pass["tool_response"])
+                    else:
+                        final_output = first_pass["output"]
+
+                    # Score the output
+                    score = self.metric_fn(item, final_output)
+
+                    # Collect results
+                    outputs.append(
+                        {
+                            "final_answer": final_output,
+                            "tool_called": first_pass["tool_called"],
+                            "selected_tool": first_pass["selected_tool"],
+                            "tool_response": first_pass["tool_response"],
+                        }
+                    )
+                    scores.append(score)
+
+                    # Capture trajectory
+                    if capture_traces and trajectories is not None:
+                        trajectories.append(
+                            {
+                                "user_query": item["user_query"],
+                                "tool_names": self.tool_names,
+                                "selected_tool": first_pass["selected_tool"],
+                                "tool_called": first_pass["tool_called"],
+                                "tool_arguments": first_pass["tool_arguments"],
+                                "tool_response": first_pass["tool_response"],
+                                "tool_description_used": candidate.get("tool_description", ""),
+                                "system_prompt_used": system_prompt,
+                                "model_first_pass_output": first_pass["output"],
+                                "model_final_output": final_output,
+                                "score": score,
+                            }
+                        )
+
+                except Exception as e:
+                    logger.exception(f"Failed to evaluate item: {item['user_query']}")
+                    outputs.append(
+                        {
+                            "final_answer": "",
+                            "tool_called": False,
+                            "selected_tool": None,
+                            "tool_response": None,
+                        }
+                    )
+                    scores.append(self.failure_score)
+
+                    if capture_traces and trajectories is not None:
+                        trajectories.append(
+                            {
+                                "user_query": item["user_query"],
+                                "tool_names": self.tool_names,
+                                "selected_tool": None,
+                                "tool_called": False,
+                                "tool_arguments": None,
+                                "tool_response": None,
+                                "tool_description_used": candidate.get("tool_description", ""),
+                                "system_prompt_used": system_prompt,
+                                "model_first_pass_output": f"ERROR: {e!s}",
+                                "model_final_output": "",
+                                "score": self.failure_score,
+                            }
+                        )
+
+        except Exception as e:
+            logger.exception("Failed to create MCP session")
+            # Return failure for entire batch
+            for item in batch:
+                outputs.append(
+                    {
+                        "final_answer": "",
+                        "tool_called": False,
+                        "selected_tool": None,
+                        "tool_response": None,
+                    }
+                )
+                scores.append(self.failure_score)
+                if capture_traces and trajectories is not None:
+                    trajectories.append(
+                        {
+                            "user_query": item["user_query"],
+                            "tool_names": self.tool_names,
+                            "selected_tool": None,
+                            "tool_called": False,
+                            "tool_arguments": None,
+                            "tool_response": None,
+                            "tool_description_used": "",
+                            "system_prompt_used": "",
+                            "model_first_pass_output": f"SESSION ERROR: {e!s}",
+                            "model_final_output": "",
+                            "score": self.failure_score,
+                        }
+                    )
+        finally:
+            if client:
+                await client.close()
+
+        return EvaluationBatch(outputs=outputs, scores=scores, trajectories=trajectories)
+
+    async def _first_pass(
+        self,
+        client,
+        item: MCPDataInst,
+        system_prompt: str,
+        available_tools: list[dict[str, Any]],
+    ) -> dict[str, Any]:
+        """
+        First pass: Model receives query and calls tool if needed.
+
+        Returns dict with:
+        - output: Raw model output
+        - tool_called: Whether tool was called
+        - selected_tool: Which tool was selected
+        - tool_arguments: Tool arguments
+        - tool_response: Tool response
+        """
+        messages = [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": item["user_query"]},
+        ]
+
+        try:
+            if isinstance(self.task_model, str):
+                logger.debug(f"Calling model with messages: {messages}")
+                response = self.litellm.completion(
+                    model=self.task_model,
+                    messages=messages,
+                )
+                model_output = response.choices[0].message.content.strip()  # type: ignore[union-attr]
+                logger.debug(f"Model output: '{model_output}'")
+            else:
+                model_output = self.task_model(messages)
+
+            # Parse tool call (JSON format)
+            tool_called = False
+            selected_tool = None
+            tool_arguments = None
+            tool_response = None
+
+            try:
+                parsed = json.loads(model_output)
+                if parsed.get("action") == "call_tool":
+                    tool_called = True
+                    selected_tool = parsed.get("tool")
+                    tool_arguments = parsed.get("arguments", {})
+
+                    # Validate tool selection
+                    if selected_tool not in self.tool_names:
+                        logger.warning(f"Invalid tool '{selected_tool}', available: {self.tool_names}")
+                        tool_called = False
+                        selected_tool = None
+                    else:
+                        # Call the tool
+                        result = await client.call_tool(selected_tool, tool_arguments)
+                        tool_response = self._extract_tool_response(result)
+
+            except (json.JSONDecodeError, KeyError):
+                # Model didn't follow JSON format
+                pass
+
+            return {
+                "output": model_output,
+                "tool_called": tool_called,
+                "selected_tool": selected_tool,
+                "tool_arguments": tool_arguments,
+                "tool_response": tool_response,
+            }
+
+        except Exception as e:
+            logger.exception("First pass failed")
+            return {
+                "output": f"ERROR: {e!s}",
+                "tool_called": False,
+                "selected_tool": None,
+                "tool_arguments": None,
+                "tool_response": None,
+            }
+
+    async def _second_pass(
+        self,
+        client,
+        item: MCPDataInst,
+        system_prompt: str,
+        tool_response: str | None,
+    ) -> str:
+        """Second pass: Model receives tool response and generates final answer."""
+        messages = [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": item["user_query"]},
+            {
+                "role": "assistant",
+                "content": f"I'll use the tool to help answer this. Tool response: {tool_response}",
+            },
+            {
+                "role": "user",
+                "content": "Based on the tool response, please provide your final answer.",
+            },
+        ]
+
+        try:
+            if isinstance(self.task_model, str):
+                response = self.litellm.completion(
+                    model=self.task_model,
+                    messages=messages,
+                )
+                return response.choices[0].message.content.strip()  # type: ignore[union-attr]
+            else:
+                return self.task_model(messages)
+
+        except Exception as e:
+            logger.exception("Second pass failed")
+            return f"ERROR: {e!s}"
+
+    def _build_system_prompt(self, candidate: dict[str, str], available_tools: list[dict[str, Any]]) -> str:
+        """Build system prompt with tool information."""
+        custom_system_prompt = candidate.get("system_prompt", self.base_system_prompt)
+
+        # Build tool descriptions
+        tool_descriptions = {}
+        for tool in available_tools:
+            tool_name = tool.get("name")
+            # Use optimized description if available
+            # Support both tool_description_{tool_name} (multi-tool) and tool_description (single-tool)
+            optimized_desc = candidate.get(f"tool_description_{tool_name}") or candidate.get("tool_description")
+            tool_descriptions[tool_name] = optimized_desc or tool.get("description", "")
+
+        # Build tools section
+        tools_section = "You have access to the following tools:\n\n"
+
+        for tool in available_tools:
+            tool_name = tool.get("name")
+            tool_description = tool_descriptions[tool_name]
+            input_schema = tool.get("inputSchema", {})
+
+            # Build example arguments from schema
+            properties = input_schema.get("properties", {})
+            example_args = {}
+            for param_name, param_info in properties.items():
+                if param_info.get("type") == "string":
+                    example_args[param_name] = "example_value"
+                elif param_info.get("type") == "number":
+                    example_args[param_name] = 123
+                elif param_info.get("type") == "boolean":
+                    example_args[param_name] = True
+                else:
+                    example_args[param_name] = "value"
+
+            if not example_args:
+                example_args = {"param": "value"}
+
+            example_json = json.dumps(example_args)
+
+            tools_section += f"""Tool: {tool_name}
+Description: {tool_description}
+Input Schema: {json.dumps(input_schema, indent=2)}
+Example usage: {{"action": "call_tool", "tool": "{tool_name}", "arguments": {example_json}}}
+
+"""
+
+        # Add usage instructions
+        usage_instructions = f"""
+When you need to use a tool, respond ONLY with JSON:
+{{"action": "call_tool", "tool": "tool_name", "arguments": {{"param": "value"}}}}
+
+When you can answer directly, respond ONLY with JSON:
+{{"action": "answer", "text": "your answer"}}
+
+Choose the most appropriate tool for the task. Available tools: {[t.get("name") for t in available_tools]}
+
+Always respond with valid JSON. No other text.
+"""
+
+        return f"{custom_system_prompt}\n{tools_section}{usage_instructions}"
+
+    def _extract_tool_response(self, result) -> str:
+        """
+        Extract text from MCP tool response.
+
+        Handles multiple content types following MCP SDK best practices:
+        - TextContent: Plain text responses
+        - EmbeddedResource: Resource references
+        - ImageContent: Image data (converted to description)
+        - structuredContent: Structured JSON data
+
+        Based on latest MCP SDK examples and DSPy implementation.
+        """
+        try:
+            # Import MCP types for proper parsing
+            from mcp.types import EmbeddedResource, ImageContent, TextContent  # type: ignore[import-untyped]
+
+            # Check for errors first (following DSPy pattern)
+            if hasattr(result, "isError") and result.isError:
+                # Extract error message from content
+                error_texts = []
+                for content_item in result.content:
+                    if isinstance(content_item, TextContent):
+                        error_texts.append(content_item.text)
+                error_msg = "\n".join(error_texts) if error_texts else "Tool execution failed"
+                logger.warning(f"Tool returned error: {error_msg}")
+                return f"ERROR: {error_msg}"
+
+            # Try structured content first (modern MCP pattern)
+            if hasattr(result, "structuredContent") and result.structuredContent:
+                import json
+
+                return json.dumps(result.structuredContent, indent=2)
+
+            # Parse content array
+            if hasattr(result, "content"):
+                texts = []
+                for content_item in result.content:
+                    if isinstance(content_item, TextContent):
+                        texts.append(content_item.text)
+                    elif isinstance(content_item, EmbeddedResource):
+                        # Handle embedded resources
+                        resource = content_item.resource
+                        if hasattr(resource, "text"):
+                            texts.append(resource.text)
+                        else:
+                            texts.append(f"[Resource: {getattr(resource, 'uri', 'unknown')}]")
+                    elif isinstance(content_item, ImageContent):
+                        # Handle images with description
+                        mime_type = getattr(content_item, "mimeType", "unknown")
+                        data_len = len(getattr(content_item, "data", b""))
+                        texts.append(f"[Image: {mime_type}, {data_len} bytes]")
+
+                if texts:
+                    return "\n".join(texts)
+
+            # Fallback to dict access for backward compatibility
+            if isinstance(result, dict):
+                content = result.get("content", [])
+                if isinstance(content, list):
+                    texts = []
+                    for item in content:
+                        if isinstance(item, dict) and item.get("type") == "text":
+                            texts.append(item.get("text", ""))
+                    # Return empty string if content list is present but empty
+                    return "\n".join(texts)
+
+            return str(result)
+
+        except Exception as e:
+            logger.exception("Failed to extract tool response")
+            return f"ERROR extracting response: {e!s}"
+
+    def make_reflective_dataset(
+        self,
+        candidate: dict[str, str],
+        eval_batch: EvaluationBatch[MCPTrajectory, MCPOutput],
+        components_to_update: list[str],
+    ) -> dict[str, list[dict[str, Any]]]:
+        """
+        Build reflective dataset for instruction refinement.
+
+        Args:
+            candidate: Current candidate components
+            eval_batch: Evaluation results with trajectories
+            components_to_update: Which components to generate data for
+
+        Returns:
+            Dictionary mapping component names to reflective examples
+        """
+        reflective_data: dict[str, list[dict[str, Any]]] = {}
+
+        for component in components_to_update:
+            examples: list[dict[str, Any]] = []
+
+            for traj, score, _output in zip(
+                eval_batch.trajectories or [],
+                eval_batch.scores,
+                eval_batch.outputs,
+                strict=False,
+            ):
+                if component == "tool_description":
+                    feedback = self._generate_tool_feedback(traj, score)
+                    examples.append(
+                        {
+                            "Inputs": {
+                                "user_query": traj["user_query"],
+                                "tool_description": traj["tool_description_used"],
+                            },
+                            "Generated Outputs": {
+                                "tool_called": traj["tool_called"],
+                                "selected_tool": traj["selected_tool"],
+                                "tool_arguments": traj["tool_arguments"],
+                                "final_answer": traj["model_final_output"],
+                            },
+                            "Feedback": feedback,
+                        }
+                    )
+
+                elif component == "system_prompt":
+                    feedback = self._generate_system_prompt_feedback(traj, score)
+                    examples.append(
+                        {
+                            "Inputs": {
+                                "user_query": traj["user_query"],
+                                "system_prompt": traj["system_prompt_used"],
+                            },
+                            "Generated Outputs": traj["model_final_output"],
+                            "Feedback": feedback,
+                        }
+                    )
+
+            reflective_data[component] = examples
+
+        return reflective_data
+
+    def _generate_tool_feedback(self, traj: MCPTrajectory, score: float) -> str:
+        """Generate feedback focused on tool usage and selection."""
+        if score > 0.5:
+            if traj["tool_called"]:
+                return (
+                    f"Good! The tool '{traj['selected_tool']}' was used appropriately. "
+                    f"Score: {score:.2f}"
+                )
+            else:
+                return f"Good! No tool needed, direct answer was correct. Score: {score:.2f}"
+        else:
+            feedback_parts = [f"Incorrect response (score: {score:.2f})."]
+
+            if not traj["tool_called"]:
+                feedback_parts.append("Tool was not called. Consider if a tool would help.")
+            else:
+                selected_tool = traj["selected_tool"]
+                available_tools = traj["tool_names"]
+                feedback_parts.append(
+                    f"Tool '{selected_tool}' was called with {traj['tool_arguments']}, "
+                    f"but answer was incorrect."
+                )
+                if len(available_tools) > 1:
+                    feedback_parts.append(
+                        f"Consider a different tool from {available_tools} or clearer description."
+                    )
+                else:
+                    feedback_parts.append("Tool description may need improvement.")
+
+            return " ".join(feedback_parts)
+
+    def _generate_system_prompt_feedback(self, traj: MCPTrajectory, score: float) -> str:
+        """Generate feedback focused on system prompt guidance."""
+        if score > 0.5:
+            return f"System prompt provided good guidance. Score: {score:.2f}"
+        else:
+            return (
+                f"System prompt may need improvement (score: {score:.2f}). "
+                f"Model {'called' if traj['tool_called'] else 'did not call'} tool, "
+                f"but answer was incorrect."
+            )
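
For orientation, the sketch below shows how this adapter plugs into a GEPA optimization run. It is a minimal, unverified sketch: it assumes MCPAdapter and MCPDataInst are re-exported from the mcp_adapter package (its __init__.py is part of this release) and that mantisdk/algorithm/gepa/lib/api.py exposes an optimize() entry point mirroring upstream GEPA; the server command, tool name, dataset item, and budget are illustrative placeholders.

from mcp import StdioServerParameters

from mantisdk.algorithm.gepa.lib.adapters.mcp_adapter import MCPAdapter, MCPDataInst
from mantisdk.algorithm.gepa.lib.api import optimize  # assumed entry point


def exact_match(item: MCPDataInst, output: str) -> float:
    # 1.0 when the reference answer appears in the model's final answer.
    ref = item["reference_answer"]
    return 1.0 if ref and ref in output else 0.0


adapter = MCPAdapter(
    tool_names=["read_file"],  # tool(s) whose descriptions get optimized
    task_model="gpt-4o-mini",  # litellm model string, per __init__ above
    metric_fn=exact_match,
    server_params=StdioServerParameters(command="python", args=["server.py"]),  # hypothetical server
)

trainset: list[MCPDataInst] = [
    {
        "user_query": "What retention period does config.yaml set?",
        "tool_arguments": {"path": "config.yaml"},
        "reference_answer": "30 days",
        "additional_context": {},
    },
]

# The seed candidate carries exactly the components that evaluate() and
# make_reflective_dataset() read: "system_prompt" and "tool_description"
# (or "tool_description_<name>" in the multi-tool case).
result = optimize(
    seed_candidate={
        "system_prompt": "You are a helpful assistant with access to tools.",
        "tool_description": "Read a file from disk and return its contents.",
    },
    trainset=trainset,
    adapter=adapter,
    max_metric_calls=50,  # illustrative budget
)

During each rollout the model is constrained (see _build_system_prompt) to emit only {"action": "call_tool", ...} or {"action": "answer", ...}; metric_fn then scores the final answer string produced after the optional second pass.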