mantisdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mantisdk might be problematic. Click here for more details.
- mantisdk/__init__.py +22 -0
- mantisdk/adapter/__init__.py +15 -0
- mantisdk/adapter/base.py +94 -0
- mantisdk/adapter/messages.py +270 -0
- mantisdk/adapter/triplet.py +1028 -0
- mantisdk/algorithm/__init__.py +39 -0
- mantisdk/algorithm/apo/__init__.py +5 -0
- mantisdk/algorithm/apo/apo.py +889 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
- mantisdk/algorithm/base.py +162 -0
- mantisdk/algorithm/decorator.py +264 -0
- mantisdk/algorithm/fast.py +250 -0
- mantisdk/algorithm/gepa/__init__.py +59 -0
- mantisdk/algorithm/gepa/adapter.py +459 -0
- mantisdk/algorithm/gepa/gepa.py +364 -0
- mantisdk/algorithm/gepa/lib/__init__.py +18 -0
- mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
- mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
- mantisdk/algorithm/gepa/lib/api.py +375 -0
- mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
- mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
- mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
- mantisdk/algorithm/gepa/lib/core/result.py +233 -0
- mantisdk/algorithm/gepa/lib/core/state.py +636 -0
- mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
- mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
- mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
- mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
- mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
- mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
- mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
- mantisdk/algorithm/gepa/lib/py.typed +0 -0
- mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
- mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
- mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
- mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
- mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
- mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
- mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
- mantisdk/algorithm/gepa/tracing.py +105 -0
- mantisdk/algorithm/utils.py +177 -0
- mantisdk/algorithm/verl/__init__.py +5 -0
- mantisdk/algorithm/verl/interface.py +202 -0
- mantisdk/cli/__init__.py +56 -0
- mantisdk/cli/prometheus.py +115 -0
- mantisdk/cli/store.py +131 -0
- mantisdk/cli/vllm.py +29 -0
- mantisdk/client.py +408 -0
- mantisdk/config.py +348 -0
- mantisdk/emitter/__init__.py +43 -0
- mantisdk/emitter/annotation.py +370 -0
- mantisdk/emitter/exception.py +54 -0
- mantisdk/emitter/message.py +61 -0
- mantisdk/emitter/object.py +117 -0
- mantisdk/emitter/reward.py +320 -0
- mantisdk/env_var.py +156 -0
- mantisdk/execution/__init__.py +15 -0
- mantisdk/execution/base.py +64 -0
- mantisdk/execution/client_server.py +443 -0
- mantisdk/execution/events.py +69 -0
- mantisdk/execution/inter_process.py +16 -0
- mantisdk/execution/shared_memory.py +282 -0
- mantisdk/instrumentation/__init__.py +119 -0
- mantisdk/instrumentation/agentops.py +314 -0
- mantisdk/instrumentation/agentops_langchain.py +45 -0
- mantisdk/instrumentation/litellm.py +83 -0
- mantisdk/instrumentation/vllm.py +81 -0
- mantisdk/instrumentation/weave.py +500 -0
- mantisdk/litagent/__init__.py +11 -0
- mantisdk/litagent/decorator.py +536 -0
- mantisdk/litagent/litagent.py +252 -0
- mantisdk/llm_proxy.py +1890 -0
- mantisdk/logging.py +370 -0
- mantisdk/reward.py +7 -0
- mantisdk/runner/__init__.py +11 -0
- mantisdk/runner/agent.py +845 -0
- mantisdk/runner/base.py +182 -0
- mantisdk/runner/legacy.py +309 -0
- mantisdk/semconv.py +170 -0
- mantisdk/server.py +401 -0
- mantisdk/store/__init__.py +23 -0
- mantisdk/store/base.py +897 -0
- mantisdk/store/client_server.py +2092 -0
- mantisdk/store/collection/__init__.py +30 -0
- mantisdk/store/collection/base.py +587 -0
- mantisdk/store/collection/memory.py +970 -0
- mantisdk/store/collection/mongo.py +1412 -0
- mantisdk/store/collection_based.py +1823 -0
- mantisdk/store/insight.py +648 -0
- mantisdk/store/listener.py +58 -0
- mantisdk/store/memory.py +396 -0
- mantisdk/store/mongo.py +165 -0
- mantisdk/store/sqlite.py +3 -0
- mantisdk/store/threading.py +357 -0
- mantisdk/store/utils.py +142 -0
- mantisdk/tracer/__init__.py +16 -0
- mantisdk/tracer/agentops.py +242 -0
- mantisdk/tracer/base.py +287 -0
- mantisdk/tracer/dummy.py +106 -0
- mantisdk/tracer/otel.py +555 -0
- mantisdk/tracer/weave.py +677 -0
- mantisdk/trainer/__init__.py +6 -0
- mantisdk/trainer/init_utils.py +263 -0
- mantisdk/trainer/legacy.py +367 -0
- mantisdk/trainer/registry.py +12 -0
- mantisdk/trainer/trainer.py +618 -0
- mantisdk/types/__init__.py +6 -0
- mantisdk/types/core.py +553 -0
- mantisdk/types/resources.py +204 -0
- mantisdk/types/tracer.py +515 -0
- mantisdk/types/tracing.py +218 -0
- mantisdk/utils/__init__.py +1 -0
- mantisdk/utils/id.py +18 -0
- mantisdk/utils/metrics.py +1025 -0
- mantisdk/utils/otel.py +578 -0
- mantisdk/utils/otlp.py +536 -0
- mantisdk/utils/server_launcher.py +1045 -0
- mantisdk/utils/system_snapshot.py +81 -0
- mantisdk/verl/__init__.py +8 -0
- mantisdk/verl/__main__.py +6 -0
- mantisdk/verl/async_server.py +46 -0
- mantisdk/verl/config.yaml +27 -0
- mantisdk/verl/daemon.py +1154 -0
- mantisdk/verl/dataset.py +44 -0
- mantisdk/verl/entrypoint.py +248 -0
- mantisdk/verl/trainer.py +549 -0
- mantisdk-0.1.0.dist-info/METADATA +119 -0
- mantisdk-0.1.0.dist-info/RECORD +190 -0
- mantisdk-0.1.0.dist-info/WHEEL +4 -0
- mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
- mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
|
@@ -0,0 +1,364 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Unified MCP Client - Supports stdio, SSE, and StreamableHTTP transports.
|
|
3
|
+
|
|
4
|
+
This utility provides a single abstraction for connecting to MCP servers
|
|
5
|
+
using different transport mechanisms.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import asyncio
|
|
9
|
+
import json
|
|
10
|
+
import logging
|
|
11
|
+
from abc import ABC, abstractmethod
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
# Module-level logger used by the transport client classes in this module.
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class BaseMCPClient(ABC):
|
|
18
|
+
"""Abstract base class for MCP clients."""
|
|
19
|
+
|
|
20
|
+
def __init__(self):
|
|
21
|
+
self.request_id = 0
|
|
22
|
+
|
|
23
|
+
@abstractmethod
|
|
24
|
+
async def start(self):
|
|
25
|
+
"""Start the MCP connection."""
|
|
26
|
+
pass
|
|
27
|
+
|
|
28
|
+
@abstractmethod
|
|
29
|
+
async def send_request(self, method: str, params: dict | None = None) -> dict:
|
|
30
|
+
"""Send JSON-RPC request and get response."""
|
|
31
|
+
pass
|
|
32
|
+
|
|
33
|
+
@abstractmethod
|
|
34
|
+
async def close(self):
|
|
35
|
+
"""Close the connection."""
|
|
36
|
+
pass
|
|
37
|
+
|
|
38
|
+
async def initialize(self) -> dict:
|
|
39
|
+
"""Initialize MCP session (common across all transports)."""
|
|
40
|
+
result = await self.send_request(
|
|
41
|
+
"initialize",
|
|
42
|
+
{
|
|
43
|
+
"protocolVersion": "2024-11-05",
|
|
44
|
+
"capabilities": {},
|
|
45
|
+
"clientInfo": {"name": "gepa-mcp-adapter", "version": "1.0"},
|
|
46
|
+
},
|
|
47
|
+
)
|
|
48
|
+
await self._send_initialized_notification()
|
|
49
|
+
return result
|
|
50
|
+
|
|
51
|
+
@abstractmethod
|
|
52
|
+
async def _send_initialized_notification(self):
|
|
53
|
+
"""Send initialized notification (transport-specific)."""
|
|
54
|
+
pass
|
|
55
|
+
|
|
56
|
+
async def list_tools(self) -> list[dict]:
|
|
57
|
+
"""List available tools."""
|
|
58
|
+
result = await self.send_request("tools/list")
|
|
59
|
+
return result.get("tools", [])
|
|
60
|
+
|
|
61
|
+
async def call_tool(self, name: str, arguments: dict) -> dict:
|
|
62
|
+
"""Call a tool."""
|
|
63
|
+
return await self.send_request("tools/call", {"name": name, "arguments": arguments})
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class StdioMCPClient(BaseMCPClient):
    """MCP client using stdio transport (subprocess-based).

    The server is launched as ``command *args``. JSON-RPC messages are
    exchanged as one JSON object per line over its stdin/stdout, and its
    stderr is discarded.
    """

    def __init__(self, command: str, args: list[str], read_limit: int = 16 * 1024 * 1024):
        """
        Args:
            command: Executable that starts the MCP server.
            args: Arguments passed to ``command``.
            read_limit: Maximum size in bytes of one line read from the server's
                stdout. asyncio's default of 64 KiB is too small for large tool
                results, because each message arrives as a single line.
        """
        super().__init__()
        self.command = command
        self.args = args
        self.read_limit = read_limit
        self.process = None

    async def start(self):
        """Start the MCP server process."""
        logger.info(f"Starting stdio MCP server: {self.command} {' '.join(self.args)}")
        self.process = await asyncio.create_subprocess_exec(
            self.command,
            *self.args,
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.DEVNULL,
            limit=self.read_limit,
        )

    async def _write_message(self, message: dict) -> None:
        """Write ``message`` as a single JSON line to the server's stdin."""
        if not self.process or not self.process.stdin:
            raise RuntimeError("Process not started or stdin not available")
        self.process.stdin.write((json.dumps(message) + "\n").encode())
        await self.process.stdin.drain()

    async def _reply_to_server_request(self, message: dict) -> None:
        """Answer a server-initiated request so the server is not left waiting on us."""
        if message.get("method") == "ping":
            reply = {"jsonrpc": "2.0", "id": message["id"], "result": {}}
        else:
            reply = {
                "jsonrpc": "2.0",
                "id": message["id"],
                "error": {"code": -32601, "message": f"Method not supported by client: {message.get('method')}"},
            }
        await self._write_message(reply)

    async def send_request(self, method: str, params: dict | None = None) -> dict:
        """Send a JSON-RPC request via stdio and return the matching response's ``result``.

        Lines on stdout that are not the awaited response are handled, not
        mistaken for it:
        - notifications are skipped;
        - server-initiated requests are answered;
        - blank and non-JSON lines are ignored;
        - responses carrying another id are discarded.

        Raises:
            RuntimeError: if the process is not started, or it closes stdout before answering.
            Exception: if the server answers with a JSON-RPC error.
        """
        if not self.process or not self.process.stdin or not self.process.stdout:
            raise RuntimeError("Process not started or streams not available")

        self.request_id += 1
        request_id = self.request_id
        request = {"jsonrpc": "2.0", "method": method, "id": request_id}
        if params is not None:
            request["params"] = params
        await self._write_message(request)

        while True:
            line = await self.process.stdout.readline()
            if not line:
                # EOF: the server exited or closed stdout without answering.
                raise RuntimeError(f"MCP server closed stdout before responding to {method!r} (id={request_id})")
            text = line.decode().strip()
            if not text:
                continue
            try:
                message = json.loads(text)
            except json.JSONDecodeError:
                logger.warning(f"Ignoring non-JSON output from MCP server: {text[:200]}")
                continue
            if not isinstance(message, dict):
                logger.warning(f"Ignoring unexpected JSON from MCP server: {text[:200]}")
                continue

            if "method" in message:
                # Server -> client traffic: a request (has id) or a notification (no id).
                if "id" in message:
                    await self._reply_to_server_request(message)
                else:
                    logger.debug(f"Ignoring MCP notification: {message['method']}")
                continue

            response_id = message.get("id")
            # A null id means the server could not even parse our request, so the error is ours.
            if "error" in message and response_id in (request_id, None):
                raise Exception(f"MCP error: {message['error']}")
            if response_id != request_id:
                logger.debug(f"Ignoring response for id {response_id!r} while waiting for {request_id}")
                continue
            return message.get("result", {})

    async def _send_initialized_notification(self):
        """Send the ``notifications/initialized`` notification via stdio."""
        await self._write_message({"jsonrpc": "2.0", "method": "notifications/initialized"})

    async def close(self, timeout: float = 5.0):
        """Close the subprocess.

        Closes the server's stdin, which asks a stdio MCP server to exit. It then
        waits up to ``timeout`` seconds, escalating to ``terminate()`` and then
        ``kill()``, so that a hung server cannot block shutdown forever.
        """
        process, self.process = self.process, None
        if process is not None:
            if process.stdin and not process.stdin.is_closing():
                process.stdin.close()
            try:
                await asyncio.wait_for(process.wait(), timeout=timeout)
            except asyncio.TimeoutError:
                logger.warning(f"Stdio MCP server did not exit within {timeout}s; terminating it")
                try:
                    process.terminate()
                    await asyncio.wait_for(process.wait(), timeout=timeout)
                except ProcessLookupError:
                    pass  # exited on its own in the meantime
                except asyncio.TimeoutError:
                    process.kill()
                    await process.wait()
        logger.info("Stdio MCP connection closed")
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
class SSEMCPClient(BaseMCPClient):
    """MCP client using the Server-Sent Events transport.

    ``start`` enters the MCP SDK's ``sse_client`` async context manager and keeps
    its read/write streams; ``close`` exits it.
    NOTE(review): the SDK transport runs inside an anyio task group, so ``start``
    and ``close`` probably have to be awaited from the same task — confirm.
    """

    def __init__(
        self,
        url: str,
        headers: dict[str, str] | None = None,
        timeout: float = 30,
        sse_read_timeout: float = 300,
    ):
        """
        Args:
            url: SSE endpoint of the MCP server.
            headers: Extra HTTP headers for the connection.
            timeout: Timeout for HTTP operations, forwarded to ``sse_client``.
            sse_read_timeout: How long to wait for a new SSE event, forwarded to ``sse_client``.
        """
        super().__init__()
        self.url = url
        self.headers = headers or {}
        self.timeout = timeout
        self.sse_read_timeout = sse_read_timeout
        self.read_stream = None
        self.write_stream = None
        self._sse_context = None

    async def start(self):
        """Start the SSE connection."""
        from mcp.client.sse import sse_client  # type: ignore[import-untyped]

        logger.info(f"Connecting to SSE MCP server at {self.url}")

        context = sse_client(
            url=self.url,
            headers=self.headers,
            timeout=self.timeout,
            sse_read_timeout=self.sse_read_timeout,
        )
        streams = await context.__aenter__()
        # Remember the context only once it is entered, so close() never exits an un-entered one.
        self._sse_context = context
        self.read_stream, self.write_stream = streams
        logger.info("SSE connection established")

    async def _receive_response(self, request_id: int):
        """Read from the stream until the response to ``request_id`` arrives and return its result."""
        while True:
            incoming = await self.read_stream.receive()
            # The SDK forwards transport/parse failures on the read stream as exception objects.
            if isinstance(incoming, Exception):
                raise incoming

            root = incoming.message.root
            if hasattr(root, "method"):
                # Server notification or server-initiated request, not the awaited response.
                logger.debug(f"Ignoring server message while waiting for id={request_id}: {root.method}")
                continue

            response_id = getattr(root, "id", None)
            if hasattr(root, "error") and response_id in (request_id, None):
                raise Exception(f"MCP error: {root.error}")
            if response_id != request_id:
                logger.debug(f"Ignoring response for id {response_id!r} while waiting for {request_id}")
                continue
            if hasattr(root, "result"):
                return root.result
            raise Exception(f"Unexpected response format: {incoming}")

    async def send_request(self, method: str, params: dict | None = None) -> dict:
        """Send a JSON-RPC request via SSE and return the matching response's ``result``.

        Raises:
            RuntimeError: if ``start`` has not been called.
            Exception: on a JSON-RPC error response, a transport error forwarded by
                the SDK, or a matching response with neither ``result`` nor ``error``.
        """
        from mcp.shared.message import SessionMessage  # type: ignore[import-untyped]
        from mcp.types import JSONRPCMessage, JSONRPCRequest  # type: ignore[import-untyped]

        if not self.read_stream or not self.write_stream:
            raise RuntimeError("SSE streams not initialized")

        self.request_id += 1
        request_id = self.request_id
        request_dict = {"jsonrpc": "2.0", "method": method, "id": request_id}
        if params is not None:
            request_dict["params"] = params

        logger.debug(f"Sending SSE request: {method} (id={request_id})")

        request = JSONRPCRequest(**request_dict)
        await self.write_stream.send(SessionMessage(message=JSONRPCMessage(request)))
        return await self._receive_response(request_id)

    async def _send_initialized_notification(self):
        """Send the ``notifications/initialized`` notification via SSE."""
        from mcp.shared.message import SessionMessage  # type: ignore[import-untyped]
        from mcp.types import JSONRPCMessage, JSONRPCNotification  # type: ignore[import-untyped]

        if not self.write_stream:
            raise RuntimeError("SSE write stream not initialized")

        notification = JSONRPCNotification(
            jsonrpc="2.0",
            method="notifications/initialized",
        )
        await self.write_stream.send(SessionMessage(message=JSONRPCMessage(notification)))

    async def close(self):
        """Close the SSE connection; safe to call more than once."""
        context, self._sse_context = self._sse_context, None
        self.read_stream = None
        self.write_stream = None
        if context:
            try:
                await context.__aexit__(None, None, None)
                logger.info("SSE connection closed")
            except Exception as e:
                logger.warning(f"Error closing SSE connection: {e}")
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
class StreamableHTTPMCPClient(BaseMCPClient):
    """MCP client using the StreamableHTTP transport (production-grade).

    ``start`` enters the MCP SDK's StreamableHTTP client context manager and keeps
    its read/write streams; ``close`` exits it.
    NOTE(review): the SDK transport runs inside an anyio task group, so ``start``
    and ``close`` probably have to be awaited from the same task — confirm.
    """

    def __init__(
        self,
        url: str,
        headers: dict[str, str] | None = None,
        timeout: float = 30,
        sse_read_timeout: float = 300,
    ):
        """
        Args:
            url: StreamableHTTP endpoint of the MCP server.
            headers: Extra HTTP headers for the connection.
            timeout: Timeout for HTTP operations, forwarded to the SDK client.
            sse_read_timeout: How long to wait for a new SSE event, forwarded to the SDK client.
        """
        super().__init__()
        self.url = url
        self.headers = headers or {}
        self.timeout = timeout
        self.sse_read_timeout = sse_read_timeout
        self.read_stream = None
        self.write_stream = None
        self._transport_context = None

    async def start(self):
        """Start the StreamableHTTP connection."""
        from mcp.client.streamable_http import streamable_http_client  # type: ignore[import-untyped]

        logger.info(f"Connecting to StreamableHTTP MCP server at {self.url}")

        context = streamable_http_client(
            url=self.url,
            headers=self.headers,
            timeout=self.timeout,
            sse_read_timeout=self.sse_read_timeout,
        )
        streams = await context.__aenter__()
        # Remember the context only once it is entered, so close() never exits an un-entered one.
        self._transport_context = context
        # The SDK's StreamableHTTP client yields (read, write, get_session_id); a plain
        # two-name unpack would raise ValueError. Keep the first two, ignore any extras.
        self.read_stream, self.write_stream, *_ = streams
        logger.info("StreamableHTTP connection established")

    async def _receive_response(self, request_id: int):
        """Read from the stream until the response to ``request_id`` arrives and return its result."""
        while True:
            incoming = await self.read_stream.receive()
            # The SDK forwards transport/parse failures on the read stream as exception objects.
            if isinstance(incoming, Exception):
                raise incoming

            root = incoming.message.root
            if hasattr(root, "method"):
                # Server notification or server-initiated request, not the awaited response.
                logger.debug(f"Ignoring server message while waiting for id={request_id}: {root.method}")
                continue

            response_id = getattr(root, "id", None)
            if hasattr(root, "error") and response_id in (request_id, None):
                raise Exception(f"MCP error: {root.error}")
            if response_id != request_id:
                logger.debug(f"Ignoring response for id {response_id!r} while waiting for {request_id}")
                continue
            if hasattr(root, "result"):
                return root.result
            raise Exception(f"Unexpected response format: {incoming}")

    async def send_request(self, method: str, params: dict | None = None) -> dict:
        """Send a JSON-RPC request via StreamableHTTP and return the matching response's ``result``.

        Raises:
            RuntimeError: if ``start`` has not been called.
            Exception: on a JSON-RPC error response, a transport error forwarded by
                the SDK, or a matching response with neither ``result`` nor ``error``.
        """
        from mcp.shared.message import SessionMessage  # type: ignore[import-untyped]
        from mcp.types import JSONRPCMessage, JSONRPCRequest  # type: ignore[import-untyped]

        if not self.read_stream or not self.write_stream:
            raise RuntimeError("StreamableHTTP streams not initialized")

        self.request_id += 1
        request_id = self.request_id
        request_dict = {"jsonrpc": "2.0", "method": method, "id": request_id}
        if params is not None:
            request_dict["params"] = params

        logger.debug(f"Sending StreamableHTTP request: {method} (id={request_id})")

        request = JSONRPCRequest(**request_dict)
        await self.write_stream.send(SessionMessage(message=JSONRPCMessage(request)))
        return await self._receive_response(request_id)

    async def _send_initialized_notification(self):
        """Send the ``notifications/initialized`` notification via StreamableHTTP."""
        from mcp.shared.message import SessionMessage  # type: ignore[import-untyped]
        from mcp.types import JSONRPCMessage, JSONRPCNotification  # type: ignore[import-untyped]

        if not self.write_stream:
            raise RuntimeError("StreamableHTTP write stream not initialized")

        notification = JSONRPCNotification(
            jsonrpc="2.0",
            method="notifications/initialized",
        )
        await self.write_stream.send(SessionMessage(message=JSONRPCMessage(notification)))

    async def close(self):
        """Close the StreamableHTTP connection; safe to call more than once."""
        context, self._transport_context = self._transport_context, None
        self.read_stream = None
        self.write_stream = None
        if context:
            try:
                await context.__aexit__(None, None, None)
                logger.info("StreamableHTTP connection closed")
            except Exception as e:
                logger.warning(f"Error closing StreamableHTTP connection: {e}")
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
def create_mcp_client(
    server_params: Any = None,
    remote_url: str | None = None,
    remote_transport: str = "sse",
    remote_headers: dict[str, str] | None = None,
    remote_timeout: float = 30,
    sse_read_timeout: float = 300,
) -> BaseMCPClient:
    """
    Build the MCP client matching the given connection settings.

    Exactly one of ``server_params`` (local subprocess) or ``remote_url``
    (remote server) must be supplied; both are checked by truthiness.

    Args:
        server_params: Object with ``command`` and ``args`` attributes
            (e.g. StdioServerParameters) describing a local server.
        remote_url: URL of a remote server.
        remote_transport: Either "sse" or "streamable_http".
        remote_headers: HTTP headers for remote connections.
        remote_timeout: Timeout for HTTP operations.
        sse_read_timeout: SSE streaming timeout; only forwarded to the
            StreamableHTTP client.

    Returns:
        A StdioMCPClient, SSEMCPClient or StreamableHTTPMCPClient.

    Raises:
        ValueError: if both or neither connection source is given, or the
            remote transport name is unknown.
    """
    wants_local = bool(server_params)
    wants_remote = bool(remote_url)

    if wants_local and wants_remote:
        raise ValueError("Provide either server_params (local) or remote_url (remote), not both")
    if not (wants_local or wants_remote):
        raise ValueError("Must provide either server_params (local) or remote_url (remote)")

    if wants_local:
        return StdioMCPClient(command=server_params.command, args=server_params.args)

    # From here on remote_url is guaranteed to be a non-empty URL.
    if remote_transport == "sse":
        return SSEMCPClient(url=remote_url, headers=remote_headers, timeout=remote_timeout)
    if remote_transport == "streamable_http":
        return StreamableHTTPMCPClient(
            url=remote_url,
            headers=remote_headers,
            timeout=remote_timeout,
            sse_read_timeout=sse_read_timeout,
        )
    raise ValueError(f"Unknown remote transport: {remote_transport}. Must be 'sse' or 'streamable_http'")
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
### Terminal-bench adapter
|
|
2
|
+
|
|
3
|
+
This adapter optimizes the system prompt (the terminal-use instruction) of the default Terminus agent through a custom `GEPAAdapter` implementation.
|
|
4
|
+
|
|
5
|
+
To run this example, install Terminal-Bench (`pip install terminal-bench`) and run the following command:
|
|
6
|
+
|
|
7
|
+
```bash
|
|
8
|
+
python mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py --model_name=gpt-5-mini
|
|
9
|
+
```
|
|
File without changes
|
|
@@ -0,0 +1,217 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
import subprocess
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
from pydantic import BaseModel
|
|
8
|
+
from terminal_bench.agents.terminus_1 import CommandBatchResponse
|
|
9
|
+
|
|
10
|
+
from mantisdk.algorithm.gepa.lib import EvaluationBatch, GEPAAdapter
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# One Terminal-Bench task to evaluate, together with the model that should run it.
# (Comments only: a docstring would become the pydantic schema description.)
class TerminalBenchTask(BaseModel):
    # Task identifier; TerminusAdapter.evaluate forwards it to `tb run --task-id`.
    task_id: str
    # Model name forwarded to `tb run --model-name`; evaluate() uses the first
    # task's value for the whole batch.
    model_name: str
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def run_agent_tb(
|
|
19
|
+
task_ids: str | list[str],
|
|
20
|
+
run_id: str,
|
|
21
|
+
model_name: str,
|
|
22
|
+
instruction_prompt: str,
|
|
23
|
+
dataset_name: str = "terminal-bench-core",
|
|
24
|
+
dataset_version: str = "head",
|
|
25
|
+
agent_import_path: str = "train_terminus:TerminusWrapper",
|
|
26
|
+
n_concurrent: int = 6,
|
|
27
|
+
prompt_template_path: str = "prompt-templates/instruction_prompt.txt",
|
|
28
|
+
):
|
|
29
|
+
"""Run the replay agent for multiple task IDs using tb run command."""
|
|
30
|
+
|
|
31
|
+
env = os.environ.copy()
|
|
32
|
+
# write instruction prompt to file
|
|
33
|
+
with open(prompt_template_path, "w") as f:
|
|
34
|
+
f.write(instruction_prompt)
|
|
35
|
+
|
|
36
|
+
cmd = [
|
|
37
|
+
"tb",
|
|
38
|
+
"run",
|
|
39
|
+
"--dataset-name",
|
|
40
|
+
dataset_name,
|
|
41
|
+
"--dataset-version",
|
|
42
|
+
dataset_version,
|
|
43
|
+
"--agent-import-path",
|
|
44
|
+
agent_import_path,
|
|
45
|
+
"--model-name",
|
|
46
|
+
model_name,
|
|
47
|
+
"--run-id",
|
|
48
|
+
run_id,
|
|
49
|
+
"--n-concurrent",
|
|
50
|
+
str(n_concurrent),
|
|
51
|
+
"--output-path",
|
|
52
|
+
str(Path(os.getcwd()) / "runs"),
|
|
53
|
+
]
|
|
54
|
+
if isinstance(task_ids, list):
|
|
55
|
+
for task_id in task_ids:
|
|
56
|
+
cmd.extend(["--task-id", task_id])
|
|
57
|
+
else:
|
|
58
|
+
cmd.extend(["--task-id", task_ids])
|
|
59
|
+
|
|
60
|
+
print(f"Running command: {' '.join(cmd)}")
|
|
61
|
+
|
|
62
|
+
try:
|
|
63
|
+
result = subprocess.run(cmd, env=env, cwd=Path(prompt_template_path).parent.parent, check=True)
|
|
64
|
+
print(f"Command completed successfully with return code: {result.returncode}")
|
|
65
|
+
return result.returncode
|
|
66
|
+
except subprocess.CalledProcessError as e:
|
|
67
|
+
print(f"Command failed with return code: {e.returncode}")
|
|
68
|
+
return e.returncode
|
|
69
|
+
except Exception as e:
|
|
70
|
+
print(f"Error running command: {e}")
|
|
71
|
+
return 1
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def get_results(task_id: str, run_id: str) -> tuple[int, list]:
|
|
75
|
+
def _read_episode_response(episode_dir: Path) -> CommandBatchResponse | None:
|
|
76
|
+
"""Helper method to read and parse response.json from an episode directory."""
|
|
77
|
+
response_file = episode_dir / "response.json"
|
|
78
|
+
if response_file.exists():
|
|
79
|
+
try:
|
|
80
|
+
response_content = response_file.read_text()
|
|
81
|
+
return CommandBatchResponse.model_validate_json(response_content)
|
|
82
|
+
except Exception:
|
|
83
|
+
pass
|
|
84
|
+
return None
|
|
85
|
+
|
|
86
|
+
def _get_logging_dir(task_id: str, run_id: str):
|
|
87
|
+
logging_dir_base = Path("runs") / run_id / task_id
|
|
88
|
+
for dir in logging_dir_base.iterdir():
|
|
89
|
+
if dir.is_dir() and dir.name.startswith(task_id):
|
|
90
|
+
return dir
|
|
91
|
+
raise ValueError(f"No logging directory found for task {task_id} and run {run_id}")
|
|
92
|
+
|
|
93
|
+
logging_dir = _get_logging_dir(task_id, run_id)
|
|
94
|
+
result_json = logging_dir / "results.json"
|
|
95
|
+
with open(result_json) as f:
|
|
96
|
+
result = json.load(f)
|
|
97
|
+
if result.get("parser_results", None):
|
|
98
|
+
score = sum(x == "passed" for x in result["parser_results"].values())
|
|
99
|
+
else:
|
|
100
|
+
score = 0
|
|
101
|
+
|
|
102
|
+
if result.get("is_resolved", None):
|
|
103
|
+
success = True
|
|
104
|
+
else:
|
|
105
|
+
success = False
|
|
106
|
+
|
|
107
|
+
failed_reason = result.get("failure_mode", "unknown")
|
|
108
|
+
|
|
109
|
+
trajectory_path = logging_dir / "agent-logs"
|
|
110
|
+
episode_dirs = []
|
|
111
|
+
for dir in trajectory_path.iterdir():
|
|
112
|
+
if dir.is_dir() and dir.name.startswith("episode-"):
|
|
113
|
+
episode_dirs.append(dir)
|
|
114
|
+
|
|
115
|
+
if episode_dirs:
|
|
116
|
+
# Sort by episode number to get the last one
|
|
117
|
+
episode_dirs.sort(key=lambda x: int(x.name.split("-")[1]))
|
|
118
|
+
last_episode_dir = episode_dirs[-1]
|
|
119
|
+
|
|
120
|
+
last_episode_dir_trajectory = last_episode_dir / "debug.json"
|
|
121
|
+
with open(last_episode_dir_trajectory) as f:
|
|
122
|
+
trajectory = json.load(f)
|
|
123
|
+
|
|
124
|
+
if "input" in trajectory and isinstance(trajectory["input"], list):
|
|
125
|
+
messages = trajectory["input"]
|
|
126
|
+
|
|
127
|
+
# Add the last assistant response using helper method
|
|
128
|
+
parsed_response = _read_episode_response(last_episode_dir)
|
|
129
|
+
|
|
130
|
+
if parsed_response:
|
|
131
|
+
assistant_message = {
|
|
132
|
+
"role": "assistant",
|
|
133
|
+
"content": parsed_response.model_dump_json(),
|
|
134
|
+
}
|
|
135
|
+
messages.append(assistant_message)
|
|
136
|
+
|
|
137
|
+
return success, score, failed_reason, messages
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class TerminusAdapter(GEPAAdapter):
    """GEPA adapter that scores an ``instruction_prompt`` candidate on Terminal-Bench.

    Each ``evaluate`` call runs the whole batch in one ``tb run`` invocation
    (via ``run_agent_tb``). It then reads each task's outcome from
    ``runs/<run_id>/`` (via ``get_results``).
    """

    def __init__(
        self,
        n_concurrent: int = 6,
        instruction_prompt_path: str = "prompt-templates/instruction_prompt.txt",
    ):
        """
        Args:
            n_concurrent: Forwarded to ``tb run --n-concurrent``.
            instruction_prompt_path: File the candidate prompt is written to before each run.
        """
        self.n_concurrent = n_concurrent
        self.instruction_prompt_path = instruction_prompt_path

    def evaluate(
        self,
        batch: list[TerminalBenchTask],
        candidate: dict[str, str],
        capture_traces: bool = False,
    ) -> EvaluationBatch:
        """Run ``candidate["instruction_prompt"]`` on every task in ``batch``.

        Trajectories are always populated, regardless of ``capture_traces``. A
        task whose results cannot be read scores 0, with the error text as its
        failure reason.

        Returns:
            An EvaluationBatch with one output, score and trajectory per task
            (all empty for an empty batch).
        """
        if not batch:
            return EvaluationBatch(outputs=[], scores=[], trajectories=[])

        instruction_prompt = candidate["instruction_prompt"]
        # Microsecond precision so two evaluations within the same second get distinct run directories.
        run_id = "temp_gepa_run" + "_" + datetime.now().strftime("%Y%m%d%H%M%S%f")
        # One `tb run` uses a single model: the first task's model applies to the whole batch.
        model_name = batch[0].model_name

        run_agent_tb(
            [task.task_id for task in batch],
            run_id,
            model_name,
            instruction_prompt=instruction_prompt,
            n_concurrent=self.n_concurrent,
            prompt_template_path=self.instruction_prompt_path,
        )

        outputs = []
        scores = []
        trajectories = []
        for example in batch:
            try:
                success, score, failed_reason, messages = get_results(example.task_id, run_id)
            except Exception as e:
                # Missing or partial logs (e.g. `tb run` failed) count as an unsolved task.
                print(f"Error running example {example.task_id} {run_id}: {e}")
                success, score, failed_reason, messages = False, 0, str(e), []

            outputs.append(
                f"Terminal Bench outputs are omitted. Please see runs/{run_id}/{example.task_id}/ for detailed logging."
            )
            scores.append(score)
            trajectories.append(
                {
                    "messages": messages,
                    "instruction_prompt": instruction_prompt,
                    "failed_reason": failed_reason,
                    "success": success,
                }
            )
        return EvaluationBatch(
            outputs=outputs,
            scores=scores,
            trajectories=trajectories,
        )

    def make_reflective_dataset(
        self,
        candidate: dict[str, str],
        eval_batch: EvaluationBatch,
        components_to_update: list[str],
    ):
        """Build reflection records for the ``instruction_prompt`` component.

        Only ``instruction_prompt`` is populated, whatever ``components_to_update``
        contains. Each record holds the task's message history, the current
        prompt, and success/failure feedback.
        """
        records = []
        # Treat missing trajectories as "nothing to reflect on" rather than crashing.
        for _score, trajectory in zip(eval_batch.scores, eval_batch.trajectories or [], strict=False):
            if trajectory["success"]:
                feedback = "Successfully solved the task!"
            else:
                feedback = f"Failed to solve the task. Reason: {trajectory['failed_reason']}"
            records.append(
                {
                    "Message History": trajectory["messages"],
                    "Instruction Prompt": candidate["instruction_prompt"],
                    "Feedback": feedback,
                }
            )
        return {"instruction_prompt": records}
|