augment-sdk 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
augment/agent.py ADDED
@@ -0,0 +1,1139 @@
1
+ """
2
+ Main Agent class for the Augment Python SDK.
3
+
4
+ This implementation uses the ACP (Agent Client Protocol) client for
5
+ better performance and real-time streaming.
6
+ """
7
+
8
+ import json
9
+ import os
10
+ import re
11
+ import subprocess
12
+ from contextlib import contextmanager
13
+ from dataclasses import dataclass, fields, is_dataclass
14
+ from enum import Enum
15
+ from pathlib import Path
16
+ from typing import (
17
+ Any,
18
+ Callable,
19
+ Dict,
20
+ Generator,
21
+ List,
22
+ Optional,
23
+ Tuple,
24
+ Type,
25
+ TypeVar,
26
+ Union,
27
+ get_args,
28
+ get_origin,
29
+ )
30
+
31
+ from .acp import ACPClient, AuggieACPClient
32
+ from .exceptions import (
33
+ AugmentCLIError,
34
+ AugmentNotFoundError,
35
+ AugmentParseError,
36
+ AugmentWorkspaceError,
37
+ AugmentVerificationError,
38
+ )
39
+ from .listener import AgentListener
40
+ from .listener_adapter import AgentListenerAdapter
41
+ from .prompt_formatter import AgentPromptFormatter, DEFAULT_INFERENCE_TYPES
42
+
43
+ T = TypeVar("T")
44
+
45
+
46
@dataclass
class VerificationResult:
    """Structured outcome of one success-criteria verification pass.

    Attributes:
        all_criteria_met: True when every criterion was judged satisfied.
        unmet_criteria: 1-based indices of the criteria that are not met.
        issues: Description of each issue found.
        overall_assessment: Brief overall assessment.
    """

    all_criteria_met: bool
    unmet_criteria: List[int]
    issues: List[str]
    overall_assessment: str
54
+
55
+
56
+ def _get_type_name(type_hint: Any) -> str:
57
+ """
58
+ Safely get the name of a type hint.
59
+
60
+ Handles both actual types and string annotations.
61
+
62
+ Args:
63
+ type_hint: A type or type annotation
64
+
65
+ Returns:
66
+ String name of the type
67
+ """
68
+ if isinstance(type_hint, str):
69
+ return type_hint
70
+ if hasattr(type_hint, "__name__"):
71
+ return type_hint.__name__
72
+ return str(type_hint)
73
+
74
+
75
@dataclass
class Model:
    """
    Describes one AI model that can be selected.

    Attributes:
        id: Identifier passed to the --model flag (e.g., "sonnet4.5")
        name: Human-readable model name (e.g., "Claude Sonnet 4.5")
        description: Extra information about the model
            (e.g., "Anthropic Claude Sonnet 4.5, 200k context")
    """

    id: str
    name: str
    description: str

    def __str__(self) -> str:
        """Render as ``<name> [<id>]``."""
        return "{} [{}]".format(self.name, self.id)
92
+
93
+
94
+ class Agent:
95
+ """
96
+ Augment CLI agent interface for programmatic access.
97
+
98
+ This class provides a Python interface to the Augment CLI agent (auggie),
99
+ using the ACP (Agent Client Protocol) for better performance and
100
+ real-time streaming of responses.
101
+
102
+ By default, each run() call creates a fresh session. Use the session()
103
+ context manager to maintain conversation continuity across multiple calls.
104
+
105
+ Attributes:
106
+ last_model_answer: The last textual explanation returned by the model
107
+ when using typed results. This contains the agent's reasoning or
108
+ context and may be helpful for logging and debugging. None for
109
+ untyped responses or if no message was provided.
110
+ model: The AI model to use for instructions. None uses the CLI's default.
111
+ workspace_path: The resolved workspace path for this agent.
112
+ session_id: The current session ID (only set when inside a session context).
113
+ timeout: Default timeout in seconds for agent operations (defaults to 180 seconds).
114
+ """
115
+
116
def __init__(
    self,
    workspace_root: Optional[Union[str, Path]] = None,
    model: Optional[str] = None,
    listener: Optional[AgentListener] = None,
    cli_path: Optional[str] = None,
    acp_client: Optional[ACPClient] = None,
    removed_tools: Optional[List[str]] = None,
    timeout: int = 180,
    api_key: Optional[str] = None,
    api_url: Optional[str] = None,
):
    """
    Set up a new agent.

    Args:
        workspace_root: Workspace root path; validated via
            _validate_workspace_path before anything else is stored.
        model: AI model to use (e.g., "claude-3-5-sonnet-latest", "gpt-4o").
            None leaves the choice to the CLI.
        listener: Optional AgentListener that receives agent events.
        cli_path: Optional path to the Augment CLI.
        acp_client: Optional pre-built ACP client, mainly for tests/mocking.
            When given, it is used instead of creating a new client.
        removed_tools: Names of tools to disable (e.g., ["github-api", "linear"]).
            A falsy value is stored as an empty list.
        timeout: Default timeout in seconds for agent operations (180 by
            default); used whenever run() is not given its own timeout.
        api_key: Optional API key, stored and later handed to the ACP client.
        api_url: Optional API URL. When None, the AUGMENT_API_URL environment
            variable is used, falling back to "https://api.augmentcode.com".
    """
    # Validate first so an invalid workspace fails before any other setup.
    self.workspace_path = self._validate_workspace_path(workspace_root)

    # Plain configuration.
    self.model = model
    self.cli_path = cli_path
    self.removed_tools = removed_tools or []
    self.timeout = timeout
    self.api_key = api_key

    if api_url is None:
        api_url = os.getenv("AUGMENT_API_URL", "https://api.augmentcode.com")
    self.api_url = api_url

    # Listener plus the adapter that bridges ACP events to it.
    self.listener = listener
    if listener:
        self._listener_adapter = AgentListenerAdapter(listener)
    else:
        self._listener_adapter = None

    # Runtime state.
    self.last_model_answer: Optional[str] = None
    self._acp_client: Optional[ACPClient] = acp_client
    self._provided_client = acp_client is not None  # was a client injected?
    self._in_session = False  # True only inside a session() context
    self._prompt_formatter = AgentPromptFormatter()
171
+
172
def run(
    self,
    instruction: str,
    return_type: Optional[Type[T]] = None,
    timeout: Optional[int] = None,
    max_retries: int = 3,
    functions: Optional[List[Callable]] = None,
    success_criteria: Optional[List[str]] = None,
    max_verification_rounds: int = 3,
) -> Union[T, Any]:
    """
    Execute an instruction and return the agent's response.

    If a typed response is requested and parsing fails, the agent will be asked
    to correct its output up to max_retries times.

    When no return_type is specified, the agent automatically infers the best type
    from common types (int, float, bool, str, list, dict) and returns the result.

    Args:
        instruction: The instruction to send to the agent
        return_type: Expected return type (int, str, bool, float, list, dict, dataclass, or enum).
                    If None, the agent will automatically infer the type.
        timeout: Optional timeout in seconds
        max_retries: Maximum number of retry attempts for parsing failures (default: 3)
        functions: Optional list of Python functions the agent can call. Functions should
                  have type hints and docstrings describing their parameters.
        success_criteria: Optional list of criteria that must be met after execution.
                         The agent will iteratively work on the task and verify criteria
                         until all are met or max_verification_rounds is reached.
        max_verification_rounds: Maximum number of verification rounds when using
                                success_criteria (default: 3). Each round consists of
                                executing/fixing the task and verifying criteria.
                                Must be at least 1 when success_criteria is given.

    Returns:
        Parsed result (either of requested type, or inferred type)

    Note:
        After execution, self.last_model_answer contains the agent's textual explanation,
        which may be helpful for logging and debugging. This is separate from the
        structured return value and provides the model's reasoning or context.

    Raises:
        ValueError: If instruction is empty or whitespace, if return_type is
            unsupported, or if success_criteria is given with
            max_verification_rounds < 1
        AugmentParseError: If parsing typed response fails after all retries
        AugmentVerificationError: If success_criteria are not met after max_verification_rounds
    """
    # Validate inputs before touching the client.
    if not instruction or not instruction.strip():
        raise ValueError("Instruction cannot be empty or whitespace")
    if success_criteria and max_verification_rounds < 1:
        # Previously this silently returned None without running the task.
        raise ValueError(
            "max_verification_rounds must be at least 1 when success_criteria are provided"
        )

    client = self._prepare_client()

    # No success criteria: a single task run is all that is needed.
    if not success_criteria:
        return self._run_task(
            client=client,
            instruction=instruction,
            return_type=return_type,
            timeout=timeout,
            max_retries=max_retries,
            functions=functions,
        )

    # With success criteria: execute, verify, and repeat with a fix
    # instruction until verified or the rounds are used up.
    current_instruction = instruction
    for round_num in range(1, max_verification_rounds + 1):
        result = self._run_task(
            client=client,
            instruction=current_instruction,
            return_type=return_type,
            timeout=timeout,
            max_retries=max_retries,
            functions=functions,
        )

        verification = self._verify_success_criteria(
            client=client,
            success_criteria=success_criteria,
            timeout=timeout,
        )
        if verification.all_criteria_met:
            return result

        # Only build a fix instruction when another round will consume it.
        if round_num < max_verification_rounds:
            current_instruction = self._prepare_fix_instruction(
                success_criteria=success_criteria,
                verification=verification,
            )

    # All rounds used; `verification` holds the last round's outcome.
    raise AugmentVerificationError(
        f"Success criteria not fully met after {max_verification_rounds} rounds. "
        f"Unmet criteria: {verification.unmet_criteria}. "
        f"Issues: {verification.issues}",
        unmet_criteria=verification.unmet_criteria,
        issues=verification.issues,
        rounds_attempted=max_verification_rounds,
    )
283
+
284
+ def _prepare_client(self) -> ACPClient:
285
+ """
286
+ Prepare the ACP client for execution.
287
+
288
+ This method:
289
+ 1. Creates the client if it doesn't exist
290
+ 2. Starts the client if it's not running
291
+ 3. Clears context if not in a session (for fresh start)
292
+
293
+ Returns:
294
+ ACP client instance ready for use
295
+ """
296
+ # Create client if needed
297
+ if self._acp_client is None:
298
+ self._acp_client = AuggieACPClient(
299
+ cli_path=self.cli_path,
300
+ model=self.model,
301
+ workspace_root=str(self.workspace_path),
302
+ listener=self._listener_adapter, # Pass the adapter to ACP
303
+ removed_tools=self.removed_tools,
304
+ api_key=self.api_key,
305
+ api_url=self.api_url,
306
+ )
307
+
308
+ # Start if not running
309
+ if not self._acp_client.is_running:
310
+ self._acp_client.start()
311
+
312
+ # If not in a session, clear context for a fresh start
313
+ if not self._in_session:
314
+ self._acp_client.clear_context()
315
+
316
+ return self._acp_client
317
+
318
+ def _run_task(
319
+ self,
320
+ client: ACPClient,
321
+ instruction: str,
322
+ return_type: Optional[Type[T]],
323
+ timeout: Optional[int],
324
+ max_retries: int,
325
+ functions: Optional[List[Callable]] = None,
326
+ ) -> Union[T, Any]:
327
+ """
328
+ Run a single task: talk to the agent in a loop until the task is complete.
329
+
330
+ This method implements a loop that:
331
+ 1. Sends the current instruction to the agent
332
+ 2. If response contains function calls, executes them and continues
333
+ 3. If response can be parsed successfully, returns the result
334
+ 4. If parsing fails and retries remain, asks agent to fix output and continues
335
+ 5. Repeats until success or max rounds exhausted
336
+
337
+ Args:
338
+ client: The ACP client to use
339
+ instruction: The instruction to send
340
+ return_type: Expected return type (None for type inference)
341
+ timeout: Timeout in seconds (None for default)
342
+ max_retries: Maximum number of retry attempts for parsing failures
343
+ functions: Optional list of functions the agent can call
344
+
345
+ Returns:
346
+ Parsed result (either of requested type, or inferred type)
347
+
348
+ Raises:
349
+ AugmentParseError: If parsing fails after all retries
350
+ """
351
+ # Set effective timeout - use provided timeout, or fall back to instance default
352
+ effective_timeout = (
353
+ float(timeout) if timeout is not None else float(self.timeout)
354
+ )
355
+
356
+ # Prepare the instruction prompt
357
+ current_instruction, function_map = (
358
+ self._prompt_formatter.prepare_instruction_prompt(
359
+ instruction, return_type, functions
360
+ )
361
+ )
362
+
363
+ # Determine if we're in type inference mode
364
+ infer_type = return_type is None
365
+
366
+ # Track retry attempts for parsing failures
367
+ parse_retry_count = 0
368
+
369
+ # Maximum total rounds to prevent infinite loops
370
+ # This includes both function calls and parse retries
371
+ max_rounds = 20
372
+
373
+ for round_num in range(max_rounds):
374
+ # Send message to agent
375
+ raw_response = client.send_message(current_instruction, effective_timeout)
376
+
377
+ # Check if there are function calls to handle
378
+ if function_map:
379
+ function_calls = self._parse_function_calls(raw_response)
380
+ if function_calls:
381
+ # Execute functions and prepare next instruction with results
382
+ current_instruction = self._handle_function_calls(
383
+ raw_response, function_map
384
+ )
385
+ # Continue loop to send function results back to agent
386
+ continue
387
+
388
+ # No function calls - try to parse the response
389
+ try:
390
+ if infer_type:
391
+ result, _ = self._parse_type_inference_response(
392
+ raw_response, DEFAULT_INFERENCE_TYPES
393
+ )
394
+ return result
395
+ else:
396
+ # At this point, return_type must be a valid type (not None)
397
+ assert return_type is not None
398
+ result = self._parse_typed_response(raw_response, return_type)
399
+ return result # type: ignore[return-value,no-any-return]
400
+
401
+ except AugmentParseError as e:
402
+ # Parsing failed - check if we have retries left
403
+ if parse_retry_count >= max_retries:
404
+ raise
405
+
406
+ # Increment retry count and prepare retry instruction
407
+ parse_retry_count += 1
408
+ current_instruction = self._prompt_formatter.prepare_retry_prompt(
409
+ return_type,
410
+ str(e),
411
+ parse_retry_count,
412
+ )
413
+ # Continue loop to retry with corrected instruction
414
+ continue
415
+
416
+ # If we've exhausted all rounds, raise an error
417
+ raise AugmentParseError(
418
+ f"Exceeded maximum rounds ({max_rounds}) without successful completion"
419
+ )
420
+
421
def _verify_success_criteria(
    self,
    client: ACPClient,
    success_criteria: List[str],
    timeout: Optional[int],
) -> VerificationResult:
    """
    Verify that success criteria are met and return structured feedback.

    Sends a verification prompt listing the numbered criteria and asks for a
    JSON answer that is parsed into a VerificationResult via _run_task
    (with max_retries=2 and no callable functions).

    Verification is best-effort: if the verification run raises any
    exception, a UserWarning is emitted and a conservative result marking
    every criterion as unmet is returned instead of propagating the error.

    Args:
        client: The ACP client to use
        success_criteria: List of criteria that must be met
        timeout: Timeout in seconds (None for default)

    Returns:
        VerificationResult with detailed verification feedback
    """
    # Number the criteria 1-based so they match the indices requested back.
    criteria_list = "\n".join(
        f"{i + 1}. {criterion}" for i, criterion in enumerate(success_criteria)
    )

    verification_instruction = f"""The task has been completed. Please verify that ALL of the following success criteria are met:

{criteria_list}

For each criterion, check if it is satisfied.

Respond with a JSON object with the following structure:
{{
"all_criteria_met": true/false,
"unmet_criteria": [list of criterion numbers (1-based) that are NOT met],
"issues": ["description of each issue found"],
"overall_assessment": "brief assessment of the current state"
}}

Example:
{{
"all_criteria_met": false,
"unmet_criteria": [2, 3],
"issues": ["Criterion 2: Function is missing type hints for parameter 'x'", "Criterion 3: No docstring present"],
"overall_assessment": "Function exists but lacks type hints and documentation"
}}"""

    try:
        return self._run_task(
            client=client,
            instruction=verification_instruction,
            return_type=VerificationResult,
            timeout=timeout,
            max_retries=2,
            functions=None,
        )
    except Exception as e:
        # Deliberately broad: Exception already covers AugmentParseError, and
        # any failure here should degrade to "not verified" rather than abort.
        import warnings

        warnings.warn(
            f"Success criteria verification failed: {e}. "
            "Assuming criteria are not met.",
            UserWarning,
        )
        return VerificationResult(
            all_criteria_met=False,
            unmet_criteria=list(range(1, len(success_criteria) + 1)),
            issues=[f"Verification failed with error: {e}"],
            overall_assessment="Verification could not be completed",
        )
495
+
496
+ def _prepare_fix_instruction(
497
+ self,
498
+ success_criteria: List[str],
499
+ verification: VerificationResult,
500
+ ) -> str:
501
+ """
502
+ Prepare an instruction to fix issues identified during verification.
503
+
504
+ Args:
505
+ success_criteria: List of all success criteria
506
+ verification: Verification result with issues
507
+
508
+ Returns:
509
+ Instruction for the agent to fix the issues
510
+ """
511
+ # Build list of unmet criteria with their descriptions
512
+ unmet_details = []
513
+ for criterion_num in verification.unmet_criteria:
514
+ if 1 <= criterion_num <= len(success_criteria):
515
+ criterion_text = success_criteria[criterion_num - 1]
516
+ unmet_details.append(f"{criterion_num}. {criterion_text}")
517
+
518
+ unmet_list = "\n".join(unmet_details)
519
+ issues_list = "\n".join(f"- {issue}" for issue in verification.issues)
520
+
521
+ fix_instruction = f"""The following success criteria are NOT yet met:
522
+
523
+ {unmet_list}
524
+
525
+ Issues identified:
526
+ {issues_list}
527
+
528
+ Overall assessment: {verification.overall_assessment}
529
+
530
+ Please fix these issues to ensure ALL success criteria are satisfied."""
531
+
532
+ return fix_instruction
533
+
534
+ def _handle_function_calls(
535
+ self,
536
+ raw_response: str,
537
+ function_map: Dict[str, Callable],
538
+ ) -> str:
539
+ """
540
+ Execute function calls from the agent's response and prepare next instruction.
541
+
542
+ This method:
543
+ 1. Parses function calls from the response
544
+ 2. Executes the functions
545
+ 3. Returns an instruction containing the function results to send back to the agent
546
+
547
+ Args:
548
+ raw_response: The agent's response that contains function calls
549
+ function_map: Mapping of function names to callables
550
+
551
+ Returns:
552
+ Instruction text with function results to send back to the agent
553
+ """
554
+ # Parse function calls from response
555
+ function_calls = self._parse_function_calls(raw_response)
556
+
557
+ # Execute function calls
558
+ function_results = []
559
+ for func_call in function_calls:
560
+ func_name = func_call.get("name")
561
+ func_args = func_call.get("arguments", {})
562
+
563
+ # Skip if function name is missing
564
+ if not func_name:
565
+ continue
566
+
567
+ if func_name not in function_map:
568
+ result = {"error": f"Function '{func_name}' not found"}
569
+ error = f"Function '{func_name}' not found"
570
+ # Notify listener of error
571
+ if self.listener:
572
+ self.listener.on_function_result(func_name, None, error)
573
+ else:
574
+ try:
575
+ func = function_map[func_name]
576
+
577
+ # Notify listener of function call (right before execution)
578
+ if self.listener:
579
+ self.listener.on_function_call(func_name, func_args)
580
+
581
+ result = func(**func_args)
582
+ # Notify listener of success
583
+ if self.listener:
584
+ self.listener.on_function_result(func_name, result, None)
585
+ except Exception as e:
586
+ result = {"error": f"Error calling {func_name}: {str(e)}"}
587
+ # Notify listener of error
588
+ if self.listener:
589
+ self.listener.on_function_result(func_name, None, str(e))
590
+
591
+ function_results.append({"function": func_name, "result": result})
592
+
593
+ # Build follow-up instruction with function results
594
+ results_text = "Function call results:\n\n"
595
+ for fr in function_results:
596
+ results_text += f"Function: {fr['function']}\n"
597
+ results_text += f"Result: {json.dumps(fr['result'], indent=2)}\n\n"
598
+
599
+ results_text += "\nPlease continue with your response based on these results."
600
+
601
+ return results_text
602
+
603
+ def _parse_function_calls(self, response: str) -> List[Dict[str, Any]]:
604
+ """
605
+ Parse function calls from agent response.
606
+
607
+ Looks for <function-call>...</function-call> blocks containing JSON.
608
+
609
+ Args:
610
+ response: Agent's response text
611
+
612
+ Returns:
613
+ List of function call dictionaries with 'name' and 'arguments' keys
614
+ """
615
+ function_calls = []
616
+ pattern = r"<function-call>\s*(\{.*?\})\s*</function-call>"
617
+ matches = re.findall(pattern, response, re.DOTALL)
618
+
619
+ for match in matches:
620
+ try:
621
+ func_call = json.loads(match)
622
+ if "name" in func_call:
623
+ function_calls.append(func_call)
624
+ except json.JSONDecodeError:
625
+ # Skip invalid JSON
626
+ continue
627
+
628
+ return function_calls
629
+
630
+ def _parse_type_inference_response(
631
+ self, response: str, possible_types: List[Type[Any]]
632
+ ) -> Tuple[Any, Type[Any]]:
633
+ """
634
+ Parse type inference response.
635
+
636
+ Args:
637
+ response: The agent's response
638
+ possible_types: List of possible types
639
+
640
+ Returns:
641
+ Tuple of (parsed result, chosen type)
642
+
643
+ Raises:
644
+ AugmentParseError: If parsing fails
645
+ """
646
+ # Extract message
647
+ message_match = re.search(
648
+ r"<augment-agent-message>\s*(.*?)\s*</augment-agent-message>",
649
+ response,
650
+ re.DOTALL,
651
+ )
652
+ if message_match:
653
+ self.last_model_answer = message_match.group(1).strip()
654
+ else:
655
+ self.last_model_answer = None
656
+
657
+ # Extract type name
658
+ type_match = re.search(
659
+ r"<augment-agent-type>\s*(\w+)\s*</augment-agent-type>", response, re.DOTALL
660
+ )
661
+ if not type_match:
662
+ # If no type tags found, check if response is empty or just whitespace
663
+ # This can happen when agent completes a task (like file creation) successfully
664
+ # but doesn't provide a structured response
665
+ if not response or response.strip() == "":
666
+ # Return empty string as success indicator
667
+ return "", str
668
+
669
+ # If there's content but no tags, try to extract it as a string
670
+ # This handles cases where agent responds without proper formatting
671
+ content = response.strip()
672
+ if content:
673
+ # Return the content as a string
674
+ return content, str
675
+
676
+ raise AugmentParseError(
677
+ "No type classification found. Expected <augment-agent-type> tags."
678
+ )
679
+
680
+ type_name = type_match.group(1).strip()
681
+
682
+ # Find the matching type by name
683
+ chosen_type = None
684
+ for t in possible_types:
685
+ if t.__name__ == type_name:
686
+ chosen_type = t
687
+ break
688
+
689
+ if chosen_type is None:
690
+ type_names = [t.__name__ for t in possible_types]
691
+ raise AugmentParseError(
692
+ f"Invalid type name '{type_name}'. Must be one of: {', '.join(type_names)}"
693
+ )
694
+
695
+ # Extract result
696
+ result_match = re.search(
697
+ r"<augment-agent-result>\s*(.*?)\s*</augment-agent-result>",
698
+ response,
699
+ re.DOTALL,
700
+ )
701
+ if not result_match:
702
+ raise AugmentParseError(
703
+ "No structured result found. Expected <augment-agent-result> tags."
704
+ )
705
+
706
+ content = result_match.group(1).strip()
707
+
708
+ # Parse the result according to the chosen type
709
+ try:
710
+ parsed_result = self._convert_to_type(content, chosen_type)
711
+ return parsed_result, chosen_type
712
+ except (json.JSONDecodeError, ValueError, TypeError) as e:
713
+ raise AugmentParseError(
714
+ f"Could not parse result as {chosen_type.__name__}: {e}"
715
+ )
716
+
717
@contextmanager
def session(
    self, session_id: Optional[str] = None
) -> Generator["Agent", None, None]:
    """
    Create a session context for maintaining conversation continuity.

    By default, each run() call creates a fresh session. Use this context
    manager to maintain conversation continuity across multiple calls.

    Sessions may be nested: leaving an inner session restores the state of
    the enclosing one instead of ending it.

    Usage:
        agent = Agent()

        # Without session - each call is independent
        agent.run("Create a function")
        agent.run("Test it")  # Won't remember the function

        # With session - calls remember each other
        with agent.session() as session:
            session.run("Create a function called add_numbers")
            session.run("Now test that function")  # Remembers add_numbers

    Args:
        session_id: Optional session ID (currently ignored, auto-generated)

    Yields:
        Agent: This agent instance
    """
    # Remember the previous flag so nested sessions unwind correctly.
    was_in_session = self._in_session
    self._in_session = True
    try:
        yield self
    finally:
        # Restore rather than force False: an outer session may still be open.
        self._in_session = was_in_session
753
+
754
+ def _get_type_description(self, return_type: Type) -> str:
755
+ """Get the JSON structure description for the type."""
756
+ # Built-in types
757
+ if return_type in (int, float, str, bool):
758
+ return "your_result_value"
759
+ elif return_type in (list, dict):
760
+ return "your_result_value"
761
+
762
+ # Generic types (list[SomeClass], dict[str, SomeClass], etc.)
763
+ elif hasattr(return_type, "__origin__") or get_origin(return_type) is not None:
764
+ origin = get_origin(return_type)
765
+ args = get_args(return_type)
766
+
767
+ if origin is list:
768
+ if args and len(args) == 1:
769
+ # list[SomeClass] - create example array with one element
770
+ element_type = args[0]
771
+ if is_dataclass(element_type):
772
+ field_names = [f.name for f in fields(element_type)]
773
+ element_example = {
774
+ name: f"<{name}_value>" for name in field_names
775
+ }
776
+ return json.dumps([element_example], indent=2)
777
+ elif hasattr(element_type, "__name__") and issubclass(
778
+ element_type, Enum
779
+ ):
780
+ return '["<enum_value>"]'
781
+ else:
782
+ return '["<element_value>"]'
783
+ else:
784
+ return "your_result_value" # Plain list
785
+ elif origin is dict:
786
+ return "your_result_value" # For now, treat as plain dict
787
+ else:
788
+ return "your_result_value" # Other generic types
789
+
790
+ # Dataclass
791
+ elif is_dataclass(return_type):
792
+ field_names = [f.name for f in fields(return_type)]
793
+ example = {name: f"<{name}_value>" for name in field_names}
794
+ return json.dumps(example, indent=2)
795
+
796
+ # Enum
797
+ elif issubclass(return_type, Enum):
798
+ return "your_enum_value"
799
+
800
+ else:
801
+ raise ValueError(f"Unsupported return type: {return_type}")
802
+
803
+ def _get_type_instructions(self, return_type: Type) -> str:
804
+ """Get specific instructions for the type."""
805
+ # Built-in types
806
+ if return_type is int:
807
+ return "IMPORTANT: Put your result inside <augment-agent-result> tags. Return just the integer number (no quotes)."
808
+ elif return_type is float:
809
+ return "IMPORTANT: Put your result inside <augment-agent-result> tags. Return just the decimal number (no quotes)."
810
+ elif return_type is str:
811
+ return "IMPORTANT: Put your result inside <augment-agent-result> tags. Return the string value in double quotes."
812
+ elif return_type is bool:
813
+ return "IMPORTANT: Put your result inside <augment-agent-result> tags. Return either true or false (no quotes)."
814
+ elif return_type is list:
815
+ return "IMPORTANT: Put your result inside <augment-agent-result> tags. Return a JSON array: [item1, item2, item3]"
816
+ elif return_type is dict:
817
+ return 'IMPORTANT: Put your result inside <augment-agent-result> tags. Return a JSON object: {"key1": "value1", "key2": "value2"}'
818
+
819
+ # Generic types (list[SomeClass], dict[str, SomeClass], etc.)
820
+ elif hasattr(return_type, "__origin__") or get_origin(return_type) is not None:
821
+ origin = get_origin(return_type)
822
+ args = get_args(return_type)
823
+
824
+ if origin is list:
825
+ if args and len(args) == 1:
826
+ element_type = args[0]
827
+ if is_dataclass(element_type):
828
+ field_info = []
829
+ for field in fields(element_type):
830
+ field_info.append(
831
+ f" - {field.name}: {_get_type_name(field.type)}"
832
+ )
833
+
834
+ return f"""IMPORTANT: Put your JSON array inside <augment-agent-result> tags.
835
+
836
+ Return a JSON array of objects, each with these fields:
837
+ {chr(10).join(field_info)}
838
+
839
+ Example:
840
+ <augment-agent-result>
841
+ [{{"field1": value1, "field2": value2}}, {{"field1": value3, "field2": value4}}]
842
+ </augment-agent-result>"""
843
+ elif hasattr(element_type, "__name__") and issubclass(
844
+ element_type, Enum
845
+ ):
846
+ enum_values = [f'"{e.value}"' for e in element_type]
847
+ return f"IMPORTANT: Put your result inside <augment-agent-result> tags. Return a JSON array of enum values: {enum_values}"
848
+ else:
849
+ return f"IMPORTANT: Put your result inside <augment-agent-result> tags. Return a JSON array of {_get_type_name(element_type)} values: [value1, value2, value3]"
850
+ else:
851
+ return "IMPORTANT: Put your result inside <augment-agent-result> tags. Return a JSON array: [item1, item2, item3]"
852
+ elif origin is dict:
853
+ return 'IMPORTANT: Put your result inside <augment-agent-result> tags. Return a JSON object: {"key1": "value1", "key2": "value2"}'
854
+ else:
855
+ return "IMPORTANT: Put your result inside <augment-agent-result> tags. Return the appropriate JSON structure for your type."
856
+
857
+ # Dataclass
858
+ elif is_dataclass(return_type):
859
+ field_info = []
860
+ for field in fields(return_type):
861
+ field_info.append(f" - {field.name}: {_get_type_name(field.type)}")
862
+
863
+ return f"""IMPORTANT: Put your JSON object inside <augment-agent-result> tags.
864
+
865
+ Return a JSON object with these exact fields:
866
+ {chr(10).join(field_info)}
867
+
868
+ Example:
869
+ <augment-agent-result>
870
+ {{"field1": value1, "field2": value2}}
871
+ </augment-agent-result>"""
872
+
873
+ # Enum
874
+ elif issubclass(return_type, Enum):
875
+ enum_values = [f'"{e.value}"' for e in return_type]
876
+ return f"IMPORTANT: Put your result inside <augment-agent-result> tags. Return one of these exact values: {', '.join(enum_values)}"
877
+
878
+ else:
879
+ raise ValueError(f"Unsupported return type: {return_type}")
880
+
881
+ def _parse_typed_response(self, response: str, return_type: Type[T]) -> T:
882
+ """Parse the agent's structured response into the desired type."""
883
+ # Extract message from <augment-agent-message> tags
884
+ message_match = re.search(
885
+ r"<augment-agent-message>\s*(.*?)\s*</augment-agent-message>",
886
+ response,
887
+ re.DOTALL,
888
+ )
889
+ if message_match:
890
+ self.last_model_answer = message_match.group(1).strip()
891
+ else:
892
+ self.last_model_answer = None
893
+
894
+ # Extract content from <augment-agent-result> tags
895
+ result_match = re.search(
896
+ r"<augment-agent-result>\s*(.*?)\s*</augment-agent-result>",
897
+ response,
898
+ re.DOTALL,
899
+ )
900
+
901
+ if not result_match:
902
+ raise AugmentParseError(
903
+ "No structured result found. Expected <augment-agent-result> tags in response."
904
+ )
905
+
906
+ content = result_match.group(1).strip()
907
+
908
+ try:
909
+ return self._convert_to_type(content, return_type)
910
+ except (json.JSONDecodeError, ValueError, TypeError) as e:
911
+ raise AugmentParseError(
912
+ f"Could not parse result as {return_type.__name__}: {e}"
913
+ )
914
+
915
+ def _convert_to_type(self, content: str, return_type: Type[T]) -> T:
916
+ """Convert string content to the specified Python type."""
917
+
918
+ # Special case for str - don't JSON parse it, just return as-is
919
+ if return_type is str:
920
+ return content # type: ignore[return-value]
921
+
922
+ # Built-in types that need JSON parsing
923
+ if return_type in (int, float, bool, list, dict):
924
+ parsed = json.loads(content)
925
+
926
+ if not isinstance(parsed, return_type):
927
+ raise ValueError(
928
+ f"Expected {return_type.__name__}, got {type(parsed).__name__}"
929
+ )
930
+
931
+ return parsed
932
+
933
+ # Generic types (list[SomeClass], dict[str, SomeClass], etc.)
934
+ elif hasattr(return_type, "__origin__") or get_origin(return_type) is not None:
935
+ origin = get_origin(return_type)
936
+ args = get_args(return_type)
937
+
938
+ if origin is list:
939
+ parsed = json.loads(content)
940
+ if not isinstance(parsed, list):
941
+ raise ValueError(f"Expected list, got {type(parsed).__name__}")
942
+
943
+ if args and len(args) == 1:
944
+ element_type = args[0]
945
+ # Convert each element to the specified type
946
+ result = []
947
+ for item in parsed:
948
+ if is_dataclass(element_type):
949
+ if not isinstance(item, dict):
950
+ raise ValueError(
951
+ f"Expected dict for dataclass element, got {type(item).__name__}"
952
+ )
953
+ result.append(element_type(**item)) # type: ignore[misc]
954
+ elif hasattr(element_type, "__name__") and issubclass(
955
+ element_type, Enum
956
+ ):
957
+ result.append(element_type(item))
958
+ else:
959
+ # For basic types, just validate and append
960
+ if not isinstance(item, element_type):
961
+ raise ValueError(
962
+ f"Expected {element_type.__name__}, got {type(item).__name__}"
963
+ )
964
+ result.append(item)
965
+ return result # type: ignore[return-value]
966
+ else:
967
+ # Plain list without type parameter
968
+ return parsed # type: ignore[return-value]
969
+ elif origin is dict:
970
+ parsed = json.loads(content)
971
+ if not isinstance(parsed, dict):
972
+ raise ValueError(f"Expected dict, got {type(parsed).__name__}")
973
+ return parsed # type: ignore[return-value]
974
+ else:
975
+ # Other generic types - try basic JSON parsing
976
+ return json.loads(content) # type: ignore[return-value,no-any-return]
977
+
978
+ # Dataclass
979
+ elif is_dataclass(return_type):
980
+ parsed = json.loads(content)
981
+
982
+ if not isinstance(parsed, dict):
983
+ raise ValueError(
984
+ f"Expected dict for dataclass, got {type(parsed).__name__}"
985
+ )
986
+
987
+ return return_type(**parsed)
988
+
989
+ # Enum
990
+ elif issubclass(return_type, Enum):
991
+ # Try JSON parsing first (for quoted values)
992
+ try:
993
+ parsed = json.loads(content)
994
+ return return_type(parsed) # type: ignore[return-value]
995
+ except json.JSONDecodeError:
996
+ # If JSON parsing fails, try direct string value
997
+ return return_type(content.strip()) # type: ignore[return-value]
998
+
999
+ else:
1000
+ raise ValueError(f"Unsupported return type: {return_type}")
1001
+
1002
def get_workspace_path(self) -> Path:
    """
    Return the workspace root this agent is bound to.

    Returns:
        The value stored in ``self.workspace_path``.
    """
    workspace_root = self.workspace_path
    return workspace_root
1010
+
1011
@property
def session_id(self) -> Optional[str]:
    """
    Session ID of the currently active session, if any.

    The ID is read from the ACP client, and only while the agent is marked
    as being inside a session (``_in_session``) and has an ACP client.
    In every other situation ``None`` is returned.
    """
    inside_session = self._in_session and self._acp_client
    if not inside_session:
        return None
    return self._acp_client.session_id  # type: ignore[attr-defined,no-any-return]
1022
+
1023
def __repr__(self) -> str:
    """Return a debug representation; includes the session ID when one is active."""
    active_session = self.session_id
    if not active_session:
        return f"Agent(workspace_path='{self.workspace_path}')"
    return f"Agent(workspace_path='{self.workspace_path}', session_id='{active_session}')"
1028
+
1029
def __del__(self) -> None:
    """Stop the ACP client, if there is one, when the agent is destroyed."""
    # The attribute may be missing when __init__ failed before assigning it.
    client = getattr(self, "_acp_client", None)
    if not client:
        return
    try:
        client.stop()
    except Exception:
        # Cleanup is best-effort: errors raised here are deliberately ignored.
        pass
1037
+
1038
+ @staticmethod
1039
+ def _validate_workspace_path(workspace_root: Optional[Union[str, Path]]) -> Path:
1040
+ """
1041
+ Validate and resolve workspace path.
1042
+
1043
+ Args:
1044
+ workspace_root: User-provided workspace path or None
1045
+
1046
+ Returns:
1047
+ Resolved Path object
1048
+
1049
+ Raises:
1050
+ AugmentWorkspaceError: If path is invalid
1051
+ """
1052
+ if workspace_root is None:
1053
+ return Path.cwd()
1054
+
1055
+ path = Path(workspace_root).resolve()
1056
+
1057
+ if not path.exists():
1058
+ raise AugmentWorkspaceError(f"Workspace path does not exist: {path}")
1059
+
1060
+ if not path.is_dir():
1061
+ raise AugmentWorkspaceError(f"Workspace path is not a directory: {path}")
1062
+
1063
+ return path
1064
+
1065
@staticmethod
def get_available_models() -> List[Model]:
    """
    Get the list of available AI models for the current account.

    This method runs `auggie model list` (30 second timeout) and parses
    its standard output. A model entry is a line of the form
    ``- Model Name [model-id]``; if the line that follows is non-empty and
    is not itself an entry, it is taken as that model's description.
    Lines that match neither shape are ignored.

    Returns:
        List of Model objects containing id, name, and description

    Raises:
        AugmentNotFoundError: If auggie CLI is not found
        AugmentCLIError: If the CLI command times out or exits with a
            non-zero status

    Example:
        >>> models = Agent.get_available_models()
        >>> for model in models:
        ...     print(f"{model.name} [{model.id}]")
        ...     print(f"  {model.description}")
        Claude Sonnet 4.5 [sonnet4.5]
          Anthropic Claude Sonnet 4.5, 200k context
    """
    try:
        result = subprocess.run(
            ["auggie", "model", "list"],
            capture_output=True,
            text=True,
            timeout=30,
        )
    except FileNotFoundError as err:
        # Chain explicitly so the traceback names the OS error as the cause.
        raise AugmentNotFoundError(
            "auggie CLI not found. Please install auggie and ensure it's in your PATH."
        ) from err
    except subprocess.TimeoutExpired as err:
        raise AugmentCLIError("Command timed out after 30 seconds", -1, "") from err

    if result.returncode != 0:
        raise AugmentCLIError(
            f"Failed to get model list: {result.stderr}",
            result.returncode,
            result.stderr,
        )

    # Entry format (after stripping): "- Model Name [model-id]"
    entry_pattern = re.compile(r"^- (.+?)\s+\[([^\]]+)\]$")
    lines = [raw.strip() for raw in result.stdout.strip().split("\n")]

    models: List[Model] = []
    index = 0
    while index < len(lines):
        entry = entry_pattern.match(lines[index])
        index += 1
        if entry is None:
            continue

        # The line after an entry is its description unless it is blank or
        # already the next entry; a consumed description is skipped.
        description = ""
        if (
            index < len(lines)
            and lines[index]
            and not lines[index].startswith("- ")
        ):
            description = lines[index]
            index += 1

        models.append(
            Model(
                id=entry.group(2).strip(),
                name=entry.group(1).strip(),
                description=description,
            )
        )

    return models