hanzo 0.3.22__py3-none-any.whl → 0.3.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of hanzo might be problematic; see the package's registry page for more details.

hanzo/base_agent.py ADDED
@@ -0,0 +1,516 @@
1
+ """Base Agent - Unified foundation for all AI agent implementations.
2
+
3
+ This module provides the single base class for all agent operations,
4
+ following DRY principles and ensuring consistent behavior across all agents.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import os
10
+ import asyncio
11
+ import logging
12
+ from abc import ABC, abstractmethod
13
+ from typing import Any, Dict, List, Generic, TypeVar, Optional, Protocol
14
+ from pathlib import Path
15
+ from datetime import datetime
16
+ from dataclasses import field, dataclass
17
+
18
+ from .model_registry import ModelConfig, registry
19
+
20
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Type parameters for the generic agent hierarchy: TContext is the type of the
# optional execution context handed to an agent; TResult is a result-type
# parameter (CLIAgent and APIAgent bind it to ``str``).
TContext, TResult = TypeVar("TContext"), TypeVar("TResult")
26
+
27
+
28
class AgentContext(Protocol[TContext]):
    """Structural interface for the context object passed into an agent run.

    Any object exposing these two coroutine methods satisfies the protocol;
    no inheritance is required. The protocol is not ``runtime_checkable``, so
    code in this module probes for ``log`` with ``hasattr`` instead of
    ``isinstance``.
    """

    async def log(self, message: str, level: str = "info") -> None:
        """Record *message* at severity *level* (``"info"`` unless given)."""
        pass

    async def progress(self, message: str, percentage: Optional[float] = None) -> None:
        """Report a progress update, optionally with a completion *percentage*."""
        pass
39
+
40
@dataclass
class AgentConfig:
    """Configuration for agent execution.

    Attributes:
        model: Model name or alias. It is passed through ``registry.resolve``
            in ``__post_init__`` (presumably alias -> canonical name).
        timeout: Per-attempt timeout in seconds (enforced by ``CLIAgent``).
        max_retries: Total number of attempts ``BaseAgent`` makes per run.
        working_dir: Working directory for subprocess-based agents. Truthy
            non-``Path`` values (e.g. ``str``) are converted to ``Path``.
        environment: Extra environment variables overlaid on ``os.environ``.
        stream_output: Streaming flag; not read anywhere in this module.
        use_worktree: Worktree flag; not read anywhere in this module.
    """

    model: str = "claude-3-5-sonnet-20241022"
    timeout: int = 300
    max_retries: int = 3
    working_dir: Optional[Path] = None
    environment: Dict[str, str] = field(default_factory=dict)
    stream_output: bool = False
    use_worktree: bool = False

    def __post_init__(self) -> None:
        """Validate the configuration, resolve the model name and normalize paths.

        Raises:
            ValueError: If ``timeout`` is not positive (every attempt would
                time out immediately) or ``max_retries`` is less than 1 (no
                attempt would ever be made).
        """
        if self.timeout <= 0:
            raise ValueError(f"timeout must be positive, got {self.timeout!r}")
        if self.max_retries < 1:
            raise ValueError(f"max_retries must be at least 1, got {self.max_retries!r}")

        self.model = registry.resolve(self.model)

        # Accept str (or other path-like) input for convenience.
        if self.working_dir and not isinstance(self.working_dir, Path):
            self.working_dir = Path(self.working_dir)
57
+
58
+
59
@dataclass
class AgentResult:
    """Outcome of a single agent execution.

    Attributes:
        success: Whether the execution completed without an exception.
        output: Text produced by the agent on success.
        error: Error message on failure.
        duration: Elapsed time in seconds, when measured.
        metadata: Free-form extra information (``BaseAgent.execute`` stores
            the model and agent name here).
    """

    success: bool
    output: Optional[str] = None
    error: Optional[str] = None
    duration: Optional[float] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def content(self) -> str:
        """Return the primary text: ``output`` on success, ``error`` otherwise.

        Always returns a ``str``. A successful result without output yields
        ``""``; a failed result without an error message yields
        ``"Unknown error"``.
        """
        if self.success:
            # ``output`` is Optional; honour the ``-> str`` contract instead of leaking None.
            return self.output if self.output is not None else ""
        return self.error or "Unknown error"
73
+
74
+
75
class BaseAgent(ABC, Generic[TContext, TResult]):
    """Abstract base class for all AI agents.

    Subclasses supply ``name``, ``description`` and ``_execute_impl``. This
    class provides the shared pipeline around them:
    - building the execution environment;
    - retrying with exponential backoff;
    - timing the run;
    - turning any ordinary failure into an unsuccessful ``AgentResult``
      instead of raising.
    """

    def __init__(self, config: Optional[AgentConfig] = None) -> None:
        """Initialize the agent.

        Args:
            config: Agent configuration. A default ``AgentConfig()`` is used
                when omitted.
        """
        self.config = config or AgentConfig()
        # Timestamps of the most recently started / finished ``execute`` call,
        # kept for introspection. Durations are computed from locals so that
        # concurrent runs on one instance cannot corrupt each other.
        self._start_time: Optional[datetime] = None
        self._end_time: Optional[datetime] = None

    @property
    @abstractmethod
    def name(self) -> str:
        """Agent name (used as the registration key by ``AgentOrchestrator``)."""
        ...

    @property
    @abstractmethod
    def description(self) -> str:
        """Human-readable agent description."""
        ...

    async def execute(
        self,
        prompt: str,
        context: Optional[TContext] = None,
        **kwargs: Any,
    ) -> AgentResult:
        """Run the agent on *prompt* and wrap the outcome in an ``AgentResult``.

        Any ``Exception`` raised by the following is logged and returned as
        ``AgentResult(success=False, error=str(exc))`` rather than raised:
        - environment setup;
        - context logging;
        - the retried implementation.

        Args:
            prompt: The prompt or task.
            context: Optional execution context. If it has a ``log`` method,
                it receives start and per-attempt messages.
            **kwargs: Forwarded to ``_execute_impl``.

        Returns:
            The execution result. It carries the elapsed time in seconds and
            ``{"model": ..., "agent": ...}`` metadata.
        """
        # Local copy: another concurrent execute() may overwrite self._start_time.
        start = datetime.now()
        self._start_time = start

        try:
            env = self._prepare_environment()

            if context and hasattr(context, "log"):
                await context.log(f"Starting {self.name} with model {self.config.model}")

            output = await self._execute_with_retries(prompt, context, env, **kwargs)
        except Exception as e:
            end = datetime.now()
            self._end_time = end
            logger.error("Agent %s failed: %s", self.name, e)
            return AgentResult(
                success=False,
                error=str(e),
                duration=(end - start).total_seconds(),
                metadata={"model": self.config.model, "agent": self.name},
            )

        end = datetime.now()
        self._end_time = end
        return AgentResult(
            success=True,
            output=output,
            duration=(end - start).total_seconds(),
            metadata={"model": self.config.model, "agent": self.name},
        )

    def _prepare_environment(self) -> Dict[str, str]:
        """Build the environment variables for execution.

        Starts from a copy of the current process environment and overlays
        ``config.environment``, whose entries take precedence. The copy
        already contains any provider API key variable and ``HANZO_API_KEY``
        that are set in the process environment.

        Returns:
            A new dict; ``os.environ`` itself is not modified.
        """
        env = os.environ.copy()
        env.update(self.config.environment)
        return env

    async def _execute_with_retries(
        self,
        prompt: str,
        context: Optional[TContext],
        env: Dict[str, str],
        **kwargs: Any,
    ) -> str:
        """Call ``_execute_impl`` with retries and exponential backoff.

        Up to ``config.max_retries`` attempts are made (at least one). Between
        attempts the delay is 1s, 2s, 4s, and so on. The following errors are
        re-raised immediately instead of retried, since retrying cannot help:
        - errors whose message mentions "unauthorized" or "forbidden";
        - ``NotImplementedError``.

        Args:
            prompt: The prompt.
            context: Execution context; receives a warning per failed attempt
                if it has a ``log`` method.
            env: Environment variables.
            **kwargs: Forwarded to ``_execute_impl``.

        Returns:
            Output of the first successful attempt.

        Raises:
            Exception: If all attempts fail. The last failure is chained as
                the cause.
        """
        attempts = max(1, self.config.max_retries)
        last_error: Optional[str] = None
        last_exc: Optional[BaseException] = None

        for attempt in range(attempts):
            try:
                return await self._execute_impl(prompt, context, env, **kwargs)

            except asyncio.TimeoutError as e:
                last_error = f"Timeout after {self.config.timeout} seconds"
                last_exc = e
                if context and hasattr(context, "log"):
                    await context.log(f"Attempt {attempt + 1} timed out", "warning")

            except Exception as e:
                last_error = str(e)
                last_exc = e
                if context and hasattr(context, "log"):
                    await context.log(f"Attempt {attempt + 1} failed: {e}", "warning")

                # Auth failures and missing implementations won't succeed on retry.
                lowered = str(e).lower()
                if (
                    isinstance(e, NotImplementedError)
                    or "unauthorized" in lowered
                    or "forbidden" in lowered
                ):
                    raise

            # Exponential backoff, skipped after the final attempt.
            if attempt < attempts - 1:
                await asyncio.sleep(2 ** attempt)

        raise Exception(
            f"All {attempts} attempts failed. Last error: {last_error}"
        ) from last_exc

    @abstractmethod
    async def _execute_impl(
        self,
        prompt: str,
        context: Optional[TContext],
        env: Dict[str, str],
        **kwargs: Any,
    ) -> str:
        """Implementation-specific execution (one attempt).

        Args:
            prompt: The prompt.
            context: Execution context.
            env: Environment variables.
            **kwargs: Additional parameters.

        Returns:
            Execution output.
        """
        ...
250
+
251
+
252
class CLIAgent(BaseAgent[TContext, str]):
    """Base class for agents that run an external command-line tool.

    Subclasses provide ``cli_command`` and may override ``build_command``.
    The inherited ``_execute_impl`` runs the command as a subprocess (no
    shell), enforces ``config.timeout`` and returns the decoded stdout.
    """

    @property
    @abstractmethod
    def cli_command(self) -> str:
        """Executable (name on PATH or path) to run."""
        ...

    def build_command(self, prompt: str, **kwargs: Any) -> List[str]:
        """Build the argument vector for the CLI invocation.

        The default layout is ``[cli_command, "--model", <full name>, prompt]``.
        The ``--model`` pair is omitted when the registry has no entry for
        ``config.model``.

        Args:
            prompt: The prompt, appended as the final argument.
            **kwargs: Unused by this default; available to overrides.

        Returns:
            Command arguments list.
        """
        command = [self.cli_command]

        model_config = registry.get(self.config.model)
        if model_config:
            command.extend(["--model", model_config.full_name])

        command.append(prompt)
        return command

    async def _execute_impl(
        self,
        prompt: str,
        context: Optional[TContext],
        env: Dict[str, str],
        **kwargs: Any,
    ) -> str:
        """Run the CLI command once and return its standard output.

        Args:
            prompt: The prompt.
            context: Execution context (unused here).
            env: Environment variables for the child process.
            **kwargs: Forwarded to ``build_command``.

        Returns:
            Decoded stdout. Undecodable bytes are replaced rather than raising.

        Raises:
            asyncio.TimeoutError: If the command exceeds ``config.timeout``.
                The child is killed and reaped first.
            Exception: If the command exits non-zero. The message is the
                child's stderr, or its exit code when stderr is empty.
        """
        command = self.build_command(prompt, **kwargs)

        # Only a bare command (argv of length 1, which needs an overridden
        # build_command) gets the prompt on stdin. Otherwise stdin is DEVNULL,
        # so a tool that reads stdin sees EOF instead of waiting on an unused pipe.
        stdin_data = prompt.encode() if len(command) == 1 else None

        process = await asyncio.create_subprocess_exec(
            *command,
            stdin=(
                asyncio.subprocess.PIPE
                if stdin_data is not None
                else asyncio.subprocess.DEVNULL
            ),
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            cwd=str(self.config.working_dir) if self.config.working_dir else None,
            env=env,
        )

        try:
            stdout, stderr = await asyncio.wait_for(
                process.communicate(stdin_data),
                timeout=self.config.timeout,
            )
        except asyncio.TimeoutError:
            await self._kill_and_reap(process)
            raise asyncio.TimeoutError(
                f"Command timed out after {self.config.timeout} seconds"
            ) from None

        if process.returncode != 0:
            error_msg = (
                stderr.decode(errors="replace")
                if stderr
                else f"Command failed (exit code {process.returncode})"
            )
            raise Exception(error_msg)

        return stdout.decode(errors="replace")

    @staticmethod
    async def _kill_and_reap(process: asyncio.subprocess.Process) -> None:
        """Kill *process* and wait briefly for it to exit so it is not left as a zombie."""
        try:
            process.kill()
        except ProcessLookupError:
            # The child exited between the timeout firing and the kill.
            pass
        try:
            # Bounded so a descendant holding the pipes open cannot hang us.
            await asyncio.wait_for(process.wait(), timeout=5)
        except asyncio.TimeoutError:
            logger.warning("Timed-out CLI process %s did not exit after kill", process.pid)
329
+
330
+
331
class APIAgent(BaseAgent[TContext, str]):
    """Abstract base for agents that reach a model through a client library.

    Concrete subclasses are expected to override ``_execute_impl`` with the
    actual API call. The inherited default below only signals that this has
    not been done.
    """

    async def _execute_impl(
        self,
        prompt: str,
        context: Optional[TContext],
        env: Dict[str, str],
        **kwargs: Any,
    ) -> str:
        """Placeholder implementation that always fails.

        Args:
            prompt: The prompt.
            context: Execution context.
            env: Environment variables.
            **kwargs: Additional parameters.

        Raises:
            NotImplementedError: Always; subclasses must supply the API call.
        """
        message = "API agents must implement _execute_impl"
        raise NotImplementedError(message)
355
+
356
+
357
class AgentOrchestrator:
    """Registry and coordinator for multiple agents.

    Agents are registered under their ``name``. They can then be run:
    - singly;
    - concurrently, bounded by a semaphore;
    - as a success-ratio "consensus";
    - as a sequential chain that feeds each output into the next prompt.
    """

    def __init__(self, default_config: Optional[AgentConfig] = None) -> None:
        """Initialize the orchestrator.

        Args:
            default_config: Default configuration for agents. It is stored on
                the instance; no method of this class currently applies it.
        """
        self.default_config = default_config or AgentConfig()
        self._agents: Dict[str, BaseAgent] = {}
        # Semaphore created by the most recent ``execute_parallel`` call, kept
        # for introspection only. Each call limits its own tasks with a local
        # semaphore so overlapping calls cannot replace each other's limit.
        self._semaphore: Optional[asyncio.Semaphore] = None

    def register(self, agent: BaseAgent) -> None:
        """Register *agent* under ``agent.name``.

        An agent already registered with the same name is replaced.

        Args:
            agent: Agent to register.
        """
        self._agents[agent.name] = agent

    def get_agent(self, name: str) -> Optional[BaseAgent]:
        """Look up a registered agent.

        Args:
            name: Agent name.

        Returns:
            The agent instance, or ``None`` if none is registered under *name*.
        """
        return self._agents.get(name)

    async def execute_single(
        self,
        agent_name: str,
        prompt: str,
        context: Optional[Any] = None,
        **kwargs: Any,
    ) -> AgentResult:
        """Execute a single registered agent.

        Args:
            agent_name: Name of the agent to use.
            prompt: The prompt.
            context: Execution context.
            **kwargs: Forwarded to the agent's ``execute``.

        Returns:
            The agent's result. If no agent is registered under *agent_name*,
            an unsuccessful result with a "not found" error is returned.
        """
        agent = self.get_agent(agent_name)
        if agent is None:
            return AgentResult(
                success=False,
                error=f"Agent '{agent_name}' not found",
            )

        return await agent.execute(prompt, context, **kwargs)

    async def execute_parallel(
        self,
        tasks: List[Dict[str, Any]],
        max_concurrent: int = 5,
    ) -> List[AgentResult]:
        """Execute multiple agent tasks concurrently.

        Each task dict has these keys:
        - ``"agent"`` (name) and ``"prompt"``: required;
        - ``"context"``: optional;
        - ``"kwargs"``: optional dict forwarded to ``execute``.

        Args:
            tasks: List of task definitions.
            max_concurrent: Maximum number of tasks running at once (>= 1).

        Returns:
            Results in the same order as *tasks*.

        Raises:
            ValueError: If *max_concurrent* is less than 1. A zero-sized
                semaphore would otherwise block every task forever.
        """
        if max_concurrent < 1:
            raise ValueError(f"max_concurrent must be at least 1, got {max_concurrent!r}")

        semaphore = asyncio.Semaphore(max_concurrent)
        self._semaphore = semaphore

        async def run_with_semaphore(task: Dict[str, Any]) -> AgentResult:
            # Bind the local semaphore, not self._semaphore, which a
            # concurrent call may replace.
            async with semaphore:
                return await self.execute_single(
                    task["agent"],
                    task["prompt"],
                    task.get("context"),
                    **task.get("kwargs", {}),
                )

        return await asyncio.gather(
            *[run_with_semaphore(task) for task in tasks],
            return_exceptions=False,
        )

    async def execute_consensus(
        self,
        prompt: str,
        agents: List[str],
        threshold: float = 0.66,
    ) -> Dict[str, Any]:
        """Run several agents on the same prompt and score their success rate.

        ``agreement_score`` is the fraction of agents whose execution
        succeeded (0 when *agents* is empty). Agent outputs are not compared
        with each other.

        Args:
            prompt: The prompt.
            agents: List of agent names.
            threshold: Minimum ``agreement_score`` for ``consensus_reached``.

        Returns:
            Dict with ``consensus_reached``, ``agreement_score``,
            ``individual_results`` and ``agents_used``.
        """
        tasks = [{"agent": agent, "prompt": prompt} for agent in agents]
        results = await self.execute_parallel(tasks)

        successful = [r for r in results if r.success]
        agreement = len(successful) / len(results) if results else 0

        return {
            "consensus_reached": agreement >= threshold,
            "agreement_score": agreement,
            "individual_results": results,
            "agents_used": agents,
        }

    async def execute_chain(
        self,
        initial_prompt: str,
        agents: List[str],
    ) -> List[AgentResult]:
        """Execute agents sequentially, feeding each output into the next prompt.

        The chain stops at the first step that fails or produces empty
        output; that step's result is still included.

        Args:
            initial_prompt: Prompt for the first agent.
            agents: Agent names in execution order.

        Returns:
            Results of the steps that were run, in order.
        """
        results: List[AgentResult] = []
        current_prompt = initial_prompt

        for agent_name in agents:
            result = await self.execute_single(agent_name, current_prompt)
            results.append(result)

            if not (result.success and result.output):
                break  # chain broken

            current_prompt = f"Review and improve:\n{result.output}"

        return results
506
+
507
+
508
# Names exported by ``from hanzo.base_agent import *``.
__all__ = [
    # Protocols and data containers
    "AgentContext",
    "AgentConfig",
    "AgentResult",
    # Agent base classes
    "BaseAgent",
    "CLIAgent",
    "APIAgent",
    # Multi-agent coordination
    "AgentOrchestrator",
]