mantisdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mantisdk might be problematic. Click here for more details.

Files changed (190) hide show
  1. mantisdk/__init__.py +22 -0
  2. mantisdk/adapter/__init__.py +15 -0
  3. mantisdk/adapter/base.py +94 -0
  4. mantisdk/adapter/messages.py +270 -0
  5. mantisdk/adapter/triplet.py +1028 -0
  6. mantisdk/algorithm/__init__.py +39 -0
  7. mantisdk/algorithm/apo/__init__.py +5 -0
  8. mantisdk/algorithm/apo/apo.py +889 -0
  9. mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
  10. mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
  11. mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
  12. mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
  13. mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
  14. mantisdk/algorithm/base.py +162 -0
  15. mantisdk/algorithm/decorator.py +264 -0
  16. mantisdk/algorithm/fast.py +250 -0
  17. mantisdk/algorithm/gepa/__init__.py +59 -0
  18. mantisdk/algorithm/gepa/adapter.py +459 -0
  19. mantisdk/algorithm/gepa/gepa.py +364 -0
  20. mantisdk/algorithm/gepa/lib/__init__.py +18 -0
  21. mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
  22. mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
  23. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
  24. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
  25. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
  26. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
  27. mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
  28. mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
  29. mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
  30. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
  31. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
  32. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
  33. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
  34. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
  35. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
  36. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
  37. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
  38. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
  39. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
  40. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
  41. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
  42. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
  43. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
  44. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
  45. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
  46. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
  47. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
  48. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
  49. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
  50. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
  51. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
  52. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
  53. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
  54. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
  55. mantisdk/algorithm/gepa/lib/api.py +375 -0
  56. mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
  57. mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
  58. mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
  59. mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
  60. mantisdk/algorithm/gepa/lib/core/result.py +233 -0
  61. mantisdk/algorithm/gepa/lib/core/state.py +636 -0
  62. mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
  63. mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
  64. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
  65. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
  66. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
  67. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
  68. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
  69. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
  70. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
  71. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
  72. mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
  73. mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
  74. mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
  75. mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
  76. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
  77. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
  78. mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
  79. mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
  80. mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
  81. mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
  82. mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
  83. mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
  84. mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
  85. mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
  86. mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
  87. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
  88. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
  89. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
  90. mantisdk/algorithm/gepa/lib/py.typed +0 -0
  91. mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
  92. mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
  93. mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
  94. mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
  95. mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
  96. mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
  97. mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
  98. mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
  99. mantisdk/algorithm/gepa/tracing.py +105 -0
  100. mantisdk/algorithm/utils.py +177 -0
  101. mantisdk/algorithm/verl/__init__.py +5 -0
  102. mantisdk/algorithm/verl/interface.py +202 -0
  103. mantisdk/cli/__init__.py +56 -0
  104. mantisdk/cli/prometheus.py +115 -0
  105. mantisdk/cli/store.py +131 -0
  106. mantisdk/cli/vllm.py +29 -0
  107. mantisdk/client.py +408 -0
  108. mantisdk/config.py +348 -0
  109. mantisdk/emitter/__init__.py +43 -0
  110. mantisdk/emitter/annotation.py +370 -0
  111. mantisdk/emitter/exception.py +54 -0
  112. mantisdk/emitter/message.py +61 -0
  113. mantisdk/emitter/object.py +117 -0
  114. mantisdk/emitter/reward.py +320 -0
  115. mantisdk/env_var.py +156 -0
  116. mantisdk/execution/__init__.py +15 -0
  117. mantisdk/execution/base.py +64 -0
  118. mantisdk/execution/client_server.py +443 -0
  119. mantisdk/execution/events.py +69 -0
  120. mantisdk/execution/inter_process.py +16 -0
  121. mantisdk/execution/shared_memory.py +282 -0
  122. mantisdk/instrumentation/__init__.py +119 -0
  123. mantisdk/instrumentation/agentops.py +314 -0
  124. mantisdk/instrumentation/agentops_langchain.py +45 -0
  125. mantisdk/instrumentation/litellm.py +83 -0
  126. mantisdk/instrumentation/vllm.py +81 -0
  127. mantisdk/instrumentation/weave.py +500 -0
  128. mantisdk/litagent/__init__.py +11 -0
  129. mantisdk/litagent/decorator.py +536 -0
  130. mantisdk/litagent/litagent.py +252 -0
  131. mantisdk/llm_proxy.py +1890 -0
  132. mantisdk/logging.py +370 -0
  133. mantisdk/reward.py +7 -0
  134. mantisdk/runner/__init__.py +11 -0
  135. mantisdk/runner/agent.py +845 -0
  136. mantisdk/runner/base.py +182 -0
  137. mantisdk/runner/legacy.py +309 -0
  138. mantisdk/semconv.py +170 -0
  139. mantisdk/server.py +401 -0
  140. mantisdk/store/__init__.py +23 -0
  141. mantisdk/store/base.py +897 -0
  142. mantisdk/store/client_server.py +2092 -0
  143. mantisdk/store/collection/__init__.py +30 -0
  144. mantisdk/store/collection/base.py +587 -0
  145. mantisdk/store/collection/memory.py +970 -0
  146. mantisdk/store/collection/mongo.py +1412 -0
  147. mantisdk/store/collection_based.py +1823 -0
  148. mantisdk/store/insight.py +648 -0
  149. mantisdk/store/listener.py +58 -0
  150. mantisdk/store/memory.py +396 -0
  151. mantisdk/store/mongo.py +165 -0
  152. mantisdk/store/sqlite.py +3 -0
  153. mantisdk/store/threading.py +357 -0
  154. mantisdk/store/utils.py +142 -0
  155. mantisdk/tracer/__init__.py +16 -0
  156. mantisdk/tracer/agentops.py +242 -0
  157. mantisdk/tracer/base.py +287 -0
  158. mantisdk/tracer/dummy.py +106 -0
  159. mantisdk/tracer/otel.py +555 -0
  160. mantisdk/tracer/weave.py +677 -0
  161. mantisdk/trainer/__init__.py +6 -0
  162. mantisdk/trainer/init_utils.py +263 -0
  163. mantisdk/trainer/legacy.py +367 -0
  164. mantisdk/trainer/registry.py +12 -0
  165. mantisdk/trainer/trainer.py +618 -0
  166. mantisdk/types/__init__.py +6 -0
  167. mantisdk/types/core.py +553 -0
  168. mantisdk/types/resources.py +204 -0
  169. mantisdk/types/tracer.py +515 -0
  170. mantisdk/types/tracing.py +218 -0
  171. mantisdk/utils/__init__.py +1 -0
  172. mantisdk/utils/id.py +18 -0
  173. mantisdk/utils/metrics.py +1025 -0
  174. mantisdk/utils/otel.py +578 -0
  175. mantisdk/utils/otlp.py +536 -0
  176. mantisdk/utils/server_launcher.py +1045 -0
  177. mantisdk/utils/system_snapshot.py +81 -0
  178. mantisdk/verl/__init__.py +8 -0
  179. mantisdk/verl/__main__.py +6 -0
  180. mantisdk/verl/async_server.py +46 -0
  181. mantisdk/verl/config.yaml +27 -0
  182. mantisdk/verl/daemon.py +1154 -0
  183. mantisdk/verl/dataset.py +44 -0
  184. mantisdk/verl/entrypoint.py +248 -0
  185. mantisdk/verl/trainer.py +549 -0
  186. mantisdk-0.1.0.dist-info/METADATA +119 -0
  187. mantisdk-0.1.0.dist-info/RECORD +190 -0
  188. mantisdk-0.1.0.dist-info/WHEEL +4 -0
  189. mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
  190. mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
@@ -0,0 +1,364 @@
1
+ """
2
+ Unified MCP Client - Supports stdio, SSE, and StreamableHTTP transports.
3
+
4
+ This utility provides a single abstraction for connecting to MCP servers
5
+ using different transport mechanisms.
6
+ """
7
+
8
+ import asyncio
9
+ import json
10
+ import logging
11
+ from abc import ABC, abstractmethod
12
+ from typing import Any
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ class BaseMCPClient(ABC):
18
+ """Abstract base class for MCP clients."""
19
+
20
+ def __init__(self):
21
+ self.request_id = 0
22
+
23
+ @abstractmethod
24
+ async def start(self):
25
+ """Start the MCP connection."""
26
+ pass
27
+
28
+ @abstractmethod
29
+ async def send_request(self, method: str, params: dict | None = None) -> dict:
30
+ """Send JSON-RPC request and get response."""
31
+ pass
32
+
33
+ @abstractmethod
34
+ async def close(self):
35
+ """Close the connection."""
36
+ pass
37
+
38
+ async def initialize(self) -> dict:
39
+ """Initialize MCP session (common across all transports)."""
40
+ result = await self.send_request(
41
+ "initialize",
42
+ {
43
+ "protocolVersion": "2024-11-05",
44
+ "capabilities": {},
45
+ "clientInfo": {"name": "gepa-mcp-adapter", "version": "1.0"},
46
+ },
47
+ )
48
+ await self._send_initialized_notification()
49
+ return result
50
+
51
+ @abstractmethod
52
+ async def _send_initialized_notification(self):
53
+ """Send initialized notification (transport-specific)."""
54
+ pass
55
+
56
+ async def list_tools(self) -> list[dict]:
57
+ """List available tools."""
58
+ result = await self.send_request("tools/list")
59
+ return result.get("tools", [])
60
+
61
+ async def call_tool(self, name: str, arguments: dict) -> dict:
62
+ """Call a tool."""
63
+ return await self.send_request("tools/call", {"name": name, "arguments": arguments})
64
+
65
+
66
class StdioMCPClient(BaseMCPClient):
    """MCP client that exchanges newline-delimited JSON-RPC with a subprocess over stdio."""

    # Seconds to wait for the server to exit in close(): first after stdin is
    # closed, then again after terminate() before falling back to kill().
    _CLOSE_TIMEOUT = 5.0

    # Per-line buffer limit for the server's stdout. asyncio's default
    # StreamReader limit (64 KiB) is easily exceeded by a single large tool
    # result, which makes readline() raise ValueError.
    _STREAM_LIMIT = 16 * 1024 * 1024

    def __init__(self, command: str, args: list[str]):
        """
        Args:
            command: Executable that starts the MCP server.
            args: Command-line arguments passed to ``command``.
        """
        super().__init__()
        self.command = command
        self.args = args
        self.process = None

    async def start(self):
        """Launch the MCP server subprocess with piped stdin/stdout."""
        logger.info(f"Starting stdio MCP server: {self.command} {' '.join(self.args)}")
        self.process = await asyncio.create_subprocess_exec(
            self.command,
            *self.args,
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            # Server stderr is discarded, so its diagnostics are not visible here.
            stderr=asyncio.subprocess.DEVNULL,
            limit=self._STREAM_LIMIT,
        )

    async def send_request(self, method: str, params: dict | None = None) -> dict:
        """Send a JSON-RPC request over stdin and wait for the response with the same id.

        Stdout lines that are not the response to this request (blank lines,
        server notifications, server-initiated requests, responses to other
        ids) are skipped.

        NOTE(review): server-initiated requests are not answered by this client.

        Returns:
            The response's ``result`` object (``{}`` if absent).

        Raises:
            RuntimeError: if the process is not started, or it closes stdout
                before answering.
            Exception: if the server answers with a JSON-RPC error.
            json.JSONDecodeError: if the server writes a non-JSON line.
        """
        if not self.process or not self.process.stdin or not self.process.stdout:
            raise RuntimeError("Process not started or streams not available")

        self.request_id += 1
        request_id = self.request_id
        request = {"jsonrpc": "2.0", "method": method, "id": request_id}
        if params is not None:
            request["params"] = params

        self.process.stdin.write((json.dumps(request) + "\n").encode())
        await self.process.stdin.drain()

        while True:
            raw_line = await self.process.stdout.readline()
            if not raw_line:
                # readline() returns b"" only at EOF: the server exited or closed stdout.
                raise RuntimeError(
                    f"MCP server closed stdout before responding to {method!r} (id={request_id})"
                )
            line = raw_line.decode().strip()
            if not line:
                continue

            response = json.loads(line)
            # Anything carrying a "method" is a notification or a server request, not our reply.
            if not isinstance(response, dict) or "method" in response or response.get("id") != request_id:
                logger.debug(f"Skipping unrelated stdio message while waiting for id={request_id}: {line}")
                continue

            if "error" in response:
                raise Exception(f"MCP error: {response['error']}")
            return response.get("result", {})

    async def _send_initialized_notification(self):
        """Write the ``notifications/initialized`` message to the server's stdin."""
        if not self.process or not self.process.stdin:
            raise RuntimeError("Process not started or stdin not available")

        notification = {"jsonrpc": "2.0", "method": "notifications/initialized"}
        self.process.stdin.write((json.dumps(notification) + "\n").encode())
        await self.process.stdin.drain()

    async def _wait_for_exit(self, process) -> bool:
        """Wait up to ``_CLOSE_TIMEOUT`` seconds for *process*; return True if it exited."""
        try:
            await asyncio.wait_for(process.wait(), timeout=self._CLOSE_TIMEOUT)
            return True
        except asyncio.TimeoutError:
            return False

    async def close(self):
        """Close the server's stdin and wait for it to exit, escalating to terminate/kill.

        The escalation keeps close() from hanging on a server that does not
        exit when its stdin reaches EOF.
        """
        process = self.process
        if process is None:
            return

        if process.stdin is not None and not process.stdin.is_closing():
            process.stdin.close()

        if not await self._wait_for_exit(process):
            logger.warning("Stdio MCP server did not exit after stdin was closed; terminating it")
            try:
                process.terminate()
            except ProcessLookupError:
                pass  # exited in the meantime
            if not await self._wait_for_exit(process):
                logger.warning("Stdio MCP server ignored terminate(); killing it")
                try:
                    process.kill()
                except ProcessLookupError:
                    pass
                await process.wait()

        logger.info("Stdio MCP connection closed")
127
+
128
+
129
class SSEMCPClient(BaseMCPClient):
    """MCP client using the Server-Sent Events transport from the ``mcp`` package."""

    def __init__(
        self,
        url: str,
        headers: dict[str, str] | None = None,
        timeout: float = 30,
        sse_read_timeout: float = 300,
    ):
        """
        Args:
            url: SSE endpoint of the remote MCP server.
            headers: Extra HTTP headers passed to ``sse_client``.
            timeout: HTTP operation timeout passed to ``sse_client``.
            sse_read_timeout: Event-stream read timeout passed to ``sse_client``.
        """
        super().__init__()
        self.url = url
        self.headers = headers or {}
        self.timeout = timeout
        self.sse_read_timeout = sse_read_timeout
        self.read_stream = None
        self.write_stream = None
        self._sse_context = None

    async def start(self):
        """Open the SSE connection and keep its read/write streams."""
        from mcp.client.sse import sse_client  # type: ignore[import-untyped]

        logger.info(f"Connecting to SSE MCP server at {self.url}")

        self._sse_context = sse_client(
            url=self.url,
            headers=self.headers,
            timeout=self.timeout,
            sse_read_timeout=self.sse_read_timeout,
        )

        streams = await self._sse_context.__aenter__()
        self.read_stream, self.write_stream = streams[0], streams[1]
        logger.info("SSE connection established")

    async def send_request(self, method: str, params: dict | None = None) -> dict:
        """Send a JSON-RPC request and wait for the response carrying its id.

        Server notifications, server-initiated requests and responses to other
        ids that arrive first are logged at debug level and skipped.

        NOTE(review): server-initiated requests are not answered by this client.

        Raises:
            RuntimeError: if ``start()`` has not been called.
            Exception: on a JSON-RPC error response, on an exception object
                delivered on the read stream, or on a response in an
                unexpected format.
        """
        from mcp.shared.message import SessionMessage  # type: ignore[import-untyped]
        from mcp.types import JSONRPCMessage, JSONRPCRequest  # type: ignore[import-untyped]

        if not self.read_stream or not self.write_stream:
            raise RuntimeError("SSE streams not initialized")

        self.request_id += 1
        request_id = self.request_id
        request_dict = {
            "jsonrpc": "2.0",
            "method": method,
            "id": request_id,
        }
        if params is not None:
            request_dict["params"] = params

        logger.debug(f"Sending SSE request: {method} (id={request_id})")

        request = JSONRPCRequest(**request_dict)
        await self.write_stream.send(SessionMessage(message=JSONRPCMessage(request)))

        while True:
            response_message = await self.read_stream.receive()
            # NOTE(review): the mcp transports may put exception objects on the
            # read stream instead of messages; surface them rather than failing
            # with an AttributeError on `.message`.
            if isinstance(response_message, Exception):
                raise Exception(f"SSE transport error: {response_message}") from response_message

            root = response_message.message.root
            if hasattr(root, "method") or getattr(root, "id", None) != request_id:
                # Notification, server-initiated request, or a reply to another id.
                logger.debug(f"Skipping unrelated SSE message while waiting for id={request_id}: {root}")
                continue

            if hasattr(root, "error"):
                raise Exception(f"MCP error: {root.error}")
            if hasattr(root, "result"):
                return root.result
            raise Exception(f"Unexpected response format: {response_message}")

    async def _send_initialized_notification(self):
        """Send the ``notifications/initialized`` message over the SSE write stream."""
        from mcp.shared.message import SessionMessage  # type: ignore[import-untyped]
        from mcp.types import JSONRPCMessage, JSONRPCNotification  # type: ignore[import-untyped]

        if not self.write_stream:
            raise RuntimeError("SSE write stream not initialized")

        notification = JSONRPCNotification(
            jsonrpc="2.0",
            method="notifications/initialized",
        )
        await self.write_stream.send(SessionMessage(message=JSONRPCMessage(notification)))

    async def close(self):
        """Exit the SSE transport context; calling this again is a no-op.

        Errors raised while exiting are logged as warnings, not propagated.
        """
        context, self._sse_context = self._sse_context, None
        if not context:
            return
        self.read_stream = None
        self.write_stream = None
        try:
            await context.__aexit__(None, None, None)
            logger.info("SSE connection closed")
        except Exception as e:
            logger.warning(f"Error closing SSE connection: {e}")
218
+
219
+
220
class StreamableHTTPMCPClient(BaseMCPClient):
    """MCP client using the StreamableHTTP transport from the ``mcp`` package."""

    def __init__(
        self,
        url: str,
        headers: dict[str, str] | None = None,
        timeout: float = 30,
        sse_read_timeout: float = 300,
    ):
        """
        Args:
            url: StreamableHTTP endpoint of the remote MCP server.
            headers: Extra HTTP headers passed to the transport.
            timeout: HTTP operation timeout passed to the transport.
            sse_read_timeout: Streaming read timeout passed to the transport.
        """
        super().__init__()
        self.url = url
        self.headers = headers or {}
        self.timeout = timeout
        self.sse_read_timeout = sse_read_timeout
        self.read_stream = None
        self.write_stream = None
        self._transport_context = None

    async def start(self):
        """Open the StreamableHTTP connection and keep its read/write streams."""
        from mcp.client.streamable_http import streamable_http_client  # type: ignore[import-untyped]

        logger.info(f"Connecting to StreamableHTTP MCP server at {self.url}")

        self._transport_context = streamable_http_client(
            url=self.url,
            headers=self.headers,
            timeout=self.timeout,
            sse_read_timeout=self.sse_read_timeout,
        )

        streams = await self._transport_context.__aenter__()
        # NOTE(review): the mcp streamable HTTP transport yields a third element
        # (a session-id getter) in addition to the two streams; index instead of
        # unpacking so both 2- and 3-element shapes work.
        self.read_stream, self.write_stream = streams[0], streams[1]
        logger.info("StreamableHTTP connection established")

    async def send_request(self, method: str, params: dict | None = None) -> dict:
        """Send a JSON-RPC request and wait for the response carrying its id.

        Server notifications, server-initiated requests and responses to other
        ids that arrive first are logged at debug level and skipped.

        NOTE(review): server-initiated requests are not answered by this client.

        Raises:
            RuntimeError: if ``start()`` has not been called.
            Exception: on a JSON-RPC error response, on an exception object
                delivered on the read stream, or on a response in an
                unexpected format.
        """
        from mcp.shared.message import SessionMessage  # type: ignore[import-untyped]
        from mcp.types import JSONRPCMessage, JSONRPCRequest  # type: ignore[import-untyped]

        if not self.read_stream or not self.write_stream:
            raise RuntimeError("StreamableHTTP streams not initialized")

        self.request_id += 1
        request_id = self.request_id
        request_dict = {
            "jsonrpc": "2.0",
            "method": method,
            "id": request_id,
        }
        if params is not None:
            request_dict["params"] = params

        logger.debug(f"Sending StreamableHTTP request: {method} (id={request_id})")

        request = JSONRPCRequest(**request_dict)
        await self.write_stream.send(SessionMessage(message=JSONRPCMessage(request)))

        while True:
            response_message = await self.read_stream.receive()
            # NOTE(review): the mcp transports may put exception objects on the
            # read stream instead of messages; surface them rather than failing
            # with an AttributeError on `.message`.
            if isinstance(response_message, Exception):
                raise Exception(f"StreamableHTTP transport error: {response_message}") from response_message

            root = response_message.message.root
            if hasattr(root, "method") or getattr(root, "id", None) != request_id:
                # Notification, server-initiated request, or a reply to another id.
                logger.debug(
                    f"Skipping unrelated StreamableHTTP message while waiting for id={request_id}: {root}"
                )
                continue

            if hasattr(root, "error"):
                raise Exception(f"MCP error: {root.error}")
            if hasattr(root, "result"):
                return root.result
            raise Exception(f"Unexpected response format: {response_message}")

    async def _send_initialized_notification(self):
        """Send the ``notifications/initialized`` message over the write stream."""
        from mcp.shared.message import SessionMessage  # type: ignore[import-untyped]
        from mcp.types import JSONRPCMessage, JSONRPCNotification  # type: ignore[import-untyped]

        if not self.write_stream:
            raise RuntimeError("StreamableHTTP write stream not initialized")

        notification = JSONRPCNotification(
            jsonrpc="2.0",
            method="notifications/initialized",
        )
        await self.write_stream.send(SessionMessage(message=JSONRPCMessage(notification)))

    async def close(self):
        """Exit the StreamableHTTP transport context; calling this again is a no-op.

        Errors raised while exiting are logged as warnings, not propagated.
        """
        context, self._transport_context = self._transport_context, None
        if not context:
            return
        self.read_stream = None
        self.write_stream = None
        try:
            await context.__aexit__(None, None, None)
            logger.info("StreamableHTTP connection closed")
        except Exception as e:
            logger.warning(f"Error closing StreamableHTTP connection: {e}")
316
+
317
+
318
def create_mcp_client(
    server_params: Any = None,
    remote_url: str | None = None,
    remote_transport: str = "sse",
    remote_headers: dict[str, str] | None = None,
    remote_timeout: float = 30,
    sse_read_timeout: float = 300,
) -> BaseMCPClient:
    """
    Factory function to create the appropriate MCP client.

    Exactly one of ``server_params`` (local server) or ``remote_url`` (remote
    server) must be provided.

    Args:
        server_params: StdioServerParameters-like object for a local server.
            Only its ``command`` and ``args`` attributes are used.
        remote_url: URL for a remote server.
        remote_transport: ``"sse"`` or ``"streamable_http"``. Matching ignores
            case and surrounding whitespace, and ``"streamable-http"`` is
            accepted as an alias.
        remote_headers: HTTP headers for remote connections.
        remote_timeout: Timeout for HTTP operations.
        sse_read_timeout: Streaming read timeout. Only forwarded to the
            ``streamable_http`` client; the ``sse`` client is created with its
            own default.

    Returns:
        BaseMCPClient instance (Stdio, SSE, or StreamableHTTP). The client is
        not started; call ``start()`` before use.

    Raises:
        ValueError: If both or neither of ``server_params`` / ``remote_url``
            are given, or ``remote_transport`` is not recognised.
    """
    if server_params and remote_url:
        raise ValueError("Provide either server_params (local) or remote_url (remote), not both")
    if not server_params and not remote_url:
        raise ValueError("Must provide either server_params (local) or remote_url (remote)")

    if server_params:
        return StdioMCPClient(command=server_params.command, args=server_params.args)

    # Past the checks above, remote_url is a non-empty string.
    transport = str(remote_transport).strip().lower().replace("-", "_")
    if transport == "sse":
        return SSEMCPClient(url=remote_url, headers=remote_headers, timeout=remote_timeout)
    if transport == "streamable_http":
        return StreamableHTTPMCPClient(
            url=remote_url,
            headers=remote_headers,
            timeout=remote_timeout,
            sse_read_timeout=sse_read_timeout,
        )
    raise ValueError(f"Unknown remote transport: {remote_transport}. Must be 'sse' or 'streamable_http'")
@@ -0,0 +1,9 @@
1
+ ### Terminal-bench adapter
2
+
3
+ This adapter optimizes the system prompt (the terminal-use instruction) for the default Terminus agent through a custom `GEPAAdapter` implementation.
4
+
5
+ To run this example, install terminal-bench (`pip install terminal-bench`) and run the following command:
6
+
7
+ ```bash
8
+ python mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py --model_name=gpt-5-mini
9
+ ```
@@ -0,0 +1,217 @@
1
+ import json
2
+ import os
3
+ import subprocess
4
+ from datetime import datetime
5
+ from pathlib import Path
6
+
7
+ from pydantic import BaseModel
8
+ from terminal_bench.agents.terminus_1 import CommandBatchResponse
9
+
10
+ from mantisdk.algorithm.gepa.lib import EvaluationBatch, GEPAAdapter
11
+
12
+
13
class TerminalBenchTask(BaseModel):
    """A single terminal-bench task to evaluate, paired with the model that should solve it."""

    # Task identifier, passed to `tb run --task-id`.
    task_id: str
    # Model identifier, passed to `tb run --model-name`.
    model_name: str
16
+
17
+
18
+ def run_agent_tb(
19
+ task_ids: str | list[str],
20
+ run_id: str,
21
+ model_name: str,
22
+ instruction_prompt: str,
23
+ dataset_name: str = "terminal-bench-core",
24
+ dataset_version: str = "head",
25
+ agent_import_path: str = "train_terminus:TerminusWrapper",
26
+ n_concurrent: int = 6,
27
+ prompt_template_path: str = "prompt-templates/instruction_prompt.txt",
28
+ ):
29
+ """Run the replay agent for multiple task IDs using tb run command."""
30
+
31
+ env = os.environ.copy()
32
+ # write instruction prompt to file
33
+ with open(prompt_template_path, "w") as f:
34
+ f.write(instruction_prompt)
35
+
36
+ cmd = [
37
+ "tb",
38
+ "run",
39
+ "--dataset-name",
40
+ dataset_name,
41
+ "--dataset-version",
42
+ dataset_version,
43
+ "--agent-import-path",
44
+ agent_import_path,
45
+ "--model-name",
46
+ model_name,
47
+ "--run-id",
48
+ run_id,
49
+ "--n-concurrent",
50
+ str(n_concurrent),
51
+ "--output-path",
52
+ str(Path(os.getcwd()) / "runs"),
53
+ ]
54
+ if isinstance(task_ids, list):
55
+ for task_id in task_ids:
56
+ cmd.extend(["--task-id", task_id])
57
+ else:
58
+ cmd.extend(["--task-id", task_ids])
59
+
60
+ print(f"Running command: {' '.join(cmd)}")
61
+
62
+ try:
63
+ result = subprocess.run(cmd, env=env, cwd=Path(prompt_template_path).parent.parent, check=True)
64
+ print(f"Command completed successfully with return code: {result.returncode}")
65
+ return result.returncode
66
+ except subprocess.CalledProcessError as e:
67
+ print(f"Command failed with return code: {e.returncode}")
68
+ return e.returncode
69
+ except Exception as e:
70
+ print(f"Error running command: {e}")
71
+ return 1
72
+
73
+
74
+ def get_results(task_id: str, run_id: str) -> tuple[int, list]:
75
+ def _read_episode_response(episode_dir: Path) -> CommandBatchResponse | None:
76
+ """Helper method to read and parse response.json from an episode directory."""
77
+ response_file = episode_dir / "response.json"
78
+ if response_file.exists():
79
+ try:
80
+ response_content = response_file.read_text()
81
+ return CommandBatchResponse.model_validate_json(response_content)
82
+ except Exception:
83
+ pass
84
+ return None
85
+
86
+ def _get_logging_dir(task_id: str, run_id: str):
87
+ logging_dir_base = Path("runs") / run_id / task_id
88
+ for dir in logging_dir_base.iterdir():
89
+ if dir.is_dir() and dir.name.startswith(task_id):
90
+ return dir
91
+ raise ValueError(f"No logging directory found for task {task_id} and run {run_id}")
92
+
93
+ logging_dir = _get_logging_dir(task_id, run_id)
94
+ result_json = logging_dir / "results.json"
95
+ with open(result_json) as f:
96
+ result = json.load(f)
97
+ if result.get("parser_results", None):
98
+ score = sum(x == "passed" for x in result["parser_results"].values())
99
+ else:
100
+ score = 0
101
+
102
+ if result.get("is_resolved", None):
103
+ success = True
104
+ else:
105
+ success = False
106
+
107
+ failed_reason = result.get("failure_mode", "unknown")
108
+
109
+ trajectory_path = logging_dir / "agent-logs"
110
+ episode_dirs = []
111
+ for dir in trajectory_path.iterdir():
112
+ if dir.is_dir() and dir.name.startswith("episode-"):
113
+ episode_dirs.append(dir)
114
+
115
+ if episode_dirs:
116
+ # Sort by episode number to get the last one
117
+ episode_dirs.sort(key=lambda x: int(x.name.split("-")[1]))
118
+ last_episode_dir = episode_dirs[-1]
119
+
120
+ last_episode_dir_trajectory = last_episode_dir / "debug.json"
121
+ with open(last_episode_dir_trajectory) as f:
122
+ trajectory = json.load(f)
123
+
124
+ if "input" in trajectory and isinstance(trajectory["input"], list):
125
+ messages = trajectory["input"]
126
+
127
+ # Add the last assistant response using helper method
128
+ parsed_response = _read_episode_response(last_episode_dir)
129
+
130
+ if parsed_response:
131
+ assistant_message = {
132
+ "role": "assistant",
133
+ "content": parsed_response.model_dump_json(),
134
+ }
135
+ messages.append(assistant_message)
136
+
137
+ return success, score, failed_reason, messages
138
+
139
+
140
class TerminusAdapter(GEPAAdapter):
    """GEPA adapter that scores instruction prompts by running terminal-bench tasks.

    Candidates must contain an ``"instruction_prompt"`` component. Each call
    to ``evaluate`` runs the whole batch in a single ``tb run`` invocation and
    then reads every task's results back from ``runs/<run_id>/``.
    """

    def __init__(
        self,
        n_concurrent: int = 6,
        instruction_prompt_path: str = "prompt-templates/instruction_prompt.txt",
    ):
        # Forwarded to `tb run --n-concurrent`.
        self.n_concurrent = n_concurrent
        # File the candidate prompt is written to before each run.
        self.instruction_prompt_path = instruction_prompt_path

    def evaluate(
        self,
        batch: list[TerminalBenchTask],
        candidate: dict[str, str],
        capture_traces: bool = False,
    ) -> EvaluationBatch:
        """Run *batch* with ``candidate["instruction_prompt"]`` and score each task.

        A task's score is the value returned by ``get_results`` (0 if its
        results could not be read). Trajectories are always recorded,
        regardless of ``capture_traces``. An empty batch returns an empty
        ``EvaluationBatch`` without running anything.
        """
        if not batch:
            return EvaluationBatch(outputs=[], scores=[], trajectories=[])

        instruction_prompt = candidate["instruction_prompt"]
        # Microsecond resolution means two evaluations started within the same
        # second do not share (and read back) the same runs/<run_id> directory.
        example_run_id = "temp_gepa_run_" + datetime.now().strftime("%Y%m%d%H%M%S%f")
        # One `tb run` is issued per batch, so only the first task's model is used.
        example_model_name = batch[0].model_name

        run_agent_tb(
            [task.task_id for task in batch],
            example_run_id,
            example_model_name,
            instruction_prompt=instruction_prompt,
            n_concurrent=self.n_concurrent,
            prompt_template_path=self.instruction_prompt_path,
        )

        outputs = []
        scores = []
        trajectories = []
        for example in batch:
            try:
                success, score, failed_reason, messages = get_results(example.task_id, example_run_id)
            except Exception as e:
                # A task whose results cannot be read is scored as a failure.
                print(f"Error running example {example.task_id} {example_run_id}: {e}")
                success = False
                score = 0
                failed_reason = str(e)
                messages = []

            outputs.append(
                f"Terminal Bench outputs are omitted. Please see runs/{example_run_id}/{example.task_id}/ for detailed logging."
            )
            scores.append(score)
            trajectories.append(
                {
                    "messages": messages,
                    "instruction_prompt": instruction_prompt,
                    "failed_reason": failed_reason,
                    "success": success,
                }
            )

        return EvaluationBatch(
            outputs=outputs,
            scores=scores,
            trajectories=trajectories,
        )

    def make_reflective_dataset(
        self,
        candidate: dict[str, str],
        eval_batch: EvaluationBatch,
        components_to_update: list[str],
    ):
        """Build reflection records for the ``instruction_prompt`` component.

        There is one record per evaluated task. Each record holds the task's
        final message history, the prompt that was used, and a
        success/failure feedback line. ``components_to_update`` is not
        consulted; only the ``instruction_prompt`` key is produced.
        """
        records = []
        # zip (non-strict) stops at the shorter of scores/trajectories.
        for _score, trajectory in zip(eval_batch.scores, eval_batch.trajectories, strict=False):
            if trajectory["success"]:
                feedback = "Successfully solved the task!"
            else:
                feedback = f"Failed to solve the task. Reason: {trajectory['failed_reason']}"
            records.append(
                {
                    "Message History": trajectory["messages"],
                    "Instruction Prompt": candidate["instruction_prompt"],
                    "Feedback": feedback,
                }
            )
        return {"instruction_prompt": records}