augment-sdk 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
augment/acp/client.py ADDED
@@ -0,0 +1,640 @@
1
+ """
2
+ Synchronous ACP Client for Augment CLI
3
+
4
+ A clean, easy-to-use wrapper around the Agent Client Protocol for communicating
5
+ with the Augment CLI agent.
6
+ """
7
+
8
+ import asyncio
9
+ import os
10
+ from abc import ABC, abstractmethod
11
+ from queue import Empty, Queue
12
+ from threading import Thread
13
+ from typing import Optional, Any, List
14
+
15
+ from acp import (
16
+ Client,
17
+ ClientSideConnection,
18
+ InitializeRequest,
19
+ NewSessionRequest,
20
+ PromptRequest,
21
+ RequestError,
22
+ SessionNotification,
23
+ text_block,
24
+ PROTOCOL_VERSION,
25
+ spawn_agent_process,
26
+ )
27
+ from acp.schema import (
28
+ RequestPermissionRequest,
29
+ RequestPermissionResponse,
30
+ AllowedOutcome,
31
+ )
32
+
33
+
34
class AgentEventListener(ABC):
    """
    Callback interface for observing what the agent is doing.

    Subclass this and hand an instance to the client to be notified about
    streamed reply text, reasoning, and tool activity. The three hooks marked
    abstract must be implemented; ``on_agent_thought`` and ``on_agent_message``
    are optional and do nothing by default.
    """

    @abstractmethod
    def on_agent_message_chunk(self, text: str) -> None:
        """
        Receive one streamed fragment of the agent's reply.

        Replies arrive incrementally, so this hook fires many times per reply;
        the fragments, taken in order, make up the complete message.

        Args:
            text: The next fragment of the agent's response text
        """

    @abstractmethod
    def on_tool_call(
        self,
        tool_call_id: str,
        title: str,
        kind: Optional[str] = None,
        status: Optional[str] = None,
    ) -> None:
        """
        Receive notice that the agent has started a tool call.

        Args:
            tool_call_id: Identifier that ties later updates to this call
            title: Human-readable summary of what the tool is doing
            kind: Tool category (read, edit, delete, execute, ...), if known
            status: Call status (pending, in_progress, completed, failed), if known
        """

    @abstractmethod
    def on_tool_response(
        self,
        tool_call_id: str,
        status: Optional[str] = None,
        content: Optional[Any] = None,
    ) -> None:
        """
        Receive an update/result for a previously announced tool call.

        Args:
            tool_call_id: Identifier of the tool call being updated
            status: Updated status (completed, failed, ...), if provided
            content: Result payload produced by the tool, if any
        """

    def on_agent_thought(self, text: str) -> None:
        """
        Receive a fragment of the agent's internal reasoning.

        Optional hook; the default implementation ignores it.

        Args:
            text: Reasoning text shared by the agent
        """

    def on_agent_message(self, message: str) -> None:
        """
        Receive the agent's reply once it has been fully assembled.

        Fires a single time after every chunk of the reply has been delivered,
        carrying the whole message. Optional hook; ignored by default.

        Args:
            message: The complete reply text from the agent
        """
111
+
112
+
113
class _InternalACPClient(Client):
    """
    Internal ACP client that forwards agent events to an optional listener.

    Streamed agent message text is accumulated in ``last_response`` so the
    synchronous wrapper can return it after a prompt completes. File-system
    and terminal capabilities are not offered: those requests are rejected
    with ``method_not_found``. Permission requests are auto-approved.
    """

    def __init__(self, listener: Optional[AgentEventListener] = None):
        self.listener = listener
        # Concatenation of all agent_message_chunk texts received since the
        # attribute was last reset (the owner resets it before each prompt).
        self.last_response = ""

    @staticmethod
    def _field(obj: Any, name: str, default: Any = None) -> Any:
        """Read ``name`` from ``obj``, which may be a plain dict or a model object."""
        if isinstance(obj, dict):
            return obj.get(name, default)
        return getattr(obj, name, default)

    @classmethod
    def _text_of(cls, content: Any) -> str:
        """Return the text carried by a content block, or ``""`` if it has none."""
        # Non-text blocks have no ``text`` field; an explicit None is treated
        # the same way so it is never concatenated onto a str.
        return cls._field(content, "text", "") or ""

    async def requestPermission(
        self, params: RequestPermissionRequest
    ) -> RequestPermissionResponse:
        """
        Handle permission requests from the CLI (e.g., indexing permission).

        Auto-approves by selecting the first option whose kind starts with
        "allow" (allow_once / allow_always). If there is none, the first
        offered option is selected. If no options are offered at all, a
        placeholder option id "default" is returned.
        """
        # Tolerate a missing options list instead of failing on iteration.
        options = list(params.options or [])
        fallback = options[0] if options else None
        chosen = next(
            (option for option in options if option.kind.startswith("allow")),
            fallback,
        )
        # NOTE(review): "default" is not an id the agent offered; confirm the
        # CLI accepts it (a cancelled/denied outcome may be more appropriate).
        option_id = chosen.optionId if chosen is not None else "default"
        return RequestPermissionResponse(
            outcome=AllowedOutcome(optionId=option_id, outcome="selected")
        )

    async def writeTextFile(self, params):  # type: ignore[override]
        raise RequestError.method_not_found("fs/write_text_file")

    async def readTextFile(self, params):  # type: ignore[override]
        raise RequestError.method_not_found("fs/read_text_file")

    async def createTerminal(self, params):  # type: ignore[override]
        raise RequestError.method_not_found("terminal/create")

    async def terminalOutput(self, params):  # type: ignore[override]
        raise RequestError.method_not_found("terminal/output")

    async def releaseTerminal(self, params):  # type: ignore[override]
        raise RequestError.method_not_found("terminal/release")

    async def waitForTerminalExit(self, params):  # type: ignore[override]
        raise RequestError.method_not_found("terminal/wait_for_exit")

    async def killTerminal(self, params):  # type: ignore[override]
        raise RequestError.method_not_found("terminal/kill")

    async def sessionUpdate(self, params: SessionNotification) -> None:
        """
        Dispatch a session update notification to the listener.

        The update may arrive either as a dict or as a model object. Every
        field, including the tool-call fields, is read through ``_field``, so
        both shapes behave the same.
        """
        update = params.update
        kind = self._field(update, "sessionUpdate")
        content = self._field(update, "content")

        if kind == "agent_message_chunk":
            if content is not None:
                text = self._text_of(content)
                self.last_response += text
                if self.listener:
                    self.listener.on_agent_message_chunk(text)

        elif kind == "agent_thought_chunk":
            if content is not None and self.listener:
                self.listener.on_agent_thought(self._text_of(content))

        elif kind == "tool_call":
            if self.listener:
                self.listener.on_tool_call(
                    self._field(update, "toolCallId", "unknown"),
                    self._field(update, "title", ""),
                    self._field(update, "kind"),
                    self._field(update, "status"),
                )

        elif kind == "tool_call_update":
            if self.listener:
                self.listener.on_tool_response(
                    self._field(update, "toolCallId", "unknown"),
                    self._field(update, "status"),
                    content,
                )

    async def extMethod(self, method: str, params: dict) -> dict:  # noqa: ARG002
        raise RequestError.method_not_found(method)

    async def extNotification(self, method: str, params: dict) -> None:  # noqa: ARG002
        raise RequestError.method_not_found(method)

    def get_last_response(self) -> str:
        """Return the accumulated agent message text, stripped of outer whitespace."""
        return self.last_response.strip()
223
+
224
+
225
class ACPClient:
    """
    Abstract-by-convention interface shared by synchronous ACP clients.

    Every member raises NotImplementedError here; concrete clients such as
    AuggieACPClient supply the real behavior.
    """

    def start(self) -> None:
        """
        Launch the agent process and open the ACP connection.

        Raises:
            RuntimeError: If the agent has already been started
            Exception: If the agent could not be initialized
        """
        raise NotImplementedError

    def stop(self) -> None:
        """Shut the agent process down and release all resources."""
        raise NotImplementedError

    def send_message(self, message: str, timeout: float = 30.0) -> str:
        """
        Deliver ``message`` to the agent and block until it answers.

        Args:
            message: Text to send to the agent
            timeout: Upper bound, in seconds, on how long to wait for the reply

        Returns:
            The agent's reply text

        Raises:
            RuntimeError: If the agent has not been started
            TimeoutError: If no reply arrives within ``timeout``
        """
        raise NotImplementedError

    def clear_context(self) -> None:
        """
        Discard the conversation context by restarting the agent.

        The running agent is stopped and a fresh one with a new session is
        started in its place.
        """
        raise NotImplementedError

    @property
    def is_running(self) -> bool:
        """Whether the agent is currently running."""
        raise NotImplementedError
271
+
272
+
273
class AuggieACPClient(ACPClient):
    """
    Synchronous ACP client for the Augment CLI agent.

    The agent runs as a subprocess driven by an asyncio event loop on a
    private daemon thread. The public methods block the calling thread until
    the corresponding coroutine finishes on that loop.

    This client provides a simple interface for:
    - Starting/stopping the agent
    - Sending messages and getting responses
    - Listening to agent events (messages, tool calls, etc.)
    - Clearing session context

    Example:
    ```python
    # Create a client with an event listener
    client = AuggieACPClient(listener=MyListener())

    # Start the agent
    client.start()

    # Send a message
    response = client.send_message("What is 2 + 2?")
    print(response)

    # Clear context and start fresh
    client.clear_context()

    # Stop the agent
    client.stop()
    ```

    The client is also a context manager: ``__enter__`` calls start() and
    ``__exit__`` calls stop().
    """

    # Seconds stop() waits for the subprocess shutdown coroutine before it
    # stops the event loop regardless.
    _STOP_TIMEOUT = 5.0

    def __init__(
        self,
        cli_path: Optional[str] = None,
        listener: Optional[AgentEventListener] = None,
        model: Optional[str] = None,
        workspace_root: Optional[str] = None,
        acp_max_tool_result_bytes: int = 35 * 1024,
        removed_tools: Optional[List[str]] = None,
        api_key: Optional[str] = None,
        api_url: Optional[str] = None,
    ):
        """
        Initialize the ACP client.

        Args:
            cli_path: Path to the Augment CLI. If None, "auggie" is looked up on PATH.
                Paths ending in ".js" or ".mjs" are run with "node".
            listener: Optional event listener to receive agent events.
            model: AI model to use (e.g., "claude-3-5-sonnet-latest", "gpt-4o").
                If None, uses the CLI's default model.
            workspace_root: Workspace root directory. If None, uses current directory.
            acp_max_tool_result_bytes: Maximum bytes for tool results sent through ACP.
                Default: 35KB (35840 bytes). Intended to keep Python's
                asyncio.StreamReader under its 64KB line buffer limit.
                NOTE: currently stored but NOT forwarded to the CLI (the flag
                is disabled pending CLI support).
            removed_tools: List of tool names to remove/disable (e.g., ["github-api", "linear"]).
                These tools will not be available to the agent.
            api_key: Optional API key for authentication. If provided, sets AUGMENT_API_TOKEN
                environment variable for the agent process.
            api_url: Optional API URL. If not provided, uses AUGMENT_API_URL environment variable,
                or defaults to "https://api.augmentcode.com". Sets AUGMENT_API_URL environment
                variable for the agent process.
        """
        # Default to 'auggie' in PATH
        self.cli_path = cli_path if cli_path is not None else "auggie"

        self.listener = listener
        self.model = model
        self.workspace_root = workspace_root
        self.acp_max_tool_result_bytes = acp_max_tool_result_bytes
        self.removed_tools = removed_tools or []
        self.api_key = api_key
        self.api_url = (
            api_url
            if api_url is not None
            else os.getenv("AUGMENT_API_URL", "https://api.augmentcode.com")
        )
        self._client: Optional[_InternalACPClient] = None
        self._conn: Optional[ClientSideConnection] = None
        self._proc = None  # agent subprocess handle, set once spawned
        self._session_id: Optional[str] = None
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._thread: Optional[Thread] = None
        self._context = None  # entered spawn_agent_process context manager
        # Background task that raises once the agent subprocess exits.
        self._monitor_task: Optional[asyncio.Task] = None
        self._ready_queue: Optional[Queue] = None

    def start(self, timeout: float = 30.0) -> None:
        """
        Start the agent process and establish ACP connection.

        If startup fails or times out, the partially started agent is torn
        down, so start() can be called again.

        Args:
            timeout: Maximum time to wait for the agent to start (seconds)

        Raises:
            RuntimeError: If the agent is already started, or if the agent
                process exits during startup
            TimeoutError: If the agent fails to start within the timeout
            Exception: If initialization fails
        """
        if self._thread is not None:
            raise RuntimeError("Agent already started")

        ready_queue: Queue = Queue()
        self._ready_queue = ready_queue
        # The queue is handed to the thread explicitly so that a stale thread
        # from an earlier, timed-out start cannot report into a newer queue.
        self._thread = Thread(
            target=self._run_async_loop, args=(ready_queue,), daemon=True
        )
        self._thread.start()

        try:
            result = ready_queue.get(timeout=timeout)
        except Empty:
            # Queue.get() timed out - no result received within timeout.
            # Tear down so the subprocess is not leaked and start() can retry.
            self.stop()
            raise TimeoutError(
                f"Agent failed to start within {timeout} seconds. "
                f"Check that the CLI path is correct and the agent process can start. "
                f"CLI path: {self.cli_path}"
            ) from None

        if isinstance(result, Exception):
            self.stop()
            raise result

    def stop(self) -> None:
        """
        Stop the agent process and cleanup resources.

        Waits up to ``_STOP_TIMEOUT`` seconds for the subprocess to shut down
        on the event loop. It then stops the loop and joins the background
        thread, waiting up to 2 seconds. Safe to call when the agent is not
        running. Must not be called from the event-loop thread itself (for
        example from inside a listener callback).
        """
        loop = self._loop
        if loop is not None and loop.is_running():
            cleanup = asyncio.run_coroutine_threadsafe(self._async_stop(), loop)
            try:
                # Let shutdown finish *before* stopping the loop: stopping it
                # immediately could abandon _async_stop midway and leak the
                # agent subprocess.
                cleanup.result(timeout=self._STOP_TIMEOUT)
            except Exception:
                # Best-effort: a failed or slow shutdown must not prevent the
                # loop and thread from being released below.
                cleanup.cancel()
            try:
                loop.call_soon_threadsafe(loop.stop)
            except RuntimeError:
                pass  # the loop was closed concurrently
        if self._thread is not None:
            self._thread.join(timeout=2.0)

        self._client = None
        self._conn = None
        self._proc = None
        self._session_id = None
        self._loop = None
        self._thread = None
        self._context = None
        self._monitor_task = None

    def send_message(self, message: str, timeout: float = 30.0) -> str:
        """
        Send a message to the agent and get the response.

        Args:
            message: The message to send
            timeout: Maximum time to wait for response (seconds)

        Returns:
            The agent's response as a string (accumulated message chunks,
            stripped of surrounding whitespace)

        Raises:
            RuntimeError: If the agent is not started, or if the agent
                process exited before responding
            TimeoutError: If the response takes too long
        """
        # Before Python 3.11 this differs from the builtin TimeoutError that
        # this method documents, so it is translated explicitly below.
        from concurrent.futures import TimeoutError as FutureTimeoutError

        if self._loop is None or self._conn is None or self._client is None:
            raise RuntimeError("Agent not started. Call start() first.")

        # Reset the accumulator; chunks for this prompt are appended to it.
        self._client.last_response = ""

        future = asyncio.run_coroutine_threadsafe(
            self._async_send_message(message), self._loop
        )
        try:
            # When prompt() completes, the message is done
            # (per ACP spec, there is no agent_message_end event).
            future.result(timeout=timeout)
        except FutureTimeoutError:
            if future.done():
                raise  # the prompt itself failed with a TimeoutError
            # Cancel the in-flight request so it stops running on the loop.
            # NOTE(review): this does not send an ACP session/cancel, so the
            # agent may keep streaming chunks for this turn into the next
            # call's response; confirm whether a cancel should be sent.
            future.cancel()
            raise TimeoutError(
                f"Agent did not respond within {timeout} seconds"
            ) from None

        response = self._client.get_last_response()

        # Call listener with complete message now that prompt() has completed
        if self.listener and response:
            self.listener.on_agent_message(response)

        return response

    def clear_context(self) -> None:
        """
        Clear the session context by restarting the agent.

        This stops the current agent and starts a new one with a fresh session.
        """
        self.stop()
        self.start()

    @property
    def session_id(self) -> Optional[str]:
        """Get the current session ID."""
        return self._session_id

    @property
    def is_running(self) -> bool:
        """Check if the agent is currently running (its loop thread is alive)."""
        return (
            self._thread is not None
            and self._thread.is_alive()
            and self._loop is not None
        )

    def _run_async_loop(self, ready_queue: Optional[Queue] = None) -> None:
        """
        Thread target: own an event loop, start the agent, then serve requests.

        Reports ``True`` (started) or the startup exception on ``ready_queue``.
        The loop is always closed when it stops, so its resources do not leak.
        """
        queue = ready_queue if ready_queue is not None else self._ready_queue
        loop = asyncio.new_event_loop()
        self._loop = loop
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(self._async_start())
            queue.put(True)
            loop.run_forever()
        except Exception as e:
            queue.put(e)
        finally:
            loop.close()

    def _build_cli_args(self) -> List[str]:
        """Build the argv used to launch the agent in ACP mode."""
        cli_path_str = str(self.cli_path)
        if cli_path_str.endswith((".mjs", ".js")):
            # It's a JS file, run with node
            cli_args = ["node", cli_path_str]
        else:
            # It's a binary or script (like 'auggie'), run directly
            cli_args = [cli_path_str]

        cli_args += ["--acp", "--log-file", "./auggie-acp.log"]

        if self.model:
            cli_args += ["--model", self.model]

        if self.workspace_root:
            cli_args += ["--workspace-root", self.workspace_root]

        for tool in self.removed_tools:
            cli_args += ["--remove-tool", tool]

        # TODO: Re-enable once --acp-max-tool-result-bytes is in pre-release
        # cli_args.extend(
        #     ["--acp-max-tool-result-bytes", str(self.acp_max_tool_result_bytes)]
        # )
        return cli_args

    def _build_env(self) -> dict:
        """Return a copy of os.environ with the API credentials/URL applied."""
        env = os.environ.copy()
        if self.api_key:
            env["AUGMENT_API_TOKEN"] = self.api_key
        if self.api_url:
            env["AUGMENT_API_URL"] = self.api_url
        return env

    async def _read_stderr(self, proc: Any, timeout: float = 1.0) -> str:
        """Best-effort read of a process's remaining stderr, decoded as UTF-8."""
        if proc is None or not proc.stderr:
            return ""
        try:
            data = await asyncio.wait_for(proc.stderr.read(), timeout=timeout)
        except Exception:
            # Diagnostics only: an unreadable or slow stderr must not mask
            # the real error about the process exiting.
            return ""
        return data.decode("utf-8", errors="replace")

    async def _watch_process_exit(self) -> None:
        """Wait for the agent process to exit, then raise describing the exit."""
        proc = self._proc
        await proc.wait()
        stderr = await self._read_stderr(proc)
        raise RuntimeError(
            f"Agent process exited with code {proc.returncode}. "
            f"CLI path: {self.cli_path}\n"
            f"Stderr: {stderr}"
        )

    async def _await_unless_exited(self, awaitable: Any) -> Any:
        """
        Await ``awaitable``, but fail fast if the agent process exits first.

        Races the awaitable against the process monitor task. If the monitor
        finishes first, its RuntimeError (exit code + stderr) is raised. The
        awaited task is always cancelled if it is left unfinished.
        """
        task = asyncio.ensure_future(awaitable)
        monitor = self._monitor_task
        try:
            if monitor is not None:
                done, _ = await asyncio.wait(
                    [task, monitor], return_when=asyncio.FIRST_COMPLETED
                )
                if monitor in done:
                    await monitor  # re-raises the process-exit error
            return await task
        finally:
            if not task.done():
                task.cancel()

    async def _async_start(self) -> None:
        """
        Spawn the agent, run the ACP handshake and create a session.

        If any step fails, the subprocess and the monitor task are cleaned up
        before the error propagates.
        """
        self._client = _InternalACPClient(self.listener)

        # Spawn the agent process with environment variables
        context = spawn_agent_process(
            lambda _agent: self._client, *self._build_cli_args(), env=self._build_env()
        )
        self._conn, self._proc = await context.__aenter__()
        self._context = context

        try:
            if self._proc.returncode is not None:
                stderr = await self._read_stderr(self._proc)
                raise RuntimeError(
                    f"Agent process exited immediately with code {self._proc.returncode}. "
                    f"CLI path: {self.cli_path}\n"
                    f"Stderr: {stderr}"
                )

            # Keeps running after startup so later prompts can detect a crash.
            self._monitor_task = asyncio.create_task(self._watch_process_exit())

            await self._await_unless_exited(
                self._conn.initialize(
                    InitializeRequest(
                        protocolVersion=PROTOCOL_VERSION, clientCapabilities=None
                    )
                )
            )

            # Use workspace_root as cwd if provided, otherwise use current directory
            cwd = self.workspace_root if self.workspace_root else os.getcwd()
            session = await self._await_unless_exited(
                self._conn.newSession(NewSessionRequest(mcpServers=[], cwd=cwd))
            )
            self._session_id = session.sessionId
        except BaseException:
            # BaseException so a cancelled start also releases the subprocess.
            try:
                await self._async_stop()
            except Exception:
                pass  # keep the original startup error
            raise

    async def _async_send_message(self, message: str) -> None:
        """Send one prompt turn; returns when the agent finishes the turn."""
        await self._await_unless_exited(
            self._conn.prompt(
                PromptRequest(
                    sessionId=self._session_id,
                    prompt=[text_block(message)],
                )
            )
        )

    async def _async_stop(self) -> None:
        """Cancel the process monitor and exit the spawn context (idempotent)."""
        monitor, self._monitor_task = self._monitor_task, None
        if monitor is not None:
            if monitor.done():
                if not monitor.cancelled():
                    # Mark the exit error as retrieved so asyncio does not log
                    # "Task exception was never retrieved".
                    monitor.exception()
            else:
                monitor.cancel()

        context, self._context = self._context, None
        if context is not None:
            await context.__aexit__(None, None, None)

    def __enter__(self):
        """Context manager entry."""
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit."""
        self.stop()