flatagents 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flatagents/__init__.py +136 -0
- flatagents/actions.py +239 -0
- flatagents/assets/__init__.py +0 -0
- flatagents/assets/flatagent.d.ts +189 -0
- flatagents/assets/flatagent.schema.json +210 -0
- flatagents/assets/flatagent.slim.d.ts +52 -0
- flatagents/assets/flatmachine.d.ts +363 -0
- flatagents/assets/flatmachine.schema.json +515 -0
- flatagents/assets/flatmachine.slim.d.ts +94 -0
- flatagents/backends.py +222 -0
- flatagents/baseagent.py +814 -0
- flatagents/execution.py +462 -0
- flatagents/expressions/__init__.py +60 -0
- flatagents/expressions/cel.py +101 -0
- flatagents/expressions/simple.py +166 -0
- flatagents/flatagent.py +735 -0
- flatagents/flatmachine.py +1176 -0
- flatagents/gcp/__init__.py +25 -0
- flatagents/gcp/firestore.py +227 -0
- flatagents/hooks.py +380 -0
- flatagents/locking.py +69 -0
- flatagents/monitoring.py +373 -0
- flatagents/persistence.py +200 -0
- flatagents/utils.py +46 -0
- flatagents/validation.py +141 -0
- flatagents-0.4.1.dist-info/METADATA +310 -0
- flatagents-0.4.1.dist-info/RECORD +28 -0
- flatagents-0.4.1.dist-info/WHEEL +4 -0
flatagents/flatagent.py
ADDED
@@ -0,0 +1,735 @@
"""
|
|
2
|
+
FlatAgent - Single-call agent implementation.
|
|
3
|
+
|
|
4
|
+
A single-call agent executes one prompt/response cycle with optional:
|
|
5
|
+
- Input/output schemas
|
|
6
|
+
- Response extraction (free, structured, tools, regex)
|
|
7
|
+
- MCP tool integration
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import asyncio
|
|
11
|
+
import json
|
|
12
|
+
import os
|
|
13
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
14
|
+
|
|
15
|
+
from .monitoring import get_logger
|
|
16
|
+
from .utils import strip_markdown_json
|
|
17
|
+
from .baseagent import (
|
|
18
|
+
FlatAgent as BaseFlatAgent,
|
|
19
|
+
LLMBackend,
|
|
20
|
+
LiteLLMBackend,
|
|
21
|
+
AISuiteBackend,
|
|
22
|
+
Extractor,
|
|
23
|
+
FreeExtractor,
|
|
24
|
+
FreeThinkingExtractor,
|
|
25
|
+
StructuredExtractor,
|
|
26
|
+
ToolsExtractor,
|
|
27
|
+
RegexExtractor,
|
|
28
|
+
MCPToolProvider,
|
|
29
|
+
AgentResponse,
|
|
30
|
+
ToolCall,
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
logger = get_logger(__name__)
|
|
34
|
+
|
|
35
|
+
try:
|
|
36
|
+
import jinja2
|
|
37
|
+
except ImportError:
|
|
38
|
+
jinja2 = None
|
|
39
|
+
|
|
40
|
+
try:
|
|
41
|
+
import litellm
|
|
42
|
+
except ImportError:
|
|
43
|
+
litellm = None
|
|
44
|
+
|
|
45
|
+
try:
|
|
46
|
+
import aisuite
|
|
47
|
+
except ImportError:
|
|
48
|
+
aisuite = None
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class FlatAgent:
|
|
52
|
+
"""
|
|
53
|
+
A single LLM call configured entirely via YAML. No code required.
|
|
54
|
+
|
|
55
|
+
v0.6.0 Container format:
|
|
56
|
+
|
|
57
|
+
spec: flatagent
|
|
58
|
+
spec_version: "0.6.0"
|
|
59
|
+
|
|
60
|
+
data:
|
|
61
|
+
name: greeter
|
|
62
|
+
|
|
63
|
+
model:
|
|
64
|
+
provider: cerebras
|
|
65
|
+
name: zai-glm-4.6
|
|
66
|
+
temperature: 0.7
|
|
67
|
+
|
|
68
|
+
system: "You are a friendly greeter."
|
|
69
|
+
|
|
70
|
+
user: |
|
|
71
|
+
Greet the user named {{ input.name }}.
|
|
72
|
+
|
|
73
|
+
output:
|
|
74
|
+
greeting:
|
|
75
|
+
type: str
|
|
76
|
+
description: "A friendly greeting message"
|
|
77
|
+
|
|
78
|
+
# Optional MCP configuration
|
|
79
|
+
mcp:
|
|
80
|
+
servers:
|
|
81
|
+
filesystem:
|
|
82
|
+
command: npx
|
|
83
|
+
args: ["-y", "@modelcontextprotocol/server-filesystem", "/docs"]
|
|
84
|
+
tool_filter:
|
|
85
|
+
allow: ["filesystem:read_file"]
|
|
86
|
+
tool_prompt: |
|
|
87
|
+
You have access to these tools:
|
|
88
|
+
{% for tool in tools %}
|
|
89
|
+
- {{ tool.name }}: {{ tool.description }}
|
|
90
|
+
{% endfor %}
|
|
91
|
+
|
|
92
|
+
metadata:
|
|
93
|
+
author: "your-name"
|
|
94
|
+
|
|
95
|
+
Example usage:
|
|
96
|
+
from flatagents import setup_logging, get_logger
|
|
97
|
+
setup_logging(level="INFO")
|
|
98
|
+
logger = get_logger(__name__)
|
|
99
|
+
|
|
100
|
+
agent = FlatAgent(config_file="agent.yaml")
|
|
101
|
+
result = await agent.call(name="Alice")
|
|
102
|
+
logger.info(f"Result: {result}")
|
|
103
|
+
|
|
104
|
+
Example with MCP:
|
|
105
|
+
from flatagents import FlatAgent, MCPToolProvider
|
|
106
|
+
|
|
107
|
+
agent = FlatAgent(config_file="agent.yaml")
|
|
108
|
+
provider = MyMCPProvider() # User implements MCPToolProvider protocol
|
|
109
|
+
result = await agent.call(tool_provider=provider, question="List files")
|
|
110
|
+
|
|
111
|
+
if result.tool_calls:
|
|
112
|
+
for tc in result.tool_calls:
|
|
113
|
+
tool_result = provider.call_tool(tc.server, tc.tool, tc.arguments)
|
|
114
|
+
# Handle tool result...
|
|
115
|
+
"""
|
|
116
|
+
|
|
117
|
+
SPEC_VERSION = "0.6.0"
|
|
118
|
+
DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant."
|
|
119
|
+
|
|
120
|
+
def __init__(
|
|
121
|
+
self,
|
|
122
|
+
config_file: Optional[str] = None,
|
|
123
|
+
config_dict: Optional[Dict] = None,
|
|
124
|
+
tool_provider: Optional["MCPToolProvider"] = None,
|
|
125
|
+
backend: Optional[str] = None,
|
|
126
|
+
**kwargs
|
|
127
|
+
):
|
|
128
|
+
if jinja2 is None:
|
|
129
|
+
raise ImportError("jinja2 is required for FlatAgent. Install with: pip install jinja2")
|
|
130
|
+
|
|
131
|
+
self._load_config(config_file, config_dict, **kwargs)
|
|
132
|
+
self._validate_spec()
|
|
133
|
+
self._parse_agent_config()
|
|
134
|
+
|
|
135
|
+
# Determine backend: explicit param > config > auto-detect
|
|
136
|
+
config_backend = self.config.get('data', {}).get('model', {}).get('backend')
|
|
137
|
+
self._backend = backend or config_backend or self._auto_detect_backend()
|
|
138
|
+
self._init_backend()
|
|
139
|
+
|
|
140
|
+
# MCP support
|
|
141
|
+
self._tool_provider = tool_provider
|
|
142
|
+
self._tools_cache: Optional[List[Dict]] = None
|
|
143
|
+
|
|
144
|
+
# Tracking
|
|
145
|
+
self.total_cost = 0.0
|
|
146
|
+
self.total_api_calls = 0
|
|
147
|
+
|
|
148
|
+
logger.info(f"Initialized FlatAgent: {self.agent_name} (backend: {self._backend})")
|
|
149
|
+
|
|
150
|
+
def _auto_detect_backend(self) -> str:
|
|
151
|
+
"""
|
|
152
|
+
Auto-detect available LLM backend.
|
|
153
|
+
|
|
154
|
+
Priority:
|
|
155
|
+
1. FLATAGENTS_BACKEND env var (e.g., "litellm" or "aisuite")
|
|
156
|
+
2. litellm if installed (preferred for stability)
|
|
157
|
+
3. aisuite if installed
|
|
158
|
+
"""
|
|
159
|
+
# Check env var first (SDK-specific override, not in config)
|
|
160
|
+
env_backend = os.environ.get("FLATAGENTS_BACKEND", "").lower()
|
|
161
|
+
if env_backend in ("litellm", "aisuite"):
|
|
162
|
+
return env_backend
|
|
163
|
+
|
|
164
|
+
# Prefer litellm for stability
|
|
165
|
+
if litellm is not None:
|
|
166
|
+
return "litellm"
|
|
167
|
+
if aisuite is not None:
|
|
168
|
+
return "aisuite"
|
|
169
|
+
raise ImportError(
|
|
170
|
+
"No LLM backend available. Install one of:\n"
|
|
171
|
+
" pip install litellm (recommended)\n"
|
|
172
|
+
" pip install aisuite"
|
|
173
|
+
)
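
    # Example (sketch): FLATAGENTS_BACKEND lets a deployment force a backend
    # without touching the YAML config, assuming that backend is installed:
    #
    #   os.environ["FLATAGENTS_BACKEND"] = "aisuite"
    #   agent = FlatAgent(config_file="agent.yaml")  # resolves to aisuite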

    def _init_backend(self) -> None:
        """Initialize the selected backend."""
        if self._backend == "aisuite":
            if aisuite is None:
                raise ImportError("aisuite backend selected but not installed. Install with: pip install aisuite")
            self._aisuite_client = aisuite.Client()
        elif self._backend == "litellm":
            if litellm is None:
                raise ImportError("litellm backend selected but not installed. Install with: pip install litellm")
        else:
            raise ValueError(f"Unknown backend: {self._backend}. Use 'aisuite' or 'litellm'.")

    async def _call_llm(self, params: Dict[str, Any]) -> Any:
        """Call the LLM using the selected backend."""
        if self._backend == "aisuite":
            return await self._call_aisuite(params)
        else:
            return await litellm.acompletion(**params)

    async def _call_aisuite(self, params: Dict[str, Any]) -> Any:
        """Call LLM via aisuite backend."""
        # Normalize "provider/model" (litellm style) to "provider:model" (aisuite style)
        model = params["model"]
        if "/" in model:
            model = model.replace("/", ":", 1)

        if ":" not in model:
            raise ValueError(f"aisuite requires 'provider:model' (or 'provider/model') format, got: {model}")
        provider_key, model_name = model.split(":", 1)

        # WORKAROUND: aisuite drops tools unless max_turns is set.
        # Use direct provider call for Cerebras.
        if provider_key == "cerebras":
            return await self._call_aisuite_cerebras_direct(model_name, params)

        call_params = {
            "model": model,
            "messages": params["messages"],
            "temperature": params.get("temperature", 0.7),
            "max_tokens": params.get("max_tokens", 2048),
        }
        if "tools" in params:
            call_params["tools"] = params["tools"]

        # aisuite's client is synchronous; run it in a worker thread
        response = await asyncio.to_thread(
            self._aisuite_client.chat.completions.create,
            **call_params
        )

        return response
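
    # Example (sketch): model id normalization in _call_aisuite.
    #
    #   model = "cerebras/zai-glm-4.6".replace("/", ":", 1)
    #   provider_key, model_name = model.split(":", 1)
    #   assert (provider_key, model_name) == ("cerebras", "zai-glm-4.6")
    #
    # With provider_key == "cerebras" the call takes the direct-provider path.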

    async def _call_aisuite_cerebras_direct(self, model_name: str, params: Dict[str, Any]) -> Any:
        """Direct Cerebras provider call. Workaround for aisuite dropping tools."""
        from aisuite.provider import ProviderFactory

        config = self._aisuite_client.provider_configs.get("cerebras", {})
        provider = ProviderFactory.create_provider("cerebras", config)

        kwargs = {
            "temperature": params.get("temperature", 0.7),
            "max_tokens": params.get("max_tokens", 2048),
        }
        if "tools" in params:
            kwargs["tools"] = params["tools"]

        response = await asyncio.to_thread(
            provider.chat_completions_create,
            model_name,
            params["messages"],
            **kwargs
        )

        return response

    def _load_config(
        self,
        config_file: Optional[str],
        config_dict: Optional[Dict],
        **kwargs
    ):
        """Load v0.6.0 container config."""
        try:
            import yaml
        except ImportError:
            yaml = None

        config = {}
        if config_file is not None:
            if not os.path.exists(config_file):
                raise FileNotFoundError(f"Configuration file not found: {config_file}")
            with open(config_file, 'r') as f:
                if config_file.endswith('.json'):
                    config = json.load(f) or {}
                else:
                    if yaml is None:
                        raise ImportError("pyyaml is required for YAML config files.")
                    config = yaml.safe_load(f) or {}
        elif config_dict is not None:
            config = config_dict

        self.config = config

        # Extract model config from data section
        data = config.get('data', {})
        model_config = data.get('model', {})

        # Build model name from provider/name
        provider = model_config.get('provider')
        model_name = model_config.get('name')
        if provider and model_name and '/' not in model_name:
            full_model_name = f"{provider}/{model_name}"
        else:
            full_model_name = model_name

        # Set model attributes (with kwargs override)
        self.model = kwargs.get('model', full_model_name)
        self.temperature = kwargs.get('temperature', model_config.get('temperature', 0.7))
        self.max_tokens = kwargs.get('max_tokens', model_config.get('max_tokens', 2048))

        # Store full model config for template access (includes custom fields)
        self._model_config_raw = model_config
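
    # Example (sketch): how provider/name from the config become self.model.
    #
    #   model: {provider: cerebras, name: zai-glm-4.6}  ->  "cerebras/zai-glm-4.6"
    #   model: {name: openai/gpt-4o}                    ->  "openai/gpt-4o" (already qualified)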

    def _validate_spec(self):
        """Validate the spec envelope."""
        config = self.config

        if config.get('spec') != 'flatagent':
            raise ValueError(
                f"Invalid spec: expected 'flatagent', got '{config.get('spec')}'. "
                "Config must have: spec: flatagent"
            )

        if 'data' not in config:
            raise ValueError("Config missing 'data' section")

        # Version check with warning
        self.spec_version = config.get('spec_version', '0.6.0')
        major_minor = '.'.join(self.spec_version.split('.')[:2])
        if major_minor not in ['0.5', '0.6']:
            logger.warning(
                f"Config version {self.spec_version} may not be fully supported. "
                f"Current SDK supports 0.5.x - 0.6.x."
            )

        # Schema validation (warnings only, non-blocking)
        try:
            from .validation import validate_flatagent_config
            validate_flatagent_config(config, warn=True, strict=False)
        except ImportError:
            pass  # jsonschema not installed, skip validation

    def _parse_agent_config(self):
        """Parse the v0.6.0 flatagent configuration."""
        data = self.config['data']
        self.metadata = self.config.get('metadata', {})

        # Agent name
        self.agent_name = data.get('name') or self.metadata.get('name', 'unnamed-agent')

        # Prompts
        self._system_prompt_template = data.get('system', self.DEFAULT_SYSTEM_PROMPT)
        self._user_prompt_template = data.get('user', '')
        self._instruction_suffix = data.get('instruction_suffix', '')

        # Compile Jinja2 templates
        self._jinja_env = jinja2.Environment()
        self._compiled_system = self._jinja_env.from_string(self._system_prompt_template)
        self._compiled_user = self._jinja_env.from_string(self._user_prompt_template)

        # Output schema (stored for reference, extraction uses json_object mode)
        self.output_schema = data.get('output', {})

        # MCP configuration
        self.mcp_config = data.get('mcp')

    # ─────────────────────────────────────────────────────────────────────────
    # MCP Tool Support
    # ─────────────────────────────────────────────────────────────────────────

    def set_tool_provider(self, provider: "MCPToolProvider") -> None:
        """
        Set the MCP tool provider.

        Args:
            provider: An object implementing the MCPToolProvider protocol
        """
        self._tool_provider = provider
        self._tools_cache = None  # Clear cache when provider changes

    def _discover_tools(self) -> List[Dict[str, Any]]:
        """
        Discover and filter tools from configured MCP servers.

        Tools are cached for the lifetime of this agent instance.
        Call set_tool_provider() to reset the cache.

        Returns:
            List of tool definitions with '_server' and '_qualified' metadata
        """
        if self._tools_cache is not None:
            return self._tools_cache

        if not self.mcp_config or not self._tool_provider:
            return []

        tools = []
        servers = self.mcp_config.get('servers', {})
        tool_filter = self.mcp_config.get('tool_filter', {})

        for server_name, server_config in servers.items():
            # Connect to server if not already connected
            self._tool_provider.connect(server_name, server_config)

            # Get tools from this server
            try:
                server_tools = self._tool_provider.get_tools(server_name)
            except Exception as e:
                logger.warning(f"Failed to get tools from server '{server_name}': {e}")
                continue

            for tool in server_tools:
                qualified_name = f"{server_name}:{tool['name']}"
                if self._passes_filter(qualified_name, tool_filter):
                    tools.append({
                        **tool,
                        '_server': server_name,
                        '_qualified': qualified_name
                    })

        self._tools_cache = tools
        logger.info(f"Discovered {len(tools)} tools from {len(servers)} MCP server(s)")
        return tools
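
    # Example (sketch): shape of one discovered entry, assuming the filesystem
    # server from the class docstring exposes a read_file tool:
    #
    #   {
    #       "name": "read_file",
    #       "description": "Read a file from disk",
    #       "inputSchema": {"type": "object", "properties": {"path": {"type": "string"}}},
    #       "_server": "filesystem",
    #       "_qualified": "filesystem:read_file",
    #   }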

    def _passes_filter(self, qualified_name: str, filter_config: Dict) -> bool:
        """
        Check if a tool passes the allow/deny filters.

        Args:
            qualified_name: Tool name in "server:tool" format
            filter_config: Dict with optional 'allow' and 'deny' lists

        Returns:
            True if tool should be included
        """
        allow = filter_config.get('allow', [])
        deny = filter_config.get('deny', [])

        # Deny takes precedence
        for pattern in deny:
            if self._match_pattern(qualified_name, pattern):
                return False

        # If allow list exists, must match at least one pattern
        if allow:
            return any(self._match_pattern(qualified_name, p) for p in allow)

        # No allow list = allow all (that aren't denied)
        return True

    def _match_pattern(self, name: str, pattern: str) -> bool:
        """
        Match a qualified name against a pattern.

        Supports:
        - Exact match: "server:tool"
        - Wildcard: "server:*" matches all tools from server

        Args:
            name: Qualified tool name ("server:tool")
            pattern: Pattern to match against

        Returns:
            True if name matches pattern
        """
        if pattern.endswith(':*'):
            return name.startswith(pattern[:-1])
        return name == pattern
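
    # Example (sketch): filter semantics; deny always wins.
    #
    #   f = {"allow": ["filesystem:*"], "deny": ["filesystem:delete_file"]}
    #   self._passes_filter("filesystem:read_file", f)    # True  (matches "filesystem:*")
    #   self._passes_filter("filesystem:delete_file", f)  # False (denied)
    #   self._passes_filter("web:fetch", f)               # False (allow list set, no match)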

    def _render_tool_prompt(self, tools: List[Dict]) -> str:
        """
        Render the tool_prompt template from the spec.

        Args:
            tools: List of discovered tool definitions

        Returns:
            Rendered tool prompt string, or empty string if no tools/template
        """
        if not tools or not self.mcp_config:
            return ""

        tool_prompt_template = self.mcp_config.get('tool_prompt', '')
        if not tool_prompt_template:
            return ""

        template = self._jinja_env.from_string(tool_prompt_template)
        return template.render(tools=tools)

    def _convert_tools_for_llm(self, tools: List[Dict]) -> List[Dict]:
        """
        Convert MCP tool schemas to OpenAI function calling format.

        Args:
            tools: List of MCP tool definitions

        Returns:
            List of tools in OpenAI function format
        """
        return [
            {
                "type": "function",
                "function": {
                    "name": t['name'],
                    "description": t.get('description', ''),
                    "parameters": t.get('inputSchema', {"type": "object", "properties": {}})
                }
            }
            for t in tools
        ]
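
    # Example (sketch): the read_file entry above, converted for the LLM call.
    # The private '_server'/'_qualified' keys are not forwarded to the model.
    #
    #   {
    #       "type": "function",
    #       "function": {
    #           "name": "read_file",
    #           "description": "Read a file from disk",
    #           "parameters": {"type": "object", "properties": {"path": {"type": "string"}}},
    #       },
    #   }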

    def _find_tool_server(self, tool_name: str, tools: List[Dict]) -> str:
        """
        Find which server a tool belongs to.

        Args:
            tool_name: Name of the tool
            tools: List of discovered tools with '_server' metadata

        Returns:
            Server name, or empty string if not found
        """
        for tool in tools:
            if tool['name'] == tool_name:
                return tool.get('_server', '')
        return ''

    # ─────────────────────────────────────────────────────────────────────────
    # Prompt Rendering
    # ─────────────────────────────────────────────────────────────────────────

    def _render_system_prompt(
        self,
        input_data: Dict[str, Any],
        tools_prompt: str = "",
        tools: Optional[List[Dict]] = None
    ) -> str:
        """
        Render system prompt with input data and optional tools context.

        Args:
            input_data: Input values for {{ input.* }}
            tools_prompt: Rendered tool prompt to inject
            tools: List of tool definitions (available as {{ tools }})

        Returns:
            Rendered system prompt
        """
        # Merge raw config with computed values for template access
        model_config = {
            **self._model_config_raw,
            "name": self.model,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }
        return self._compiled_system.render(
            input=input_data,
            tools_prompt=tools_prompt,
            tools=tools or [],
            model=model_config
        )

    def _render_user_prompt(
        self,
        input_data: Dict[str, Any],
        tools_prompt: str = "",
        tools: Optional[List[Dict]] = None
    ) -> str:
        """
        Render user prompt with input data and optional tools context.

        Args:
            input_data: Input values for {{ input.* }}
            tools_prompt: Rendered tool prompt (available as {{ tools_prompt }})
            tools: List of tool definitions (available as {{ tools }})

        Returns:
            Rendered user prompt
        """
        # Merge raw config with computed values for template access
        model_config = {
            **self._model_config_raw,
            "name": self.model,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }
        prompt = self._compiled_user.render(
            input=input_data,
            tools_prompt=tools_prompt,
            tools=tools or [],
            model=model_config
        )
        if self._instruction_suffix:
            prompt = f"{prompt}\n\n{self._instruction_suffix}"
        return prompt
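
    # Example (sketch): everything a prompt template can reference, e.g. in agent.yaml:
    #
    #   user: |
    #     {{ tools_prompt }}
    #     Answer {{ input.question }} using {{ model.name }}
    #     (temperature {{ model.temperature }}, max_tokens {{ model.max_tokens }}).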

    def _build_output_instruction(self) -> str:
        """Build instruction for JSON output based on schema."""
        if not self.output_schema:
            return ""

        fields = []
        for name, field_def in self.output_schema.items():
            desc = field_def.get('description', '')
            field_type = field_def.get('type', 'str')
            enum_vals = field_def.get('enum')

            parts = [f'"{name}"']
            if desc:
                parts.append(f"({desc})")
            if enum_vals:
                parts.append(f"- one of: {enum_vals}")

            fields.append(" ".join(parts))

        return "Respond with JSON containing: " + ", ".join(fields)

    # ─────────────────────────────────────────────────────────────────────────
    # Execution
    # ─────────────────────────────────────────────────────────────────────────

    async def call(
        self,
        tool_provider: Optional["MCPToolProvider"] = None,
        messages: Optional[List[Dict[str, Any]]] = None,
        **input_data
    ) -> "AgentResponse":
        """
        Execute a single LLM call with the given input.

        Args:
            tool_provider: Optional MCPToolProvider (overrides constructor value)
            messages: Optional conversation history (for tool call continuations)
            **input_data: Input values available as {{ input.* }} in templates

        Returns:
            AgentResponse with content, output, and optionally tool_calls
        """
        # Use provided tool provider or fall back to stored one
        if tool_provider is not None:
            self._tool_provider = tool_provider
            self._tools_cache = None  # Clear cache

        # Discover tools if MCP is configured
        tools = self._discover_tools()
        tools_prompt = self._render_tool_prompt(tools)

        # Render prompts
        system_prompt = self._render_system_prompt(input_data, tools_prompt=tools_prompt, tools=tools)
        user_prompt = self._render_user_prompt(input_data, tools_prompt=tools_prompt, tools=tools)

        # Add output instruction if we have a schema and no tools
        # (with tools, the LLM may call tools instead of returning JSON)
        if self.output_schema and not tools:
            output_instruction = self._build_output_instruction()
            if output_instruction:
                user_prompt = f"{user_prompt}\n\n{output_instruction}"

        # Build messages
        if messages:
            # Continue from provided message history
            all_messages = [{"role": "system", "content": system_prompt}] + messages
            # Only add user prompt if input_data was provided
            if input_data:
                all_messages.append({"role": "user", "content": user_prompt})
        else:
            all_messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ]

        # Build LLM call parameters
        params = {
            "model": self.model,
            "messages": all_messages,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }

        # Add tools if available
        if tools:
            params["tools"] = self._convert_tools_for_llm(tools)

        # Use JSON mode if we have an output schema and no tools
        if self.output_schema and not tools:
            params["response_format"] = {"type": "json_object"}

        # Call LLM via selected backend
        response = await self._call_llm(params)

        # Track usage (cost is a rough flat-rate estimate: $1 per 1M input
        # tokens, $2 per 1M output tokens, regardless of model)
        self.total_api_calls += 1
        if hasattr(response, 'usage') and response.usage:
            input_tokens = getattr(response.usage, 'prompt_tokens', 0)
            output_tokens = getattr(response.usage, 'completion_tokens', 0)
            self.total_cost += (input_tokens * 0.001 + output_tokens * 0.002) / 1000

        # Extract response
        message = response.choices[0].message
        content = message.content

        # Parse output schema if applicable
        output = None
        if self.output_schema and content and not tools:
            try:
                # Strip markdown fences - LLMs sometimes wrap JSON in ```json blocks
                output = json.loads(strip_markdown_json(content))
            except json.JSONDecodeError:
                logger.warning(f"Failed to parse JSON response: {content}")
                output = {"_raw": content}

        # Extract tool calls if present
        tool_calls = None
        if hasattr(message, 'tool_calls') and message.tool_calls:
            tool_calls = []
            for tc in message.tool_calls:
                tool_name = tc.function.name
                server = self._find_tool_server(tool_name, tools)

                try:
                    arguments = json.loads(tc.function.arguments)
                except json.JSONDecodeError:
                    arguments = {}

                tool_calls.append(ToolCall(
                    id=tc.id,
                    server=server,
                    tool=tool_name,
                    arguments=arguments
                ))

        return AgentResponse(
            content=content,
            output=output,
            tool_calls=tool_calls,
            raw_response=response
        )
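
    # Example (sketch, assuming OpenAI-style tool messages): feeding a tool result
    # back through `messages` for a follow-up turn; names are illustrative.
    #
    #   result = await agent.call(tool_provider=provider, question="List files")
    #   if result.tool_calls:
    #       tc = result.tool_calls[0]
    #       out = provider.call_tool(tc.server, tc.tool, tc.arguments)
    #       assistant_msg = result.raw_response.choices[0].message
    #       result = await agent.call(messages=[
    #           {"role": "user", "content": "List files"},
    #           assistant_msg,
    #           {"role": "tool", "tool_call_id": tc.id, "content": str(out)},
    #       ])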

    def call_sync(
        self,
        tool_provider: Optional["MCPToolProvider"] = None,
        messages: Optional[List[Dict[str, Any]]] = None,
        **input_data
    ) -> "AgentResponse":
        """Synchronous wrapper for call()."""
        return asyncio.run(self.call(tool_provider=tool_provider, messages=messages, **input_data))
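
    # Example (sketch): synchronous usage. Note asyncio.run() cannot be called
    # from inside an already-running event loop.
    #
    #   agent = FlatAgent(config_file="agent.yaml")
    #   result = agent.call_sync(name="Alice")
    #   print(result.output["greeting"])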