praisonaiagents 0.0.98__tar.gz → 0.0.100__tar.gz
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/PKG-INFO +1 -1
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/agent/agent.py +199 -9
- praisonaiagents-0.0.100/praisonaiagents/guardrails/__init__.py +11 -0
- praisonaiagents-0.0.100/praisonaiagents/guardrails/guardrail_result.py +43 -0
- praisonaiagents-0.0.100/praisonaiagents/guardrails/llm_guardrail.py +88 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/llm/llm.py +248 -148
- praisonaiagents-0.0.100/praisonaiagents/memory/__init__.py +15 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/memory/memory.py +7 -4
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/task/task.py +5 -1
- praisonaiagents-0.0.100/praisonaiagents/tools/searxng_tools.py +94 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents.egg-info/PKG-INFO +1 -1
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents.egg-info/SOURCES.txt +5 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/pyproject.toml +4 -3
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/README.md +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/__init__.py +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/agent/__init__.py +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/agent/image_agent.py +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/agents/__init__.py +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/agents/agents.py +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/agents/autoagents.py +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/approval.py +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/knowledge/__init__.py +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/knowledge/chunking.py +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/knowledge/knowledge.py +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/llm/__init__.py +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/main.py +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/mcp/__init__.py +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/mcp/mcp.py +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/mcp/mcp_sse.py +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/process/__init__.py +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/process/process.py +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/session.py +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/task/__init__.py +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/tools/__init__.py +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/tools/arxiv_tools.py +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/tools/calculator_tools.py +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/tools/csv_tools.py +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/tools/duckdb_tools.py +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/tools/duckduckgo_tools.py +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/tools/excel_tools.py +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/tools/file_tools.py +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/tools/json_tools.py +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/tools/newspaper_tools.py +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/tools/pandas_tools.py +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/tools/python_tools.py +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/tools/shell_tools.py +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/tools/spider_tools.py +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/tools/test.py +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/tools/tools.py +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/tools/train/data/generatecot.py +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/tools/wikipedia_tools.py +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/tools/xml_tools.py +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/tools/yaml_tools.py +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents/tools/yfinance_tools.py +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents.egg-info/dependency_links.txt +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents.egg-info/requires.txt +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/praisonaiagents.egg-info/top_level.txt +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/setup.cfg +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/tests/test-graph-memory.py +0 -0
- {praisonaiagents-0.0.98 → praisonaiagents-0.0.100}/tests/test.py +0 -0
--- praisonaiagents-0.0.98/praisonaiagents/agent/agent.py
+++ praisonaiagents-0.0.100/praisonaiagents/agent/agent.py
@@ -3,7 +3,7 @@ import time
 import json
 import logging
 import asyncio
-from typing import List, Optional, Any, Dict, Union, Literal, TYPE_CHECKING
+from typing import List, Optional, Any, Dict, Union, Literal, TYPE_CHECKING, Callable, Tuple
 from rich.console import Console
 from rich.live import Live
 from openai import AsyncOpenAI
@@ -32,6 +32,7 @@ _shared_apps = {} # Dict of port -> FastAPI app
 
 if TYPE_CHECKING:
     from ..task.task import Task
+    from ..main import TaskOutput
 
 @dataclass
 class ChatCompletionMessage:
@@ -368,7 +369,9 @@ class Agent:
         min_reflect: int = 1,
         reflect_llm: Optional[str] = None,
         user_id: Optional[str] = None,
-        reasoning_steps: bool = False
+        reasoning_steps: bool = False,
+        guardrail: Optional[Union[Callable[['TaskOutput'], Tuple[bool, Any]], str]] = None,
+        max_guardrail_retries: int = 3
     ):
         # Add check at start if memory is requested
         if memory is not None:
@@ -483,6 +486,12 @@ Your Goal: {self.goal}
         # Store user_id
         self.user_id = user_id or "praison"
         self.reasoning_steps = reasoning_steps
+
+        # Initialize guardrail settings
+        self.guardrail = guardrail
+        self.max_guardrail_retries = max_guardrail_retries
+        self._guardrail_fn = None
+        self._setup_guardrail()
 
         # Check if knowledge parameter has any values
         if not knowledge:
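The new constructor arguments accept either a callable or a plain-language description; the __init__ change above simply stores them and calls _setup_guardrail(), shown in the next hunk. A minimal usage sketch, assuming Agent keeps its existing name/instructions arguments and remains importable from the package root as in earlier releases:

from praisonaiagents import Agent

# A string guardrail is wrapped in an LLMGuardrail; a callable is validated and used directly.
agent = Agent(
    name="Writer",
    instructions="Write a short product description.",
    guardrail="The response must be polite and must not invent pricing details.",
    max_guardrail_retries=2,
)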
@@ -512,6 +521,149 @@ Your Goal: {self.goal}
                 except Exception as e:
                     logging.error(f"Error processing knowledge item: {knowledge_item}, error: {e}")
 
+    def _setup_guardrail(self):
+        """Setup the guardrail function based on the provided guardrail parameter."""
+        if self.guardrail is None:
+            self._guardrail_fn = None
+            return
+
+        if callable(self.guardrail):
+            # Validate function signature
+            sig = inspect.signature(self.guardrail)
+            positional_args = [
+                param for param in sig.parameters.values()
+                if param.default is inspect.Parameter.empty
+            ]
+            if len(positional_args) != 1:
+                raise ValueError("Agent guardrail function must accept exactly one parameter (TaskOutput)")
+
+            # Check return annotation if present
+            from typing import get_args, get_origin
+            return_annotation = sig.return_annotation
+            if return_annotation != inspect.Signature.empty:
+                return_annotation_args = get_args(return_annotation)
+                if not (
+                    get_origin(return_annotation) is tuple
+                    and len(return_annotation_args) == 2
+                    and return_annotation_args[0] is bool
+                    and (
+                        return_annotation_args[1] is Any
+                        or return_annotation_args[1] is str
+                        or str(return_annotation_args[1]).endswith('TaskOutput')
+                        or str(return_annotation_args[1]).startswith('typing.Union')
+                    )
+                ):
+                    raise ValueError(
+                        "If return type is annotated, it must be Tuple[bool, Any] or Tuple[bool, Union[str, TaskOutput]]"
+                    )
+
+            self._guardrail_fn = self.guardrail
+        elif isinstance(self.guardrail, str):
+            # Create LLM-based guardrail
+            from ..guardrails import LLMGuardrail
+            llm = getattr(self, 'llm', None) or getattr(self, 'llm_instance', None)
+            self._guardrail_fn = LLMGuardrail(description=self.guardrail, llm=llm)
+        else:
+            raise ValueError("Agent guardrail must be either a callable or a string description")
+
+    def _process_guardrail(self, task_output):
+        """Process the guardrail validation for a task output.
+
+        Args:
+            task_output: The task output to validate
+
+        Returns:
+            GuardrailResult: The result of the guardrail validation
+        """
+        from ..guardrails import GuardrailResult
+
+        if not self._guardrail_fn:
+            return GuardrailResult(success=True, result=task_output)
+
+        try:
+            # Call the guardrail function
+            result = self._guardrail_fn(task_output)
+
+            # Convert the result to a GuardrailResult
+            return GuardrailResult.from_tuple(result)
+
+        except Exception as e:
+            logging.error(f"Agent {self.name}: Error in guardrail validation: {e}")
+            # On error, return failure
+            return GuardrailResult(
+                success=False,
+                result=None,
+                error=f"Agent guardrail validation error: {str(e)}"
+            )
+
+    def _apply_guardrail_with_retry(self, response_text, prompt, temperature=0.2, tools=None):
+        """Apply guardrail validation with retry logic.
+
+        Args:
+            response_text: The response to validate
+            prompt: Original prompt for regeneration if needed
+            temperature: Temperature for regeneration
+            tools: Tools for regeneration
+
+        Returns:
+            str: The validated response text or None if validation fails after retries
+        """
+        if not self._guardrail_fn:
+            return response_text
+
+        from ..main import TaskOutput
+
+        retry_count = 0
+        current_response = response_text
+
+        while retry_count <= self.max_guardrail_retries:
+            # Create TaskOutput object
+            task_output = TaskOutput(
+                description="Agent response output",
+                raw=current_response,
+                agent=self.name
+            )
+
+            # Process guardrail
+            guardrail_result = self._process_guardrail(task_output)
+
+            if guardrail_result.success:
+                logging.info(f"Agent {self.name}: Guardrail validation passed")
+                # Return the potentially modified result
+                if guardrail_result.result and hasattr(guardrail_result.result, 'raw'):
+                    return guardrail_result.result.raw
+                elif guardrail_result.result:
+                    return str(guardrail_result.result)
+                else:
+                    return current_response
+
+            # Guardrail failed
+            if retry_count >= self.max_guardrail_retries:
+                raise Exception(
+                    f"Agent {self.name} response failed guardrail validation after {self.max_guardrail_retries} retries. "
+                    f"Last error: {guardrail_result.error}"
+                )
+
+            retry_count += 1
+            logging.warning(f"Agent {self.name}: Guardrail validation failed (retry {retry_count}/{self.max_guardrail_retries}): {guardrail_result.error}")
+
+            # Regenerate response for retry
+            try:
+                retry_prompt = f"{prompt}\n\nNote: Previous response failed validation due to: {guardrail_result.error}. Please provide an improved response."
+                response = self._chat_completion([{"role": "user", "content": retry_prompt}], temperature, tools)
+                if response and response.choices:
+                    current_response = response.choices[0].message.content.strip()
+                else:
+                    raise Exception("Failed to generate retry response")
+            except Exception as e:
+                logging.error(f"Agent {self.name}: Error during guardrail retry: {e}")
+                # If we can't regenerate, fail the guardrail
+                raise Exception(
+                    f"Agent {self.name} guardrail retry failed: {e}"
+                )
+
+        return current_response
+
     def generate_task(self) -> 'Task':
         """Generate a Task object from the agent's instructions"""
         from ..task.task import Task
@@ -967,7 +1119,13 @@ Your Goal: {self.goal}
                 total_time = time.time() - start_time
                 logging.debug(f"Agent.chat completed in {total_time:.2f} seconds")
 
-                return response_text
+                # Apply guardrail validation for custom LLM response
+                try:
+                    validated_response = self._apply_guardrail_with_retry(response_text, prompt, temperature, tools)
+                    return validated_response
+                except Exception as e:
+                    logging.error(f"Agent {self.name}: Guardrail validation failed for custom LLM: {e}")
+                    return None
             except Exception as e:
                 display_error(f"Error in LLM chat: {e}")
                 return None
@@ -1055,8 +1213,20 @@
                 display_interaction(original_prompt, response_text, markdown=self.markdown, generation_time=time.time() - start_time, console=self.console)
                 # Return only reasoning content if reasoning_steps is True
                 if reasoning_steps and hasattr(response.choices[0].message, 'reasoning_content'):
-                    return response.choices[0].message.reasoning_content
-                return response_text
+                    # Apply guardrail to reasoning content
+                    try:
+                        validated_reasoning = self._apply_guardrail_with_retry(response.choices[0].message.reasoning_content, original_prompt, temperature, tools)
+                        return validated_reasoning
+                    except Exception as e:
+                        logging.error(f"Agent {self.name}: Guardrail validation failed for reasoning content: {e}")
+                        return None
+                # Apply guardrail to regular response
+                try:
+                    validated_response = self._apply_guardrail_with_retry(response_text, original_prompt, temperature, tools)
+                    return validated_response
+                except Exception as e:
+                    logging.error(f"Agent {self.name}: Guardrail validation failed: {e}")
+                    return None
 
             reflection_prompt = f"""
 Reflect on your previous response: '{response_text}'.
@@ -1089,7 +1259,13 @@ Output MUST be JSON with 'reflection' and 'satisfactory'.
                     self.chat_history.append({"role": "user", "content": prompt})
                     self.chat_history.append({"role": "assistant", "content": response_text})
                     display_interaction(prompt, response_text, markdown=self.markdown, generation_time=time.time() - start_time, console=self.console)
-                    return response_text
+                    # Apply guardrail validation after satisfactory reflection
+                    try:
+                        validated_response = self._apply_guardrail_with_retry(response_text, prompt, temperature, tools)
+                        return validated_response
+                    except Exception as e:
+                        logging.error(f"Agent {self.name}: Guardrail validation failed after reflection: {e}")
+                        return None
 
                 # Check if we've hit max reflections
                 if reflection_count >= self.max_reflect - 1:
@@ -1098,7 +1274,13 @@ Output MUST be JSON with 'reflection' and 'satisfactory'.
                     self.chat_history.append({"role": "user", "content": prompt})
                     self.chat_history.append({"role": "assistant", "content": response_text})
                     display_interaction(prompt, response_text, markdown=self.markdown, generation_time=time.time() - start_time, console=self.console)
-                    return response_text
+                    # Apply guardrail validation after max reflections
+                    try:
+                        validated_response = self._apply_guardrail_with_retry(response_text, prompt, temperature, tools)
+                        return validated_response
+                    except Exception as e:
+                        logging.error(f"Agent {self.name}: Guardrail validation failed after max reflections: {e}")
+                        return None
 
                 logging.debug(f"{self.name} reflection count {reflection_count + 1}, continuing reflection process")
                 messages.append({"role": "user", "content": "Now regenerate your response using the reflection you made"})
@@ -1122,8 +1304,16 @@ Output MUST be JSON with 'reflection' and 'satisfactory'.
         if logging.getLogger().getEffectiveLevel() == logging.DEBUG:
             total_time = time.time() - start_time
             logging.debug(f"Agent.chat completed in {total_time:.2f} seconds")
-
-        return response_text
+
+        # Apply guardrail validation before returning
+        try:
+            validated_response = self._apply_guardrail_with_retry(response_text, prompt, temperature, tools)
+            return validated_response
+        except Exception as e:
+            logging.error(f"Agent {self.name}: Guardrail validation failed: {e}")
+            if self.verbose:
+                display_error(f"Guardrail validation failed: {e}", console=self.console)
+            return None
 
     def clean_json_output(self, output: str) -> str:
         """Clean and extract JSON from response text."""
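Taken together, the agent.py hunks route every return path of Agent.chat through _apply_guardrail_with_retry, which wraps the response in a TaskOutput, runs the guardrail, and regenerates the response up to max_guardrail_retries times. Per the signature check in _setup_guardrail, a callable guardrail must accept exactly one TaskOutput and return a (bool, Any) tuple. A sketch of such a function (the length rule and the function name are illustrative, not part of the package):

from typing import Any, Tuple

from praisonaiagents.main import TaskOutput

def minimum_length_guardrail(output: TaskOutput) -> Tuple[bool, Any]:
    """Illustrative check: reject very short responses."""
    if len(output.raw.strip()) < 50:
        return False, "Response is too short; please elaborate."
    return True, output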
--- /dev/null
+++ praisonaiagents-0.0.100/praisonaiagents/guardrails/__init__.py
@@ -0,0 +1,11 @@
+"""
+Guardrails module for PraisonAI Agents.
+
+This module provides validation and safety mechanisms for task outputs,
+including both function-based and LLM-based guardrails.
+"""
+
+from .guardrail_result import GuardrailResult
+from .llm_guardrail import LLMGuardrail
+
+__all__ = ["GuardrailResult", "LLMGuardrail"]
--- /dev/null
+++ praisonaiagents-0.0.100/praisonaiagents/guardrails/guardrail_result.py
@@ -0,0 +1,43 @@
+"""
+Guardrail result classes for PraisonAI Agents.
+
+This module provides the result types for guardrail validation,
+following the same pattern as CrewAI for consistency.
+"""
+
+from typing import Any, Tuple, Union
+from pydantic import BaseModel, Field
+from ..main import TaskOutput
+
+
+class GuardrailResult(BaseModel):
+    """Result of a guardrail validation."""
+
+    success: bool = Field(description="Whether the guardrail check passed")
+    result: Union[str, TaskOutput, None] = Field(description="The result if modified, or None if unchanged")
+    error: str = Field(default="", description="Error message if validation failed")
+
+    @classmethod
+    def from_tuple(cls, result: Tuple[bool, Any]) -> "GuardrailResult":
+        """Create a GuardrailResult from a tuple returned by a guardrail function.
+
+        Args:
+            result: Tuple of (success, result_or_error)
+
+        Returns:
+            GuardrailResult: The structured result
+        """
+        success, data = result
+
+        if success:
+            return cls(
+                success=True,
+                result=data,
+                error=""
+            )
+        else:
+            return cls(
+                success=False,
+                result=None,
+                error=str(data) if data else "Guardrail validation failed"
+            )
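GuardrailResult.from_tuple is the bridge between the (bool, data) tuples that guardrail callables return and the structured result that _process_guardrail works with. A quick illustration of the convention (the literal values are made up):

from praisonaiagents.guardrails import GuardrailResult

passed = GuardrailResult.from_tuple((True, "validated text"))
failed = GuardrailResult.from_tuple((False, "missing a citation"))
print(passed.success, passed.result)  # True validated text
print(failed.success, failed.error)   # False missing a citation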
--- /dev/null
+++ praisonaiagents-0.0.100/praisonaiagents/guardrails/llm_guardrail.py
@@ -0,0 +1,88 @@
+"""
+LLM-based guardrail implementation for PraisonAI Agents.
+
+This module provides LLM-powered guardrails that can validate task outputs
+using natural language descriptions, similar to CrewAI's implementation.
+"""
+
+import logging
+from typing import Any, Tuple, Union, Optional
+from pydantic import BaseModel
+from ..main import TaskOutput
+
+
+class LLMGuardrail:
+    """An LLM-powered guardrail that validates task outputs using natural language."""
+
+    def __init__(self, description: str, llm: Any = None):
+        """Initialize the LLM guardrail.
+
+        Args:
+            description: Natural language description of what to validate
+            llm: The LLM instance to use for validation
+        """
+        self.description = description
+        self.llm = llm
+        self.logger = logging.getLogger(__name__)
+
+    def __call__(self, task_output: TaskOutput) -> Tuple[bool, Union[str, TaskOutput]]:
+        """Validate the task output using the LLM.
+
+        Args:
+            task_output: The task output to validate
+
+        Returns:
+            Tuple of (success, result) where result is the output or error message
+        """
+        try:
+            if not self.llm:
+                self.logger.warning("No LLM provided for guardrail validation")
+                return True, task_output
+
+            # Create validation prompt
+            validation_prompt = f"""
+You are a quality assurance validator. Your task is to evaluate the following output against specific criteria.
+
+Validation Criteria: {self.description}
+
+Output to Validate:
+{task_output.raw}
+
+Please evaluate if this output meets the criteria. Respond with:
+1. "PASS" if the output meets all criteria
+2. "FAIL: [specific reason]" if the output does not meet criteria
+
+Your response:"""
+
+            # Get LLM response
+            if hasattr(self.llm, 'chat'):
+                # For Agent's LLM interface
+                response = self.llm.chat(validation_prompt, temperature=0.1)
+            elif hasattr(self.llm, 'get_response'):
+                # For custom LLM instances
+                response = self.llm.get_response(validation_prompt, temperature=0.1)
+            elif callable(self.llm):
+                # For simple callable LLMs
+                response = self.llm(validation_prompt)
+            else:
+                self.logger.error(f"Unsupported LLM type: {type(self.llm)}")
+                return True, task_output
+
+            # Parse response
+            response = str(response).strip()
+
+            if response.upper().startswith("PASS"):
+                return True, task_output
+            elif response.upper().startswith("FAIL"):
+                # Extract the reason
+                reason = response[5:].strip(": ")
+                return False, f"Guardrail validation failed: {reason}"
+            else:
+                # Unclear response, log and pass through
+                self.logger.warning(f"Unclear guardrail response: {response}")
+                return True, task_output
+
+        except Exception as e:
+            self.logger.error(f"Error in LLM guardrail validation: {str(e)}")
+            # On error, pass through the original output
+            return True, task_output
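Because __call__ falls back to treating self.llm as a plain callable, the class can be exercised without a real model. A minimal sketch with a stub callable (the stub always answers "PASS" and is purely illustrative; the TaskOutput fields mirror the ones used in the agent.py diff above):

from praisonaiagents.guardrails import LLMGuardrail
from praisonaiagents.main import TaskOutput

def stub_llm(prompt: str) -> str:
    # Stand-in for a real model call; always approves the output.
    return "PASS"

guardrail = LLMGuardrail(description="The answer must cite a source.", llm=stub_llm)
success, result = guardrail(
    TaskOutput(description="demo", raw="See RFC 2616 for details.", agent="Demo")
)
print(success)  # True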