praisonaiagents 0.0.123__py3-none-any.whl → 0.0.125__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- praisonaiagents/__init__.py +24 -0
- praisonaiagents/agent/agent.py +329 -192
- praisonaiagents/agents/autoagents.py +1 -1
- praisonaiagents/knowledge/knowledge.py +0 -3
- praisonaiagents/llm/__init__.py +6 -9
- praisonaiagents/llm/llm.py +68 -26
- praisonaiagents/main.py +1 -18
- praisonaiagents/mcp/mcp.py +46 -8
- praisonaiagents/mcp/mcp_http_stream.py +466 -0
- praisonaiagents/mcp/mcp_sse.py +19 -2
- praisonaiagents/process/process.py +88 -3
- praisonaiagents/task/task.py +1 -0
- {praisonaiagents-0.0.123.dist-info → praisonaiagents-0.0.125.dist-info}/METADATA +2 -1
- {praisonaiagents-0.0.123.dist-info → praisonaiagents-0.0.125.dist-info}/RECORD +16 -15
- {praisonaiagents-0.0.123.dist-info → praisonaiagents-0.0.125.dist-info}/WHEEL +0 -0
- {praisonaiagents-0.0.123.dist-info → praisonaiagents-0.0.125.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,466 @@
|
|
1
|
+
"""
|
2
|
+
HTTP Stream client implementation for MCP (Model Context Protocol).
|
3
|
+
This module provides the necessary classes and functions to connect to an MCP server
|
4
|
+
over HTTP Stream transport, implementing the Streamable HTTP transport protocol.
|
5
|
+
"""
|
6
|
+
|
7
|
+
import asyncio
|
8
|
+
import logging
|
9
|
+
import threading
|
10
|
+
import inspect
|
11
|
+
import json
|
12
|
+
import time
|
13
|
+
import uuid
|
14
|
+
from typing import List, Dict, Any, Optional, Callable, Iterable, Union
|
15
|
+
from urllib.parse import urlparse, urljoin
|
16
|
+
|
17
|
+
from mcp import ClientSession
|
18
|
+
try:
|
19
|
+
import aiohttp
|
20
|
+
except ImportError:
|
21
|
+
raise ImportError("aiohttp is required for HTTP Stream transport. Install with: pip install praisonaiagents[mcp]")
|
22
|
+
|
23
|
+
logger = logging.getLogger("mcp-http-stream")

# Process-wide event loop shared by every HTTP Stream client and tool wrapper.
_event_loop = None


def get_event_loop():
    """Return the module's shared event loop, creating a fresh one if needed.

    A new loop is created when none exists yet or when the previous one has
    been closed. A newly created loop is also installed as the calling
    thread's current event loop via ``asyncio.set_event_loop``.
    """
    global _event_loop
    needs_new_loop = _event_loop is None or _event_loop.is_closed()
    if needs_new_loop:
        fresh_loop = asyncio.new_event_loop()
        _event_loop = fresh_loop
        asyncio.set_event_loop(fresh_loop)
    return _event_loop
|
35
|
+
|
36
|
+
|
37
|
+
class HTTPStreamMCPTool:
    """A wrapper for an MCP tool that can be used with praisonaiagents.

    Instances behave like plain functions for the agent. They are callable
    with keyword arguments and expose ``__name__``, ``__qualname__``,
    ``__doc__`` and ``__signature__`` derived from the tool's JSON Schema.
    Each call runs ``session.call_tool`` on the module's shared event loop.
    """

    # JSON Schema primitive type -> annotation used in the generated signature.
    _JSON_TYPE_ANNOTATIONS = {
        'string': str,
        'integer': int,
        'number': float,
        'boolean': bool,
        'array': list,
        'object': dict,
    }

    # JSON Schema keywords whose values are lists / name->schema maps of sub-schemas.
    _SUBSCHEMA_LIST_KEYS = ('anyOf', 'oneOf', 'allOf')
    _SUBSCHEMA_MAP_KEYS = ('$defs', 'definitions')

    def __init__(self, name: str, description: str, session: ClientSession, input_schema: Optional[Dict[str, Any]] = None, timeout: int = 60):
        """
        Args:
            name: Tool name as reported by the MCP server.
            description: Tool description; also exposed as ``__doc__``.
            session: MCP client session used to invoke the tool.
            input_schema: JSON Schema of the tool's arguments (may be None).
            timeout: Seconds ``__call__`` waits for a result before giving up.
        """
        self.name = name
        self.__name__ = name  # Required for Agent to recognize it as a tool
        self.__qualname__ = name  # Required for Agent to recognize it as a tool
        self.__doc__ = description  # Required for Agent to recognize it as a tool
        self.description = description
        self.session = session
        self.input_schema = input_schema or {}
        self.timeout = timeout

        # Create a signature based on input schema
        self.__signature__ = self._build_signature(input_schema)

    @classmethod
    def _annotation_for(cls, prop_schema):
        """Map one property schema to a Python annotation (``Any`` if unknown)."""
        prop_type = prop_schema.get('type', 'string') if isinstance(prop_schema, dict) else 'string'
        if isinstance(prop_type, list):
            # Nullable unions such as ["string", "null"]: use the first non-null type.
            prop_type = next((t for t in prop_type if t != 'null'), 'string')
        if not isinstance(prop_type, str):
            return Any
        return cls._JSON_TYPE_ANNOTATIONS.get(prop_type, Any)

    @classmethod
    def _build_signature(cls, input_schema):
        """Build an ``inspect.Signature`` from the tool's JSON Schema.

        Required properties get no default; optional ones default to ``None``.
        Required parameters are placed before optional ones (keeping their
        relative order). JSON Schema property order is arbitrary, and
        ``inspect.Signature`` raises ``ValueError`` when a parameter without a
        default follows one with a default. Property names that cannot be
        Python parameter names (e.g. ``"file-path"`` or ``"from"``) are left out
        of the signature. They remain in ``input_schema`` and therefore in
        ``to_openai_tool()``.
        """
        import keyword

        if not input_schema or not isinstance(input_schema.get('properties'), dict):
            return inspect.Signature([])

        required_names = set(input_schema.get('required') or [])
        required_params, optional_params = [], []
        for param_name, prop_schema in input_schema['properties'].items():
            if not isinstance(param_name, str) or not param_name.isidentifier() or keyword.iskeyword(param_name):
                logger.warning(f"Tool parameter {param_name!r} is not a valid Python identifier; omitting it from the signature")
                continue
            is_required = param_name in required_names
            param = inspect.Parameter(
                name=param_name,
                kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
                default=inspect.Parameter.empty if is_required else None,
                annotation=cls._annotation_for(prop_schema)
            )
            (required_params if is_required else optional_params).append(param)
        return inspect.Signature(required_params + optional_params)

    def __call__(self, **kwargs):
        """Synchronous wrapper for the async call.

        Schedules ``_async_call`` on the shared event loop and blocks for up to
        ``self.timeout`` seconds. Any error, including a timeout, is logged and
        returned as an ``"Error: ..."`` string instead of being raised.
        """
        logger.debug(f"Tool {self.name} called with args: {kwargs}")

        # Use the global event loop
        loop = get_event_loop()

        # Run the async call in the event loop
        future = asyncio.run_coroutine_threadsafe(self._async_call(**kwargs), loop)
        try:
            # Wait for the result with a timeout
            return future.result(timeout=self.timeout)
        except Exception as e:
            # Stop the coroutine if we gave up waiting (e.g. timeout); this is a
            # no-op when it has already finished.
            future.cancel()
            logger.error(f"Error calling tool {self.name}: {e}")
            # Some exceptions (notably timeouts) have an empty message.
            return f"Error: {str(e) or type(e).__name__}"

    async def _async_call(self, **kwargs):
        """Call the tool with the provided arguments.

        Returns the ``text`` of the first content item when present, otherwise
        the string form of that item, or of the whole result when it has no
        content. Exceptions are logged and re-raised.
        """
        logger.debug(f"Async calling tool {self.name} with args: {kwargs}")
        try:
            result = await self.session.call_tool(self.name, kwargs)

            # Extract text from result
            if hasattr(result, 'content') and result.content:
                if hasattr(result.content[0], 'text'):
                    return result.content[0].text
                return str(result.content[0])
            return str(result)
        except Exception as e:
            logger.error(f"Error in _async_call for {self.name}: {e}")
            raise

    def _fix_array_schemas(self, schema):
        """
        Fix array schemas by adding missing 'items' attribute required by OpenAI.

        This ensures compatibility with OpenAI's function calling format which
        requires array types to specify the type of items they contain. An
        array without ``items`` gets ``{"type": "string"}``. The fix is applied
        recursively through ``properties``, ``items`` (single or tuple form),
        ``anyOf``/``oneOf``/``allOf``, ``$defs``/``definitions`` and
        ``additionalProperties``.

        Args:
            schema: The schema dictionary to fix (non-dicts are returned as-is)

        Returns:
            dict: The fixed schema (a copy; the input is not modified)
        """
        if not isinstance(schema, dict):
            return schema

        # Create a copy to avoid modifying the original
        fixed_schema = schema.copy()

        # Fix array types at the current level ("type" may also be a list such
        # as ["array", "null"])
        schema_type = fixed_schema.get("type")
        is_array = schema_type == "array" or (isinstance(schema_type, list) and "array" in schema_type)
        if is_array and "items" not in fixed_schema:
            fixed_schema["items"] = {"type": "string"}

        # Recursively fix nested schemas
        if isinstance(fixed_schema.get("properties"), dict):
            fixed_schema["properties"] = {
                prop_name: self._fix_array_schemas(prop_schema)
                for prop_name, prop_schema in fixed_schema["properties"].items()
            }

        if "items" in fixed_schema:
            items = fixed_schema["items"]
            if isinstance(items, list):
                # Tuple-style "items": [schema, schema, ...]
                fixed_schema["items"] = [self._fix_array_schemas(item) for item in items]
            else:
                fixed_schema["items"] = self._fix_array_schemas(items)

        for key in self._SUBSCHEMA_LIST_KEYS:
            if isinstance(fixed_schema.get(key), list):
                fixed_schema[key] = [self._fix_array_schemas(sub) for sub in fixed_schema[key]]

        for key in self._SUBSCHEMA_MAP_KEYS:
            if isinstance(fixed_schema.get(key), dict):
                fixed_schema[key] = {
                    def_name: self._fix_array_schemas(sub)
                    for def_name, sub in fixed_schema[key].items()
                }

        if isinstance(fixed_schema.get("additionalProperties"), dict):
            fixed_schema["additionalProperties"] = self._fix_array_schemas(fixed_schema["additionalProperties"])

        return fixed_schema

    def to_openai_tool(self):
        """Convert the tool to OpenAI function-calling format."""
        # Fix array schemas to include 'items' attribute
        fixed_schema = self._fix_array_schemas(self.input_schema)

        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": fixed_schema
            }
        }
|
164
|
+
|
165
|
+
|
166
|
+
class HTTPStreamTransport:
|
167
|
+
"""
|
168
|
+
HTTP Stream Transport implementation for MCP.
|
169
|
+
|
170
|
+
This transport provides a single endpoint for all MCP communication,
|
171
|
+
supporting both batch (JSON) and streaming (SSE) response modes.
|
172
|
+
"""
|
173
|
+
|
174
|
+
def __init__(self, base_url: str, session_id: Optional[str] = None, options: Optional[Dict[str, Any]] = None):
|
175
|
+
self.base_url = base_url
|
176
|
+
self.session_id = session_id
|
177
|
+
self.options = options or {}
|
178
|
+
self.response_mode = self.options.get('responseMode', 'batch')
|
179
|
+
self.headers = {
|
180
|
+
'Content-Type': 'application/json',
|
181
|
+
'Accept': 'application/json, text/event-stream'
|
182
|
+
}
|
183
|
+
if session_id:
|
184
|
+
self.headers['Mcp-Session-Id'] = session_id
|
185
|
+
|
186
|
+
# Add custom headers if provided
|
187
|
+
if 'headers' in self.options:
|
188
|
+
self.headers.update(self.options['headers'])
|
189
|
+
|
190
|
+
self._session = None
|
191
|
+
self._sse_task = None
|
192
|
+
self._message_queue = asyncio.Queue()
|
193
|
+
self._pending_requests = {}
|
194
|
+
self._closing = False
|
195
|
+
|
196
|
+
async def __aenter__(self):
|
197
|
+
self._session = aiohttp.ClientSession()
|
198
|
+
return self
|
199
|
+
|
200
|
+
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
201
|
+
# Set closing flag to stop listener gracefully
|
202
|
+
self._closing = True
|
203
|
+
|
204
|
+
if self._sse_task:
|
205
|
+
self._sse_task.cancel()
|
206
|
+
try:
|
207
|
+
await self._sse_task
|
208
|
+
except asyncio.CancelledError:
|
209
|
+
pass
|
210
|
+
if self._session:
|
211
|
+
await self._session.close()
|
212
|
+
|
213
|
+
async def send_request(self, request: Dict[str, Any]) -> Union[Dict[str, Any], None]:
|
214
|
+
"""Send a request to the HTTP Stream endpoint."""
|
215
|
+
if not self._session:
|
216
|
+
raise RuntimeError("Transport not initialized. Use async context manager.")
|
217
|
+
|
218
|
+
try:
|
219
|
+
async with self._session.post(self.base_url, json=request, headers=self.headers) as response:
|
220
|
+
# Update session ID if provided in response
|
221
|
+
if 'Mcp-Session-Id' in response.headers:
|
222
|
+
self.session_id = response.headers['Mcp-Session-Id']
|
223
|
+
self.headers['Mcp-Session-Id'] = self.session_id
|
224
|
+
|
225
|
+
# Handle different response types
|
226
|
+
content_type = response.headers.get('Content-Type', '')
|
227
|
+
|
228
|
+
if 'text/event-stream' in content_type:
|
229
|
+
# Stream mode - process SSE events
|
230
|
+
return await self._process_sse_response(response)
|
231
|
+
else:
|
232
|
+
# Batch mode - return JSON response
|
233
|
+
return await response.json()
|
234
|
+
|
235
|
+
except Exception as e:
|
236
|
+
logger.error(f"Error sending request: {e}")
|
237
|
+
raise
|
238
|
+
|
239
|
+
async def _process_sse_response(self, response):
|
240
|
+
"""Process SSE response stream."""
|
241
|
+
buffer = ""
|
242
|
+
async for chunk in response.content:
|
243
|
+
buffer += chunk.decode('utf-8')
|
244
|
+
|
245
|
+
# Process complete SSE events
|
246
|
+
while "\n\n" in buffer:
|
247
|
+
event, buffer = buffer.split("\n\n", 1)
|
248
|
+
lines = event.strip().split("\n")
|
249
|
+
|
250
|
+
# Parse SSE event
|
251
|
+
data = None
|
252
|
+
for line in lines:
|
253
|
+
if line.startswith("data: "):
|
254
|
+
data = line[6:] # Remove "data: " prefix
|
255
|
+
|
256
|
+
if data:
|
257
|
+
try:
|
258
|
+
message = json.loads(data)
|
259
|
+
# Process the message
|
260
|
+
if 'id' in message and message['id'] in self._pending_requests:
|
261
|
+
# This is a response to a pending request
|
262
|
+
self._pending_requests[message['id']].set_result(message)
|
263
|
+
else:
|
264
|
+
# This is a server-initiated message
|
265
|
+
await self._message_queue.put(message)
|
266
|
+
except json.JSONDecodeError:
|
267
|
+
logger.error(f"Failed to parse SSE event: {data}")
|
268
|
+
|
269
|
+
async def start_sse_listener(self):
|
270
|
+
"""Start listening for SSE events from the server."""
|
271
|
+
if self._sse_task is None or self._sse_task.done():
|
272
|
+
self._sse_task = asyncio.create_task(self._sse_listener())
|
273
|
+
|
274
|
+
async def _sse_listener(self):
|
275
|
+
"""Background task to listen for SSE events."""
|
276
|
+
while True:
|
277
|
+
try:
|
278
|
+
# Check if we should stop
|
279
|
+
if hasattr(self, '_closing') and self._closing:
|
280
|
+
break
|
281
|
+
|
282
|
+
url = self.base_url
|
283
|
+
if self.session_id:
|
284
|
+
# Add session as query parameter for SSE connection
|
285
|
+
url = f"{url}?session={self.session_id}"
|
286
|
+
|
287
|
+
headers = {
|
288
|
+
'Accept': 'text/event-stream',
|
289
|
+
'Cache-Control': 'no-cache'
|
290
|
+
}
|
291
|
+
if self.session_id:
|
292
|
+
headers['Mcp-Session-Id'] = self.session_id
|
293
|
+
|
294
|
+
async with self._session.get(url, headers=headers) as response:
|
295
|
+
buffer = ""
|
296
|
+
async for chunk in response.content:
|
297
|
+
# Check if we should stop
|
298
|
+
if hasattr(self, '_closing') and self._closing:
|
299
|
+
break
|
300
|
+
|
301
|
+
buffer += chunk.decode('utf-8')
|
302
|
+
|
303
|
+
# Process complete SSE events
|
304
|
+
while "\n\n" in buffer:
|
305
|
+
event, buffer = buffer.split("\n\n", 1)
|
306
|
+
lines = event.strip().split("\n")
|
307
|
+
|
308
|
+
# Parse SSE event
|
309
|
+
data = None
|
310
|
+
for line in lines:
|
311
|
+
if line.startswith("data: "):
|
312
|
+
data = line[6:] # Remove "data: " prefix
|
313
|
+
|
314
|
+
if data:
|
315
|
+
try:
|
316
|
+
message = json.loads(data)
|
317
|
+
await self._message_queue.put(message)
|
318
|
+
except json.JSONDecodeError:
|
319
|
+
logger.error(f"Failed to parse SSE event: {data}")
|
320
|
+
|
321
|
+
except asyncio.CancelledError:
|
322
|
+
# Proper shutdown
|
323
|
+
break
|
324
|
+
except Exception as e:
|
325
|
+
if not (hasattr(self, '_closing') and self._closing):
|
326
|
+
logger.error(f"SSE listener error: {e}")
|
327
|
+
await asyncio.sleep(1) # Reconnect after 1 second
|
328
|
+
else:
|
329
|
+
break
|
330
|
+
|
331
|
+
def read_stream(self):
|
332
|
+
"""Create a read stream for the ClientSession."""
|
333
|
+
async def _read():
|
334
|
+
while True:
|
335
|
+
message = await self._message_queue.get()
|
336
|
+
yield message
|
337
|
+
return _read()
|
338
|
+
|
339
|
+
def write_stream(self):
|
340
|
+
"""Create a write stream for the ClientSession."""
|
341
|
+
async def _write(message):
|
342
|
+
if hasattr(message, 'to_dict'):
|
343
|
+
message = message.to_dict()
|
344
|
+
response = await self.send_request(message)
|
345
|
+
return response
|
346
|
+
return _write
|
347
|
+
|
348
|
+
|
349
|
+
class HTTPStreamMCPClient:
    """A client for connecting to an MCP server over HTTP Stream transport.

    Construction connects synchronously. It makes sure the shared event loop
    (see ``get_event_loop``) is being run by a background daemon thread,
    performs the MCP handshake on that loop and wraps every listed tool in an
    ``HTTPStreamMCPTool``. Iterating the client yields those tools.
    """

    # One daemon thread drives the shared loop for all clients: calling
    # run_forever() on an already-running loop raises RuntimeError.
    _loop_thread_lock = threading.Lock()
    _shared_loop_thread = None
    _shared_loop = None

    def __init__(self, server_url: str, debug: bool = False, timeout: int = 60, options: Optional[Dict[str, Any]] = None):
        """
        Initialize an HTTP Stream MCP client.

        Args:
            server_url: The URL of the HTTP Stream MCP server. If it has no path
                (or only "/"), the default "/mcp" endpoint is used.
            debug: Whether to enable debug logging
            timeout: Timeout in seconds for operations (default: 60)
            options: Additional configuration options for the transport
                (``responseMode`` and ``headers`` are read)
        """
        # Parse URL to extract base URL and endpoint
        parsed = urlparse(server_url)

        # If the URL already has a path, use it; otherwise use default /mcp endpoint
        if parsed.path and parsed.path != '/':
            self.base_url = server_url
        else:
            # Default endpoint is /mcp
            self.base_url = urljoin(server_url, '/mcp')

        self.debug = debug
        self.timeout = timeout
        self.options = options or {}
        self.session = None
        self.tools = []
        self.transport = None
        self._session_context = None
        self._loop = None

        # Set up logging
        if debug:
            logger.setLevel(logging.DEBUG)
        else:
            # Set to WARNING by default to hide INFO messages
            logger.setLevel(logging.WARNING)

        self._initialize()

    @classmethod
    def _ensure_loop_thread(cls, loop, timeout=None):
        """Make sure ``loop`` is being run by a background daemon thread.

        A thread running ``loop.run_forever()`` is started only when the loop
        is not already running. The method then waits until the loop is
        processing callbacks, so a client created right afterwards sees
        ``loop.is_running()``.

        Returns:
            The thread this class started for ``loop``, or None if the loop is
            driven by something else.

        Raises:
            RuntimeError: If the loop does not start within ``timeout`` seconds.
        """
        with cls._loop_thread_lock:
            if loop.is_running():
                return cls._shared_loop_thread if cls._shared_loop is loop else None

            def run_event_loop():
                asyncio.set_event_loop(loop)
                loop.run_forever()

            started = threading.Event()
            # Executed once run_forever() starts processing callbacks.
            loop.call_soon_threadsafe(started.set)
            thread = threading.Thread(target=run_event_loop, daemon=True)
            thread.start()
            if not started.wait(timeout):
                raise RuntimeError("Timed out waiting for the MCP event loop thread to start")
            cls._shared_loop_thread = thread
            cls._shared_loop = loop
            return thread

    def _initialize(self):
        """Initialize the connection and tools (blocks up to ``self.timeout`` seconds)."""
        # Use the global event loop
        loop = get_event_loop()
        self._loop = loop

        # Start a background thread to run the event loop (only if not already running)
        self.loop_thread = self._ensure_loop_thread(loop, self.timeout)

        # Run the initialization in the event loop
        future = asyncio.run_coroutine_threadsafe(self._async_initialize(), loop)
        try:
            self.tools = future.result(timeout=self.timeout)
        except BaseException:
            # If we stopped waiting (e.g. timeout), don't leave the handshake
            # running on the shared loop; no-op if it already finished.
            future.cancel()
            raise

    async def _async_initialize(self):
        """Asynchronously initialize the connection and tools.

        Returns:
            list: One ``HTTPStreamMCPTool`` per tool reported by the server.
        """
        logger.debug(f"Connecting to MCP server at {self.base_url}")

        # Create HTTP Stream transport
        self.transport = HTTPStreamTransport(self.base_url, options=self.options)
        await self.transport.__aenter__()

        try:
            # Create read and write streams
            read_stream = self.transport.read_stream()
            write_stream = self.transport.write_stream()

            # Start SSE listener if in stream mode
            if self.options.get('responseMode', 'batch') == 'stream':
                await self.transport.start_sse_listener()

            # Create client session; remember the context only once it has been
            # entered so cleanup never exits a context that failed to enter.
            session_context = ClientSession(read_stream, write_stream)
            self.session = await session_context.__aenter__()
            self._session_context = session_context

            # Initialize
            await self.session.initialize()

            # List available tools
            logger.debug("Listing tools...")
            response = await self.session.list_tools()
        except BaseException:
            # Release the session context / HTTP session opened above.
            try:
                await self._async_close(None, None, None)
            except Exception as cleanup_error:
                logger.debug(f"Error while cleaning up after failed initialization: {cleanup_error}")
            raise

        tools_data = response.tools
        logger.debug(f"Found {len(tools_data)} tools: {[tool.name for tool in tools_data]}")

        # Create tool wrappers
        tools = []
        for tool in tools_data:
            input_schema = getattr(tool, 'inputSchema', None)
            # A missing *or empty* description falls back to a generic one.
            description = getattr(tool, 'description', None) or f"Call the {tool.name} tool"
            wrapper = HTTPStreamMCPTool(
                name=tool.name,
                description=description,
                session=self.session,
                input_schema=input_schema,
                timeout=self.timeout
            )
            tools.append(wrapper)

        return tools

    async def _async_close(self, exc_type, exc_val, exc_tb):
        """Exit the MCP session context, then close the transport.

        Teardown runs in reverse order of setup, because the session reads and
        writes through the transport. The transport is closed even if exiting
        the session fails.
        """
        session_context, self._session_context = self._session_context, None
        try:
            if session_context is not None:
                await session_context.__aexit__(exc_type, exc_val, exc_tb)
        finally:
            if self.transport:
                await self.transport.__aexit__(exc_type, exc_val, exc_tb)

    def __iter__(self):
        """Return an iterator over the tools."""
        return iter(self.tools)

    def to_openai_tools(self):
        """Convert all tools to OpenAI format."""
        return [tool.to_openai_tool() for tool in self.tools]

    async def __aenter__(self):
        """Async context manager entry."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit.

        The session and transport were created on the shared background loop.
        When exiting from a different loop, cleanup is therefore scheduled on
        that background loop and awaited from here.

        NOTE(review): the session context is exited from a different task than
        the one that entered it; verify this is allowed by ``mcp.ClientSession``.
        """
        loop = self._loop
        if loop is not None and loop.is_running() and loop is not asyncio.get_running_loop():
            future = asyncio.run_coroutine_threadsafe(self._async_close(exc_type, exc_val, exc_tb), loop)
            await asyncio.wrap_future(future)
        else:
            await self._async_close(exc_type, exc_val, exc_tb)
|
praisonaiagents/mcp/mcp_sse.py
CHANGED
@@ -44,13 +44,30 @@ class SSEMCPTool:
|
|
44
44
|
# Create a signature based on input schema
|
45
45
|
params = []
|
46
46
|
if input_schema and 'properties' in input_schema:
|
47
|
-
for param_name in input_schema['properties']:
|
47
|
+
for param_name, prop_schema in input_schema['properties'].items():
|
48
|
+
# Determine type annotation based on schema
|
49
|
+
prop_type = prop_schema.get('type', 'string') if isinstance(prop_schema, dict) else 'string'
|
50
|
+
if prop_type == 'string':
|
51
|
+
annotation = str
|
52
|
+
elif prop_type == 'integer':
|
53
|
+
annotation = int
|
54
|
+
elif prop_type == 'number':
|
55
|
+
annotation = float
|
56
|
+
elif prop_type == 'boolean':
|
57
|
+
annotation = bool
|
58
|
+
elif prop_type == 'array':
|
59
|
+
annotation = list
|
60
|
+
elif prop_type == 'object':
|
61
|
+
annotation = dict
|
62
|
+
else:
|
63
|
+
annotation = str # Default to string for SSE
|
64
|
+
|
48
65
|
params.append(
|
49
66
|
inspect.Parameter(
|
50
67
|
name=param_name,
|
51
68
|
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
52
69
|
default=inspect.Parameter.empty if param_name in input_schema.get('required', []) else None,
|
53
|
-
annotation=
|
70
|
+
annotation=annotation
|
54
71
|
)
|
55
72
|
)
|
56
73
|
|
@@ -16,6 +16,7 @@ class LoopItems(BaseModel):
|
|
16
16
|
|
17
17
|
class Process:
|
18
18
|
DEFAULT_RETRY_LIMIT = 3 # Predefined retry limit in a common place
|
19
|
+
VALIDATION_FAILURE_DECISIONS = ["invalid", "retry", "failed", "error", "unsuccessful", "fail", "errors", "reject", "rejected", "incomplete"] # Decision strings that trigger validation feedback
|
19
20
|
|
20
21
|
def __init__(self, tasks: Dict[str, Task], agents: List[Agent], manager_llm: Optional[str] = None, verbose: bool = False, max_iter: int = 10):
|
21
22
|
logging.debug(f"=== Initializing Process ===")
|
@@ -33,12 +34,38 @@ class Process:
|
|
33
34
|
self.task_retry_counter: Dict[str, int] = {} # Initialize retry counter
|
34
35
|
self.workflow_finished = False # ADDED: Workflow finished flag
|
35
36
|
|
37
|
+
def _create_loop_subtasks(self, loop_task: Task):
    """Placeholder for expanding a loop task into per-item subtasks.

    Not implemented yet: it only logs a warning that names ``loop_task`` and
    creates no subtasks.
    """
    # TODO: Implement loop subtask creation from input file.
    # Intended behaviour: read loop_task.input_file and create one subtask per entry.
    logging.warning(f"_create_loop_subtasks called for {loop_task.name} but method not fully implemented")
|
43
|
+
|
36
44
|
def _build_task_context(self, current_task: Task) -> str:
|
37
45
|
"""Build context for a task based on its retain_full_context setting"""
|
38
|
-
if
|
39
|
-
|
46
|
+
# Check if we have validation feedback to include
|
47
|
+
if current_task.validation_feedback:
|
48
|
+
feedback = current_task.validation_feedback
|
49
|
+
context = f"\nPrevious attempt failed validation with reason: {feedback['validation_response']}"
|
50
|
+
if feedback.get('validated_task'):
|
51
|
+
context += f"\nValidated task: {feedback['validated_task']}"
|
52
|
+
if feedback.get('validation_details'):
|
53
|
+
context += f"\nValidation feedback: {feedback['validation_details']}"
|
54
|
+
if feedback.get('rejected_output'):
|
55
|
+
context += f"\nRejected output: {feedback['rejected_output']}"
|
56
|
+
context += "\nPlease try again with a different approach based on this feedback.\n"
|
57
|
+
# Clear the feedback after including it to prevent it from persisting
|
58
|
+
current_task.validation_feedback = None
|
40
59
|
|
41
|
-
|
60
|
+
# If we have validation feedback but no previous tasks context, return just the feedback
|
61
|
+
if not (current_task.previous_tasks or current_task.context):
|
62
|
+
return context
|
63
|
+
# Otherwise, append the regular context
|
64
|
+
context += "\nInput data from previous tasks:"
|
65
|
+
elif not (current_task.previous_tasks or current_task.context):
|
66
|
+
return ""
|
67
|
+
else:
|
68
|
+
context = "\nInput data from previous tasks:"
|
42
69
|
|
43
70
|
if current_task.retain_full_context:
|
44
71
|
# Original behavior: include all previous tasks
|
@@ -496,6 +523,35 @@ Subtask: {st.name}
|
|
496
523
|
next_task = next((t for t in self.tasks.values() if t.name == task_value), None)
|
497
524
|
if next_task:
|
498
525
|
next_task.status = "not started" # Reset status to allow execution
|
526
|
+
|
527
|
+
# Capture validation feedback for retry scenarios
|
528
|
+
if decision_str in Process.VALIDATION_FAILURE_DECISIONS:
|
529
|
+
if current_task and current_task.result:
|
530
|
+
# Get the rejected output from the task that was validated
|
531
|
+
validated_task = None
|
532
|
+
# Find the task that produced the output being validated
|
533
|
+
if current_task.previous_tasks:
|
534
|
+
# For validation tasks, typically validate the most recent previous task
|
535
|
+
prev_task_name = current_task.previous_tasks[-1]
|
536
|
+
validated_task = next((t for t in self.tasks.values() if t.name == prev_task_name), None)
|
537
|
+
elif current_task.context:
|
538
|
+
# If no previous_tasks, check context for the validated task
|
539
|
+
# Use the most recent task with a result from context
|
540
|
+
for ctx_task in reversed(current_task.context):
|
541
|
+
if ctx_task.result and ctx_task.name != current_task.name:
|
542
|
+
validated_task = ctx_task
|
543
|
+
break
|
544
|
+
|
545
|
+
feedback = {
|
546
|
+
'validation_response': decision_str,
|
547
|
+
'validation_details': current_task.result.raw,
|
548
|
+
'rejected_output': validated_task.result.raw if validated_task and validated_task.result else None,
|
549
|
+
'validator_task': current_task.name,
|
550
|
+
'validated_task': validated_task.name if validated_task else None
|
551
|
+
}
|
552
|
+
next_task.validation_feedback = feedback
|
553
|
+
logging.debug(f"Added validation feedback to {next_task.name}: {feedback['validation_response']} (validated task: {feedback.get('validated_task', 'None')})")
|
554
|
+
|
499
555
|
logging.debug(f"Routing to {next_task.name} based on decision: {decision_str}")
|
500
556
|
# Don't mark workflow as finished when following condition path
|
501
557
|
self.workflow_finished = False
|
@@ -1098,6 +1154,35 @@ Subtask: {st.name}
|
|
1098
1154
|
next_task = next((t for t in self.tasks.values() if t.name == task_value), None)
|
1099
1155
|
if next_task:
|
1100
1156
|
next_task.status = "not started" # Reset status to allow execution
|
1157
|
+
|
1158
|
+
# Capture validation feedback for retry scenarios
|
1159
|
+
if decision_str in Process.VALIDATION_FAILURE_DECISIONS:
|
1160
|
+
if current_task and current_task.result:
|
1161
|
+
# Get the rejected output from the task that was validated
|
1162
|
+
validated_task = None
|
1163
|
+
# Find the task that produced the output being validated
|
1164
|
+
if current_task.previous_tasks:
|
1165
|
+
# For validation tasks, typically validate the most recent previous task
|
1166
|
+
prev_task_name = current_task.previous_tasks[-1]
|
1167
|
+
validated_task = next((t for t in self.tasks.values() if t.name == prev_task_name), None)
|
1168
|
+
elif current_task.context:
|
1169
|
+
# If no previous_tasks, check context for the validated task
|
1170
|
+
# Use the most recent task with a result from context
|
1171
|
+
for ctx_task in reversed(current_task.context):
|
1172
|
+
if ctx_task.result and ctx_task.name != current_task.name:
|
1173
|
+
validated_task = ctx_task
|
1174
|
+
break
|
1175
|
+
|
1176
|
+
feedback = {
|
1177
|
+
'validation_response': decision_str,
|
1178
|
+
'validation_details': current_task.result.raw,
|
1179
|
+
'rejected_output': validated_task.result.raw if validated_task and validated_task.result else None,
|
1180
|
+
'validator_task': current_task.name,
|
1181
|
+
'validated_task': validated_task.name if validated_task else None
|
1182
|
+
}
|
1183
|
+
next_task.validation_feedback = feedback
|
1184
|
+
logging.debug(f"Added validation feedback to {next_task.name}: {feedback['validation_response']} (validated task: {feedback.get('validated_task', 'None')})")
|
1185
|
+
|
1101
1186
|
logging.debug(f"Routing to {next_task.name} based on decision: {decision_str}")
|
1102
1187
|
# Don't mark workflow as finished when following condition path
|
1103
1188
|
self.workflow_finished = False
|
praisonaiagents/task/task.py
CHANGED
@@ -88,6 +88,7 @@ class Task:
|
|
88
88
|
self.max_retries = max_retries
|
89
89
|
self.retry_count = retry_count
|
90
90
|
self._guardrail_fn = None
|
91
|
+
self.validation_feedback = None # Store validation failure feedback for retry attempts
|
91
92
|
|
92
93
|
# Set logger level based on config verbose level
|
93
94
|
verbose = self.config.get("verbose", 0)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: praisonaiagents
|
3
|
-
Version: 0.0.
|
3
|
+
Version: 0.0.125
|
4
4
|
Summary: Praison AI agents for completing complex tasks with Self Reflection Agents
|
5
5
|
Author: Mervin Praison
|
6
6
|
Requires-Python: >=3.10
|
@@ -9,6 +9,7 @@ Requires-Dist: rich
|
|
9
9
|
Requires-Dist: openai
|
10
10
|
Requires-Dist: mcp>=1.6.0
|
11
11
|
Requires-Dist: posthog>=3.0.0
|
12
|
+
Requires-Dist: aiohttp>=3.8.0
|
12
13
|
Provides-Extra: mcp
|
13
14
|
Requires-Dist: mcp>=1.6.0; extra == "mcp"
|
14
15
|
Requires-Dist: fastapi>=0.115.0; extra == "mcp"
|