coding-agent-wrapper 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
caw/mcp.py ADDED
@@ -0,0 +1,602 @@
1
+ """FastMCP-based HTTP tool-server infrastructure for CAW.
2
+
3
+ Ported from ``poc_playground/poc/utils/mcp.py`` with additions for:
4
+ - Synchronous start/stop bridge (CAW's Agent API is sync, FastMCP/Uvicorn are async)
5
+ - Subagent tool factory (replaces the old stdio ``subagent_server.py``)
6
+
7
+ Trajectory marker constants must stay in sync with ``caw.agent._TRAJ_MARKER_RE``.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import asyncio
13
+ import json
14
+ import logging
15
+ import os
16
+ import re
17
+ import socket
18
+ import threading
19
+ import uuid as uuid_mod
20
+ from collections.abc import AsyncIterator
21
+ from contextlib import asynccontextmanager
22
+ from dataclasses import dataclass, field
23
+ from typing import Any, Awaitable, Callable
24
+
25
+ import uvicorn
26
+ from mcp.server.fastmcp import Context, FastMCP
27
+
28
__all__ = [
    "Context",
    "MCPServerHandle",
    "create_mcp_http_server_bundle",
    "create_stateless_tool_server",
    "create_subagent_tool_server",
    "get_state_from_context",
    "mcp_tool",
    "register_tool",
]

# Raise these MCP SDK loggers to WARNING so routine session/request messages
# (presumably emitted below WARNING) don't clutter the agent's output.
logging.getLogger("mcp.server.streamable_http_manager").setLevel(logging.WARNING)
logging.getLogger("mcp.server.lowlevel.server").setLevel(logging.WARNING)

# -- Trajectory markers (shared with caw.agent) --------------------------------

# A subagent reply is suffixed with an HTML comment carrying the id of the
# trajectory file written for it: "\n<!-- caw_traj:<uuid> -->".
_TRAJ_MARKER_PREFIX = "\n<!-- caw_traj:"
_TRAJ_MARKER_SUFFIX = " -->"
45
+
46
+
47
+ # -- Helpers -------------------------------------------------------------------
48
+
49
+
50
+ async def _wait_for_server_ready(host: str, port: int, timeout: float = 5.0) -> None:
51
+ """Poll until the HTTP endpoint is listening."""
52
+ loop = asyncio.get_running_loop()
53
+ deadline = loop.time() + timeout
54
+ while True:
55
+ try:
56
+ reader, writer = await asyncio.open_connection(host, port)
57
+ except OSError:
58
+ if loop.time() >= deadline:
59
+ raise RuntimeError(f"MCP server {host}:{port} did not start in time.") from None
60
+ await asyncio.sleep(0.1)
61
+ continue
62
+ writer.close()
63
+ await writer.wait_closed()
64
+ return
65
+
66
+
67
+ def _create_bound_socket(host: str) -> socket.socket:
68
+ """Create and return a bound socket, keeping it open to reserve the port."""
69
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
70
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
71
+ sock.bind((host, 0))
72
+ sock.listen(128)
73
+ sock.setblocking(False)
74
+ return sock
75
+
76
+
77
+ # -- Decorator / registration --------------------------------------------------
78
+
79
+
80
def mcp_tool(
    name: str | None = None,
    *,
    title: str | None = None,
    description: str | None = None,
    annotations: Any | None = None,
    icons: list[Any] | None = None,
    meta: dict[str, Any] | None = None,
    structured_output: bool | None = None,
):
    """Build a decorator that tags a function with MCP tool metadata.

    The decorated function is returned unchanged. The metadata is stored as a
    dict on its ``_mcp_tool_info`` attribute, for ``register_tool`` to read.
    If *description* is falsy, the function's docstring (or ``""``) is used.
    """

    def _attach(fn: Callable[..., Any]):
        fn._mcp_tool_info = dict(
            name=name,
            title=title,
            description=description or fn.__doc__ or "",
            annotations=annotations,
            icons=icons,
            meta=meta,
            structured_output=structured_output,
        )
        return fn

    return _attach
106
+
107
+
108
def register_tool(server: FastMCP, func: Callable[..., Any]) -> None:
    """Register *func* as a tool on the FastMCP *server*.

    Metadata is taken from ``func._mcp_tool_info`` (set by ``mcp_tool``) or,
    failing that, ``func._toolkit_tool_info``. Missing keys are passed as
    ``None``, except the description, which falls back to the docstring.

    ``meta`` is forwarded only when it was actually set. Previously it was
    always dropped. Tools that don't use it therefore still register if the
    installed FastMCP's ``tool()`` lacks a ``meta`` parameter.
    """
    info = getattr(func, "_mcp_tool_info", None) or getattr(func, "_toolkit_tool_info", None) or {}
    options: dict[str, Any] = {
        "name": info.get("name"),
        "title": info.get("title"),
        "description": info.get("description", func.__doc__ or ""),
        "annotations": info.get("annotations"),
        "icons": info.get("icons"),
        "structured_output": info.get("structured_output"),
    }
    meta = info.get("meta")
    if meta is not None:
        options["meta"] = meta
    server.tool(**options)(func)
119
+
120
+
121
def get_state_from_context(ctx: Context) -> Any:
    """Fetch the lifespan state attached to the request behind *ctx*.

    For servers built with a ``state_instance``/``state_factory`` this is the
    object yielded by the managed lifespan.
    """
    request_ctx = ctx.request_context
    return request_ctx.lifespan_context
124
+
125
+
126
+ # -- MCPServerHandle -----------------------------------------------------------
127
+
128
+
129
@dataclass
class MCPServerHandle:
    """Convenience wrapper exposing runner + agent config for a FastMCP server.

    Holds a FastMCP server plus the address it is served on, and offers three
    ways to run it:

    * ``runner()``: a blocking callable that owns its own event loop;
    * ``start``/``stop``: a background task on the caller's running loop;
    * ``start_sync``/``stop_sync``: a daemon thread with a private loop, for
      synchronous callers.
    """

    server_id: str
    server: FastMCP
    host: str = "127.0.0.1"
    port: int | None = None
    path: str | None = None
    _state_instance: Any = None
    _server_task: asyncio.Task | None = None
    _uvicorn_server: uvicorn.Server | None = None
    uvicorn_log_level: str | None = "error"
    # Pre-bound listening socket that reserves the port until uvicorn adopts it.
    _bound_socket: socket.socket | None = None

    # -- Sync bridge fields (set by start_sync / stop_sync) --
    _daemon_thread: threading.Thread | None = field(default=None, repr=False)
    _daemon_loop: asyncio.AbstractEventLoop | None = field(default=None, repr=False)
    _ready_event: threading.Event | None = field(default=None, repr=False)
    _startup_error: BaseException | None = field(default=None, repr=False)

    def _build_uvicorn_server(self) -> uvicorn.Server:
        """Build a uvicorn server for the FastMCP streamable-HTTP app at this address."""
        host, port = self._ensure_address()
        app = self.server.streamable_http_app()
        # An explicit uvicorn level wins; otherwise reuse FastMCP's configured level.
        log_level = self.uvicorn_log_level or self.server.settings.log_level
        if isinstance(log_level, str):
            log_level = log_level.lower()
        config = uvicorn.Config(
            app,
            host=host,
            port=port,
            loop="asyncio",
            log_level=log_level,
            access_log=False,
        )
        return uvicorn.Server(config)

    @property
    def url(self) -> str:
        """HTTP URL of the streamable-HTTP endpoint (assigns a path if none is set)."""
        host, port = self._ensure_address()
        path = self._ensure_path()
        return f"http://{host}:{port}{path}"

    def runner(self) -> Callable[[], None]:
        """Return a sync function that blocks while serving the MCP HTTP endpoint."""
        bound_socket = self._bound_socket

        async def _serve() -> None:
            uvicorn_server = self._build_uvicorn_server()
            if bound_socket is not None:
                await uvicorn_server.serve(sockets=[bound_socket])
            else:
                await uvicorn_server.serve()

        def _run() -> None:
            asyncio.run(_serve())

        return _run

    @asynccontextmanager
    async def run_in_background(self) -> AsyncIterator[None]:
        """Async context manager that runs the HTTP server on a background task."""
        await self.start()
        try:
            yield
        finally:
            await self.stop()

    def get_state(self) -> Any:
        """Return the cached state instance (if provided)."""
        return self._state_instance

    # -- Async start / stop -----------------------------------------------

    async def start(self, max_retries: int = 5) -> None:
        """Start the HTTP server as a background task on the running loop.

        Up to *max_retries* attempts are made. After a failed attempt, the
        listening socket is discarded and a fresh one (new ephemeral port) is
        bound for the next attempt.

        Raises:
            RuntimeError: if the server is already running or every attempt failed.
        """
        if self._server_task is not None:
            raise RuntimeError("Server already running")

        last_error = None
        for attempt in range(max_retries):
            try:
                if self._bound_socket is None:
                    self._bound_socket = _create_bound_socket(self.host)
                    self.port = self._bound_socket.getsockname()[1]
                    self.server.settings.port = self.port

                uvicorn_server = self._build_uvicorn_server()
                self._uvicorn_server = uvicorn_server
                self._server_task = asyncio.create_task(uvicorn_server.serve(sockets=[self._bound_socket]))
                # The socket is already listening, so a TCP probe alone would
                # succeed even if uvicorn never starts serving. Wait for uvicorn's
                # own ``started`` flag (or for its task to end) first.
                await self._wait_until_started(uvicorn_server, self._server_task)
                host, port = self._ensure_address()
                await _wait_for_server_ready(host, port)
                return
            except (OSError, RuntimeError) as e:
                last_error = e
                if self._uvicorn_server is not None:
                    self._uvicorn_server.should_exit = True
                if self._server_task is not None:
                    self._server_task.cancel()
                    try:
                        await self._server_task
                    except (asyncio.CancelledError, Exception):
                        pass
                    self._server_task = None
                self._uvicorn_server = None
                self._release_socket()
                if attempt < max_retries - 1:
                    await asyncio.sleep(0.1 * (attempt + 1))

        raise RuntimeError(f"Failed to start MCP server after {max_retries} attempts: {last_error}")

    async def stop(self) -> None:
        """Stop the background HTTP server started by :meth:`start` (no-op if idle).

        The listening socket is dropped afterwards (uvicorn is expected to have
        closed it during shutdown), so a later :meth:`start` binds a fresh
        socket instead of failing on a dead one.
        """
        if self._server_task is None:
            return
        if self._uvicorn_server is not None:
            self._uvicorn_server.should_exit = True
        try:
            await self._server_task
        except asyncio.CancelledError:
            pass
        finally:
            self._server_task = None
            self._uvicorn_server = None
            self._release_socket()

    # -- Sync bridge (for use from synchronous CAW Agent API) ---------------

    def start_sync(self, timeout: float = 30.0) -> None:
        """Start the server from a synchronous context.

        Spawns a daemon thread with its own event loop, starts the server
        inside it, and blocks the calling thread until the server is ready.

        Raises:
            RuntimeError: if already running, if startup failed, or if startup
                did not finish within *timeout* seconds. On timeout the thread
                stays registered, so :meth:`stop_sync` can still tear it down.
        """
        if self._daemon_thread is not None:
            raise RuntimeError("Server already running (sync)")

        ready = threading.Event()
        self._ready_event = ready
        self._startup_error = None

        def _daemon_main() -> None:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            self._daemon_loop = loop

            async def _run() -> None:
                try:
                    await self.start()
                except BaseException as exc:
                    self._startup_error = exc
                    return
                finally:
                    ready.set()

                # Keep the loop alive while the server task runs
                if self._server_task is not None:
                    try:
                        await self._server_task
                    except asyncio.CancelledError:
                        pass

            try:
                loop.run_until_complete(_run())
            finally:
                # Cancel lingering tasks (e.g. SSE shutdown watchers) to avoid
                # "Task was destroyed but it is pending!" warnings on exit.
                pending = asyncio.all_tasks(loop)
                for task in pending:
                    task.cancel()
                if pending:
                    loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
                loop.close()

        thread = threading.Thread(target=_daemon_main, name=f"caw-mcp-{self.server_id}", daemon=True)
        self._daemon_thread = thread
        thread.start()

        if not ready.wait(timeout=timeout):
            raise RuntimeError(f"MCP server did not start within {timeout} seconds")
        if self._startup_error is not None:
            # Join the thread since startup failed
            thread.join(timeout=5)
            self._daemon_thread = None
            self._daemon_loop = None
            self._ready_event = None
            raise RuntimeError(f"MCP server failed to start: {self._startup_error}") from self._startup_error

    def stop_sync(self, timeout: float = 10.0) -> None:
        """Stop a server started with :meth:`start_sync` (no-op if not running).

        Best-effort: shutdown errors are logged at DEBUG level and swallowed.
        The bridge state is always reset, so :meth:`start_sync` can run again.
        """
        thread = self._daemon_thread
        loop = self._daemon_loop
        if thread is None or loop is None:
            return

        try:
            # The daemon loop is closed once its thread finishes (e.g. the server
            # task ended on its own); scheduling onto it would raise RuntimeError.
            if thread.is_alive() and not loop.is_closed():
                coro = self.stop()
                try:
                    future = asyncio.run_coroutine_threadsafe(coro, loop)
                except RuntimeError:
                    # The loop closed between the check and the call.
                    coro.close()
                else:
                    try:
                        future.result(timeout=timeout)
                    except Exception:
                        logging.getLogger(__name__).debug(
                            "Error while stopping MCP server %s", self.server_id, exc_info=True
                        )
            thread.join(timeout=timeout)
        finally:
            self._daemon_thread = None
            self._daemon_loop = None
            self._ready_event = None

    # -- Internal helpers -------------------------------------------------

    async def _wait_until_started(self, uvicorn_server: Any, task: asyncio.Task, timeout: float = 5.0) -> None:
        """Wait for ``uvicorn_server.started``, failing fast if *task* ends first.

        Raises:
            RuntimeError: if the serve task exits before startup completes, or
                if startup takes longer than *timeout* seconds.
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while not uvicorn_server.started:
            if task.done():
                error = None if task.cancelled() else task.exception()
                raise RuntimeError(f"MCP server exited during startup: {error!r}") from error
            if loop.time() >= deadline:
                raise RuntimeError(f"MCP server {self.host}:{self.port} did not start in time.")
            await asyncio.sleep(0.05)

    def _release_socket(self) -> None:
        """Close (idempotently) and forget the pre-bound listening socket."""
        if self._bound_socket is not None:
            try:
                self._bound_socket.close()
            except OSError:
                pass
            self._bound_socket = None

    def _ensure_address(self) -> tuple[str, int]:
        """Return ``(host, port)``, defaulting the port to FastMCP's configured one."""
        if self.port is None:
            self.port = self.server.settings.port
        return self.host, self.port

    def _ensure_path(self) -> str:
        """Return the endpoint path, generating a random one (and syncing FastMCP) if unset."""
        if self.path is None:
            self.path = f"/mcp/{self.server_id}/{uuid_mod.uuid4().hex[:6]}"
            self.server.settings.streamable_http_path = self.path
        return self.path
347
+
348
+
349
+ # -- Factory -------------------------------------------------------------------
350
+
351
+
352
def create_mcp_http_server_bundle(
    server_id: str,
    *,
    display_name: str,
    state_factory: Callable[[], Any] | None = None,
    state_instance: Any = None,
    state_display_name: str | None = None,
    state_shutdown: Callable[[Any], None | Awaitable[None]] | None = None,
    uvicorn_log_level: str | None = "error",
    state_logger: Callable[[str], None] | None = None,
    **fastmcp_kwargs: Any,
) -> MCPServerHandle:
    """Create a FastMCP server on a freshly reserved loopback port, wrapped in a handle.

    Args:
        server_id: Used in the endpoint path ``/mcp/<server_id>/<random hex>``.
        display_name: Name given to the FastMCP server.
        state_factory: Called once, immediately, to build the shared state.
        state_instance: Shared state object (mutually exclusive with *state_factory*).
        state_display_name: When set, "Initializing ..." / "Shutting down ..."
            messages are passed to *state_logger*.
        state_shutdown: Called with the state when the last active lifespan
            exits; may return an awaitable, which is awaited.
        uvicorn_log_level: Log level used when the handle runs uvicorn.
        state_logger: Receives the state lifecycle messages.
        **fastmcp_kwargs: Forwarded to ``FastMCP``. ``lifespan`` is accepted
            only when no state option is given.

    Returns:
        An un-started :class:`MCPServerHandle` whose port is already reserved.

    Raises:
        ValueError: on conflicting ``lifespan`` / state arguments.
    """
    import inspect

    lifespan = fastmcp_kwargs.pop("lifespan", None)
    if lifespan is not None and (state_factory is not None or state_instance is not None):
        raise ValueError("Provide either lifespan or state_factory/state_instance, not both.")
    if state_factory is not None and state_instance is not None:
        raise ValueError("Provide either state_factory or state_instance, not both.")

    if state_factory is not None:
        state_instance = state_factory()

    state = state_instance
    if state is not None:
        # Counts lifespans currently entered. Initialization/shutdown hooks fire
        # only on the 0 -> 1 and 1 -> 0 transitions.
        # NOTE(review): the MCP SDK may enter the lifespan once per session
        # rather than once per process; confirm shutdown semantics against it.
        active = 0

        @asynccontextmanager
        async def managed_lifespan(_server: FastMCP):
            nonlocal active
            if active == 0 and state_display_name and state_logger:
                state_logger(f"Initializing {state_display_name}")
            active += 1
            try:
                yield state
            finally:
                active -= 1
                if active == 0:
                    if state_display_name and state_logger:
                        state_logger(f"Shutting down {state_display_name}")
                    if state_shutdown is not None:
                        result = state_shutdown(state)
                        if inspect.isawaitable(result):
                            await result

        lifespan = managed_lifespan

    streamable_path = f"/mcp/{server_id}/{uuid_mod.uuid4().hex[:6]}"

    bound_socket = _create_bound_socket("127.0.0.1")
    port = bound_socket.getsockname()[1]

    try:
        server = FastMCP(
            name=display_name,
            host="127.0.0.1",
            port=port,
            streamable_http_path=streamable_path,
            lifespan=lifespan,
            **fastmcp_kwargs,
        )
    except BaseException:
        # Don't leak the reserved port if FastMCP rejects its arguments.
        bound_socket.close()
        raise
    return MCPServerHandle(
        server_id=server_id,
        server=server,
        host="127.0.0.1",
        port=port,
        path=streamable_path,
        _state_instance=state_instance,
        uvicorn_log_level=uvicorn_log_level,
        _bound_socket=bound_socket,
    )
426
+
427
+
428
def create_stateless_tool_server(
    funcs: list[Callable[..., Any]],
    *,
    server_id: str | None = None,
    display_name: str = "stateless_tools",
) -> MCPServerHandle:
    """Expose a list of plain functions as tools of one new :class:`MCPServerHandle`.

    Every function is registered as an MCP tool on a state-less server. Metadata
    can be supplied with :func:`mcp_tool` or :func:`~caw.toolkit.tool`; undecorated
    functions are registered under their own name with their docstring.
    A falsy *server_id* is replaced by ``general_<random hex>``.
    """
    if server_id:
        sid = server_id
    else:
        sid = f"general_{uuid_mod.uuid4().hex[:6]}"
    handle = create_mcp_http_server_bundle(sid, display_name=display_name)
    target = handle.server
    for tool_func in funcs:
        register_tool(target, tool_func)
    return handle
445
+
446
+
447
+ # -- Name sanitization ---------------------------------------------------------
448
+
449
+ _INVALID_TOOL_NAME_RE = re.compile(r"[^A-Za-z0-9_\-.]")
450
+
451
+
452
+ def _sanitize_tool_name(name: str) -> str:
453
+ """Replace characters invalid in MCP tool names with underscores."""
454
+ return _INVALID_TOOL_NAME_RE.sub("_", name)
455
+
456
+
457
+ # -- Subagent tool factory ----------------------------------------------------
458
+
459
+
460
@dataclass
class SubagentState:
    """Configuration shared by every invocation of one subagent tool server.

    Installed as the server's lifespan state and read back inside the tool via
    ``get_state_from_context``.
    """

    # Identity / prompting, copied from the AgentSpec.
    name: str
    description: str
    system_prompt: str
    model: str  # "" is passed on to Agent as None
    # Output locations; "" disables the corresponding write.
    traj_dir: str
    jsonl_path: str
    # Optional capabilities handed to the subagent's Agent.
    tools: Any = None
    tool_servers: list[Any] = field(default_factory=list)
    mcp_servers: list[Any] = field(default_factory=list)
    subagents: list[Any] = field(default_factory=list)
474
+
475
+
476
def _run_subagent_blocking(
    prompt: str,
    system_prompt: str,
    model: str,
    traj_dir: str,
    jsonl_path: str,
    subagent_name: str,
    tools: Any = None,
    tool_servers: list | None = None,
    mcp_servers: list | None = None,
    subagents: list | None = None,
) -> str:
    """Run a single-turn subagent synchronously (called from a thread).

    Builds a fresh ``caw.Agent`` from the given configuration, sends *prompt*
    once, and returns the turn's result text. Two side outputs are best-effort:

    * the subagent's turn events are written to the parent's JSONL log at
      *jsonl_path*, through a ``JsonlWriter`` created with *subagent_name*;
    * the trajectory is saved as ``<traj_dir>/<uuid>.json``, and a
      ``<!-- caw_traj:<uuid> -->`` marker is appended to the returned text.

    Failures in either side output are logged as warnings and otherwise
    ignored. If the session itself raises, ``"Error: <exception>"`` is returned.
    """
    from caw import Agent

    log = logging.getLogger(__name__)

    agent = Agent(
        system_prompt=system_prompt,
        model=model or None,
        tools=tools,
        data_dir=None,
    )

    for ts in tool_servers or []:
        agent.add_tool_server(ts)

    for srv in mcp_servers or []:
        agent.add_mcp_server(srv)

    for sub in subagents or []:
        agent.add_subagent(sub)

    try:
        with agent.start_session() as session:
            turn = session.send(prompt)
            traj = session.trajectory
    except Exception as e:
        log.debug("Subagent %r session failed", subagent_name, exc_info=True)
        return f"Error: {e}"

    result_text = turn.result
    traj_dict = traj.to_dict()

    # Write subagent events to parent's JSONL (best-effort: never fails the tool call).
    if traj.turns and jsonl_path:
        try:
            from caw.storage import JsonlWriter

            writer = JsonlWriter(jsonl_path, subagent=subagent_name)
            for i, t in enumerate(traj.turns):
                writer.write_turn_events(t, i)
        except Exception:
            log.warning(
                "Could not write subagent %r events to %s", subagent_name, jsonl_path, exc_info=True
            )

    # Write trajectory to file and embed marker in response
    traj_marker = ""
    if traj_dict and traj_dir:
        traj_id = str(uuid_mod.uuid4())
        traj_path = os.path.join(traj_dir, f"{traj_id}.json")
        try:
            os.makedirs(traj_dir, exist_ok=True)
            with open(traj_path, "w", encoding="utf-8") as f:
                json.dump(traj_dict, f, indent=2)
            traj_marker = f"{_TRAJ_MARKER_PREFIX}{traj_id}{_TRAJ_MARKER_SUFFIX}"
        except OSError:
            log.warning("Could not write subagent trajectory to %s", traj_path, exc_info=True)

    return result_text + traj_marker
545
+
546
+
547
def create_subagent_tool_server(
    spec: Any,
    traj_dir: str,
    jsonl_path: str | None = None,
) -> MCPServerHandle:
    """Wrap a subagent as a single-tool HTTP MCP server.

    Parameters
    ----------
    spec
        An ``AgentSpec``-like object. ``name``, ``description``, ``system_prompt``
        and ``model`` are read directly. ``tools``, ``tool_servers``,
        ``mcp_servers`` and ``subagents`` are read only if present.
    traj_dir
        Directory where subagent trajectory JSON files are written.
    jsonl_path
        Path to the parent session's JSONL log (for interleaved subagent events).
        ``None`` is stored as ``""``, which skips the JSONL writes.

    Returns
    -------
    MCPServerHandle
        An un-started handle whose server exposes one tool named after the
        sanitized ``spec.name``. Calling the tool runs the subagent in a worker
        thread.
    """
    tool_name = _sanitize_tool_name(spec.name)

    def _listed(attr: str) -> list:
        # Optional list-valued spec attributes; missing or None become [].
        return list(getattr(spec, attr, None) or [])

    state = SubagentState(
        name=spec.name,
        description=spec.description,
        system_prompt=spec.system_prompt,
        model=spec.model or "",
        traj_dir=traj_dir,
        jsonl_path=jsonl_path or "",
        tools=getattr(spec, "tools", None),
        tool_servers=_listed("tool_servers"),
        mcp_servers=_listed("mcp_servers"),
        subagents=_listed("subagents"),
    )

    handle = create_mcp_http_server_bundle(
        "subagent",
        display_name=f"caw-subagent-{spec.name}",
        state_instance=state,
    )

    @mcp_tool(name=tool_name, description=spec.description)
    async def subagent_tool(prompt: str, ctx: Context) -> str:
        # No docstring on purpose: mcp_tool falls back to __doc__ when the
        # spec has no description. The blocking Agent session runs off the
        # event loop so the server keeps serving other requests.
        cfg: SubagentState = get_state_from_context(ctx)
        return await asyncio.to_thread(
            _run_subagent_blocking,
            prompt=prompt,
            system_prompt=cfg.system_prompt,
            model=cfg.model,
            traj_dir=cfg.traj_dir,
            jsonl_path=cfg.jsonl_path,
            subagent_name=cfg.name,
            tools=cfg.tools,
            tool_servers=cfg.tool_servers,
            mcp_servers=cfg.mcp_servers,
            subagents=cfg.subagents,
        )

    register_tool(handle.server, subagent_tool)
    return handle