hanzo 0.3.21__py3-none-any.whl → 0.3.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of hanzo might be problematic. Click here for more details.

hanzo/base_agent.py ADDED
@@ -0,0 +1,517 @@
1
+ """Base Agent - Unified foundation for all AI agent implementations.
2
+
3
+ This module provides the single base class for all agent operations,
4
+ following DRY principles and ensuring consistent behavior across all agents.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import os
10
+ import asyncio
11
+ import logging
12
+ from abc import ABC, abstractmethod
13
+ from typing import Any, Dict, List, Optional, Protocol, TypeVar, Generic
14
+ from dataclasses import dataclass, field
15
+ from datetime import datetime
16
+ from pathlib import Path
17
+
18
+ from .model_registry import registry, ModelConfig
19
+
20
+
21
logger = logging.getLogger(__name__)

# Type parameters used to parameterize AgentContext and BaseAgent
# (TContext: the execution-context type; TResult: the agent result type).
TContext, TResult = TypeVar("TContext"), TypeVar("TResult")
27
+
28
+
29
class AgentContext(Protocol[TContext]):
    """Structural interface for the context object an agent may receive.

    Any object providing these two coroutine methods satisfies the protocol;
    no inheritance is required.
    """

    async def log(self, message: str, level: str = "info") -> None:
        """Record ``message`` at severity ``level`` (``"info"`` by default)."""

    async def progress(self, message: str, percentage: Optional[float] = None) -> None:
        """Report ``message`` as progress, optionally with a completion ``percentage``."""
38
+ ...
39
+
40
+
41
@dataclass
class AgentConfig:
    """Settings controlling a single agent's execution.

    ``__post_init__`` normalizes the model name through the model registry
    and coerces a non-Path ``working_dir`` into a ``Path``.
    """

    # Model name or alias; replaced by ``registry.resolve(model)`` on init.
    model: str = "claude-3-5-sonnet-20241022"
    # Per-attempt time limit in seconds.
    timeout: int = 300
    # Upper bound on execution attempts made before giving up.
    max_retries: int = 3
    # Directory the agent runs in; ``None`` means the current directory.
    working_dir: Optional[Path] = None
    # Extra environment variables layered on top of the process environment.
    environment: Dict[str, str] = field(default_factory=dict)
    # NOTE(review): the next two flags are not read anywhere in this module.
    stream_output: bool = False
    use_worktree: bool = False

    def __post_init__(self) -> None:
        """Canonicalize ``model`` and make sure ``working_dir`` is a Path."""
        self.model = registry.resolve(self.model)
        wd = self.working_dir
        # Falsy values (None, "") are left untouched on purpose.
        if wd and not isinstance(wd, Path):
            self.working_dir = Path(wd)
58
+
59
+
60
@dataclass
class AgentResult:
    """Outcome of one agent execution.

    ``success`` decides which of ``output``/``error`` is meaningful.
    ``duration`` is the elapsed time in seconds when known. ``metadata``
    holds free-form details, such as the model and agent name.
    """

    success: bool
    output: Optional[str] = None
    error: Optional[str] = None
    duration: Optional[float] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def content(self) -> str:
        """Return the primary text of the result.

        On success this is ``output``, or ``""`` when no output was captured,
        so the declared ``str`` return type always holds. On failure it is
        ``error``, falling back to ``"Unknown error"``.
        """
        if self.success:
            return self.output or ""
        return self.error or "Unknown error"
74
+
75
+
76
class BaseAgent(ABC, Generic[TContext, TResult]):
    """Abstract foundation shared by every agent implementation.

    Subclasses provide ``name``, ``description`` and ``_execute_impl``. This
    class supplies the following:

    - environment preparation;
    - retries with exponential backoff;
    - timing;
    - conversion of the outcome into an ``AgentResult``.
    """

    def __init__(self, config: Optional[AgentConfig] = None) -> None:
        """Initialize the agent.

        Args:
            config: Agent configuration; a default ``AgentConfig()`` is used
                when omitted.
        """
        self.config = config or AgentConfig()
        # Timestamps of the most recent execute() call, kept for callers that
        # inspect them. Durations are computed from call-local values instead,
        # because one agent instance may run several executions concurrently.
        self._start_time: Optional[datetime] = None
        self._end_time: Optional[datetime] = None

    @property
    @abstractmethod
    def name(self) -> str:
        """Agent name (also used as the key when registered with an orchestrator)."""
        ...

    @property
    @abstractmethod
    def description(self) -> str:
        """Agent description."""
        ...

    async def execute(
        self,
        prompt: str,
        context: Optional[TContext] = None,
        **kwargs: Any,
    ) -> AgentResult:
        """Run the agent on ``prompt`` and wrap the outcome.

        Args:
            prompt: The prompt or task.
            context: Optional execution context. If it has a ``log`` method,
                status messages are awaited on it.
            **kwargs: Forwarded unchanged to ``_execute_impl``.

        Returns:
            An ``AgentResult`` carrying the model and agent name in
            ``metadata``. Exceptions raised while executing are not
            propagated. They are reported as ``success=False`` with the
            exception text in ``error``.
        """
        # Call-local start time: concurrent executions of the same instance
        # would otherwise overwrite self._start_time and skew each other's durations.
        started = datetime.now()
        self._start_time = started

        try:
            env = self._prepare_environment()

            if context and hasattr(context, "log"):
                await context.log(f"Starting {self.name} with model {self.config.model}")

            output = await self._execute_with_retries(prompt, context, env, **kwargs)
        except Exception as e:
            finished = datetime.now()
            self._end_time = finished
            logger.error("Agent %s failed: %s", self.name, e)
            return AgentResult(
                success=False,
                error=str(e),
                duration=(finished - started).total_seconds(),
                metadata={"model": self.config.model, "agent": self.name},
            )

        finished = datetime.now()
        self._end_time = finished
        return AgentResult(
            success=True,
            output=output,
            duration=(finished - started).total_seconds(),
            metadata={"model": self.config.model, "agent": self.name},
        )

    def _prepare_environment(self) -> Dict[str, str]:
        """Build the environment variables for a run.

        Starts from a copy of the current process environment, which already
        contains any provider API key and ``HANZO_API_KEY`` that are set. It
        then overlays ``config.environment``, whose entries take precedence.

        Returns:
            A new environment dictionary; ``os.environ`` is not modified.
        """
        env = os.environ.copy()
        env.update(self.config.environment)
        return env

    async def _execute_with_retries(
        self,
        prompt: str,
        context: Optional[TContext],
        env: Dict[str, str],
        **kwargs: Any,
    ) -> str:
        """Call ``_execute_impl``, retrying failures with exponential backoff.

        The method makes up to ``config.max_retries`` attempts, and always at
        least one, even if the setting is zero or negative. It sleeps 1s, 2s,
        4s, ... between attempts. An error whose message contains
        "unauthorized" or "forbidden" is re-raised immediately, because
        retrying cannot fix it.

        Args:
            prompt: The prompt.
            context: Execution context; warnings go to its ``log`` method if present.
            env: Environment variables.
            **kwargs: Forwarded to ``_execute_impl``.

        Returns:
            Output of the first successful attempt.

        Raises:
            Exception: When every attempt fails, chained to the last error.
        """
        attempts = max(1, self.config.max_retries)
        last_error: Optional[str] = None
        last_exc: Optional[BaseException] = None

        for attempt in range(attempts):
            try:
                return await self._execute_impl(prompt, context, env, **kwargs)

            except asyncio.TimeoutError as e:
                last_exc = e
                last_error = f"Timeout after {self.config.timeout} seconds"
                if context and hasattr(context, "log"):
                    await context.log(f"Attempt {attempt + 1} timed out", "warning")

            except Exception as e:
                last_exc = e
                last_error = str(e)
                if context and hasattr(context, "log"):
                    await context.log(f"Attempt {attempt + 1} failed: {e}", "warning")

                # Authentication/authorization failures will not succeed on retry.
                lowered = last_error.lower()
                if "unauthorized" in lowered or "forbidden" in lowered:
                    raise

            # Exponential backoff, skipped after the final attempt.
            if attempt < attempts - 1:
                await asyncio.sleep(2 ** attempt)

        raise Exception(
            f"All {attempts} attempts failed. Last error: {last_error}"
        ) from last_exc

    @abstractmethod
    async def _execute_impl(
        self,
        prompt: str,
        context: Optional[TContext],
        env: Dict[str, str],
        **kwargs: Any,
    ) -> str:
        """Perform one execution attempt.

        Raise to signal failure. ``asyncio.TimeoutError`` is reported as a
        timeout, and other exceptions are retried by ``_execute_with_retries``.

        Args:
            prompt: The prompt.
            context: Execution context.
            env: Environment variables.
            **kwargs: Additional parameters.

        Returns:
            Execution output.
        """
        ...
251
+
252
+
253
class CLIAgent(BaseAgent[TContext, str]):
    """Base class for agents that run an external command-line tool.

    Subclasses supply ``cli_command`` (plus ``name``/``description``). The
    default ``build_command`` produces ``[cli_command, --model <name>, prompt]``.
    """

    @property
    @abstractmethod
    def cli_command(self) -> str:
        """Executable name or path to run."""
        ...

    def build_command(self, prompt: str, **kwargs: Any) -> List[str]:
        """Build the argument list for one run.

        Args:
            prompt: The prompt, appended as the final positional argument.
            **kwargs: Accepted for subclass overrides; ignored here.

        Returns:
            Command arguments list. ``--model`` is included only when the
            configured model is found in the registry.
        """
        command = [self.cli_command]

        model_config = registry.get(self.config.model)
        if model_config:
            command.extend(["--model", model_config.full_name])

        command.append(prompt)
        return command

    async def _execute_impl(
        self,
        prompt: str,
        context: Optional[TContext],
        env: Dict[str, str],
        **kwargs: Any,
    ) -> str:
        """Run the CLI once and return its stdout.

        If the built command is only the executable, as with a subclass
        override that omits the prompt argument, the prompt is written to the
        child's stdin. Otherwise stdin is the null device, so the tool sees
        EOF at once and cannot block waiting for input.

        Args:
            prompt: The prompt.
            context: Execution context (unused here).
            env: Environment variables for the child process.
            **kwargs: Forwarded to ``build_command``.

        Returns:
            Command output, decoded as UTF-8 with undecodable bytes replaced.

        Raises:
            asyncio.TimeoutError: If the command exceeds ``config.timeout``.
                The child is killed and reaped before this is raised.
            Exception: With the child's stderr, or "Command failed", when it
                exits with a non-zero status.
        """
        command = self.build_command(prompt, **kwargs)
        # The default build_command always appends the prompt, so stdin input
        # is only used by overrides that return the bare executable.
        stdin_data = prompt.encode() if len(command) == 1 else None

        process = await asyncio.create_subprocess_exec(
            *command,
            # An unused, never-closed stdin pipe can make tools that read
            # piped input hang until the timeout; give them EOF instead.
            stdin=(
                asyncio.subprocess.PIPE
                if stdin_data is not None
                else asyncio.subprocess.DEVNULL
            ),
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            cwd=str(self.config.working_dir) if self.config.working_dir else None,
            env=env,
        )

        try:
            stdout, stderr = await asyncio.wait_for(
                process.communicate(stdin_data),
                timeout=self.config.timeout,
            )
        except asyncio.TimeoutError:
            try:
                process.kill()
            except ProcessLookupError:
                # The child exited between the timeout firing and the kill.
                pass
            # Reap the child so it does not linger as a zombie process.
            await process.wait()
            raise asyncio.TimeoutError(
                f"Command timed out after {self.config.timeout} seconds"
            ) from None

        if process.returncode != 0:
            error_msg = stderr.decode(errors="replace") if stderr else "Command failed"
            raise Exception(error_msg)

        return stdout.decode(errors="replace")
330
+
331
+
332
class APIAgent(BaseAgent[TContext, str]):
    """Base class for agents that reach a model through an API client.

    Concrete subclasses are expected to override ``_execute_impl`` with a
    call through their provider's client library.
    """

    async def _execute_impl(
        self,
        prompt: str,
        context: Optional[TContext],
        env: Dict[str, str],
        **kwargs: Any,
    ) -> str:
        """Placeholder implementation that always fails.

        Args:
            prompt: The prompt (unused).
            context: Execution context (unused).
            env: Environment variables (unused).
            **kwargs: Additional parameters (unused).

        Raises:
            NotImplementedError: Always; subclasses must provide the API call.
        """
        message = "API agents must implement _execute_impl"
        raise NotImplementedError(message)
356
+
357
+
358
class AgentOrchestrator:
    """Registry of agents plus helpers to run them in several ways.

    Agents can be run singly, in parallel, as a success-count "consensus",
    or as a sequential chain.
    """

    def __init__(self, default_config: Optional[AgentConfig] = None) -> None:
        """Initialize orchestrator.

        Args:
            default_config: Default configuration, stored as
                ``default_config``. NOTE(review): it is not applied to
                registered agents anywhere in this class; confirm whether
                callers rely on it.
        """
        self.default_config = default_config or AgentConfig()
        self._agents: Dict[str, BaseAgent] = {}
        # Semaphore of the most recent execute_parallel call. It is kept for
        # backward compatibility only; each call uses its own local semaphore.
        self._semaphore: Optional[asyncio.Semaphore] = None

    def register(self, agent: BaseAgent) -> None:
        """Register ``agent`` under its ``name``, replacing any previous entry.

        Args:
            agent: Agent to register.
        """
        self._agents[agent.name] = agent

    def get_agent(self, name: str) -> Optional[BaseAgent]:
        """Get agent by name.

        Args:
            name: Agent name.

        Returns:
            Agent instance, or None if no agent is registered under ``name``.
        """
        return self._agents.get(name)

    async def execute_single(
        self,
        agent_name: str,
        prompt: str,
        context: Optional[Any] = None,
        **kwargs: Any,
    ) -> AgentResult:
        """Execute a single registered agent.

        Args:
            agent_name: Name of agent to use.
            prompt: The prompt.
            context: Execution context.
            **kwargs: Forwarded to the agent's ``execute``.

        Returns:
            The agent's result, or a failed ``AgentResult`` if the name is unknown.
        """
        agent = self.get_agent(agent_name)
        if not agent:
            return AgentResult(
                success=False,
                error=f"Agent '{agent_name}' not found",
            )

        return await agent.execute(prompt, context, **kwargs)

    async def execute_parallel(
        self,
        tasks: List[Dict[str, Any]],
        max_concurrent: int = 5,
    ) -> List[AgentResult]:
        """Execute multiple agent tasks concurrently.

        Args:
            tasks: Task definitions. Each is a dict with required ``"agent"``
                and ``"prompt"`` keys, plus optional ``"context"`` and
                ``"kwargs"`` (a dict of keyword arguments for the agent).
            max_concurrent: Maximum number of tasks running at once; must be >= 1.

        Returns:
            Results in the same order as ``tasks``.

        Raises:
            ValueError: If ``max_concurrent`` is less than 1. A zero-slot
                semaphore would block every task forever.
        """
        if max_concurrent < 1:
            raise ValueError(f"max_concurrent must be >= 1, got {max_concurrent}")

        # Call-local semaphore: reading self._semaphore inside the tasks would
        # let an overlapping execute_parallel call swap the limit underneath them.
        semaphore = asyncio.Semaphore(max_concurrent)
        self._semaphore = semaphore

        async def run_with_semaphore(task: Dict[str, Any]) -> AgentResult:
            async with semaphore:
                return await self.execute_single(
                    task["agent"],
                    task["prompt"],
                    task.get("context"),
                    **task.get("kwargs", {}),
                )

        return await asyncio.gather(
            *(run_with_semaphore(task) for task in tasks),
            return_exceptions=False,
        )

    async def execute_consensus(
        self,
        prompt: str,
        agents: List[str],
        threshold: float = 0.66,
    ) -> Dict[str, Any]:
        """Run the same prompt on several agents and score how many succeed.

        ``agreement_score`` is the fraction of agents whose run succeeded.
        The outputs themselves are not compared.

        Args:
            prompt: The prompt.
            agents: List of agent names.
            threshold: Minimum ``agreement_score`` for ``consensus_reached``.

        Returns:
            Dict with ``consensus_reached``, ``agreement_score``,
            ``individual_results`` and ``agents_used``.
        """
        tasks = [{"agent": agent, "prompt": prompt} for agent in agents]
        results = await self.execute_parallel(tasks)

        successful = [r for r in results if r.success]
        agreement = len(successful) / len(results) if results else 0

        return {
            "consensus_reached": agreement >= threshold,
            "agreement_score": agreement,
            "individual_results": results,
            "agents_used": agents,
        }

    async def execute_chain(
        self,
        initial_prompt: str,
        agents: List[str],
    ) -> List[AgentResult]:
        """Execute agents in sequence, feeding each output to the next agent.

        Every agent after the first receives ``"Review and improve:\\n"``
        followed by the previous agent's output. The chain stops at the first
        result that failed or produced empty output.

        Args:
            initial_prompt: Prompt for the first agent.
            agents: List of agent names, in execution order.

        Returns:
            Results of the steps that ran, including the one that broke the chain.
        """
        results = []
        current_prompt = initial_prompt

        for agent_name in agents:
            result = await self.execute_single(agent_name, current_prompt)
            results.append(result)

            if result.success and result.output:
                current_prompt = f"Review and improve:\n{result.output}"
            else:
                break

        return results
507
+
508
+
509
# Public names exported by ``from <module> import *``.
__all__ = [
    "AgentContext", "AgentConfig", "AgentResult",
    "BaseAgent", "CLIAgent", "APIAgent",
    "AgentOrchestrator",
]