flatagents 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flatagents/__init__.py +136 -0
- flatagents/actions.py +239 -0
- flatagents/assets/__init__.py +0 -0
- flatagents/assets/flatagent.d.ts +189 -0
- flatagents/assets/flatagent.schema.json +210 -0
- flatagents/assets/flatagent.slim.d.ts +52 -0
- flatagents/assets/flatmachine.d.ts +363 -0
- flatagents/assets/flatmachine.schema.json +515 -0
- flatagents/assets/flatmachine.slim.d.ts +94 -0
- flatagents/backends.py +222 -0
- flatagents/baseagent.py +814 -0
- flatagents/execution.py +462 -0
- flatagents/expressions/__init__.py +60 -0
- flatagents/expressions/cel.py +101 -0
- flatagents/expressions/simple.py +166 -0
- flatagents/flatagent.py +735 -0
- flatagents/flatmachine.py +1176 -0
- flatagents/gcp/__init__.py +25 -0
- flatagents/gcp/firestore.py +227 -0
- flatagents/hooks.py +380 -0
- flatagents/locking.py +69 -0
- flatagents/monitoring.py +373 -0
- flatagents/persistence.py +200 -0
- flatagents/utils.py +46 -0
- flatagents/validation.py +141 -0
- flatagents-0.4.1.dist-info/METADATA +310 -0
- flatagents-0.4.1.dist-info/RECORD +28 -0
- flatagents-0.4.1.dist-info/WHEEL +4 -0
flatagents/baseagent.py
ADDED
|
@@ -0,0 +1,814 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Self-contained FlatAgent base class with pluggable LLM backends.
|
|
3
|
+
|
|
4
|
+
Unifies the agent interface, configuration, and execution loop into a single class.
|
|
5
|
+
LLM interaction is delegated to an LLMBackend, allowing different providers.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import asyncio
|
|
9
|
+
import os
|
|
10
|
+
import random
|
|
11
|
+
from abc import ABC, abstractmethod
|
|
12
|
+
from typing import Any, Tuple, Callable, List, Dict, Optional, Protocol, runtime_checkable
|
|
13
|
+
|
|
14
|
+
from .monitoring import get_logger, track_operation
|
|
15
|
+
from .utils import strip_markdown_json
|
|
16
|
+
|
|
17
|
+
logger = get_logger(__name__)
|
|
18
|
+
|
|
19
|
+
try:
|
|
20
|
+
import litellm
|
|
21
|
+
except ImportError:
|
|
22
|
+
litellm = None
|
|
23
|
+
|
|
24
|
+
try:
|
|
25
|
+
import aisuite
|
|
26
|
+
except ImportError:
|
|
27
|
+
aisuite = None
|
|
28
|
+
|
|
29
|
+
try:
|
|
30
|
+
import yaml
|
|
31
|
+
except ImportError:
|
|
32
|
+
yaml = None
|
|
33
|
+
|
|
34
|
+
import json
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
38
|
+
# LLM Backend Protocol and Implementations
|
|
39
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
40
|
+
|
|
41
|
+
@runtime_checkable
class LLMBackend(Protocol):
    """Structural interface that every LLM provider backend must satisfy.

    Any object exposing the two usage counters below plus async ``call`` and
    ``call_raw`` methods qualifies; inheriting from this class is not required.
    """

    # Accumulated spend recorded by the backend.
    total_cost: float
    # Number of API requests the backend has issued.
    total_api_calls: int

    async def call(self, messages: List[Dict[str, str]], **kwargs) -> str:
        """Send ``messages`` to the model and return only the reply text.

        Args:
            messages: Chat messages, each a dict carrying 'role' and 'content'.
            **kwargs: Extra request parameters (temperature, max_tokens, ...).

        Returns:
            The text content of the model's reply.
        """
        ...

    async def call_raw(self, messages: List[Dict[str, str]], **kwargs) -> Any:
        """Send ``messages`` to the model and return the provider response untouched.

        Args:
            messages: Chat messages, each a dict carrying 'role' and 'content'.
            **kwargs: Extra request parameters (temperature, max_tokens, ...).

        Returns:
            The raw response object produced by LiteLLM or the provider library.
        """
        ...
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class LiteLLMBackend:
    """LLM backend using the litellm library.

    Sends chat requests through ``litellm.acompletion`` with default sampling
    parameters, retries failures with jittered delays, and accumulates the
    per-response cost LiteLLM reports in ``_hidden_params['response_cost']``.
    """

    def __init__(
        self,
        model: str,
        temperature: float = 0.7,
        max_tokens: int = 2048,
        top_p: float = 1.0,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        retry_delays: Optional[List[float]] = None,
    ):
        """
        Args:
            model: LiteLLM model identifier (e.g. "openai/gpt-4").
            temperature, max_tokens, top_p, frequency_penalty, presence_penalty:
                Default request parameters; per-call kwargs override them.
            retry_delays: Base delay (seconds) associated with each attempt; its
                length is the total number of attempts. An empty or missing list
                falls back to [1, 2, 4, 8].

        Raises:
            ImportError: If litellm is not installed.
        """
        if litellm is None:
            raise ImportError("litellm is required. Install with: pip install litellm")

        self.model = model
        self.llm_kwargs = {
            "temperature": temperature,
            "max_tokens": max_tokens,
            "top_p": top_p,
            "frequency_penalty": frequency_penalty,
            "presence_penalty": presence_penalty,
        }
        self.retry_delays = retry_delays or [1, 2, 4, 8]
        self.total_cost = 0.0
        self.total_api_calls = 0

        logger.info(f"Initialized LiteLLMBackend with model: {model}")

    def _track_cost(self, response: Any) -> None:
        """Add the response's reported cost to ``total_cost`` when it is a number.

        The cost entry may be missing or None. Adding None would raise TypeError,
        so non-numeric values are ignored instead.
        """
        hidden_params = getattr(response, '_hidden_params', None)
        if not isinstance(hidden_params, dict):
            return
        cost = hidden_params.get('response_cost')
        if isinstance(cost, (int, float)) and not isinstance(cost, bool):
            self.total_cost += cost

    async def call_raw(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> Any:
        """Call the LLM and return the raw response object with retry logic.

        Every attempt increments ``total_api_calls``. Between attempts the call
        sleeps for the attempt's base delay plus up to one second of jitter.

        Raises:
            The last exception seen if every attempt fails.
        """
        call_kwargs = {**self.llm_kwargs, **kwargs}
        attempts = len(self.retry_delays)

        last_exception: Optional[Exception] = None
        for attempt, delay in enumerate(self.retry_delays):
            try:
                self.total_api_calls += 1
                logger.info(f"Calling LLM (Attempt {attempt + 1}/{attempts})...")

                response = await litellm.acompletion(
                    model=self.model,
                    messages=messages,
                    **call_kwargs
                )

                if response is None or response.choices is None or len(response.choices) == 0:
                    raise ValueError("Received an empty or invalid response from the LLM.")
            except Exception as e:
                last_exception = e
                logger.warning(f"LLM call failed on attempt {attempt + 1}: {e}")
                if attempt < attempts - 1:
                    jittered_delay = delay + random.random()
                    logger.info(f"Retrying in {jittered_delay:.2f} seconds...")
                    await asyncio.sleep(jittered_delay)
            else:
                # Bookkeeping runs outside the try so a cost-accounting problem
                # can never turn a successful response into a retried failure.
                self._track_cost(response)
                return response

        logger.error("All retry attempts failed.")
        raise last_exception or RuntimeError("LLM call failed after all retries")

    async def call(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> str:
        """Call the LLM and return the content string.

        Raises:
            ValueError: If the first choice has no content.
        """
        response = await self.call_raw(messages, **kwargs)
        content = response.choices[0].message.content
        if content is None:
            raise ValueError("The LLM response content was empty.")
        logger.info(f"LLM response received: '{content[:100]}...'")
        return content
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
class AISuiteBackend:
    """
    LLM backend built on the synchronous aisuite client (by Andrew Ng).

    aisuite routes requests to many providers through one interface:
    OpenAI, Anthropic, Google, AWS, Cohere, Mistral, Ollama, HuggingFace.

    Model ids use the "provider:model" form (e.g. "openai:gpt-4o",
    "anthropic:claude-3-5-sonnet"). A "provider/model" id is also accepted and
    its first "/" is rewritten to ":".
    """

    def __init__(
        self,
        model: str,
        temperature: float = 0.7,
        max_tokens: int = 2048,
        top_p: float = 1.0,
        retry_delays: Optional[List[float]] = None,
    ):
        if aisuite is None:
            raise ImportError("aisuite is required. Install with: pip install aisuite")

        # str.replace is a no-op when no "/" is present, so no guard is needed.
        self.model = model.replace("/", ":", 1)
        self.llm_kwargs = dict(temperature=temperature, max_tokens=max_tokens, top_p=top_p)
        self.retry_delays = retry_delays or [1, 2, 4, 8]
        self.total_cost = 0.0
        self.total_api_calls = 0
        self.client = aisuite.Client()

        logger.info(f"Initialized AISuiteBackend with model: {self.model}")

    async def call_raw(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> Any:
        """Send the request via aisuite, retrying on failure; return the raw response.

        Each attempt bumps ``total_api_calls``. On success a rough token-based
        cost estimate is added to ``total_cost``. After the final failed attempt
        the last exception is re-raised.
        """
        call_kwargs = {**self.llm_kwargs, **kwargs}
        max_attempts = len(self.retry_delays)
        failure = None

        for attempt_no, base_delay in enumerate(self.retry_delays, start=1):
            try:
                self.total_api_calls += 1
                logger.info(f"Calling LLM via AISuite (Attempt {attempt_no}/{max_attempts})...")

                # The aisuite client blocks, so run it on a worker thread to keep
                # the event loop responsive.
                response = await asyncio.to_thread(
                    self.client.chat.completions.create,
                    model=self.model,
                    messages=messages,
                    **call_kwargs
                )

                if response is None or response.choices is None or len(response.choices) == 0:
                    raise ValueError("Received an empty or invalid response from the LLM.")

                usage = response.usage if hasattr(response, 'usage') else None
                if usage:
                    # Crude flat-rate estimate (~$0.01 per 1K tokens); real
                    # provider pricing differs.
                    prompt_tokens = getattr(usage, 'prompt_tokens', 0) or 0
                    completion_tokens = getattr(usage, 'completion_tokens', 0) or 0
                    self.total_cost += (prompt_tokens + completion_tokens) * 0.00001

                return response

            except Exception as exc:
                failure = exc
                logger.warning(f"AISuite call failed on attempt {attempt_no}: {exc}")
                if attempt_no < max_attempts:
                    jittered_delay = base_delay + random.random()
                    logger.info(f"Retrying in {jittered_delay:.2f} seconds...")
                    await asyncio.sleep(jittered_delay)

        logger.error("All retry attempts failed.")
        raise failure or RuntimeError("AISuite call failed after all retries")

    async def call(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> str:
        """Return the text of the first choice; raise ValueError if it is None."""
        raw = await self.call_raw(messages, **kwargs)
        text = raw.choices[0].message.content
        if text is None:
            raise ValueError("The LLM response content was empty.")
        logger.info(f"LLM response received: '{text[:100]}...'")
        return text
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
266
|
+
# Extractors (process LiteLLM responses into structured output)
|
|
267
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
268
|
+
|
|
269
|
+
@runtime_checkable
class Extractor(Protocol):
    """Structural interface for turning a raw LLM response into usable output."""

    def extract(self, response: Any) -> Any:
        """
        Convert a raw LLM response into structured data.

        Args:
            response: The object returned by ``LLMBackend.call_raw()``.

        Returns:
            Whatever structured form the concrete extractor produces.
        """
        ...
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
class FreeExtractor:
    """Pass-through extractor: hands back the reply text without any parsing."""

    def extract(self, response: Any) -> str:
        """Return the first choice's message content, or "" when it is None."""
        text = response.choices[0].message.content
        if text is None:
            return ""
        return text
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
class FreeThinkingExtractor:
    """
    Preserves reasoning/thinking from the response.
    Returns: { "thinking": str, "response": str }

    Reasoning sources are checked in this order, and the first that yields
    something wins:
    - A dedicated 'thinking' attribute on the message, or LiteLLM's
      'reasoning_content' attribute
    - Content blocks with type='thinking' (text read from the block's 'text'
      or 'thinking' attribute); type='text' blocks form the response
    - <thinking>...</thinking> tags in the content (every occurrence)
    """

    def extract(self, response: Any) -> Dict[str, str]:
        """Split the first choice's message into reasoning and final answer."""
        import re

        message = response.choices[0].message
        content = message.content or ""

        # 1. Provider-specific dedicated attribute.
        for attr in ('thinking', 'reasoning_content'):
            value = getattr(message, attr, None)
            if value:
                return {"thinking": value, "response": content}

        # 2. Structured content blocks (Anthropic style). Tested by truthiness
        #    rather than hasattr, so a message whose content_blocks is None or
        #    empty still falls through to the tag parsing below.
        blocks = getattr(message, 'content_blocks', None)
        if blocks:
            thinking_parts: List[str] = []
            text_parts: List[str] = []
            for block in blocks:
                block_type = getattr(block, 'type', None)
                if block_type == 'thinking':
                    text = getattr(block, 'text', None) or getattr(block, 'thinking', None)
                    if text:
                        thinking_parts.append(text)
                elif block_type == 'text':
                    text = getattr(block, 'text', None)
                    if text is not None:
                        text_parts.append(text)
            # A reply may be split across several text blocks, so they are
            # concatenated directly. Separate reasoning blocks are joined by a
            # blank line.
            answer = "".join(text_parts) if text_parts else content
            return {"thinking": "\n\n".join(thinking_parts), "response": answer}

        # 3. Inline <thinking>...</thinking> tags; collect every section.
        tag = re.compile(r'<thinking>(.*?)</thinking>', re.DOTALL)
        sections = [section.strip() for section in tag.findall(content)]
        if not sections:
            return {"thinking": "", "response": content}
        return {
            "thinking": "\n\n".join(section for section in sections if section),
            "response": tag.sub('', content).strip(),
        }
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
class StructuredExtractor:
    """
    Decodes the model's reply as JSON.

    Meant for calls made with a response_format parameter. Markdown code
    fences around the JSON are removed before decoding.
    """

    def __init__(self, schema: Optional[Dict] = None):
        """
        Args:
            schema: Optional JSON schema. It is kept on ``self.schema``;
                ``extract`` does not currently validate against it.
        """
        self.schema = schema

    def extract(self, response: Any) -> Dict[str, Any]:
        """Parse the first choice's content as JSON.

        Returns:
            The decoded value; ``{}`` when the content is None; or
            ``{"_raw": content, "_error": message}`` when decoding fails.
        """
        raw_text = response.choices[0].message.content
        if raw_text is None:
            return {}

        try:
            # Models sometimes wrap JSON in ```json fences; strip them first.
            return json.loads(strip_markdown_json(raw_text))
        except json.JSONDecodeError as err:
            logger.warning(f"Failed to parse JSON response: {err}")
            return {"_raw": raw_text, "_error": str(err)}
|
|
359
|
+
|
|
360
|
+
|
|
361
|
+
class ToolsExtractor:
    """
    Extracts tool calls from the response.
    Returns: { "tool_calls": [...], "content": str }

    Each tool call is {"id", "type", "function": {"name", "arguments"}}.
    String arguments are JSON-decoded when possible. Anything else (arguments a
    provider already decoded into a dict, or invalid JSON) is passed through
    unchanged.
    """

    def extract(self, response: Any) -> Dict[str, Any]:
        """Extract tool calls and content from the first choice's message."""
        message = response.choices[0].message
        content = message.content or ""
        raw_calls = getattr(message, 'tool_calls', None) or []
        tool_calls = [self._normalize_tool_call(tc) for tc in raw_calls]
        return {"tool_calls": tool_calls, "content": content}

    @staticmethod
    def _normalize_tool_call(tc: Any) -> Dict[str, Any]:
        """Convert one provider tool-call object into a plain dict."""
        # Tolerate a missing or None `function` instead of raising AttributeError.
        function = getattr(tc, 'function', None)
        name = getattr(function, 'name', None)
        arguments = getattr(function, 'arguments', None)

        # Only text can be JSON-decoded; json.loads on a dict raises TypeError.
        if isinstance(arguments, (str, bytes, bytearray)) and arguments:
            try:
                arguments = json.loads(arguments)
            except ValueError:  # JSONDecodeError / UnicodeDecodeError
                pass  # Keep the raw text if it is not valid JSON

        return {
            "id": getattr(tc, 'id', None),
            "type": getattr(tc, 'type', 'function'),
            "function": {"name": name, "arguments": arguments},
        }
|
|
394
|
+
|
|
395
|
+
|
|
396
|
+
class RegexExtractor:
    """
    Extracts fields from response using regex patterns.
    Patterns are provided at runtime, not in the spec.

    Can extract from:
    - Raw LLM response object (response.choices[0].message.content)
    - Plain string
    """

    def __init__(self, patterns: Dict[str, str], types: Optional[Dict[str, str]] = None):
        """
        Args:
            patterns: Map of field names to regex patterns. Capture group 1
                supplies the field value; a pattern without any capture group
                uses the whole match.
            types: Optional map of field names to type names ('str', 'int',
                'float', 'bool', 'json'). Unlisted or unknown types are kept as str.
        """
        import re
        self.patterns = {name: re.compile(pattern) for name, pattern in patterns.items()}
        self.types = types or {}

    def extract(self, response: Any) -> Optional[Dict[str, Any]]:
        """Extract fields using regex patterns.

        Returns:
            A dict of converted field values, or None when the content is None,
            any pattern fails to match, or any value fails type conversion. A
            field whose optional capture group did not participate in the match
            is reported as None.
        """
        # Handle both response object and plain string
        if isinstance(response, str):
            content = response
        else:
            content = response.choices[0].message.content

        if content is None:
            return None

        result: Dict[str, Any] = {}
        for field_name, pattern in self.patterns.items():
            match = pattern.search(content)
            if not match:
                logger.debug(f"Field '{field_name}' pattern did not match")
                return None

            # Patterns without a capture group would make group(1) raise IndexError.
            value = match.group(1) if pattern.groups else match.group(0)
            if value is None:
                # Optional group did not participate. Converting None would raise
                # TypeError/AttributeError, so report the field as missing.
                result[field_name] = None
                continue

            field_type = self.types.get(field_name, 'str')

            try:
                if field_type == 'json':
                    result[field_name] = json.loads(value)
                elif field_type == 'int':
                    result[field_name] = int(value)
                elif field_type == 'float':
                    result[field_name] = float(value)
                elif field_type == 'bool':
                    # Strip surrounding whitespace, as int()/float() tolerate it too.
                    result[field_name] = value.strip().lower() in ('true', '1', 'yes')
                else:
                    result[field_name] = value
            except (json.JSONDecodeError, ValueError) as e:
                logger.debug(f"Failed to parse field '{field_name}': {e}")
                return None

        return result
|
|
453
|
+
|
|
454
|
+
|
|
455
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
456
|
+
# MCP Tool Provider Protocol and Types
|
|
457
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
458
|
+
|
|
459
|
+
from dataclasses import dataclass, field
|
|
460
|
+
|
|
461
|
+
|
|
462
|
+
@runtime_checkable
class MCPToolProvider(Protocol):
    """
    Structural interface for MCP tool providers.

    The SDK ships no implementation of its own. Users plug in their MCP backend
    (for instance aisuite.mcp.MCPClient) behind these four methods.

    A minimal adapter over aisuite might look like:

        class AISuiteMCPProvider:
            def __init__(self):
                self._clients = {}

            def connect(self, server_name: str, config: dict):
                from aisuite.mcp import MCPClient
                if server_name not in self._clients:
                    self._clients[server_name] = MCPClient.from_config(config)

            def get_tools(self, server_name: str) -> list:
                return self._clients[server_name].list_tools()

            def call_tool(self, server_name: str, tool_name: str, arguments: dict):
                return self._clients[server_name].call_tool(tool_name, arguments)

            def close(self):
                for client in self._clients.values():
                    client.close()
    """

    def connect(self, server_name: str, config: Dict[str, Any]) -> None:
        """
        Open a connection to one MCP server.

        Args:
            server_name: Name of the server (its key under mcp.servers).
            config: Connection settings: command/args for stdio servers,
                server_url for HTTP servers.
        """
        ...

    def get_tools(self, server_name: str) -> List[Dict[str, Any]]:
        """
        List the tools a connected MCP server exposes.

        Args:
            server_name: Name of the server.

        Returns:
            Tool definitions, each carrying 'name', 'description' and 'inputSchema'.
        """
        ...

    def call_tool(self, server_name: str, tool_name: str, arguments: Dict[str, Any]) -> Any:
        """
        Invoke a tool on a connected MCP server.

        Args:
            server_name: Name of the server.
            tool_name: Tool to run.
            arguments: Arguments passed to the tool.

        Returns:
            Whatever the tool returned.
        """
        ...

    def close(self) -> None:
        """Release every open server connection."""
        ...
|
|
531
|
+
|
|
532
|
+
|
|
533
|
+
@dataclass
class ToolCall:
    """
    One tool invocation the LLM asked for.

    Fields:
        id: Identifier taken from the LLM response for this call.
        server: MCP server name (its key under mcp.servers).
        tool: Name of the tool to run on that server.
        arguments: Tool arguments; each instance gets its own empty dict by default.
    """

    id: str
    server: str
    tool: str
    arguments: Dict[str, Any] = field(default_factory=dict)
|
|
548
|
+
|
|
549
|
+
|
|
550
|
+
@dataclass
class AgentResponse:
    """
    What a single agent call produced. Every field defaults to None.

    Fields:
        content: Text the LLM returned; may be None when it only requested tools.
        output: Output parsed against the agent's output schema, if one is defined.
        tool_calls: Tool invocations the LLM requested, if any.
        raw_response: The untouched LLM response object, for advanced callers.
    """

    content: Optional[str] = None
    output: Optional[Dict[str, Any]] = None
    tool_calls: Optional[List[ToolCall]] = None
    raw_response: Optional[Any] = None
|
|
565
|
+
|
|
566
|
+
|
|
567
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
568
|
+
# FlatAgent Base Class
|
|
569
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
570
|
+
|
|
571
|
+
class FlatAgent(ABC):
    """
    Abstract base class for self-contained flat agents.

    Combines the agent interface, configuration, and execution loop.
    LLM interaction is delegated to a pluggable LLMBackend.

    Configuration can be provided via:
    - config_file: Path to a YAML configuration file (a '.json' suffix is
      parsed as JSON). Takes precedence over config_dict.
    - config_dict: A dictionary with configuration
    - backend: Custom LLMBackend instance (overrides config-based backend)
    - **kwargs: Override individual parameters

    Example usage:
        class MyAgent(FlatAgent):
            def create_initial_state(self): return {}
            def generate_step_prompt(self, state): return "..."
            def update_state(self, state, result): return {**state, 'result': result}
            def is_solved(self, state): return state.get('done', False)

        # Using config file (creates LiteLLMBackend automatically)
        agent = MyAgent(config_file="config.yaml")

        # Using custom backend
        backend = LiteLLMBackend(model="openai/gpt-4", temperature=0.5)
        agent = MyAgent(backend=backend)

        trace = await agent.execute()
    """

    DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant."

    def __init__(
        self,
        config_file: Optional[str] = None,
        config_dict: Optional[Dict] = None,
        backend: Optional[LLMBackend] = None,
        **kwargs
    ):
        """
        Initialize the agent with configuration and optional backend.

        Args:
            config_file: Path to YAML (or .json) config file
            config_dict: Configuration dictionary (used only when config_file is None)
            backend: Custom LLMBackend (if not provided, creates LiteLLMBackend from config)
            **kwargs: Override specific config values (model, temperature,
                max_tokens, top_p, frequency_penalty, presence_penalty, retry_delays)

        Raises:
            FileNotFoundError: If config_file does not exist.
            ImportError: If a YAML config is given and pyyaml is not installed.
            ValueError: If no backend is given and no model name can be resolved.
        """
        self._load_config(config_file, config_dict, **kwargs)

        if backend is not None:
            self.backend = backend
        else:
            self.backend = self._create_default_backend()

        logger.info(f"Initialized {self.__class__.__name__} with backend: {self.backend.__class__.__name__}")

    def _load_config(
        self,
        config_file: Optional[str],
        config_dict: Optional[Dict],
        **kwargs
    ):
        """Load and process configuration from file (YAML or JSON), dict, or kwargs.

        Value precedence for each parameter: kwargs, then the 'model' section,
        then the 'litellm_defaults' section, then the built-in fallback.
        """
        config: Dict[str, Any] = {}

        if config_file is not None:
            # Accept str or os.PathLike; `.endswith` below needs a plain string.
            config_path = os.fspath(config_file)
            if not os.path.exists(config_path):
                raise FileNotFoundError(f"Configuration file not found: {config_file}")

            # Decode explicitly as UTF-8; without an encoding, open() uses the
            # platform's locale encoding and can mis-read non-ASCII prompts.
            with open(config_path, 'r', encoding='utf-8') as f:
                if config_path.endswith('.json'):
                    config = json.load(f) or {}
                else:
                    if yaml is None:
                        raise ImportError("pyyaml is required for YAML config files. Install with: pip install pyyaml")
                    config = yaml.safe_load(f) or {}
        elif config_dict is not None:
            config = config_dict

        # `or {}` also covers keys that are present but null (e.g. an empty
        # `model:` entry in YAML), which `.get(key, {})` would return as None.
        model_config = config.get('model') or {}
        defaults = config.get('litellm_defaults') or {}

        # Build model name from provider/name if needed
        provider = model_config.get('provider')
        model_name = model_config.get('name')
        if provider and model_name and '/' not in model_name:
            full_model_name = f"{provider}/{model_name}"
        else:
            full_model_name = model_name

        def get_value(key: str, fallback: Any) -> Any:
            return kwargs.get(key, model_config.get(key, defaults.get(key, fallback)))

        # Store config values for backend creation
        self.model = kwargs.get('model', full_model_name)
        self.temperature = get_value('temperature', 0.7)
        self.max_tokens = get_value('max_tokens', 2048)
        self.top_p = get_value('top_p', 1.0)
        self.frequency_penalty = get_value('frequency_penalty', 0.0)
        self.presence_penalty = get_value('presence_penalty', 0.0)
        # Same precedence as the other parameters, so a retry_delays kwarg is honoured.
        self.retry_delays = get_value('retry_delays', [1, 2, 4, 8])

        # Store raw config for subclass access
        self.config = config

    def _create_default_backend(self) -> LLMBackend:
        """Create the default LiteLLMBackend from loaded config.

        Raises:
            ValueError: If no model name was resolved from config or kwargs.
        """
        if self.model is None:
            raise ValueError("Model name is required. Provide via config file, config_dict, or 'model' kwarg.")

        return LiteLLMBackend(
            model=self.model,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            top_p=self.top_p,
            frequency_penalty=self.frequency_penalty,
            presence_penalty=self.presence_penalty,
            retry_delays=self.retry_delays,
        )

    # ─────────────────────────────────────────────────────────────────────────
    # Convenience Properties (delegate to backend)
    # ─────────────────────────────────────────────────────────────────────────

    @property
    def total_cost(self) -> float:
        """Total cost accumulated by the backend."""
        return self.backend.total_cost

    @property
    def total_api_calls(self) -> int:
        """Total API calls made by the backend."""
        return self.backend.total_api_calls

    # ─────────────────────────────────────────────────────────────────────────
    # Abstract Methods (subclasses must implement)
    # ─────────────────────────────────────────────────────────────────────────

    @abstractmethod
    def create_initial_state(self, *args, **kwargs) -> Any:
        """Create the initial state for the problem."""
        pass

    @abstractmethod
    def generate_step_prompt(self, state: Any) -> str:
        """Generate the user prompt for the next step based on current state."""
        pass

    @abstractmethod
    def update_state(self, current_state: Any, step_result: Any) -> Any:
        """Update the state based on the step result."""
        pass

    @abstractmethod
    def is_solved(self, state: Any) -> bool:
        """Check if the problem is solved."""
        pass

    # ─────────────────────────────────────────────────────────────────────────
    # Overridable Hooks
    # ─────────────────────────────────────────────────────────────────────────

    def get_system_prompt(self) -> str:
        """
        Get the system prompt for LLM calls.
        Override to customize the system prompt for your agent.
        """
        return self.DEFAULT_SYSTEM_PROMPT

    def get_response_parser(self) -> Callable[[str], Any]:
        """
        Get the response parser for this agent.
        Override to provide domain-specific parsing of LLM responses.
        The default parser returns the raw text unchanged.
        """
        return lambda x: x

    def validate_step_result(self, step_result: Any) -> bool:
        """
        Validate a parsed step result.
        Override for domain-specific validation. The default rejects only None.
        """
        return step_result is not None

    def step_generator(self, state: Any) -> Tuple[Tuple[str, str], Callable[[str], Any]]:
        """
        Generate the prompt tuple and parser for the current state.

        Returns:
            Tuple of ((system_prompt, user_prompt), response_parser)

        Override for full control over prompt generation.
        """
        system_prompt = self.get_system_prompt()
        user_prompt = self.generate_step_prompt(state)
        parser = self.get_response_parser()
        return (system_prompt, user_prompt), parser

    # ─────────────────────────────────────────────────────────────────────────
    # Execution
    # ─────────────────────────────────────────────────────────────────────────

    async def execute(self, *args, **kwargs) -> List[Any]:
        """
        Execute the agent to solve the problem.

        Loops until is_solved(state) is true. There is no step limit, so
        is_solved must eventually return True.

        Args:
            *args, **kwargs: Passed to create_initial_state()

        Returns:
            List of states representing the execution trace
        """
        logger.info(f"Starting execution with args={args}, kwargs={kwargs}")

        state = self.create_initial_state(*args, **kwargs)
        trace = [state]

        while not self.is_solved(state):
            prompt_tuple, parser = self.step_generator(state)
            raw_result = await self._call_llm(prompt_tuple)
            parsed_result = parser(raw_result)

            # A failed validation is only logged; the result is still applied
            # to the state below.
            if not self.validate_step_result(parsed_result):
                logger.warning(f"Step result validation failed: {parsed_result}")

            state = self.update_state(state, parsed_result)
            trace.append(state)
            logger.info("State updated.")

        logger.info(f"Execution completed. Trace length: {len(trace)} states")
        return trace

    async def _call_llm(self, prompt_tuple: Tuple[str, str]) -> str:
        """
        Call the LLM backend with the given prompt.

        Sends a two-message conversation (system, user) and returns the
        backend's text reply.

        Override this for custom pre/post processing around LLM calls.
        """
        system_prompt, user_prompt = prompt_tuple
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        return await self.backend.call(messages)
|