augment-sdk 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- augment/__init__.py +30 -0
- augment/acp/__init__.py +11 -0
- augment/acp/claude_code_client.py +365 -0
- augment/acp/client.py +640 -0
- augment/acp/test_client_e2e.py +472 -0
- augment/agent.py +1139 -0
- augment/exceptions.py +92 -0
- augment/function_tools.py +265 -0
- augment/listener.py +186 -0
- augment/listener_adapter.py +83 -0
- augment/prompt_formatter.py +343 -0
- augment_sdk-0.1.1.dist-info/METADATA +841 -0
- augment_sdk-0.1.1.dist-info/RECORD +17 -0
- augment_sdk-0.1.1.dist-info/WHEEL +5 -0
- augment_sdk-0.1.1.dist-info/entry_points.txt +2 -0
- augment_sdk-0.1.1.dist-info/licenses/LICENSE +22 -0
- augment_sdk-0.1.1.dist-info/top_level.txt +1 -0
augment/agent.py
ADDED
|
@@ -0,0 +1,1139 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Main Agent class for the Augment Python SDK.
|
|
3
|
+
|
|
4
|
+
This implementation uses the ACP (Agent Client Protocol) client for
|
|
5
|
+
better performance and real-time streaming.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
import os
|
|
10
|
+
import re
|
|
11
|
+
import subprocess
|
|
12
|
+
from contextlib import contextmanager
|
|
13
|
+
from dataclasses import dataclass, fields, is_dataclass
|
|
14
|
+
from enum import Enum
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
from typing import (
|
|
17
|
+
Any,
|
|
18
|
+
Callable,
|
|
19
|
+
Dict,
|
|
20
|
+
Generator,
|
|
21
|
+
List,
|
|
22
|
+
Optional,
|
|
23
|
+
Tuple,
|
|
24
|
+
Type,
|
|
25
|
+
TypeVar,
|
|
26
|
+
Union,
|
|
27
|
+
get_args,
|
|
28
|
+
get_origin,
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
from .acp import ACPClient, AuggieACPClient
|
|
32
|
+
from .exceptions import (
|
|
33
|
+
AugmentCLIError,
|
|
34
|
+
AugmentNotFoundError,
|
|
35
|
+
AugmentParseError,
|
|
36
|
+
AugmentWorkspaceError,
|
|
37
|
+
AugmentVerificationError,
|
|
38
|
+
)
|
|
39
|
+
from .listener import AgentListener
|
|
40
|
+
from .listener_adapter import AgentListenerAdapter
|
|
41
|
+
from .prompt_formatter import AgentPromptFormatter, DEFAULT_INFERENCE_TYPES
|
|
42
|
+
|
|
43
|
+
T = TypeVar("T")
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@dataclass
class VerificationResult:
    """Structured outcome of one success-criteria verification pass.

    Attributes:
        all_criteria_met: True when every criterion was judged to be satisfied.
        unmet_criteria: Indices (1-based) of the criteria that are not met.
        issues: Description of each issue found.
        overall_assessment: Brief overall assessment.
    """

    all_criteria_met: bool
    unmet_criteria: List[int]
    issues: List[str]
    overall_assessment: str
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _get_type_name(type_hint: Any) -> str:
|
|
57
|
+
"""
|
|
58
|
+
Safely get the name of a type hint.
|
|
59
|
+
|
|
60
|
+
Handles both actual types and string annotations.
|
|
61
|
+
|
|
62
|
+
Args:
|
|
63
|
+
type_hint: A type or type annotation
|
|
64
|
+
|
|
65
|
+
Returns:
|
|
66
|
+
String name of the type
|
|
67
|
+
"""
|
|
68
|
+
if isinstance(type_hint, str):
|
|
69
|
+
return type_hint
|
|
70
|
+
if hasattr(type_hint, "__name__"):
|
|
71
|
+
return type_hint.__name__
|
|
72
|
+
return str(type_hint)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
@dataclass
class Model:
    """
    Describes one AI model that can be selected.

    Attributes:
        id: The model identifier used with --model flag (e.g., "sonnet4.5")
        name: The human-readable model name (e.g., "Claude Sonnet 4.5")
        description: Additional information about the model
            (e.g., "Anthropic Claude Sonnet 4.5, 200k context")
    """

    id: str
    name: str
    description: str

    def __str__(self) -> str:
        # Rendered as "<name> [<id>]"; the description is intentionally omitted.
        return "{} [{}]".format(self.name, self.id)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class Agent:
|
|
95
|
+
"""
|
|
96
|
+
Augment CLI agent interface for programmatic access.
|
|
97
|
+
|
|
98
|
+
This class provides a Python interface to the Augment CLI agent (auggie),
|
|
99
|
+
using the ACP (Agent Client Protocol) for better performance and
|
|
100
|
+
real-time streaming of responses.
|
|
101
|
+
|
|
102
|
+
By default, each run() call creates a fresh session. Use the session()
|
|
103
|
+
context manager to maintain conversation continuity across multiple calls.
|
|
104
|
+
|
|
105
|
+
Attributes:
|
|
106
|
+
last_model_answer: The last textual explanation returned by the model
|
|
107
|
+
when using typed results. This contains the agent's reasoning or
|
|
108
|
+
context and may be helpful for logging and debugging. None for
|
|
109
|
+
untyped responses or if no message was provided.
|
|
110
|
+
model: The AI model to use for instructions. None uses the CLI's default.
|
|
111
|
+
workspace_path: The resolved workspace path for this agent.
|
|
112
|
+
session_id: The current session ID (only set when inside a session context).
|
|
113
|
+
timeout: Default timeout in seconds for agent operations (defaults to 180 seconds).
|
|
114
|
+
"""
|
|
115
|
+
|
|
116
|
+
def __init__(
    self,
    workspace_root: Optional[Union[str, Path]] = None,
    model: Optional[str] = None,
    listener: Optional[AgentListener] = None,
    cli_path: Optional[str] = None,
    acp_client: Optional[ACPClient] = None,
    removed_tools: Optional[List[str]] = None,
    timeout: int = 180,
    api_key: Optional[str] = None,
    api_url: Optional[str] = None,
):
    """
    Store the agent configuration; no ACP client is created or started here.

    Args:
        workspace_root: Path to the workspace root. It is passed through
            ``self._validate_workspace_path`` and the result is kept as
            ``workspace_path``.
        model: The AI model to use (e.g., "claude-3-5-sonnet-latest", "gpt-4o").
            None leaves the choice to the CLI's default model.
        listener: Optional listener for agent events (AgentListener). When
            given, it is also wrapped in an AgentListenerAdapter.
        cli_path: Optional path to the Augment CLI.
        acp_client: Optional ACP client instance, primarily for testing.
            When provided it is used instead of creating a new client.
        removed_tools: List of tool names to remove/disable
            (e.g., ["github-api", "linear"]). Falsy values become ``[]``.
        timeout: Default timeout in seconds for agent operations
            (180 seconds unless overridden).
        api_key: Optional API key; stored on the instance for later use.
        api_url: Optional API URL. When None, the AUGMENT_API_URL environment
            variable is used, falling back to "https://api.augmentcode.com".
    """
    self.workspace_path = self._validate_workspace_path(workspace_root)
    self.model = model
    self.listener = listener

    # The adapter bridges ACP events to the AgentListener; it only exists
    # when a listener was supplied.
    if listener:
        self._listener_adapter = AgentListenerAdapter(listener)
    else:
        self._listener_adapter = None

    self.cli_path = cli_path
    self.removed_tools = removed_tools if removed_tools else []
    self.timeout = timeout
    self.api_key = api_key

    # An explicit argument wins over the environment, which wins over the
    # built-in default URL.
    if api_url is None:
        api_url = os.getenv("AUGMENT_API_URL", "https://api.augmentcode.com")
    self.api_url = api_url

    self.last_model_answer: Optional[str] = None
    self._acp_client: Optional[ACPClient] = acp_client
    # Remember whether the client was injected by the caller.
    self._provided_client = acp_client is not None
    # Flipped by the session() context manager.
    self._in_session = False
    self._prompt_formatter = AgentPromptFormatter()
|
|
171
|
+
|
|
172
|
+
def run(
    self,
    instruction: str,
    return_type: Optional[Type[T]] = None,
    timeout: Optional[int] = None,
    max_retries: int = 3,
    functions: Optional[List[Callable]] = None,
    success_criteria: Optional[List[str]] = None,
    max_verification_rounds: int = 3,
) -> Union[T, Any]:
    """
    Execute an instruction and return the agent's response.

    Without ``success_criteria`` the task is run exactly once. With
    ``success_criteria`` the task is run, the criteria are verified, and if
    any are unmet a fix instruction is sent in the next round; this repeats
    for at most ``max_verification_rounds`` rounds.

    Args:
        instruction: The instruction to send to the agent
        return_type: Expected return type (int, str, bool, float, list, dict,
            dataclass, or enum). If None, the agent infers the type.
        timeout: Optional timeout in seconds
        max_retries: Maximum number of retry attempts for parsing failures
            (default: 3)
        functions: Optional list of Python functions the agent can call.
        success_criteria: Optional list of criteria that must be met after
            execution.
        max_verification_rounds: Maximum number of execute/verify rounds when
            ``success_criteria`` is given (default: 3). Must be at least 1
            in that case.

    Returns:
        Parsed result (either of requested type, or inferred type). With
        success criteria, this is the result of the round whose verification
        passed.

    Raises:
        ValueError: If instruction is empty or whitespace, or if
            ``success_criteria`` is given with ``max_verification_rounds``
            below 1.
        AugmentParseError: If parsing typed response fails after all retries
        AugmentVerificationError: If success_criteria are not met after
            max_verification_rounds
    """
    if not instruction or not instruction.strip():
        raise ValueError("Instruction cannot be empty or whitespace")

    # Zero (or negative) rounds would skip the task entirely and silently
    # return None; reject that up front, before any client work happens.
    if success_criteria and max_verification_rounds < 1:
        raise ValueError(
            "max_verification_rounds must be at least 1 when "
            "success_criteria are provided"
        )

    client = self._prepare_client()

    # No success criteria: a single execution is all that is needed.
    if not success_criteria:
        return self._run_task(
            client=client,
            instruction=instruction,
            return_type=return_type,
            timeout=timeout,
            max_retries=max_retries,
            functions=functions,
        )

    current_instruction = instruction
    final_round = max_verification_rounds - 1

    for round_num in range(max_verification_rounds):
        # Execute the task (or the fix instruction from the previous round).
        result = self._run_task(
            client=client,
            instruction=current_instruction,
            return_type=return_type,
            timeout=timeout,
            max_retries=max_retries,
            functions=functions,
        )

        verification = self._verify_success_criteria(
            client=client,
            success_criteria=success_criteria,
            timeout=timeout,
        )
        if verification.all_criteria_met:
            return result

        # Only build a fix instruction when another round will consume it.
        if round_num < final_round:
            current_instruction = self._prepare_fix_instruction(
                success_criteria=success_criteria,
                verification=verification,
            )

    # Every round ran and the last verification still reported failures.
    raise AugmentVerificationError(
        f"Success criteria not fully met after {max_verification_rounds} rounds. "
        f"Unmet criteria: {verification.unmet_criteria}. "
        f"Issues: {verification.issues}",
        unmet_criteria=verification.unmet_criteria,
        issues=verification.issues,
        rounds_attempted=max_verification_rounds,
    )
|
|
283
|
+
|
|
284
|
+
def _prepare_client(self) -> ACPClient:
|
|
285
|
+
"""
|
|
286
|
+
Prepare the ACP client for execution.
|
|
287
|
+
|
|
288
|
+
This method:
|
|
289
|
+
1. Creates the client if it doesn't exist
|
|
290
|
+
2. Starts the client if it's not running
|
|
291
|
+
3. Clears context if not in a session (for fresh start)
|
|
292
|
+
|
|
293
|
+
Returns:
|
|
294
|
+
ACP client instance ready for use
|
|
295
|
+
"""
|
|
296
|
+
# Create client if needed
|
|
297
|
+
if self._acp_client is None:
|
|
298
|
+
self._acp_client = AuggieACPClient(
|
|
299
|
+
cli_path=self.cli_path,
|
|
300
|
+
model=self.model,
|
|
301
|
+
workspace_root=str(self.workspace_path),
|
|
302
|
+
listener=self._listener_adapter, # Pass the adapter to ACP
|
|
303
|
+
removed_tools=self.removed_tools,
|
|
304
|
+
api_key=self.api_key,
|
|
305
|
+
api_url=self.api_url,
|
|
306
|
+
)
|
|
307
|
+
|
|
308
|
+
# Start if not running
|
|
309
|
+
if not self._acp_client.is_running:
|
|
310
|
+
self._acp_client.start()
|
|
311
|
+
|
|
312
|
+
# If not in a session, clear context for a fresh start
|
|
313
|
+
if not self._in_session:
|
|
314
|
+
self._acp_client.clear_context()
|
|
315
|
+
|
|
316
|
+
return self._acp_client
|
|
317
|
+
|
|
318
|
+
def _run_task(
|
|
319
|
+
self,
|
|
320
|
+
client: ACPClient,
|
|
321
|
+
instruction: str,
|
|
322
|
+
return_type: Optional[Type[T]],
|
|
323
|
+
timeout: Optional[int],
|
|
324
|
+
max_retries: int,
|
|
325
|
+
functions: Optional[List[Callable]] = None,
|
|
326
|
+
) -> Union[T, Any]:
|
|
327
|
+
"""
|
|
328
|
+
Run a single task: talk to the agent in a loop until the task is complete.
|
|
329
|
+
|
|
330
|
+
This method implements a loop that:
|
|
331
|
+
1. Sends the current instruction to the agent
|
|
332
|
+
2. If response contains function calls, executes them and continues
|
|
333
|
+
3. If response can be parsed successfully, returns the result
|
|
334
|
+
4. If parsing fails and retries remain, asks agent to fix output and continues
|
|
335
|
+
5. Repeats until success or max rounds exhausted
|
|
336
|
+
|
|
337
|
+
Args:
|
|
338
|
+
client: The ACP client to use
|
|
339
|
+
instruction: The instruction to send
|
|
340
|
+
return_type: Expected return type (None for type inference)
|
|
341
|
+
timeout: Timeout in seconds (None for default)
|
|
342
|
+
max_retries: Maximum number of retry attempts for parsing failures
|
|
343
|
+
functions: Optional list of functions the agent can call
|
|
344
|
+
|
|
345
|
+
Returns:
|
|
346
|
+
Parsed result (either of requested type, or inferred type)
|
|
347
|
+
|
|
348
|
+
Raises:
|
|
349
|
+
AugmentParseError: If parsing fails after all retries
|
|
350
|
+
"""
|
|
351
|
+
# Set effective timeout - use provided timeout, or fall back to instance default
|
|
352
|
+
effective_timeout = (
|
|
353
|
+
float(timeout) if timeout is not None else float(self.timeout)
|
|
354
|
+
)
|
|
355
|
+
|
|
356
|
+
# Prepare the instruction prompt
|
|
357
|
+
current_instruction, function_map = (
|
|
358
|
+
self._prompt_formatter.prepare_instruction_prompt(
|
|
359
|
+
instruction, return_type, functions
|
|
360
|
+
)
|
|
361
|
+
)
|
|
362
|
+
|
|
363
|
+
# Determine if we're in type inference mode
|
|
364
|
+
infer_type = return_type is None
|
|
365
|
+
|
|
366
|
+
# Track retry attempts for parsing failures
|
|
367
|
+
parse_retry_count = 0
|
|
368
|
+
|
|
369
|
+
# Maximum total rounds to prevent infinite loops
|
|
370
|
+
# This includes both function calls and parse retries
|
|
371
|
+
max_rounds = 20
|
|
372
|
+
|
|
373
|
+
for round_num in range(max_rounds):
|
|
374
|
+
# Send message to agent
|
|
375
|
+
raw_response = client.send_message(current_instruction, effective_timeout)
|
|
376
|
+
|
|
377
|
+
# Check if there are function calls to handle
|
|
378
|
+
if function_map:
|
|
379
|
+
function_calls = self._parse_function_calls(raw_response)
|
|
380
|
+
if function_calls:
|
|
381
|
+
# Execute functions and prepare next instruction with results
|
|
382
|
+
current_instruction = self._handle_function_calls(
|
|
383
|
+
raw_response, function_map
|
|
384
|
+
)
|
|
385
|
+
# Continue loop to send function results back to agent
|
|
386
|
+
continue
|
|
387
|
+
|
|
388
|
+
# No function calls - try to parse the response
|
|
389
|
+
try:
|
|
390
|
+
if infer_type:
|
|
391
|
+
result, _ = self._parse_type_inference_response(
|
|
392
|
+
raw_response, DEFAULT_INFERENCE_TYPES
|
|
393
|
+
)
|
|
394
|
+
return result
|
|
395
|
+
else:
|
|
396
|
+
# At this point, return_type must be a valid type (not None)
|
|
397
|
+
assert return_type is not None
|
|
398
|
+
result = self._parse_typed_response(raw_response, return_type)
|
|
399
|
+
return result # type: ignore[return-value,no-any-return]
|
|
400
|
+
|
|
401
|
+
except AugmentParseError as e:
|
|
402
|
+
# Parsing failed - check if we have retries left
|
|
403
|
+
if parse_retry_count >= max_retries:
|
|
404
|
+
raise
|
|
405
|
+
|
|
406
|
+
# Increment retry count and prepare retry instruction
|
|
407
|
+
parse_retry_count += 1
|
|
408
|
+
current_instruction = self._prompt_formatter.prepare_retry_prompt(
|
|
409
|
+
return_type,
|
|
410
|
+
str(e),
|
|
411
|
+
parse_retry_count,
|
|
412
|
+
)
|
|
413
|
+
# Continue loop to retry with corrected instruction
|
|
414
|
+
continue
|
|
415
|
+
|
|
416
|
+
# If we've exhausted all rounds, raise an error
|
|
417
|
+
raise AugmentParseError(
|
|
418
|
+
f"Exceeded maximum rounds ({max_rounds}) without successful completion"
|
|
419
|
+
)
|
|
420
|
+
|
|
421
|
+
def _verify_success_criteria(
    self,
    client: ACPClient,
    success_criteria: List[str],
    timeout: Optional[int],
) -> VerificationResult:
    """
    Ask the agent whether the success criteria hold and return its verdict.

    The criteria are numbered from 1 and embedded in a verification prompt
    that requests a JSON answer matching VerificationResult. The prompt is
    run through ``_run_task`` with at most 2 parse retries and no callable
    functions.

    If the verification run raises, a warning is emitted and a conservative
    result is returned that marks every criterion as unmet.

    Args:
        client: The ACP client to use
        success_criteria: List of criteria that must be met
        timeout: Timeout in seconds (None for default)

    Returns:
        VerificationResult with detailed verification feedback
    """
    numbered = [
        f"{position}. {criterion}"
        for position, criterion in enumerate(success_criteria, start=1)
    ]
    criteria_list = "\n".join(numbered)

    verification_instruction = f"""The task has been completed. Please verify that ALL of the following success criteria are met:

{criteria_list}

For each criterion, check if it is satisfied.

Respond with a JSON object with the following structure:
{{
"all_criteria_met": true/false,
"unmet_criteria": [list of criterion numbers (1-based) that are NOT met],
"issues": ["description of each issue found"],
"overall_assessment": "brief assessment of the current state"
}}

Example:
{{
"all_criteria_met": false,
"unmet_criteria": [2, 3],
"issues": ["Criterion 2: Function is missing type hints for parameter 'x'", "Criterion 3: No docstring present"],
"overall_assessment": "Function exists but lacks type hints and documentation"
}}"""

    try:
        return self._run_task(
            client=client,
            instruction=verification_instruction,
            return_type=VerificationResult,
            timeout=timeout,
            max_retries=2,
            functions=None,
        )
    except (AugmentParseError, Exception) as e:
        # Best effort: a broken verification must not crash the caller, so
        # report "nothing verified" instead.
        import warnings

        warnings.warn(
            f"Success criteria verification failed: {e}. "
            "Assuming criteria are not met.",
            UserWarning,
        )
        every_index = [
            position for position, _ in enumerate(success_criteria, start=1)
        ]
        return VerificationResult(
            all_criteria_met=False,
            unmet_criteria=every_index,
            issues=[f"Verification failed with error: {e}"],
            overall_assessment="Verification could not be completed",
        )
|
|
495
|
+
|
|
496
|
+
def _prepare_fix_instruction(
|
|
497
|
+
self,
|
|
498
|
+
success_criteria: List[str],
|
|
499
|
+
verification: VerificationResult,
|
|
500
|
+
) -> str:
|
|
501
|
+
"""
|
|
502
|
+
Prepare an instruction to fix issues identified during verification.
|
|
503
|
+
|
|
504
|
+
Args:
|
|
505
|
+
success_criteria: List of all success criteria
|
|
506
|
+
verification: Verification result with issues
|
|
507
|
+
|
|
508
|
+
Returns:
|
|
509
|
+
Instruction for the agent to fix the issues
|
|
510
|
+
"""
|
|
511
|
+
# Build list of unmet criteria with their descriptions
|
|
512
|
+
unmet_details = []
|
|
513
|
+
for criterion_num in verification.unmet_criteria:
|
|
514
|
+
if 1 <= criterion_num <= len(success_criteria):
|
|
515
|
+
criterion_text = success_criteria[criterion_num - 1]
|
|
516
|
+
unmet_details.append(f"{criterion_num}. {criterion_text}")
|
|
517
|
+
|
|
518
|
+
unmet_list = "\n".join(unmet_details)
|
|
519
|
+
issues_list = "\n".join(f"- {issue}" for issue in verification.issues)
|
|
520
|
+
|
|
521
|
+
fix_instruction = f"""The following success criteria are NOT yet met:
|
|
522
|
+
|
|
523
|
+
{unmet_list}
|
|
524
|
+
|
|
525
|
+
Issues identified:
|
|
526
|
+
{issues_list}
|
|
527
|
+
|
|
528
|
+
Overall assessment: {verification.overall_assessment}
|
|
529
|
+
|
|
530
|
+
Please fix these issues to ensure ALL success criteria are satisfied."""
|
|
531
|
+
|
|
532
|
+
return fix_instruction
|
|
533
|
+
|
|
534
|
+
def _handle_function_calls(
|
|
535
|
+
self,
|
|
536
|
+
raw_response: str,
|
|
537
|
+
function_map: Dict[str, Callable],
|
|
538
|
+
) -> str:
|
|
539
|
+
"""
|
|
540
|
+
Execute function calls from the agent's response and prepare next instruction.
|
|
541
|
+
|
|
542
|
+
This method:
|
|
543
|
+
1. Parses function calls from the response
|
|
544
|
+
2. Executes the functions
|
|
545
|
+
3. Returns an instruction containing the function results to send back to the agent
|
|
546
|
+
|
|
547
|
+
Args:
|
|
548
|
+
raw_response: The agent's response that contains function calls
|
|
549
|
+
function_map: Mapping of function names to callables
|
|
550
|
+
|
|
551
|
+
Returns:
|
|
552
|
+
Instruction text with function results to send back to the agent
|
|
553
|
+
"""
|
|
554
|
+
# Parse function calls from response
|
|
555
|
+
function_calls = self._parse_function_calls(raw_response)
|
|
556
|
+
|
|
557
|
+
# Execute function calls
|
|
558
|
+
function_results = []
|
|
559
|
+
for func_call in function_calls:
|
|
560
|
+
func_name = func_call.get("name")
|
|
561
|
+
func_args = func_call.get("arguments", {})
|
|
562
|
+
|
|
563
|
+
# Skip if function name is missing
|
|
564
|
+
if not func_name:
|
|
565
|
+
continue
|
|
566
|
+
|
|
567
|
+
if func_name not in function_map:
|
|
568
|
+
result = {"error": f"Function '{func_name}' not found"}
|
|
569
|
+
error = f"Function '{func_name}' not found"
|
|
570
|
+
# Notify listener of error
|
|
571
|
+
if self.listener:
|
|
572
|
+
self.listener.on_function_result(func_name, None, error)
|
|
573
|
+
else:
|
|
574
|
+
try:
|
|
575
|
+
func = function_map[func_name]
|
|
576
|
+
|
|
577
|
+
# Notify listener of function call (right before execution)
|
|
578
|
+
if self.listener:
|
|
579
|
+
self.listener.on_function_call(func_name, func_args)
|
|
580
|
+
|
|
581
|
+
result = func(**func_args)
|
|
582
|
+
# Notify listener of success
|
|
583
|
+
if self.listener:
|
|
584
|
+
self.listener.on_function_result(func_name, result, None)
|
|
585
|
+
except Exception as e:
|
|
586
|
+
result = {"error": f"Error calling {func_name}: {str(e)}"}
|
|
587
|
+
# Notify listener of error
|
|
588
|
+
if self.listener:
|
|
589
|
+
self.listener.on_function_result(func_name, None, str(e))
|
|
590
|
+
|
|
591
|
+
function_results.append({"function": func_name, "result": result})
|
|
592
|
+
|
|
593
|
+
# Build follow-up instruction with function results
|
|
594
|
+
results_text = "Function call results:\n\n"
|
|
595
|
+
for fr in function_results:
|
|
596
|
+
results_text += f"Function: {fr['function']}\n"
|
|
597
|
+
results_text += f"Result: {json.dumps(fr['result'], indent=2)}\n\n"
|
|
598
|
+
|
|
599
|
+
results_text += "\nPlease continue with your response based on these results."
|
|
600
|
+
|
|
601
|
+
return results_text
|
|
602
|
+
|
|
603
|
+
def _parse_function_calls(self, response: str) -> List[Dict[str, Any]]:
|
|
604
|
+
"""
|
|
605
|
+
Parse function calls from agent response.
|
|
606
|
+
|
|
607
|
+
Looks for <function-call>...</function-call> blocks containing JSON.
|
|
608
|
+
|
|
609
|
+
Args:
|
|
610
|
+
response: Agent's response text
|
|
611
|
+
|
|
612
|
+
Returns:
|
|
613
|
+
List of function call dictionaries with 'name' and 'arguments' keys
|
|
614
|
+
"""
|
|
615
|
+
function_calls = []
|
|
616
|
+
pattern = r"<function-call>\s*(\{.*?\})\s*</function-call>"
|
|
617
|
+
matches = re.findall(pattern, response, re.DOTALL)
|
|
618
|
+
|
|
619
|
+
for match in matches:
|
|
620
|
+
try:
|
|
621
|
+
func_call = json.loads(match)
|
|
622
|
+
if "name" in func_call:
|
|
623
|
+
function_calls.append(func_call)
|
|
624
|
+
except json.JSONDecodeError:
|
|
625
|
+
# Skip invalid JSON
|
|
626
|
+
continue
|
|
627
|
+
|
|
628
|
+
return function_calls
|
|
629
|
+
|
|
630
|
+
def _parse_type_inference_response(
|
|
631
|
+
self, response: str, possible_types: List[Type[Any]]
|
|
632
|
+
) -> Tuple[Any, Type[Any]]:
|
|
633
|
+
"""
|
|
634
|
+
Parse type inference response.
|
|
635
|
+
|
|
636
|
+
Args:
|
|
637
|
+
response: The agent's response
|
|
638
|
+
possible_types: List of possible types
|
|
639
|
+
|
|
640
|
+
Returns:
|
|
641
|
+
Tuple of (parsed result, chosen type)
|
|
642
|
+
|
|
643
|
+
Raises:
|
|
644
|
+
AugmentParseError: If parsing fails
|
|
645
|
+
"""
|
|
646
|
+
# Extract message
|
|
647
|
+
message_match = re.search(
|
|
648
|
+
r"<augment-agent-message>\s*(.*?)\s*</augment-agent-message>",
|
|
649
|
+
response,
|
|
650
|
+
re.DOTALL,
|
|
651
|
+
)
|
|
652
|
+
if message_match:
|
|
653
|
+
self.last_model_answer = message_match.group(1).strip()
|
|
654
|
+
else:
|
|
655
|
+
self.last_model_answer = None
|
|
656
|
+
|
|
657
|
+
# Extract type name
|
|
658
|
+
type_match = re.search(
|
|
659
|
+
r"<augment-agent-type>\s*(\w+)\s*</augment-agent-type>", response, re.DOTALL
|
|
660
|
+
)
|
|
661
|
+
if not type_match:
|
|
662
|
+
# If no type tags found, check if response is empty or just whitespace
|
|
663
|
+
# This can happen when agent completes a task (like file creation) successfully
|
|
664
|
+
# but doesn't provide a structured response
|
|
665
|
+
if not response or response.strip() == "":
|
|
666
|
+
# Return empty string as success indicator
|
|
667
|
+
return "", str
|
|
668
|
+
|
|
669
|
+
# If there's content but no tags, try to extract it as a string
|
|
670
|
+
# This handles cases where agent responds without proper formatting
|
|
671
|
+
content = response.strip()
|
|
672
|
+
if content:
|
|
673
|
+
# Return the content as a string
|
|
674
|
+
return content, str
|
|
675
|
+
|
|
676
|
+
raise AugmentParseError(
|
|
677
|
+
"No type classification found. Expected <augment-agent-type> tags."
|
|
678
|
+
)
|
|
679
|
+
|
|
680
|
+
type_name = type_match.group(1).strip()
|
|
681
|
+
|
|
682
|
+
# Find the matching type by name
|
|
683
|
+
chosen_type = None
|
|
684
|
+
for t in possible_types:
|
|
685
|
+
if t.__name__ == type_name:
|
|
686
|
+
chosen_type = t
|
|
687
|
+
break
|
|
688
|
+
|
|
689
|
+
if chosen_type is None:
|
|
690
|
+
type_names = [t.__name__ for t in possible_types]
|
|
691
|
+
raise AugmentParseError(
|
|
692
|
+
f"Invalid type name '{type_name}'. Must be one of: {', '.join(type_names)}"
|
|
693
|
+
)
|
|
694
|
+
|
|
695
|
+
# Extract result
|
|
696
|
+
result_match = re.search(
|
|
697
|
+
r"<augment-agent-result>\s*(.*?)\s*</augment-agent-result>",
|
|
698
|
+
response,
|
|
699
|
+
re.DOTALL,
|
|
700
|
+
)
|
|
701
|
+
if not result_match:
|
|
702
|
+
raise AugmentParseError(
|
|
703
|
+
"No structured result found. Expected <augment-agent-result> tags."
|
|
704
|
+
)
|
|
705
|
+
|
|
706
|
+
content = result_match.group(1).strip()
|
|
707
|
+
|
|
708
|
+
# Parse the result according to the chosen type
|
|
709
|
+
try:
|
|
710
|
+
parsed_result = self._convert_to_type(content, chosen_type)
|
|
711
|
+
return parsed_result, chosen_type
|
|
712
|
+
except (json.JSONDecodeError, ValueError, TypeError) as e:
|
|
713
|
+
raise AugmentParseError(
|
|
714
|
+
f"Could not parse result as {chosen_type.__name__}: {e}"
|
|
715
|
+
)
|
|
716
|
+
|
|
717
|
+
@contextmanager
def session(
    self, session_id: Optional[str] = None
) -> Generator["Agent", None, None]:
    """
    Create a session context for maintaining conversation continuity.

    By default, each run() call starts from a cleared context. Inside this
    context manager the context is kept, so later calls can build on
    earlier ones. Sessions may be nested: leaving an inner ``with`` block
    restores the state of the enclosing one instead of ending it.

    Usage:
        agent = Agent()

        # Without session - each call is independent
        agent.run("Create a function")
        agent.run("Test it")  # does not remember the function

        # With session - calls remember each other
        with agent.session() as session:
            session.run("Create a function called add_numbers")
            session.run("Now test that function")  # remembers add_numbers

    Args:
        session_id: Optional session ID (currently ignored)

    Yields:
        Agent: This agent instance
    """
    # Remember the surrounding state so nested sessions unwind correctly.
    was_in_session = self._in_session
    self._in_session = True

    try:
        yield self
    finally:
        self._in_session = was_in_session
|
|
753
|
+
|
|
754
|
+
def _get_type_description(self, return_type: Type) -> str:
|
|
755
|
+
"""Get the JSON structure description for the type."""
|
|
756
|
+
# Built-in types
|
|
757
|
+
if return_type in (int, float, str, bool):
|
|
758
|
+
return "your_result_value"
|
|
759
|
+
elif return_type in (list, dict):
|
|
760
|
+
return "your_result_value"
|
|
761
|
+
|
|
762
|
+
# Generic types (list[SomeClass], dict[str, SomeClass], etc.)
|
|
763
|
+
elif hasattr(return_type, "__origin__") or get_origin(return_type) is not None:
|
|
764
|
+
origin = get_origin(return_type)
|
|
765
|
+
args = get_args(return_type)
|
|
766
|
+
|
|
767
|
+
if origin is list:
|
|
768
|
+
if args and len(args) == 1:
|
|
769
|
+
# list[SomeClass] - create example array with one element
|
|
770
|
+
element_type = args[0]
|
|
771
|
+
if is_dataclass(element_type):
|
|
772
|
+
field_names = [f.name for f in fields(element_type)]
|
|
773
|
+
element_example = {
|
|
774
|
+
name: f"<{name}_value>" for name in field_names
|
|
775
|
+
}
|
|
776
|
+
return json.dumps([element_example], indent=2)
|
|
777
|
+
elif hasattr(element_type, "__name__") and issubclass(
|
|
778
|
+
element_type, Enum
|
|
779
|
+
):
|
|
780
|
+
return '["<enum_value>"]'
|
|
781
|
+
else:
|
|
782
|
+
return '["<element_value>"]'
|
|
783
|
+
else:
|
|
784
|
+
return "your_result_value" # Plain list
|
|
785
|
+
elif origin is dict:
|
|
786
|
+
return "your_result_value" # For now, treat as plain dict
|
|
787
|
+
else:
|
|
788
|
+
return "your_result_value" # Other generic types
|
|
789
|
+
|
|
790
|
+
# Dataclass
|
|
791
|
+
elif is_dataclass(return_type):
|
|
792
|
+
field_names = [f.name for f in fields(return_type)]
|
|
793
|
+
example = {name: f"<{name}_value>" for name in field_names}
|
|
794
|
+
return json.dumps(example, indent=2)
|
|
795
|
+
|
|
796
|
+
# Enum
|
|
797
|
+
elif issubclass(return_type, Enum):
|
|
798
|
+
return "your_enum_value"
|
|
799
|
+
|
|
800
|
+
else:
|
|
801
|
+
raise ValueError(f"Unsupported return type: {return_type}")
|
|
802
|
+
|
|
803
|
+
def _get_type_instructions(self, return_type: Type) -> str:
|
|
804
|
+
"""Get specific instructions for the type."""
|
|
805
|
+
# Built-in types
|
|
806
|
+
if return_type is int:
|
|
807
|
+
return "IMPORTANT: Put your result inside <augment-agent-result> tags. Return just the integer number (no quotes)."
|
|
808
|
+
elif return_type is float:
|
|
809
|
+
return "IMPORTANT: Put your result inside <augment-agent-result> tags. Return just the decimal number (no quotes)."
|
|
810
|
+
elif return_type is str:
|
|
811
|
+
return "IMPORTANT: Put your result inside <augment-agent-result> tags. Return the string value in double quotes."
|
|
812
|
+
elif return_type is bool:
|
|
813
|
+
return "IMPORTANT: Put your result inside <augment-agent-result> tags. Return either true or false (no quotes)."
|
|
814
|
+
elif return_type is list:
|
|
815
|
+
return "IMPORTANT: Put your result inside <augment-agent-result> tags. Return a JSON array: [item1, item2, item3]"
|
|
816
|
+
elif return_type is dict:
|
|
817
|
+
return 'IMPORTANT: Put your result inside <augment-agent-result> tags. Return a JSON object: {"key1": "value1", "key2": "value2"}'
|
|
818
|
+
|
|
819
|
+
# Generic types (list[SomeClass], dict[str, SomeClass], etc.)
|
|
820
|
+
elif hasattr(return_type, "__origin__") or get_origin(return_type) is not None:
|
|
821
|
+
origin = get_origin(return_type)
|
|
822
|
+
args = get_args(return_type)
|
|
823
|
+
|
|
824
|
+
if origin is list:
|
|
825
|
+
if args and len(args) == 1:
|
|
826
|
+
element_type = args[0]
|
|
827
|
+
if is_dataclass(element_type):
|
|
828
|
+
field_info = []
|
|
829
|
+
for field in fields(element_type):
|
|
830
|
+
field_info.append(
|
|
831
|
+
f" - {field.name}: {_get_type_name(field.type)}"
|
|
832
|
+
)
|
|
833
|
+
|
|
834
|
+
return f"""IMPORTANT: Put your JSON array inside <augment-agent-result> tags.
|
|
835
|
+
|
|
836
|
+
Return a JSON array of objects, each with these fields:
|
|
837
|
+
{chr(10).join(field_info)}
|
|
838
|
+
|
|
839
|
+
Example:
|
|
840
|
+
<augment-agent-result>
|
|
841
|
+
[{{"field1": value1, "field2": value2}}, {{"field1": value3, "field2": value4}}]
|
|
842
|
+
</augment-agent-result>"""
|
|
843
|
+
elif hasattr(element_type, "__name__") and issubclass(
|
|
844
|
+
element_type, Enum
|
|
845
|
+
):
|
|
846
|
+
enum_values = [f'"{e.value}"' for e in element_type]
|
|
847
|
+
return f"IMPORTANT: Put your result inside <augment-agent-result> tags. Return a JSON array of enum values: {enum_values}"
|
|
848
|
+
else:
|
|
849
|
+
return f"IMPORTANT: Put your result inside <augment-agent-result> tags. Return a JSON array of {_get_type_name(element_type)} values: [value1, value2, value3]"
|
|
850
|
+
else:
|
|
851
|
+
return "IMPORTANT: Put your result inside <augment-agent-result> tags. Return a JSON array: [item1, item2, item3]"
|
|
852
|
+
elif origin is dict:
|
|
853
|
+
return 'IMPORTANT: Put your result inside <augment-agent-result> tags. Return a JSON object: {"key1": "value1", "key2": "value2"}'
|
|
854
|
+
else:
|
|
855
|
+
return "IMPORTANT: Put your result inside <augment-agent-result> tags. Return the appropriate JSON structure for your type."
|
|
856
|
+
|
|
857
|
+
# Dataclass
|
|
858
|
+
elif is_dataclass(return_type):
|
|
859
|
+
field_info = []
|
|
860
|
+
for field in fields(return_type):
|
|
861
|
+
field_info.append(f" - {field.name}: {_get_type_name(field.type)}")
|
|
862
|
+
|
|
863
|
+
return f"""IMPORTANT: Put your JSON object inside <augment-agent-result> tags.
|
|
864
|
+
|
|
865
|
+
Return a JSON object with these exact fields:
|
|
866
|
+
{chr(10).join(field_info)}
|
|
867
|
+
|
|
868
|
+
Example:
|
|
869
|
+
<augment-agent-result>
|
|
870
|
+
{{"field1": value1, "field2": value2}}
|
|
871
|
+
</augment-agent-result>"""
|
|
872
|
+
|
|
873
|
+
# Enum
|
|
874
|
+
elif issubclass(return_type, Enum):
|
|
875
|
+
enum_values = [f'"{e.value}"' for e in return_type]
|
|
876
|
+
return f"IMPORTANT: Put your result inside <augment-agent-result> tags. Return one of these exact values: {', '.join(enum_values)}"
|
|
877
|
+
|
|
878
|
+
else:
|
|
879
|
+
raise ValueError(f"Unsupported return type: {return_type}")
|
|
880
|
+
|
|
881
|
+
def _parse_typed_response(self, response: str, return_type: Type[T]) -> T:
|
|
882
|
+
"""Parse the agent's structured response into the desired type."""
|
|
883
|
+
# Extract message from <augment-agent-message> tags
|
|
884
|
+
message_match = re.search(
|
|
885
|
+
r"<augment-agent-message>\s*(.*?)\s*</augment-agent-message>",
|
|
886
|
+
response,
|
|
887
|
+
re.DOTALL,
|
|
888
|
+
)
|
|
889
|
+
if message_match:
|
|
890
|
+
self.last_model_answer = message_match.group(1).strip()
|
|
891
|
+
else:
|
|
892
|
+
self.last_model_answer = None
|
|
893
|
+
|
|
894
|
+
# Extract content from <augment-agent-result> tags
|
|
895
|
+
result_match = re.search(
|
|
896
|
+
r"<augment-agent-result>\s*(.*?)\s*</augment-agent-result>",
|
|
897
|
+
response,
|
|
898
|
+
re.DOTALL,
|
|
899
|
+
)
|
|
900
|
+
|
|
901
|
+
if not result_match:
|
|
902
|
+
raise AugmentParseError(
|
|
903
|
+
"No structured result found. Expected <augment-agent-result> tags in response."
|
|
904
|
+
)
|
|
905
|
+
|
|
906
|
+
content = result_match.group(1).strip()
|
|
907
|
+
|
|
908
|
+
try:
|
|
909
|
+
return self._convert_to_type(content, return_type)
|
|
910
|
+
except (json.JSONDecodeError, ValueError, TypeError) as e:
|
|
911
|
+
raise AugmentParseError(
|
|
912
|
+
f"Could not parse result as {return_type.__name__}: {e}"
|
|
913
|
+
)
|
|
914
|
+
|
|
915
|
+
def _convert_to_type(self, content: str, return_type: Type[T]) -> T:
|
|
916
|
+
"""Convert string content to the specified Python type."""
|
|
917
|
+
|
|
918
|
+
# Special case for str - don't JSON parse it, just return as-is
|
|
919
|
+
if return_type is str:
|
|
920
|
+
return content # type: ignore[return-value]
|
|
921
|
+
|
|
922
|
+
# Built-in types that need JSON parsing
|
|
923
|
+
if return_type in (int, float, bool, list, dict):
|
|
924
|
+
parsed = json.loads(content)
|
|
925
|
+
|
|
926
|
+
if not isinstance(parsed, return_type):
|
|
927
|
+
raise ValueError(
|
|
928
|
+
f"Expected {return_type.__name__}, got {type(parsed).__name__}"
|
|
929
|
+
)
|
|
930
|
+
|
|
931
|
+
return parsed
|
|
932
|
+
|
|
933
|
+
# Generic types (list[SomeClass], dict[str, SomeClass], etc.)
|
|
934
|
+
elif hasattr(return_type, "__origin__") or get_origin(return_type) is not None:
|
|
935
|
+
origin = get_origin(return_type)
|
|
936
|
+
args = get_args(return_type)
|
|
937
|
+
|
|
938
|
+
if origin is list:
|
|
939
|
+
parsed = json.loads(content)
|
|
940
|
+
if not isinstance(parsed, list):
|
|
941
|
+
raise ValueError(f"Expected list, got {type(parsed).__name__}")
|
|
942
|
+
|
|
943
|
+
if args and len(args) == 1:
|
|
944
|
+
element_type = args[0]
|
|
945
|
+
# Convert each element to the specified type
|
|
946
|
+
result = []
|
|
947
|
+
for item in parsed:
|
|
948
|
+
if is_dataclass(element_type):
|
|
949
|
+
if not isinstance(item, dict):
|
|
950
|
+
raise ValueError(
|
|
951
|
+
f"Expected dict for dataclass element, got {type(item).__name__}"
|
|
952
|
+
)
|
|
953
|
+
result.append(element_type(**item)) # type: ignore[misc]
|
|
954
|
+
elif hasattr(element_type, "__name__") and issubclass(
|
|
955
|
+
element_type, Enum
|
|
956
|
+
):
|
|
957
|
+
result.append(element_type(item))
|
|
958
|
+
else:
|
|
959
|
+
# For basic types, just validate and append
|
|
960
|
+
if not isinstance(item, element_type):
|
|
961
|
+
raise ValueError(
|
|
962
|
+
f"Expected {element_type.__name__}, got {type(item).__name__}"
|
|
963
|
+
)
|
|
964
|
+
result.append(item)
|
|
965
|
+
return result # type: ignore[return-value]
|
|
966
|
+
else:
|
|
967
|
+
# Plain list without type parameter
|
|
968
|
+
return parsed # type: ignore[return-value]
|
|
969
|
+
elif origin is dict:
|
|
970
|
+
parsed = json.loads(content)
|
|
971
|
+
if not isinstance(parsed, dict):
|
|
972
|
+
raise ValueError(f"Expected dict, got {type(parsed).__name__}")
|
|
973
|
+
return parsed # type: ignore[return-value]
|
|
974
|
+
else:
|
|
975
|
+
# Other generic types - try basic JSON parsing
|
|
976
|
+
return json.loads(content) # type: ignore[return-value,no-any-return]
|
|
977
|
+
|
|
978
|
+
# Dataclass
|
|
979
|
+
elif is_dataclass(return_type):
|
|
980
|
+
parsed = json.loads(content)
|
|
981
|
+
|
|
982
|
+
if not isinstance(parsed, dict):
|
|
983
|
+
raise ValueError(
|
|
984
|
+
f"Expected dict for dataclass, got {type(parsed).__name__}"
|
|
985
|
+
)
|
|
986
|
+
|
|
987
|
+
return return_type(**parsed)
|
|
988
|
+
|
|
989
|
+
# Enum
|
|
990
|
+
elif issubclass(return_type, Enum):
|
|
991
|
+
# Try JSON parsing first (for quoted values)
|
|
992
|
+
try:
|
|
993
|
+
parsed = json.loads(content)
|
|
994
|
+
return return_type(parsed) # type: ignore[return-value]
|
|
995
|
+
except json.JSONDecodeError:
|
|
996
|
+
# If JSON parsing fails, try direct string value
|
|
997
|
+
return return_type(content.strip()) # type: ignore[return-value]
|
|
998
|
+
|
|
999
|
+
else:
|
|
1000
|
+
raise ValueError(f"Unsupported return type: {return_type}")
|
|
1001
|
+
|
|
1002
|
+
def get_workspace_path(self) -> Path:
    """
    Return the workspace path associated with this agent.

    Returns:
        The value of ``self.workspace_path`` (the workspace root).
    """
    workspace_root = self.workspace_path
    return workspace_root
|
|
1010
|
+
|
|
1011
|
+
@property
def session_id(self) -> Optional[str]:
    """
    Current session ID, if any.

    The ID is read from the ACP client and is only reported while the
    agent is inside a session context manager and a client is set; in
    every other case (including standalone run() calls) this is None.
    """
    session_is_active = self._in_session and self._acp_client
    if not session_is_active:
        return None
    return self._acp_client.session_id  # type: ignore[attr-defined,no-any-return]
|
|
1022
|
+
|
|
1023
|
+
def __repr__(self) -> str:
|
|
1024
|
+
"""String representation of the Agent."""
|
|
1025
|
+
if self.session_id:
|
|
1026
|
+
return f"Agent(workspace_path='{self.workspace_path}', session_id='{self.session_id}')"
|
|
1027
|
+
return f"Agent(workspace_path='{self.workspace_path}')"
|
|
1028
|
+
|
|
1029
|
+
def __del__(self) -> None:
|
|
1030
|
+
"""Cleanup when agent is destroyed."""
|
|
1031
|
+
# Check if attributes exist (in case __init__ failed)
|
|
1032
|
+
if hasattr(self, "_acp_client") and self._acp_client:
|
|
1033
|
+
try:
|
|
1034
|
+
self._acp_client.stop()
|
|
1035
|
+
except Exception:
|
|
1036
|
+
pass # Ignore errors during cleanup
|
|
1037
|
+
|
|
1038
|
+
@staticmethod
|
|
1039
|
+
def _validate_workspace_path(workspace_root: Optional[Union[str, Path]]) -> Path:
|
|
1040
|
+
"""
|
|
1041
|
+
Validate and resolve workspace path.
|
|
1042
|
+
|
|
1043
|
+
Args:
|
|
1044
|
+
workspace_root: User-provided workspace path or None
|
|
1045
|
+
|
|
1046
|
+
Returns:
|
|
1047
|
+
Resolved Path object
|
|
1048
|
+
|
|
1049
|
+
Raises:
|
|
1050
|
+
AugmentWorkspaceError: If path is invalid
|
|
1051
|
+
"""
|
|
1052
|
+
if workspace_root is None:
|
|
1053
|
+
return Path.cwd()
|
|
1054
|
+
|
|
1055
|
+
path = Path(workspace_root).resolve()
|
|
1056
|
+
|
|
1057
|
+
if not path.exists():
|
|
1058
|
+
raise AugmentWorkspaceError(f"Workspace path does not exist: {path}")
|
|
1059
|
+
|
|
1060
|
+
if not path.is_dir():
|
|
1061
|
+
raise AugmentWorkspaceError(f"Workspace path is not a directory: {path}")
|
|
1062
|
+
|
|
1063
|
+
return path
|
|
1064
|
+
|
|
1065
|
+
@staticmethod
def get_available_models() -> List[Model]:
    """
    Get the list of available AI models for the current account.

    Runs `auggie model list` and parses its standard output. A model entry
    is a line of the form "- Model Name [model-id]"; when the line right
    after it is non-empty and is not itself an entry, that line is used as
    the model's description.

    Returns:
        List of Model objects containing id, name, and description

    Raises:
        AugmentNotFoundError: If auggie CLI is not found
        AugmentCLIError: If the CLI command times out or exits non-zero

    Example:
        >>> models = Agent.get_available_models()
        >>> for model in models:
        ...     print(f"{model.name} [{model.id}]")
        ...     print(f" {model.description}")
        Claude Sonnet 4.5 [sonnet4.5]
         Anthropic Claude Sonnet 4.5, 200k context
    """
    command = ["auggie", "model", "list"]
    try:
        completed = subprocess.run(
            command, capture_output=True, text=True, timeout=30
        )
    except FileNotFoundError:
        raise AugmentNotFoundError(
            "auggie CLI not found. Please install auggie and ensure it's in your PATH."
        )
    except subprocess.TimeoutExpired:
        raise AugmentCLIError("Command timed out after 30 seconds", -1, "")

    if completed.returncode != 0:
        raise AugmentCLIError(
            f"Failed to get model list: {completed.stderr}",
            completed.returncode,
            completed.stderr,
        )

    # Work on whitespace-trimmed lines throughout.
    entries = [raw.strip() for raw in completed.stdout.strip().split("\n")]
    entry_pattern = re.compile(r"^- (.+?)\s+\[([^\]]+)\]$")

    found = []
    for position, entry in enumerate(entries):
        # Description lines never start with "- ", so they are skipped here
        # after having been attached to the entry that precedes them.
        if not entry.startswith("- "):
            continue
        header = entry_pattern.match(entry)
        if header is None:
            continue

        following = entries[position + 1] if position + 1 < len(entries) else ""
        if following and not following.startswith("- "):
            description = following
        else:
            description = ""

        found.append(
            Model(
                id=header.group(2).strip(),
                name=header.group(1).strip(),
                description=description,
            )
        )

    return found
|