flatagents 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,814 @@
1
+ """
2
+ Self-contained FlatAgent base class with pluggable LLM backends.
3
+
4
+ Unifies the agent interface, configuration, and execution loop into a single class.
5
+ LLM interaction is delegated to an LLMBackend, allowing different providers.
6
+ """
7
+
8
+ import asyncio
9
+ import os
10
+ import random
11
+ from abc import ABC, abstractmethod
12
+ from typing import Any, Tuple, Callable, List, Dict, Optional, Protocol, runtime_checkable
13
+
14
+ from .monitoring import get_logger, track_operation
15
+ from .utils import strip_markdown_json
16
+
17
+ logger = get_logger(__name__)
18
+
19
+ try:
20
+ import litellm
21
+ except ImportError:
22
+ litellm = None
23
+
24
+ try:
25
+ import aisuite
26
+ except ImportError:
27
+ aisuite = None
28
+
29
+ try:
30
+ import yaml
31
+ except ImportError:
32
+ yaml = None
33
+
34
+ import json
35
+
36
+
37
+ # ─────────────────────────────────────────────────────────────────────────────
38
+ # LLM Backend Protocol and Implementations
39
+ # ─────────────────────────────────────────────────────────────────────────────
40
+
41
@runtime_checkable
class LLMBackend(Protocol):
    """Structural interface that every LLM provider adapter must satisfy.

    An adapter exposes two running counters plus two async entry points:
    ``call`` yields just the reply text, ``call_raw`` yields the provider's
    complete response object. Because the protocol is runtime-checkable,
    ``isinstance(obj, LLMBackend)`` verifies that all four members exist.
    """

    # Accumulated spend as tracked by the implementation.
    total_cost: float
    # Number of API requests issued so far, as counted by the implementation.
    total_api_calls: int

    async def call(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> str:
        """
        Send a chat conversation and return only the reply text.

        Args:
            messages: Chat messages, each a dict carrying 'role' and 'content'
            **kwargs: Extra sampling options (temperature, max_tokens, ...)

        Returns:
            The text content of the model's reply
        """
        ...

    async def call_raw(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> Any:
        """
        Send a chat conversation and return the untouched provider response.

        Args:
            messages: Chat messages, each a dict carrying 'role' and 'content'
            **kwargs: Extra sampling options (temperature, max_tokens, ...)

        Returns:
            Whatever response object the underlying provider/LiteLLM produced
        """
        ...
81
+
82
+
83
class LiteLLMBackend:
    """LLM backend using the litellm library.

    Every request goes through ``litellm.acompletion``. A failed attempt is
    retried once per remaining entry in ``retry_delays``. The entry is the base
    pause in seconds, and up to one extra second of random jitter is added.
    The cost litellm attaches to each successful response is added to
    ``total_cost``. Every attempt, including failed ones, increments
    ``total_api_calls``.
    """

    def __init__(
        self,
        model: str,
        temperature: float = 0.7,
        max_tokens: int = 2048,
        top_p: float = 1.0,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        retry_delays: Optional[List[float]] = None,
    ):
        """
        Args:
            model: litellm model identifier (e.g. "openai/gpt-4").
            temperature, max_tokens, top_p, frequency_penalty, presence_penalty:
                Default sampling options sent with every call; per-call kwargs
                override them.
            retry_delays: Base pauses between attempts; its length is the
                number of attempts. Falsy values select ``[1, 2, 4, 8]``.

        Raises:
            ImportError: if litellm is not installed.
        """
        if litellm is None:
            raise ImportError("litellm is required. Install with: pip install litellm")

        self.model = model
        self.llm_kwargs = {
            "temperature": temperature,
            "max_tokens": max_tokens,
            "top_p": top_p,
            "frequency_penalty": frequency_penalty,
            "presence_penalty": presence_penalty,
        }
        self.retry_delays = retry_delays or [1, 2, 4, 8]
        self.total_cost = 0.0
        self.total_api_calls = 0

        logger.info(f"Initialized LiteLLMBackend with model: {model}")

    @staticmethod
    def _extract_cost(response: Any) -> float:
        """Return the cost litellm attached to ``response``, or 0.0 when unknown.

        litellm reports it via ``response._hidden_params["response_cost"]``;
        the entry may be absent or ``None`` (e.g. for models without pricing
        data), and neither case may be allowed to raise.
        """
        hidden = getattr(response, '_hidden_params', None)
        if not isinstance(hidden, dict):
            return 0.0
        cost = hidden.get('response_cost')
        if cost is None:
            return 0.0
        try:
            return float(cost)
        except (TypeError, ValueError):
            return 0.0

    async def call_raw(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> Any:
        """Call the LLM and return the raw response object with retry logic.

        Args:
            messages: Chat messages with 'role' and 'content' keys.
            **kwargs: Per-call options merged over the constructor defaults.

        Returns:
            The litellm response object (guaranteed to have at least one choice).

        Raises:
            The exception from the last failed attempt once all attempts fail.
        """
        call_kwargs = {**self.llm_kwargs, **kwargs}
        attempts = len(self.retry_delays)

        last_exception = None
        for attempt, delay in enumerate(self.retry_delays):
            try:
                self.total_api_calls += 1
                logger.info(f"Calling LLM (Attempt {attempt + 1}/{attempts})...")

                response = await litellm.acompletion(
                    model=self.model,
                    messages=messages,
                    **call_kwargs
                )

                if response is None or response.choices is None or len(response.choices) == 0:
                    raise ValueError("Received an empty or invalid response from the LLM.")
            except Exception as e:
                last_exception = e
                logger.warning(f"LLM call failed on attempt {attempt + 1}: {e}")
                if attempt < attempts - 1:
                    jittered_delay = delay + random.random()
                    logger.info(f"Retrying in {jittered_delay:.2f} seconds...")
                    await asyncio.sleep(jittered_delay)
            else:
                # Cost bookkeeping lives outside the try so a malformed cost
                # field can never cause an already-successful (and billed)
                # request to be retried.
                self.total_cost += self._extract_cost(response)
                return response

        logger.error("All retry attempts failed.")
        raise last_exception or RuntimeError("LLM call failed after all retries")

    async def call(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> str:
        """Call the LLM and return the content string.

        Raises:
            ValueError: if the first choice carries no content.
        """
        response = await self.call_raw(messages, **kwargs)
        content = response.choices[0].message.content
        if content is None:
            raise ValueError("The LLM response content was empty.")
        logger.info(f"LLM response received: '{content[:100]}...'")
        return content
165
+
166
+
167
class AISuiteBackend:
    """
    LLM backend using the aisuite library (by Andrew Ng).

    Provides a unified interface to multiple providers:
    OpenAI, Anthropic, Google, AWS, Cohere, Mistral, Ollama, HuggingFace.

    Model format: "provider:model" (e.g., "openai:gpt-4o", "anthropic:claude-3-5-sonnet").
    The litellm-style "provider/model" spelling is accepted as well and is
    converted on construction.

    ``total_cost`` is only an estimate. It is computed as
    ``(prompt_tokens + completion_tokens) * cost_per_token`` from the
    response's usage data, because real pricing differs per provider/model.
    """

    def __init__(
        self,
        model: str,
        temperature: float = 0.7,
        max_tokens: int = 2048,
        top_p: float = 1.0,
        retry_delays: Optional[List[float]] = None,
        cost_per_token: float = 0.00001,
    ):
        """
        Args:
            model: "provider:model" or "provider/model" identifier.
            temperature, max_tokens, top_p: Default sampling options; per-call
                kwargs override them.
            retry_delays: Base pauses between attempts; its length is the
                number of attempts. Falsy values select ``[1, 2, 4, 8]``.
            cost_per_token: Flat rate used for the cost estimate (default is
                the previous hard-coded $0.01 per 1K tokens).

        Raises:
            ImportError: if aisuite is not installed.
        """
        if aisuite is None:
            raise ImportError("aisuite is required. Install with: pip install aisuite")

        self.model = self._normalize_model(model)
        self.llm_kwargs = {
            "temperature": temperature,
            "max_tokens": max_tokens,
            "top_p": top_p,
        }
        self.retry_delays = retry_delays or [1, 2, 4, 8]
        self.cost_per_token = cost_per_token
        self.total_cost = 0.0
        self.total_api_calls = 0
        self.client = aisuite.Client()

        logger.info(f"Initialized AISuiteBackend with model: {self.model}")

    @staticmethod
    def _normalize_model(model: str) -> str:
        """Turn a leading "provider/" prefix into aisuite's "provider:" form.

        A slash counts as the provider separator only if it comes before any
        colon. IDs already in "provider:" form whose model part contains
        slashes (e.g. "huggingface:org/model") are therefore left unchanged.
        """
        slash = model.find("/")
        colon = model.find(":")
        if slash != -1 and (colon == -1 or slash < colon):
            return f"{model[:slash]}:{model[slash + 1:]}"
        return model

    def _estimate_cost(self, response: Any) -> float:
        """Rough cost estimate from the response's token usage (0.0 if absent)."""
        usage = getattr(response, 'usage', None)
        if not usage:
            return 0.0
        prompt_tokens = getattr(usage, 'prompt_tokens', 0) or 0
        completion_tokens = getattr(usage, 'completion_tokens', 0) or 0
        return (prompt_tokens + completion_tokens) * self.cost_per_token

    async def call_raw(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> Any:
        """Call the LLM and return the raw response object with retry logic.

        Raises:
            The exception from the last failed attempt once all attempts fail.
        """
        call_kwargs = {**self.llm_kwargs, **kwargs}
        attempts = len(self.retry_delays)

        last_exception = None
        for attempt, delay in enumerate(self.retry_delays):
            try:
                self.total_api_calls += 1
                logger.info(f"Calling LLM via AISuite (Attempt {attempt + 1}/{attempts})...")

                # aisuite is sync-only, wrap in thread for async compatibility
                response = await asyncio.to_thread(
                    self.client.chat.completions.create,
                    model=self.model,
                    messages=messages,
                    **call_kwargs
                )

                if response is None or response.choices is None or len(response.choices) == 0:
                    raise ValueError("Received an empty or invalid response from the LLM.")
            except Exception as e:
                last_exception = e
                logger.warning(f"AISuite call failed on attempt {attempt + 1}: {e}")
                if attempt < attempts - 1:
                    jittered_delay = delay + random.random()
                    logger.info(f"Retrying in {jittered_delay:.2f} seconds...")
                    await asyncio.sleep(jittered_delay)
            else:
                # Kept outside the try: bookkeeping errors must never trigger
                # a retry of a request that already succeeded.
                self.total_cost += self._estimate_cost(response)
                return response

        logger.error("All retry attempts failed.")
        raise last_exception or RuntimeError("AISuite call failed after all retries")

    async def call(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> str:
        """Call the LLM and return the content string.

        Raises:
            ValueError: if the first choice carries no content.
        """
        response = await self.call_raw(messages, **kwargs)
        content = response.choices[0].message.content
        if content is None:
            raise ValueError("The LLM response content was empty.")
        logger.info(f"LLM response received: '{content[:100]}...'")
        return content
263
+
264
+
265
+ # ─────────────────────────────────────────────────────────────────────────────
266
+ # Extractors (process LiteLLM responses into structured output)
267
+ # ─────────────────────────────────────────────────────────────────────────────
268
+
269
@runtime_checkable
class Extractor(Protocol):
    """Structural interface for turning a raw LLM response into usable output.

    Any object with an ``extract(response)`` method satisfies it, and
    ``isinstance(obj, Extractor)`` checks for that method at runtime.
    """

    def extract(self, response: Any) -> Any:
        """
        Convert a provider response into structured data.

        Args:
            response: The object returned by LLMBackend.call_raw()

        Returns:
            The extracted / structured result
        """
        ...
284
+
285
+
286
class FreeExtractor:
    """Pass-through extractor: hands back the first choice's text unparsed."""

    def extract(self, response: Any) -> str:
        """Return ``choices[0].message.content``, substituting "" for None."""
        text = response.choices[0].message.content
        if text is None:
            return ""
        return text
293
+
294
+
295
class FreeThinkingExtractor:
    """
    Preserves reasoning/thinking from the response.
    Returns: { "thinking": str, "response": str }

    Thinking is looked up in this order (first hit wins):
    - A 'thinking' attribute on the message
    - A 'reasoning_content' attribute (litellm's standard reasoning field)
    - Content blocks with type='thinking' (their 'thinking' or 'text' attribute)
    - <thinking>...</thinking> tags in the content. All occurrences are joined
      with a blank line and stripped from the returned response text.
    """

    def extract(self, response: Any) -> Dict[str, str]:
        """Extract thinking and response separately."""
        import re

        message = response.choices[0].message
        content = message.content or ""

        # Dedicated provider fields take precedence over anything in the text.
        thinking = (
            getattr(message, 'thinking', None)
            or getattr(message, 'reasoning_content', None)
            or ""
        )

        if not thinking:
            # Anthropic-style content blocks. A present-but-empty/None list
            # must not block the tag fallback below.
            for block in getattr(message, 'content_blocks', None) or []:
                block_type = getattr(block, 'type', None)
                if block_type == 'thinking':
                    thinking = getattr(block, 'thinking', None) or getattr(block, 'text', '')
                elif block_type == 'text':
                    content = getattr(block, 'text', content)

        if not thinking and isinstance(content, str):
            tag = re.compile(r'<thinking>(.*?)</thinking>', re.DOTALL)
            segments = [segment.strip() for segment in tag.findall(content)]
            if segments:
                thinking = "\n\n".join(segment for segment in segments if segment)
                content = tag.sub('', content).strip()

        return {"thinking": thinking, "response": content}
331
+
332
+
333
class StructuredExtractor:
    """
    Parses the response text as JSON (intended for calls made with a
    response_format parameter).

    Markdown code fences around the JSON are removed before parsing. On a
    parse failure the raw text and the error message are returned instead
    of raising.
    """

    def __init__(self, schema: Optional[Dict] = None):
        """
        Args:
            schema: Optional JSON schema. It is kept on ``self.schema`` for
                callers; ``extract`` itself does not validate against it.
        """
        self.schema = schema

    def extract(self, response: Any) -> Dict[str, Any]:
        """Return the decoded JSON, {} for empty content, or a _raw/_error dict."""
        raw = response.choices[0].message.content
        if raw is None:
            return {}

        try:
            # LLMs sometimes wrap JSON in ```json fences; strip them first.
            return json.loads(strip_markdown_json(raw))
        except json.JSONDecodeError as err:
            logger.warning(f"Failed to parse JSON response: {err}")
            return {"_raw": raw, "_error": str(err)}
359
+
360
+
361
class ToolsExtractor:
    """
    Extracts tool calls from the response.
    Returns: { "tool_calls": [...], "content": str }

    Each tool call is normalised to
    ``{"id", "type", "function": {"name", "arguments"}}``. JSON-string
    arguments are decoded. Arguments that are already decoded, or are not
    valid JSON, are passed through unchanged.
    """

    def extract(self, response: Any) -> Dict[str, Any]:
        """Extract tool calls and content."""
        message = response.choices[0].message
        content = message.content or ""
        raw_calls = getattr(message, 'tool_calls', None) or []
        tool_calls = [self._convert_tool_call(tc) for tc in raw_calls]
        return {"tool_calls": tool_calls, "content": content}

    @staticmethod
    def _convert_tool_call(tc: Any) -> Dict[str, Any]:
        """Normalise one provider tool-call object into a plain dict."""
        # getattr on None returns the default, so a missing or None
        # ``function`` simply yields None for name/arguments.
        function = getattr(tc, 'function', None)
        arguments = getattr(function, 'arguments', None)

        # Only strings need decoding; some providers already return dicts,
        # which json.loads would reject with a TypeError.
        if isinstance(arguments, str) and arguments:
            try:
                arguments = json.loads(arguments)
            except json.JSONDecodeError:
                pass  # Keep as string if not valid JSON

        return {
            "id": getattr(tc, 'id', None),
            "type": getattr(tc, 'type', 'function'),
            "function": {
                "name": getattr(function, 'name', None),
                "arguments": arguments,
            },
        }
394
+
395
+
396
class RegexExtractor:
    """
    Extracts fields from response using regex patterns.
    Patterns are provided at runtime, not in the spec.

    Can extract from:
    - Raw LLM response object (response.choices[0].message.content)
    - Plain string

    All patterns must match for extraction to succeed. Otherwise ``None`` is
    returned, which also happens when a value cannot be converted to its
    declared type.
    """

    def __init__(self, patterns: Dict[str, str], types: Optional[Dict[str, str]] = None):
        """
        Args:
            patterns: Map of field names to regex patterns. The first capture
                group is used as the value; a pattern without groups uses the
                whole match.
            types: Optional map of field names to type names ('str', 'int',
                'float', 'bool', 'json'). Unlisted/unknown types stay strings.
        """
        import re
        self.patterns = {name: re.compile(pattern) for name, pattern in patterns.items()}
        self.types = types or {}

    @staticmethod
    def _convert(value: str, field_type: str) -> Any:
        """Convert a matched string to ``field_type`` (raises ValueError on bad input)."""
        if field_type == 'json':
            return json.loads(value)
        if field_type == 'int':
            return int(value)
        if field_type == 'float':
            return float(value)
        if field_type == 'bool':
            return value.lower() in ('true', '1', 'yes')
        return value

    def extract(self, response: Any) -> Optional[Dict[str, Any]]:
        """Extract fields using regex patterns; None if any field is missing or invalid."""
        # Handle both response object and plain string
        if isinstance(response, str):
            content = response
        else:
            content = response.choices[0].message.content

        if content is None:
            return None

        result = {}
        for field_name, pattern in self.patterns.items():
            match = pattern.search(content)
            if not match:
                logger.debug(f"Field '{field_name}' pattern did not match")
                return None

            # group(1) would raise IndexError for a group-less pattern.
            value = match.group(1) if pattern.groups else match.group(0)
            if value is None:
                # An optional group that did not take part in the match.
                logger.debug(f"Field '{field_name}' capture group did not participate in the match")
                return None

            field_type = self.types.get(field_name, 'str')
            try:
                result[field_name] = self._convert(value, field_type)
            except (json.JSONDecodeError, ValueError) as e:
                logger.debug(f"Failed to parse field '{field_name}': {e}")
                return None

        return result
453
+
454
+
455
+ # ─────────────────────────────────────────────────────────────────────────────
456
+ # MCP Tool Provider Protocol and Types
457
+ # ─────────────────────────────────────────────────────────────────────────────
458
+
459
+ from dataclasses import dataclass, field
460
+
461
+
462
@runtime_checkable
class MCPToolProvider(Protocol):
    """
    Structural interface for connecting agents to MCP (Model Context Protocol) servers.

    The SDK ships no implementation. Users plug in their own backend
    (for instance aisuite.mcp.MCPClient). A minimal adapter built on aisuite:

        class AISuiteMCPProvider:
            def __init__(self):
                self._clients = {}

            def connect(self, server_name: str, config: dict):
                from aisuite.mcp import MCPClient
                if server_name not in self._clients:
                    self._clients[server_name] = MCPClient.from_config(config)

            def get_tools(self, server_name: str) -> list:
                return self._clients[server_name].list_tools()

            def call_tool(self, server_name: str, tool_name: str, arguments: dict):
                return self._clients[server_name].call_tool(tool_name, arguments)

            def close(self):
                for c in self._clients.values():
                    c.close()
    """

    def connect(self, server_name: str, config: Dict[str, Any]) -> None:
        """
        Open a connection to one MCP server.

        Args:
            server_name: Key identifying the server (same key as in mcp.servers)
            config: Connection settings (command/args for stdio, server_url for HTTP)
        """
        ...

    def get_tools(self, server_name: str) -> List[Dict[str, Any]]:
        """
        List the tools a connected MCP server offers.

        Args:
            server_name: Key identifying the server

        Returns:
            Tool definitions, each with 'name', 'description' and 'inputSchema'
        """
        ...

    def call_tool(self, server_name: str, tool_name: str, arguments: Dict[str, Any]) -> Any:
        """
        Invoke a tool on a connected MCP server.

        Args:
            server_name: Key identifying the server
            tool_name: Tool to run
            arguments: Arguments passed to the tool

        Returns:
            Whatever the tool produced
        """
        ...

    def close(self) -> None:
        """Tear down every open server connection."""
        ...
531
+
532
+
533
@dataclass
class ToolCall:
    """
    One tool invocation requested by the LLM, addressed to an MCP server.

    Attributes:
        id: Identifier the LLM assigned to this call
        server: Name of the target MCP server (a key of mcp.servers)
        tool: Name of the tool to invoke
        arguments: Keyword arguments for the tool; each instance gets its own dict
    """

    # Identifier taken from the LLM response.
    id: str
    # Target MCP server key.
    server: str
    # Tool to invoke on that server.
    tool: str
    # Fresh empty dict per instance (default_factory avoids a shared default).
    arguments: Dict[str, Any] = field(default_factory=dict)
548
+
549
+
550
@dataclass
class AgentResponse:
    """
    Result of a single agent call. Every field is optional and defaults to None.

    Attributes:
        content: Text the LLM returned (None when it only requested tools)
        output: Output parsed according to the output schema, if one is defined
        tool_calls: Tool invocations the LLM requested, if any
        raw_response: The provider's unprocessed response, for advanced callers
    """

    # Plain text of the reply.
    content: Optional[str] = None
    # Schema-parsed output.
    output: Optional[Dict[str, Any]] = None
    # Requested tool invocations.
    tool_calls: Optional[List[ToolCall]] = None
    # Untouched provider response object.
    raw_response: Optional[Any] = None
565
+
566
+
567
+ # ─────────────────────────────────────────────────────────────────────────────
568
+ # FlatAgent Base Class
569
+ # ─────────────────────────────────────────────────────────────────────────────
570
+
571
class FlatAgent(ABC):
    """
    Abstract base class for self-contained flat agents.

    Combines the agent interface, configuration, and execution loop.
    LLM interaction is delegated to a pluggable LLMBackend.

    Configuration can be provided via:
    - config_file: Path to a YAML configuration file (a ".json" suffix selects JSON)
    - config_dict: A dictionary with configuration (ignored when config_file is given)
    - backend: Custom LLMBackend instance (overrides config-based backend)
    - **kwargs: Override individual parameters

    Recognised configuration keys:
    - model: either a mapping (provider, name, temperature, max_tokens, top_p,
      frequency_penalty, presence_penalty, retry_delays) or a plain model-name string
    - litellm_defaults: fallback values for the same per-model keys

    Each setting is resolved as: kwargs > model section > litellm_defaults > built-in default.

    Example usage:
        class MyAgent(FlatAgent):
            def create_initial_state(self): return {}
            def generate_step_prompt(self, state): return "..."
            def update_state(self, state, result): return {**state, 'result': result}
            def is_solved(self, state): return state.get('done', False)

        # Using config file (creates LiteLLMBackend automatically)
        agent = MyAgent(config_file="config.yaml")

        # Using custom backend
        backend = LiteLLMBackend(model="openai/gpt-4", temperature=0.5)
        agent = MyAgent(backend=backend)

        trace = await agent.execute()
    """

    DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant."

    def __init__(
        self,
        config_file: Optional[str] = None,
        config_dict: Optional[Dict] = None,
        backend: Optional[LLMBackend] = None,
        **kwargs
    ):
        """
        Initialize the agent with configuration and optional backend.

        Args:
            config_file: Path to YAML (or .json) config file
            config_dict: Configuration dictionary
            backend: Custom LLMBackend (if not provided, creates LiteLLMBackend from config)
            **kwargs: Override specific config values
        """
        self._load_config(config_file, config_dict, **kwargs)

        if backend is not None:
            self.backend = backend
        else:
            self.backend = self._create_default_backend()

        logger.info(f"Initialized {self.__class__.__name__} with backend: {self.backend.__class__.__name__}")

    @staticmethod
    def _read_config_file(config_file: str) -> Any:
        """Parse a config file as JSON (".json" suffix) or YAML; empty files yield {}.

        Raises:
            FileNotFoundError: if the file does not exist.
            ImportError: if a YAML file is given but pyyaml is unavailable.
        """
        if not os.path.exists(config_file):
            raise FileNotFoundError(f"Configuration file not found: {config_file}")

        # Explicit UTF-8 so parsing does not depend on the platform's locale.
        with open(config_file, 'r', encoding='utf-8') as f:
            if config_file.endswith('.json'):
                return json.load(f) or {}
            if yaml is None:
                raise ImportError("pyyaml is required for YAML config files. Install with: pip install pyyaml")
            return yaml.safe_load(f) or {}

    def _load_config(
        self,
        config_file: Optional[str],
        config_dict: Optional[Dict],
        **kwargs
    ):
        """Load and process configuration from file (YAML or JSON), dict, or kwargs.

        Raises:
            FileNotFoundError: if config_file does not exist.
            ValueError: if the loaded configuration is not a mapping.
        """
        if config_file is not None:
            config = self._read_config_file(config_file)
        elif config_dict is not None:
            config = config_dict
        else:
            config = {}

        if not isinstance(config, dict):
            raise ValueError(f"Configuration must be a mapping, got {type(config).__name__}")

        # `or {}` also covers keys present with a null value (e.g. YAML "model:").
        model_config = config.get('model') or {}
        if isinstance(model_config, str):
            # Shorthand: `model: "provider/name"`.
            model_config = {'name': model_config}
        defaults = config.get('litellm_defaults') or {}

        # Build model name from provider/name if needed
        provider = model_config.get('provider')
        model_name = model_config.get('name')
        if provider and model_name and '/' not in model_name:
            full_model_name = f"{provider}/{model_name}"
        else:
            full_model_name = model_name

        def get_value(key: str, fallback: Any) -> Any:
            return kwargs.get(key, model_config.get(key, defaults.get(key, fallback)))

        # Store config values for backend creation
        self.model = kwargs.get('model', full_model_name)
        self.temperature = get_value('temperature', 0.7)
        self.max_tokens = get_value('max_tokens', 2048)
        self.top_p = get_value('top_p', 1.0)
        self.frequency_penalty = get_value('frequency_penalty', 0.0)
        self.presence_penalty = get_value('presence_penalty', 0.0)
        # Resolved like every other setting (previously kwargs/defaults were ignored).
        self.retry_delays = get_value('retry_delays', [1, 2, 4, 8])

        # Store raw config for subclass access
        self.config = config

    def _create_default_backend(self) -> LLMBackend:
        """Create the default LiteLLMBackend from loaded config.

        Raises:
            ValueError: if no model name was configured.
        """
        if self.model is None:
            raise ValueError("Model name is required. Provide via config file, config_dict, or 'model' kwarg.")

        return LiteLLMBackend(
            model=self.model,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            top_p=self.top_p,
            frequency_penalty=self.frequency_penalty,
            presence_penalty=self.presence_penalty,
            retry_delays=self.retry_delays,
        )

    # ─────────────────────────────────────────────────────────────────────────
    # Convenience Properties (delegate to backend)
    # ─────────────────────────────────────────────────────────────────────────

    @property
    def total_cost(self) -> float:
        """Total cost accumulated by the backend."""
        return self.backend.total_cost

    @property
    def total_api_calls(self) -> int:
        """Total API calls made by the backend."""
        return self.backend.total_api_calls

    # ─────────────────────────────────────────────────────────────────────────
    # Abstract Methods (subclasses must implement)
    # ─────────────────────────────────────────────────────────────────────────

    @abstractmethod
    def create_initial_state(self, *args, **kwargs) -> Any:
        """Create the initial state for the problem."""
        pass

    @abstractmethod
    def generate_step_prompt(self, state: Any) -> str:
        """Generate the user prompt for the next step based on current state."""
        pass

    @abstractmethod
    def update_state(self, current_state: Any, step_result: Any) -> Any:
        """Update the state based on the step result."""
        pass

    @abstractmethod
    def is_solved(self, state: Any) -> bool:
        """Check if the problem is solved."""
        pass

    # ─────────────────────────────────────────────────────────────────────────
    # Overridable Hooks
    # ─────────────────────────────────────────────────────────────────────────

    def get_system_prompt(self) -> str:
        """
        Get the system prompt for LLM calls.
        Override to customize the system prompt for your agent.
        """
        return self.DEFAULT_SYSTEM_PROMPT

    def get_response_parser(self) -> Callable[[str], Any]:
        """
        Get the response parser for this agent.
        Override to provide domain-specific parsing of LLM responses.
        The default returns the raw response text unchanged.
        """
        return lambda x: x

    def validate_step_result(self, step_result: Any) -> bool:
        """
        Check whether a parsed step result looks acceptable.
        Override for domain-specific validation.

        Note: ``execute`` only logs a warning for results that fail this
        check; the result is still passed to ``update_state``.
        """
        return step_result is not None

    def step_generator(self, state: Any) -> Tuple[Tuple[str, str], Callable[[str], Any]]:
        """
        Generate the prompt tuple and parser for the current state.

        Returns:
            Tuple of ((system_prompt, user_prompt), response_parser)

        Override for full control over prompt generation.
        """
        system_prompt = self.get_system_prompt()
        user_prompt = self.generate_step_prompt(state)
        parser = self.get_response_parser()
        return (system_prompt, user_prompt), parser

    # ─────────────────────────────────────────────────────────────────────────
    # Execution
    # ─────────────────────────────────────────────────────────────────────────

    async def execute(self, *args, **kwargs) -> List[Any]:
        """
        Execute the agent to solve the problem.

        The loop runs until ``is_solved`` returns True. There is no built-in
        step limit, so ``is_solved``/``update_state`` must guarantee progress.

        Args:
            *args, **kwargs: Passed to create_initial_state()

        Returns:
            List of states representing the execution trace
        """
        logger.info(f"Starting execution with args={args}, kwargs={kwargs}")

        state = self.create_initial_state(*args, **kwargs)
        trace = [state]

        while not self.is_solved(state):
            prompt_tuple, parser = self.step_generator(state)
            raw_result = await self._call_llm(prompt_tuple)
            parsed_result = parser(raw_result)

            if not self.validate_step_result(parsed_result):
                # Deliberately non-fatal: the result is still applied below.
                logger.warning(f"Step result validation failed: {parsed_result}")

            state = self.update_state(state, parsed_result)
            trace.append(state)
            logger.info("State updated.")

        logger.info(f"Execution completed. Trace length: {len(trace)} states")
        return trace

    async def _call_llm(self, prompt_tuple: Tuple[str, str]) -> str:
        """
        Call the LLM backend with the given prompt.

        Sends a two-message conversation (system, then user) and returns the
        backend's text reply.

        Override this for custom pre/post processing around LLM calls.
        """
        system_prompt, user_prompt = prompt_tuple
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        return await self.backend.call(messages)