agno 2.0.11__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agno/agent/agent.py +607 -176
- agno/db/in_memory/in_memory_db.py +42 -29
- agno/db/mongo/mongo.py +65 -66
- agno/db/postgres/postgres.py +6 -4
- agno/db/utils.py +50 -22
- agno/exceptions.py +62 -1
- agno/guardrails/__init__.py +6 -0
- agno/guardrails/base.py +19 -0
- agno/guardrails/openai.py +144 -0
- agno/guardrails/pii.py +94 -0
- agno/guardrails/prompt_injection.py +51 -0
- agno/knowledge/embedder/aws_bedrock.py +9 -4
- agno/knowledge/embedder/azure_openai.py +54 -0
- agno/knowledge/embedder/base.py +2 -0
- agno/knowledge/embedder/cohere.py +184 -5
- agno/knowledge/embedder/google.py +79 -1
- agno/knowledge/embedder/huggingface.py +9 -4
- agno/knowledge/embedder/jina.py +63 -0
- agno/knowledge/embedder/mistral.py +78 -11
- agno/knowledge/embedder/ollama.py +5 -0
- agno/knowledge/embedder/openai.py +18 -54
- agno/knowledge/embedder/voyageai.py +69 -16
- agno/knowledge/knowledge.py +11 -4
- agno/knowledge/reader/pdf_reader.py +4 -3
- agno/knowledge/reader/website_reader.py +3 -2
- agno/models/base.py +125 -32
- agno/models/cerebras/cerebras.py +1 -0
- agno/models/cerebras/cerebras_openai.py +1 -0
- agno/models/dashscope/dashscope.py +1 -0
- agno/models/google/gemini.py +27 -5
- agno/models/openai/chat.py +13 -4
- agno/models/openai/responses.py +1 -1
- agno/models/perplexity/perplexity.py +2 -3
- agno/models/requesty/__init__.py +5 -0
- agno/models/requesty/requesty.py +49 -0
- agno/models/vllm/vllm.py +1 -0
- agno/models/xai/xai.py +1 -0
- agno/os/app.py +98 -126
- agno/os/interfaces/__init__.py +1 -0
- agno/os/interfaces/agui/agui.py +21 -5
- agno/os/interfaces/base.py +4 -2
- agno/os/interfaces/slack/slack.py +13 -8
- agno/os/interfaces/whatsapp/router.py +2 -0
- agno/os/interfaces/whatsapp/whatsapp.py +12 -5
- agno/os/mcp.py +2 -2
- agno/os/middleware/__init__.py +7 -0
- agno/os/middleware/jwt.py +233 -0
- agno/os/router.py +182 -46
- agno/os/routers/home.py +2 -2
- agno/os/routers/memory/memory.py +23 -1
- agno/os/routers/memory/schemas.py +1 -1
- agno/os/routers/session/session.py +20 -3
- agno/os/utils.py +74 -8
- agno/run/agent.py +120 -77
- agno/run/base.py +2 -13
- agno/run/team.py +115 -72
- agno/run/workflow.py +5 -15
- agno/session/summary.py +9 -10
- agno/session/team.py +2 -1
- agno/team/team.py +721 -169
- agno/tools/firecrawl.py +4 -4
- agno/tools/function.py +42 -2
- agno/tools/knowledge.py +3 -3
- agno/tools/searxng.py +2 -2
- agno/tools/serper.py +2 -2
- agno/tools/spider.py +2 -2
- agno/tools/workflow.py +4 -5
- agno/utils/events.py +66 -1
- agno/utils/hooks.py +57 -0
- agno/utils/media.py +11 -9
- agno/utils/print_response/agent.py +43 -5
- agno/utils/print_response/team.py +48 -12
- agno/utils/serialize.py +32 -0
- agno/vectordb/cassandra/cassandra.py +44 -4
- agno/vectordb/chroma/chromadb.py +79 -8
- agno/vectordb/clickhouse/clickhousedb.py +43 -6
- agno/vectordb/couchbase/couchbase.py +76 -5
- agno/vectordb/lancedb/lance_db.py +38 -3
- agno/vectordb/milvus/milvus.py +76 -4
- agno/vectordb/mongodb/mongodb.py +76 -4
- agno/vectordb/pgvector/pgvector.py +50 -6
- agno/vectordb/pineconedb/pineconedb.py +39 -2
- agno/vectordb/qdrant/qdrant.py +76 -26
- agno/vectordb/singlestore/singlestore.py +77 -4
- agno/vectordb/upstashdb/upstashdb.py +42 -2
- agno/vectordb/weaviate/weaviate.py +39 -3
- agno/workflow/types.py +5 -6
- agno/workflow/workflow.py +58 -2
- {agno-2.0.11.dist-info → agno-2.1.1.dist-info}/METADATA +4 -3
- {agno-2.0.11.dist-info → agno-2.1.1.dist-info}/RECORD +93 -82
- {agno-2.0.11.dist-info → agno-2.1.1.dist-info}/WHEEL +0 -0
- {agno-2.0.11.dist-info → agno-2.1.1.dist-info}/licenses/LICENSE +0 -0
- {agno-2.0.11.dist-info → agno-2.1.1.dist-info}/top_level.txt +0 -0
agno/exceptions.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
|
-
from
|
|
1
|
+
from enum import Enum
|
|
2
|
+
from typing import Any, Dict, List, Optional, Union
|
|
2
3
|
|
|
3
4
|
from agno.models.message import Message
|
|
4
5
|
|
|
@@ -17,6 +18,8 @@ class AgentRunException(Exception):
|
|
|
17
18
|
self.agent_message = agent_message
|
|
18
19
|
self.messages = messages
|
|
19
20
|
self.stop_execution = stop_execution
|
|
21
|
+
self.type = "agent_run_error"
|
|
22
|
+
self.error_id = "agent_run_error"
|
|
20
23
|
|
|
21
24
|
|
|
22
25
|
class RetryAgentRun(AgentRunException):
|
|
@@ -32,6 +35,7 @@ class RetryAgentRun(AgentRunException):
|
|
|
32
35
|
super().__init__(
|
|
33
36
|
exc, user_message=user_message, agent_message=agent_message, messages=messages, stop_execution=False
|
|
34
37
|
)
|
|
38
|
+
self.error_id = "retry_agent_run_error"
|
|
35
39
|
|
|
36
40
|
|
|
37
41
|
class StopAgentRun(AgentRunException):
|
|
@@ -47,6 +51,7 @@ class StopAgentRun(AgentRunException):
|
|
|
47
51
|
super().__init__(
|
|
48
52
|
exc, user_message=user_message, agent_message=agent_message, messages=messages, stop_execution=True
|
|
49
53
|
)
|
|
54
|
+
self.error_id = "stop_agent_run_error"
|
|
50
55
|
|
|
51
56
|
|
|
52
57
|
class RunCancelledException(Exception):
|
|
@@ -54,6 +59,8 @@ class RunCancelledException(Exception):
|
|
|
54
59
|
|
|
55
60
|
def __init__(self, message: str = "Operation cancelled by user"):
|
|
56
61
|
super().__init__(message)
|
|
62
|
+
self.type = "run_cancelled_error"
|
|
63
|
+
self.error_id = "run_cancelled_error"
|
|
57
64
|
|
|
58
65
|
|
|
59
66
|
class AgnoError(Exception):
|
|
@@ -63,6 +70,8 @@ class AgnoError(Exception):
|
|
|
63
70
|
super().__init__(message)
|
|
64
71
|
self.message = message
|
|
65
72
|
self.status_code = status_code
|
|
73
|
+
self.type = "agno_error"
|
|
74
|
+
self.error_id = "agno_error"
|
|
66
75
|
|
|
67
76
|
def __str__(self) -> str:
|
|
68
77
|
return str(self.message)
|
|
@@ -78,6 +87,9 @@ class ModelProviderError(AgnoError):
|
|
|
78
87
|
self.model_name = model_name
|
|
79
88
|
self.model_id = model_id
|
|
80
89
|
|
|
90
|
+
self.type = "model_provider_error"
|
|
91
|
+
self.error_id = "model_provider_error"
|
|
92
|
+
|
|
81
93
|
|
|
82
94
|
class ModelRateLimitError(ModelProviderError):
|
|
83
95
|
"""Exception raised when a model provider returns a rate limit error."""
|
|
@@ -86,9 +98,58 @@ class ModelRateLimitError(ModelProviderError):
|
|
|
86
98
|
self, message: str, status_code: int = 429, model_name: Optional[str] = None, model_id: Optional[str] = None
|
|
87
99
|
):
|
|
88
100
|
super().__init__(message, status_code, model_name, model_id)
|
|
101
|
+
self.error_id = "model_rate_limit_error"
|
|
89
102
|
|
|
90
103
|
|
|
91
104
|
class EvalError(Exception):
    """Signal that an evaluation run did not succeed.

    Adds no state of its own; the message and any other positional
    arguments are handled by ``Exception``.
    """
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class CheckTrigger(Enum):
    """Reasons a guardrail check can fire.

    Each member's value is a snake_case identifier. InputCheckError and
    OutputCheckError copy the value of the trigger they receive into their
    ``error_id`` attribute.
    """

    # General-purpose triggers
    OFF_TOPIC = "off_topic"
    INPUT_NOT_ALLOWED = "input_not_allowed"
    OUTPUT_NOT_ALLOWED = "output_not_allowed"
    VALIDATION_FAILED = "validation_failed"

    # Triggers used by specific guardrails (prompt-injection and PII detection)
    PROMPT_INJECTION = "prompt_injection"
    PII_DETECTED = "pii_detected"
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class InputCheckError(Exception):
    """Raised when a check on the run input fails.

    Args:
        message: Human-readable description of the failure; it is also the
            ``Exception`` message.
        check_trigger: Which check fired. Its ``value`` is stored as ``error_id``.
        additional_data: Optional extra details supplied by the raiser.

    Attributes:
        type: Always ``"input_check_error"``.
        error_id: ``check_trigger.value``.
        message, check_trigger, additional_data: The constructor arguments.
    """

    def __init__(
        self,
        message: str,
        check_trigger: CheckTrigger = CheckTrigger.INPUT_NOT_ALLOWED,
        additional_data: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(message)

        # Category of this error and the specific trigger identifier.
        self.type = "input_check_error"
        self.error_id = check_trigger.value

        # Keep the raw constructor arguments available to handlers.
        self.message, self.check_trigger, self.additional_data = message, check_trigger, additional_data
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class OutputCheckError(Exception):
    """Raised when a check on the run output fails.

    Args:
        message: Human-readable description of the failure; it is also the
            ``Exception`` message.
        check_trigger: Which check fired. Its ``value`` is stored as ``error_id``.
        additional_data: Optional extra details supplied by the raiser.

    Attributes:
        type: Always ``"output_check_error"``.
        error_id: ``check_trigger.value``.
        message, check_trigger, additional_data: The constructor arguments.
    """

    def __init__(
        self,
        message: str,
        check_trigger: CheckTrigger = CheckTrigger.OUTPUT_NOT_ALLOWED,
        additional_data: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(message)

        # Category of this error and the specific trigger identifier.
        self.type = "output_check_error"
        self.error_id = check_trigger.value

        # Keep the raw constructor arguments available to handlers.
        self.message, self.check_trigger, self.additional_data = message, check_trigger, additional_data
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
from agno.guardrails.base import BaseGuardrail
|
|
2
|
+
from agno.guardrails.openai import OpenAIModerationGuardrail
|
|
3
|
+
from agno.guardrails.pii import PIIDetectionGuardrail
|
|
4
|
+
from agno.guardrails.prompt_injection import PromptInjectionGuardrail
|
|
5
|
+
|
|
6
|
+
# Public API of the guardrails package.
__all__ = [
    "BaseGuardrail",
    "OpenAIModerationGuardrail",
    "PIIDetectionGuardrail",
    "PromptInjectionGuardrail",
]
|
agno/guardrails/base.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Union
|
|
3
|
+
|
|
4
|
+
from agno.run.agent import RunInput
|
|
5
|
+
from agno.run.team import TeamRunInput
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class BaseGuardrail(ABC):
    """Interface that every guardrail implements.

    A concrete guardrail supplies a blocking ``check`` and a coroutine
    ``async_check``. Both receive the run input and return ``None``; the
    concrete guardrails in this package report a failed check by raising
    ``InputCheckError``.
    """

    @abstractmethod
    def check(self, run_input: Union[RunInput, TeamRunInput]) -> None:
        """Run the guardrail synchronously against ``run_input``."""

    @abstractmethod
    async def async_check(self, run_input: Union[RunInput, TeamRunInput]) -> None:
        """Run the guardrail asynchronously against ``run_input``."""
|
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
from os import getenv
|
|
2
|
+
from typing import Any, Dict, List, Literal, Optional, Union
|
|
3
|
+
|
|
4
|
+
from agno.exceptions import CheckTrigger, InputCheckError
|
|
5
|
+
from agno.guardrails.base import BaseGuardrail
|
|
6
|
+
from agno.run.agent import RunInput
|
|
7
|
+
from agno.run.team import TeamRunInput
|
|
8
|
+
from agno.utils.log import log_debug
|
|
9
|
+
from agno.utils.openai import images_to_message
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class OpenAIModerationGuardrail(BaseGuardrail):
    """Guardrail for detecting content that violates OpenAI's content policy.

    Args:
        moderation_model (str): The model to use for moderation. Defaults to "omni-moderation-latest".
        raise_for_categories (List[str]): The categories to raise for.
            Options are: "sexual", "sexual/minors", "harassment",
            "harassment/threatening", "hate", "hate/threatening",
            "illicit", "illicit/violent", "self-harm", "self-harm/intent",
            "self-harm/instructions", "violence", "violence/graphic".
            Defaults to include all categories.
        api_key (str): The API key to use for moderation. Defaults to the OPENAI_API_KEY environment variable.

    Raises:
        InputCheckError: From ``check``/``async_check`` when the moderation result is flagged
            (and, if ``raise_for_categories`` is set, one of those categories is flagged).
    """

    def __init__(
        self,
        moderation_model: str = "omni-moderation-latest",
        raise_for_categories: Optional[
            List[
                Literal[
                    "sexual",
                    "sexual/minors",
                    "harassment",
                    "harassment/threatening",
                    "hate",
                    "hate/threatening",
                    "illicit",
                    "illicit/violent",
                    "self-harm",
                    "self-harm/intent",
                    "self-harm/instructions",
                    "violence",
                    "violence/graphic",
                ]
            ]
        ] = None,
        api_key: Optional[str] = None,
    ):
        self.moderation_model = moderation_model
        self.api_key = api_key or getenv("OPENAI_API_KEY")
        self.raise_for_categories = raise_for_categories

    def _build_model_input(self, run_input: Union[RunInput, TeamRunInput]) -> Union[str, List[Dict[str, Any]]]:
        """Return the moderation ``input``: the text alone, or a text part followed by the images."""
        content = run_input.input_content_string()
        images = run_input.images
        if images is not None:
            return [{"type": "text", "text": content}, *images_to_message(images=images)]
        return content

    def _raise_if_flagged(self, result: Any) -> None:
        """Raise ``InputCheckError`` if a single moderation ``result`` should block the run."""
        if not result.flagged:
            return

        moderation_result = {
            "categories": result.categories.model_dump(),
            "category_scores": result.category_scores.model_dump(),
        }

        if self.raise_for_categories is None:
            # At least one category is flagged and no filter is configured, so raise.
            trigger_validation = True
        else:
            # Configured names use the API spelling ("self-harm", "sexual/minors"). The SDK
            # models may expose those only as field aliases, while a plain model_dump() keys
            # by field name ("self_harm", "sexual_minors"). Merge both spellings so every
            # documented option can be found. A category missing from the response counts as
            # not flagged instead of raising KeyError.
            flags = {**moderation_result["categories"], **result.categories.model_dump(by_alias=True)}
            trigger_validation = any(flags.get(category) for category in self.raise_for_categories)

        if trigger_validation:
            raise InputCheckError(
                "OpenAI moderation violation detected.",
                additional_data=moderation_result,
                check_trigger=CheckTrigger.INPUT_NOT_ALLOWED,
            )

    def check(self, run_input: Union[RunInput, TeamRunInput]) -> None:
        """Check for content that violates OpenAI's content policy.

        Raises:
            ImportError: If the ``openai`` package is not installed.
            InputCheckError: If the input is flagged for a monitored category.
        """
        try:
            from openai import OpenAI as OpenAIClient
        except ImportError:
            raise ImportError("`openai` not installed. Please install using `pip install openai`")

        model_input = self._build_model_input(run_input)

        log_debug(f"Moderating content using {self.moderation_model}")
        client = OpenAIClient(api_key=self.api_key)
        response = client.moderations.create(model=self.moderation_model, input=model_input)  # type: ignore

        self._raise_if_flagged(response.results[0])

    async def async_check(self, run_input: Union[RunInput, TeamRunInput]) -> None:
        """Check for content that violates OpenAI's content policy (async).

        Raises:
            ImportError: If the ``openai`` package is not installed.
            InputCheckError: If the input is flagged for a monitored category.
        """
        try:
            from openai import AsyncOpenAI as OpenAIClient
        except ImportError:
            raise ImportError("`openai` not installed. Please install using `pip install openai`")

        model_input = self._build_model_input(run_input)

        log_debug(f"Moderating content using {self.moderation_model}")
        client = OpenAIClient(api_key=self.api_key)
        response = await client.moderations.create(model=self.moderation_model, input=model_input)  # type: ignore

        self._raise_if_flagged(response.results[0])
|
agno/guardrails/pii.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
from re import Pattern
|
|
2
|
+
from typing import Dict, Optional, Union
|
|
3
|
+
|
|
4
|
+
from agno.exceptions import CheckTrigger, InputCheckError
|
|
5
|
+
from agno.guardrails.base import BaseGuardrail
|
|
6
|
+
from agno.run.agent import RunInput
|
|
7
|
+
from agno.run.team import TeamRunInput
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class PIIDetectionGuardrail(BaseGuardrail):
    """Guardrail for detecting Personally Identifiable Information (PII).

    Args:
        mask_pii: Whether to mask the PII in the input, rather than raising an error.
            When masking, every match is replaced by asterisks of the same length and the
            masked string is written back to ``run_input.input_content``.
        enable_ssn_check: Whether to check for Social Security Numbers. True by default.
        enable_credit_card_check: Whether to check for credit cards. True by default.
        enable_email_check: Whether to check for emails. True by default.
        enable_phone_check: Whether to check for phone numbers. True by default.
        custom_patterns: A dictionary of custom PII patterns to detect. This is added to the default patterns
            (a custom pattern with the same name as a default one replaces it).

    Raises:
        InputCheckError: From ``check``/``async_check`` when PII is found and ``mask_pii`` is False.
    """

    def __init__(
        self,
        mask_pii: bool = False,
        enable_ssn_check: bool = True,
        enable_credit_card_check: bool = True,
        enable_email_check: bool = True,
        enable_phone_check: bool = True,
        custom_patterns: Optional[Dict[str, Pattern[str]]] = None,
    ):
        import re

        self.mask_pii = mask_pii
        self.pii_patterns: Dict[str, Pattern[str]] = {}

        if enable_ssn_check:
            self.pii_patterns["SSN"] = re.compile(r"\b\d{3}-\d{2}-\d{4}\b")
        if enable_credit_card_check:
            self.pii_patterns["Credit Card"] = re.compile(r"\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b")
        if enable_email_check:
            # The TLD class is [A-Za-z]; the former "[A-Z|a-z]" also accepted a literal "|".
            self.pii_patterns["Email"] = re.compile(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b")
        if enable_phone_check:
            self.pii_patterns["Phone"] = re.compile(r"\b\d{3}[\s.-]?\d{3}[\s.-]?\d{4}\b")

        if custom_patterns:
            self.pii_patterns.update(custom_patterns)

    @staticmethod
    def _mask_match(match) -> str:
        """Replacement callback: asterisks with the same length as the matched text."""
        return "*" * len(match.group(0))

    def _detect_and_handle(self, run_input: Union[RunInput, TeamRunInput]) -> None:
        """Shared implementation of ``check`` and ``async_check``: mask or raise on PII."""
        content = run_input.input_content_string()
        detected_pii = [pii_type for pii_type, pattern in self.pii_patterns.items() if pattern.search(content)]
        if not detected_pii:
            return

        if self.mask_pii:
            for pii_type in detected_pii:
                content = self.pii_patterns[pii_type].sub(self._mask_match, content)
            run_input.input_content = content
            return

        raise InputCheckError(
            "Potential PII detected in input",
            additional_data={"detected_pii": detected_pii},
            check_trigger=CheckTrigger.PII_DETECTED,
        )

    def check(self, run_input: Union[RunInput, TeamRunInput]) -> None:
        """Check for PII patterns in the input."""
        self._detect_and_handle(run_input)

    async def async_check(self, run_input: Union[RunInput, TeamRunInput]) -> None:
        """Asynchronously check for PII patterns in the input."""
        self._detect_and_handle(run_input)
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
from typing import List, Optional, Union
|
|
2
|
+
|
|
3
|
+
from agno.exceptions import CheckTrigger, InputCheckError
|
|
4
|
+
from agno.guardrails.base import BaseGuardrail
|
|
5
|
+
from agno.run.agent import RunInput
|
|
6
|
+
from agno.run.team import TeamRunInput
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class PromptInjectionGuardrail(BaseGuardrail):
    """Guardrail for detecting prompt injection attempts.

    Matching is a case-insensitive substring search of each pattern in the input text.

    Args:
        injection_patterns (Optional[List[str]]): A list of patterns to check for. Defaults to a list of common prompt injection patterns.

    Raises:
        InputCheckError: From ``check``/``async_check`` when any pattern occurs in the input.
    """

    def __init__(self, injection_patterns: Optional[List[str]] = None):
        self.injection_patterns = injection_patterns or [
            "ignore previous instructions",
            "ignore your instructions",
            "you are now a",
            "forget everything above",
            "developer mode",
            "override safety",
            "disregard guidelines",
            "system prompt",
            "jailbreak",
            "act as if",
            "pretend you are",
            "roleplay as",
            "simulate being",
            "bypass restrictions",
            "ignore safeguards",
            "admin override",
            "root access",
        ]

    def _detect(self, run_input: Union[RunInput, TeamRunInput]) -> None:
        """Shared implementation of ``check`` and ``async_check``."""
        # Render and lowercase the input once rather than once per pattern.
        text = run_input.input_content_string().lower()
        # Lowercase the patterns too; otherwise mixed-case custom patterns could never
        # match the lowercased input.
        if any(pattern.lower() in text for pattern in self.injection_patterns):
            raise InputCheckError(
                "Potential jailbreaking or prompt injection detected.",
                check_trigger=CheckTrigger.PROMPT_INJECTION,
            )

    def check(self, run_input: Union[RunInput, TeamRunInput]) -> None:
        """Check for prompt injection patterns in the input."""
        self._detect(run_input)

    async def async_check(self, run_input: Union[RunInput, TeamRunInput]) -> None:
        """Asynchronously check for prompt injection patterns in the input."""
        self._detect(run_input)
|
|
@@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple
|
|
|
5
5
|
|
|
6
6
|
from agno.exceptions import AgnoError, ModelProviderError
|
|
7
7
|
from agno.knowledge.embedder.base import Embedder
|
|
8
|
-
from agno.utils.log import log_error,
|
|
8
|
+
from agno.utils.log import log_error, log_warning
|
|
9
9
|
|
|
10
10
|
try:
|
|
11
11
|
from boto3 import client as AwsClient
|
|
@@ -69,6 +69,11 @@ class AwsBedrockEmbedder(Embedder):
|
|
|
69
69
|
client_params: Optional[Dict[str, Any]] = None
|
|
70
70
|
client: Optional[AwsClient] = None
|
|
71
71
|
|
|
72
|
+
def __post_init__(self):
    """Force ``enable_batch`` off, since this embedder has no batch embedding path.

    A warning is logged only when the caller had asked for batching.
    """
    if not self.enable_batch:
        return
    log_warning("AwsBedrockEmbedder does not support batch embeddings, setting enable_batch to False")
    self.enable_batch = False
|
|
76
|
+
|
|
72
77
|
def get_client(self) -> AwsClient:
|
|
73
78
|
"""
|
|
74
79
|
Returns an AWS Bedrock client.
|
|
@@ -220,10 +225,10 @@ class AwsBedrockEmbedder(Embedder):
|
|
|
220
225
|
# Fallback to the first available embedding type
|
|
221
226
|
for embedding_type in response["embeddings"]:
|
|
222
227
|
return response["embeddings"][embedding_type][0]
|
|
223
|
-
|
|
228
|
+
log_warning("No embeddings found in response")
|
|
224
229
|
return []
|
|
225
230
|
except Exception as e:
|
|
226
|
-
|
|
231
|
+
log_warning(f"Error extracting embeddings: {e}")
|
|
227
232
|
return []
|
|
228
233
|
|
|
229
234
|
def get_embedding_and_usage(self, text: str) -> Tuple[List[float], Optional[Dict[str, Any]]]:
|
|
@@ -286,7 +291,7 @@ class AwsBedrockEmbedder(Embedder):
|
|
|
286
291
|
# Fallback to the first available embedding type
|
|
287
292
|
for embedding_type in response_body["embeddings"]:
|
|
288
293
|
return response_body["embeddings"][embedding_type][0]
|
|
289
|
-
|
|
294
|
+
log_warning("No embeddings found in response")
|
|
290
295
|
return []
|
|
291
296
|
except ClientError as e:
|
|
292
297
|
log_error(f"Unexpected error calling Bedrock API: {str(e)}")
|
|
@@ -154,3 +154,57 @@ class AzureOpenAIEmbedder(Embedder):
|
|
|
154
154
|
embedding = response.data[0].embedding
|
|
155
155
|
usage = response.usage
|
|
156
156
|
return embedding, usage.model_dump()
|
|
157
|
+
|
|
158
|
+
async def async_get_embeddings_batch_and_usage(
    self, texts: List[str]
) -> Tuple[List[List[float]], List[Optional[Dict]]]:
    """
    Get embeddings and usage for multiple texts in batches.

    Texts are sent ``self.batch_size`` at a time. If a batch request fails, each text
    in that batch is retried with ``async_get_embedding_and_usage``. A text that still
    fails yields an empty embedding and ``None`` usage. Both returned lists therefore
    contain one entry per input text, in input order.

    Args:
        texts: List of text strings to embed

    Returns:
        Tuple of (List of embedding vectors, List of usage dictionaries). Each text in a
        successful batch receives that batch request's usage dictionary.
    """
    all_embeddings: List[List[float]] = []
    all_usage: List[Optional[Dict]] = []
    logger.info(f"Getting embeddings and usage for {len(texts)} texts in batches of {self.batch_size}")

    for i in range(0, len(texts), self.batch_size):
        batch_texts = texts[i : i + self.batch_size]

        req: Dict[str, Any] = {
            "input": batch_texts,
            "model": self.id,
            "encoding_format": self.encoding_format,
        }
        if self.user is not None:
            req["user"] = self.user
        # Send "dimensions" only for text-embedding-3 models, and never as an explicit null.
        if self.id.startswith("text-embedding-3") and self.dimensions is not None:
            req["dimensions"] = self.dimensions
        if self.request_params:
            req.update(self.request_params)

        try:
            response: CreateEmbeddingResponse = await self.aclient.embeddings.create(**req)
            batch_embeddings = [data.embedding for data in response.data]
            usage_dict = response.usage.model_dump() if response.usage else None
        except Exception as e:
            logger.warning(f"Error in async batch embedding: {e}")
            # Fallback to individual calls for this batch
            for text in batch_texts:
                try:
                    embedding, usage = await self.async_get_embedding_and_usage(text)
                    all_embeddings.append(embedding)
                    all_usage.append(usage)
                except Exception as e2:
                    logger.warning(f"Error in individual async embedding fallback: {e2}")
                    all_embeddings.append([])
                    all_usage.append(None)
        else:
            # Extend only after the whole batch has been processed. A failure above then
            # cannot leave partial results that the fallback would duplicate.
            all_embeddings.extend(batch_embeddings)
            all_usage.extend([usage_dict] * len(batch_embeddings))

    return all_embeddings, all_usage
|
agno/knowledge/embedder/base.py
CHANGED
|
@@ -7,6 +7,8 @@ class Embedder:
|
|
|
7
7
|
"""Base class for managing embedders"""
|
|
8
8
|
|
|
9
9
|
dimensions: Optional[int] = 1536
|
|
10
|
+
enable_batch: bool = False
|
|
11
|
+
batch_size: int = 100 # Number of texts to process in each API call
|
|
10
12
|
|
|
11
13
|
def get_embedding(self, text: str) -> List[float]:
|
|
12
14
|
raise NotImplementedError
|