agno 2.0.10__py3-none-any.whl → 2.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agno/agent/agent.py +608 -175
- agno/db/in_memory/in_memory_db.py +42 -29
- agno/db/postgres/postgres.py +6 -4
- agno/exceptions.py +62 -1
- agno/guardrails/__init__.py +6 -0
- agno/guardrails/base.py +19 -0
- agno/guardrails/openai.py +144 -0
- agno/guardrails/pii.py +94 -0
- agno/guardrails/prompt_injection.py +51 -0
- agno/knowledge/embedder/aws_bedrock.py +9 -4
- agno/knowledge/embedder/azure_openai.py +54 -0
- agno/knowledge/embedder/base.py +2 -0
- agno/knowledge/embedder/cohere.py +184 -5
- agno/knowledge/embedder/google.py +79 -1
- agno/knowledge/embedder/huggingface.py +9 -4
- agno/knowledge/embedder/jina.py +63 -0
- agno/knowledge/embedder/mistral.py +78 -11
- agno/knowledge/embedder/ollama.py +5 -0
- agno/knowledge/embedder/openai.py +18 -54
- agno/knowledge/embedder/voyageai.py +69 -16
- agno/knowledge/knowledge.py +5 -4
- agno/knowledge/reader/pdf_reader.py +4 -3
- agno/knowledge/reader/website_reader.py +3 -2
- agno/models/base.py +125 -32
- agno/models/cerebras/cerebras.py +1 -0
- agno/models/cerebras/cerebras_openai.py +1 -0
- agno/models/dashscope/dashscope.py +1 -0
- agno/models/google/gemini.py +27 -5
- agno/models/litellm/chat.py +17 -0
- agno/models/openai/chat.py +13 -4
- agno/models/perplexity/perplexity.py +2 -3
- agno/models/requesty/__init__.py +5 -0
- agno/models/requesty/requesty.py +49 -0
- agno/models/vllm/vllm.py +1 -0
- agno/models/xai/xai.py +1 -0
- agno/os/app.py +167 -148
- agno/os/interfaces/whatsapp/router.py +2 -0
- agno/os/mcp.py +1 -1
- agno/os/middleware/__init__.py +7 -0
- agno/os/middleware/jwt.py +233 -0
- agno/os/router.py +181 -45
- agno/os/routers/home.py +2 -2
- agno/os/routers/memory/memory.py +23 -1
- agno/os/routers/memory/schemas.py +1 -1
- agno/os/routers/session/session.py +20 -3
- agno/os/utils.py +172 -8
- agno/run/agent.py +120 -77
- agno/run/team.py +115 -72
- agno/run/workflow.py +5 -15
- agno/session/summary.py +9 -10
- agno/session/team.py +2 -1
- agno/team/team.py +720 -168
- agno/tools/firecrawl.py +4 -4
- agno/tools/function.py +42 -2
- agno/tools/knowledge.py +3 -3
- agno/tools/searxng.py +2 -2
- agno/tools/serper.py +2 -2
- agno/tools/spider.py +2 -2
- agno/tools/workflow.py +4 -5
- agno/utils/events.py +66 -1
- agno/utils/hooks.py +57 -0
- agno/utils/media.py +11 -9
- agno/utils/print_response/agent.py +43 -5
- agno/utils/print_response/team.py +48 -12
- agno/vectordb/cassandra/cassandra.py +44 -4
- agno/vectordb/chroma/chromadb.py +79 -8
- agno/vectordb/clickhouse/clickhousedb.py +43 -6
- agno/vectordb/couchbase/couchbase.py +76 -5
- agno/vectordb/lancedb/lance_db.py +38 -3
- agno/vectordb/llamaindex/__init__.py +3 -0
- agno/vectordb/milvus/milvus.py +76 -4
- agno/vectordb/mongodb/mongodb.py +76 -4
- agno/vectordb/pgvector/pgvector.py +50 -6
- agno/vectordb/pineconedb/pineconedb.py +39 -2
- agno/vectordb/qdrant/qdrant.py +76 -26
- agno/vectordb/singlestore/singlestore.py +77 -4
- agno/vectordb/upstashdb/upstashdb.py +42 -2
- agno/vectordb/weaviate/weaviate.py +39 -3
- agno/workflow/types.py +1 -0
- agno/workflow/workflow.py +58 -2
- {agno-2.0.10.dist-info → agno-2.1.0.dist-info}/METADATA +4 -3
- {agno-2.0.10.dist-info → agno-2.1.0.dist-info}/RECORD +85 -75
- {agno-2.0.10.dist-info → agno-2.1.0.dist-info}/WHEEL +0 -0
- {agno-2.0.10.dist-info → agno-2.1.0.dist-info}/licenses/LICENSE +0 -0
- {agno-2.0.10.dist-info → agno-2.1.0.dist-info}/top_level.txt +0 -0
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import time
|
|
2
|
+
from copy import deepcopy
|
|
2
3
|
from datetime import date, datetime, timedelta, timezone
|
|
3
4
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
4
5
|
from uuid import uuid4
|
|
@@ -107,15 +108,17 @@ class InMemoryDb(BaseDb):
|
|
|
107
108
|
if session_data.get("session_type") != session_type_value:
|
|
108
109
|
continue
|
|
109
110
|
|
|
111
|
+
session_data_copy = deepcopy(session_data)
|
|
112
|
+
|
|
110
113
|
if not deserialize:
|
|
111
|
-
return
|
|
114
|
+
return session_data_copy
|
|
112
115
|
|
|
113
116
|
if session_type == SessionType.AGENT:
|
|
114
|
-
return AgentSession.from_dict(
|
|
117
|
+
return AgentSession.from_dict(session_data_copy)
|
|
115
118
|
elif session_type == SessionType.TEAM:
|
|
116
|
-
return TeamSession.from_dict(
|
|
119
|
+
return TeamSession.from_dict(session_data_copy)
|
|
117
120
|
else:
|
|
118
|
-
return WorkflowSession.from_dict(
|
|
121
|
+
return WorkflowSession.from_dict(session_data_copy)
|
|
119
122
|
|
|
120
123
|
return None
|
|
121
124
|
|
|
@@ -188,7 +191,7 @@ class InMemoryDb(BaseDb):
|
|
|
188
191
|
if session_data.get("session_type") != session_type_value:
|
|
189
192
|
continue
|
|
190
193
|
|
|
191
|
-
filtered_sessions.append(session_data)
|
|
194
|
+
filtered_sessions.append(deepcopy(session_data))
|
|
192
195
|
|
|
193
196
|
total_count = len(filtered_sessions)
|
|
194
197
|
|
|
@@ -233,15 +236,16 @@ class InMemoryDb(BaseDb):
|
|
|
233
236
|
|
|
234
237
|
log_debug(f"Renamed session with id '{session_id}' to '{session_name}'")
|
|
235
238
|
|
|
239
|
+
session_copy = deepcopy(session)
|
|
236
240
|
if not deserialize:
|
|
237
|
-
return
|
|
241
|
+
return session_copy
|
|
238
242
|
|
|
239
243
|
if session_type == SessionType.AGENT:
|
|
240
|
-
return AgentSession.from_dict(
|
|
244
|
+
return AgentSession.from_dict(session_copy)
|
|
241
245
|
elif session_type == SessionType.TEAM:
|
|
242
|
-
return TeamSession.from_dict(
|
|
246
|
+
return TeamSession.from_dict(session_copy)
|
|
243
247
|
else:
|
|
244
|
-
return WorkflowSession.from_dict(
|
|
248
|
+
return WorkflowSession.from_dict(session_copy)
|
|
245
249
|
|
|
246
250
|
return None
|
|
247
251
|
|
|
@@ -269,22 +273,26 @@ class InMemoryDb(BaseDb):
|
|
|
269
273
|
if existing_session.get("session_id") == session_dict.get("session_id") and self._matches_session_key(
|
|
270
274
|
existing_session, session
|
|
271
275
|
):
|
|
272
|
-
# Update existing session
|
|
273
276
|
session_dict["updated_at"] = int(time.time())
|
|
274
|
-
self._sessions[i] = session_dict
|
|
277
|
+
self._sessions[i] = deepcopy(session_dict)
|
|
275
278
|
session_updated = True
|
|
276
279
|
break
|
|
277
280
|
|
|
278
281
|
if not session_updated:
|
|
279
|
-
# Add new session
|
|
280
282
|
session_dict["created_at"] = session_dict.get("created_at", int(time.time()))
|
|
281
283
|
session_dict["updated_at"] = session_dict.get("created_at")
|
|
282
|
-
self._sessions.append(session_dict)
|
|
284
|
+
self._sessions.append(deepcopy(session_dict))
|
|
283
285
|
|
|
286
|
+
session_dict_copy = deepcopy(session_dict)
|
|
284
287
|
if not deserialize:
|
|
285
|
-
return
|
|
288
|
+
return session_dict_copy
|
|
286
289
|
|
|
287
|
-
|
|
290
|
+
if session_dict_copy["session_type"] == SessionType.AGENT:
|
|
291
|
+
return AgentSession.from_dict(session_dict_copy)
|
|
292
|
+
elif session_dict_copy["session_type"] == SessionType.TEAM:
|
|
293
|
+
return TeamSession.from_dict(session_dict_copy)
|
|
294
|
+
else:
|
|
295
|
+
return WorkflowSession.from_dict(session_dict_copy)
|
|
288
296
|
|
|
289
297
|
except Exception as e:
|
|
290
298
|
log_error(f"Exception upserting session: {e}")
|
|
@@ -378,9 +386,10 @@ class InMemoryDb(BaseDb):
|
|
|
378
386
|
try:
|
|
379
387
|
for memory_data in self._memories:
|
|
380
388
|
if memory_data.get("memory_id") == memory_id:
|
|
389
|
+
memory_data_copy = deepcopy(memory_data)
|
|
381
390
|
if not deserialize:
|
|
382
|
-
return
|
|
383
|
-
return UserMemory.from_dict(
|
|
391
|
+
return memory_data_copy
|
|
392
|
+
return UserMemory.from_dict(memory_data_copy)
|
|
384
393
|
|
|
385
394
|
return None
|
|
386
395
|
|
|
@@ -420,7 +429,7 @@ class InMemoryDb(BaseDb):
|
|
|
420
429
|
if search_content.lower() not in memory_content.lower():
|
|
421
430
|
continue
|
|
422
431
|
|
|
423
|
-
filtered_memories.append(memory_data)
|
|
432
|
+
filtered_memories.append(deepcopy(memory_data))
|
|
424
433
|
|
|
425
434
|
total_count = len(filtered_memories)
|
|
426
435
|
|
|
@@ -499,9 +508,11 @@ class InMemoryDb(BaseDb):
|
|
|
499
508
|
if not memory_updated:
|
|
500
509
|
self._memories.append(memory_dict)
|
|
501
510
|
|
|
511
|
+
memory_dict_copy = deepcopy(memory_dict)
|
|
502
512
|
if not deserialize:
|
|
503
|
-
return
|
|
504
|
-
|
|
513
|
+
return memory_dict_copy
|
|
514
|
+
|
|
515
|
+
return UserMemory.from_dict(memory_dict_copy)
|
|
505
516
|
|
|
506
517
|
except Exception as e:
|
|
507
518
|
log_warning(f"Exception upserting user memory: {e}")
|
|
@@ -657,8 +668,8 @@ class InMemoryDb(BaseDb):
|
|
|
657
668
|
# Only include necessary fields for metrics
|
|
658
669
|
filtered_session = {
|
|
659
670
|
"user_id": session.get("user_id"),
|
|
660
|
-
"session_data": session.get("session_data"),
|
|
661
|
-
"runs": session.get("runs"),
|
|
671
|
+
"session_data": deepcopy(session.get("session_data")),
|
|
672
|
+
"runs": deepcopy(session.get("runs")),
|
|
662
673
|
"created_at": session.get("created_at"),
|
|
663
674
|
"session_type": session.get("session_type"),
|
|
664
675
|
}
|
|
@@ -688,7 +699,7 @@ class InMemoryDb(BaseDb):
|
|
|
688
699
|
if ending_date and metric_date > ending_date:
|
|
689
700
|
continue
|
|
690
701
|
|
|
691
|
-
filtered_metrics.append(metric)
|
|
702
|
+
filtered_metrics.append(deepcopy(metric))
|
|
692
703
|
|
|
693
704
|
updated_at = metric.get("updated_at")
|
|
694
705
|
if updated_at and (latest_updated_at is None or updated_at > latest_updated_at):
|
|
@@ -763,7 +774,7 @@ class InMemoryDb(BaseDb):
|
|
|
763
774
|
Exception: If an error occurs during retrieval.
|
|
764
775
|
"""
|
|
765
776
|
try:
|
|
766
|
-
knowledge_items = self._knowledge
|
|
777
|
+
knowledge_items = [deepcopy(item) for item in self._knowledge]
|
|
767
778
|
|
|
768
779
|
total_count = len(knowledge_items)
|
|
769
780
|
|
|
@@ -858,9 +869,10 @@ class InMemoryDb(BaseDb):
|
|
|
858
869
|
try:
|
|
859
870
|
for run_data in self._eval_runs:
|
|
860
871
|
if run_data.get("run_id") == eval_run_id:
|
|
872
|
+
run_data_copy = deepcopy(run_data)
|
|
861
873
|
if not deserialize:
|
|
862
|
-
return
|
|
863
|
-
return EvalRunRecord.model_validate(
|
|
874
|
+
return run_data_copy
|
|
875
|
+
return EvalRunRecord.model_validate(run_data_copy)
|
|
864
876
|
|
|
865
877
|
return None
|
|
866
878
|
|
|
@@ -906,7 +918,7 @@ class InMemoryDb(BaseDb):
|
|
|
906
918
|
elif filter_type == EvalFilterType.WORKFLOW and run_data.get("workflow_id") is None:
|
|
907
919
|
continue
|
|
908
920
|
|
|
909
|
-
filtered_runs.append(run_data)
|
|
921
|
+
filtered_runs.append(deepcopy(run_data))
|
|
910
922
|
|
|
911
923
|
total_count = len(filtered_runs)
|
|
912
924
|
|
|
@@ -945,10 +957,11 @@ class InMemoryDb(BaseDb):
|
|
|
945
957
|
|
|
946
958
|
log_debug(f"Renamed eval run with id '{eval_run_id}' to '{name}'")
|
|
947
959
|
|
|
960
|
+
run_data_copy = deepcopy(run_data)
|
|
948
961
|
if not deserialize:
|
|
949
|
-
return
|
|
962
|
+
return run_data_copy
|
|
950
963
|
|
|
951
|
-
return EvalRunRecord.model_validate(
|
|
964
|
+
return EvalRunRecord.model_validate(run_data_copy)
|
|
952
965
|
|
|
953
966
|
return None
|
|
954
967
|
|
agno/db/postgres/postgres.py
CHANGED
|
@@ -756,7 +756,7 @@ class PostgresDb(BaseDb):
|
|
|
756
756
|
)
|
|
757
757
|
|
|
758
758
|
with self.Session() as sess, sess.begin():
|
|
759
|
-
stmt = postgresql.insert(table)
|
|
759
|
+
stmt: Any = postgresql.insert(table)
|
|
760
760
|
update_columns = {
|
|
761
761
|
col.name: stmt.excluded[col.name]
|
|
762
762
|
for col in table.columns
|
|
@@ -1263,13 +1263,15 @@ class PostgresDb(BaseDb):
|
|
|
1263
1263
|
results: List[Union[UserMemory, Dict[str, Any]]] = []
|
|
1264
1264
|
|
|
1265
1265
|
with self.Session() as sess, sess.begin():
|
|
1266
|
-
|
|
1266
|
+
insert_stmt = postgresql.insert(table)
|
|
1267
1267
|
update_columns = {
|
|
1268
|
-
col.name:
|
|
1268
|
+
col.name: insert_stmt.excluded[col.name]
|
|
1269
1269
|
for col in table.columns
|
|
1270
1270
|
if col.name not in ["memory_id"] # Don't update primary key
|
|
1271
1271
|
}
|
|
1272
|
-
stmt =
|
|
1272
|
+
stmt = insert_stmt.on_conflict_do_update(index_elements=["memory_id"], set_=update_columns).returning(
|
|
1273
|
+
table
|
|
1274
|
+
)
|
|
1273
1275
|
|
|
1274
1276
|
result = sess.execute(stmt, memory_records)
|
|
1275
1277
|
for row in result.fetchall():
|
agno/exceptions.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
|
-
from
|
|
1
|
+
from enum import Enum
|
|
2
|
+
from typing import Any, Dict, List, Optional, Union
|
|
2
3
|
|
|
3
4
|
from agno.models.message import Message
|
|
4
5
|
|
|
@@ -17,6 +18,8 @@ class AgentRunException(Exception):
|
|
|
17
18
|
self.agent_message = agent_message
|
|
18
19
|
self.messages = messages
|
|
19
20
|
self.stop_execution = stop_execution
|
|
21
|
+
self.type = "agent_run_error"
|
|
22
|
+
self.error_id = "agent_run_error"
|
|
20
23
|
|
|
21
24
|
|
|
22
25
|
class RetryAgentRun(AgentRunException):
|
|
@@ -32,6 +35,7 @@ class RetryAgentRun(AgentRunException):
|
|
|
32
35
|
super().__init__(
|
|
33
36
|
exc, user_message=user_message, agent_message=agent_message, messages=messages, stop_execution=False
|
|
34
37
|
)
|
|
38
|
+
self.error_id = "retry_agent_run_error"
|
|
35
39
|
|
|
36
40
|
|
|
37
41
|
class StopAgentRun(AgentRunException):
|
|
@@ -47,6 +51,7 @@ class StopAgentRun(AgentRunException):
|
|
|
47
51
|
super().__init__(
|
|
48
52
|
exc, user_message=user_message, agent_message=agent_message, messages=messages, stop_execution=True
|
|
49
53
|
)
|
|
54
|
+
self.error_id = "stop_agent_run_error"
|
|
50
55
|
|
|
51
56
|
|
|
52
57
|
class RunCancelledException(Exception):
|
|
@@ -54,6 +59,8 @@ class RunCancelledException(Exception):
|
|
|
54
59
|
|
|
55
60
|
def __init__(self, message: str = "Operation cancelled by user"):
|
|
56
61
|
super().__init__(message)
|
|
62
|
+
self.type = "run_cancelled_error"
|
|
63
|
+
self.error_id = "run_cancelled_error"
|
|
57
64
|
|
|
58
65
|
|
|
59
66
|
class AgnoError(Exception):
|
|
@@ -63,6 +70,8 @@ class AgnoError(Exception):
|
|
|
63
70
|
super().__init__(message)
|
|
64
71
|
self.message = message
|
|
65
72
|
self.status_code = status_code
|
|
73
|
+
self.type = "agno_error"
|
|
74
|
+
self.error_id = "agno_error"
|
|
66
75
|
|
|
67
76
|
def __str__(self) -> str:
|
|
68
77
|
return str(self.message)
|
|
@@ -78,6 +87,9 @@ class ModelProviderError(AgnoError):
|
|
|
78
87
|
self.model_name = model_name
|
|
79
88
|
self.model_id = model_id
|
|
80
89
|
|
|
90
|
+
self.type = "model_provider_error"
|
|
91
|
+
self.error_id = "model_provider_error"
|
|
92
|
+
|
|
81
93
|
|
|
82
94
|
class ModelRateLimitError(ModelProviderError):
|
|
83
95
|
"""Exception raised when a model provider returns a rate limit error."""
|
|
@@ -86,9 +98,58 @@ class ModelRateLimitError(ModelProviderError):
|
|
|
86
98
|
self, message: str, status_code: int = 429, model_name: Optional[str] = None, model_id: Optional[str] = None
|
|
87
99
|
):
|
|
88
100
|
super().__init__(message, status_code, model_name, model_id)
|
|
101
|
+
self.error_id = "model_rate_limit_error"
|
|
89
102
|
|
|
90
103
|
|
|
91
104
|
class EvalError(Exception):
|
|
92
105
|
"""Exception raised when an evaluation fails."""
|
|
93
106
|
|
|
94
107
|
pass
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class CheckTrigger(Enum):
    """Enum for guardrail triggers.

    The string value of a member is what ``InputCheckError`` and
    ``OutputCheckError`` expose as their ``error_id``.
    """

    # General-purpose reasons a check can fail.
    OFF_TOPIC = "off_topic"
    INPUT_NOT_ALLOWED = "input_not_allowed"  # default trigger of InputCheckError
    OUTPUT_NOT_ALLOWED = "output_not_allowed"  # default trigger of OutputCheckError
    VALIDATION_FAILED = "validation_failed"

    # Reasons tied to a specific kind of detection.
    PROMPT_INJECTION = "prompt_injection"
    PII_DETECTED = "pii_detected"
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class InputCheckError(Exception):
    """Raised to signal that a check on a run's input did not pass.

    Args:
        message: Human-readable description of the failure.
        check_trigger: Reason the check failed. Defaults to
            ``CheckTrigger.INPUT_NOT_ALLOWED``.
        additional_data: Optional extra details about the failure.

    Attributes:
        type: Always ``"input_check_error"``.
        error_id: The string value of ``check_trigger``.
        message: The message passed to the constructor.
        check_trigger: The trigger passed to the constructor.
        additional_data: The extra details passed to the constructor.
    """

    def __init__(
        self,
        message: str,
        check_trigger: CheckTrigger = CheckTrigger.INPUT_NOT_ALLOWED,
        additional_data: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(message)

        # Values handed in by the caller, kept verbatim.
        self.message = message
        self.check_trigger = check_trigger
        self.additional_data = additional_data

        # Identifiers: a fixed error family plus the specific trigger value.
        self.type = "input_check_error"
        self.error_id = check_trigger.value
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class OutputCheckError(Exception):
    """Raised to signal that a check on a run's output did not pass.

    Args:
        message: Human-readable description of the failure.
        check_trigger: Reason the check failed. Defaults to
            ``CheckTrigger.OUTPUT_NOT_ALLOWED``.
        additional_data: Optional extra details about the failure.

    Attributes:
        type: Always ``"output_check_error"``.
        error_id: The string value of ``check_trigger``.
        message: The message passed to the constructor.
        check_trigger: The trigger passed to the constructor.
        additional_data: The extra details passed to the constructor.
    """

    def __init__(
        self,
        message: str,
        check_trigger: CheckTrigger = CheckTrigger.OUTPUT_NOT_ALLOWED,
        additional_data: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(message)

        # Values handed in by the caller, kept verbatim.
        self.message = message
        self.check_trigger = check_trigger
        self.additional_data = additional_data

        # Identifiers: a fixed error family plus the specific trigger value.
        self.type = "output_check_error"
        self.error_id = check_trigger.value
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
from agno.guardrails.base import BaseGuardrail
|
|
2
|
+
from agno.guardrails.openai import OpenAIModerationGuardrail
|
|
3
|
+
from agno.guardrails.pii import PIIDetectionGuardrail
|
|
4
|
+
from agno.guardrails.prompt_injection import PromptInjectionGuardrail
|
|
5
|
+
|
|
6
|
+
__all__ = ["BaseGuardrail", "OpenAIModerationGuardrail", "PIIDetectionGuardrail", "PromptInjectionGuardrail"]
|
agno/guardrails/base.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Union
|
|
3
|
+
|
|
4
|
+
from agno.run.agent import RunInput
|
|
5
|
+
from agno.run.team import TeamRunInput
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class BaseGuardrail(ABC):
    """Common interface implemented by every guardrail.

    A concrete guardrail must supply both a synchronous and an asynchronous
    entry point; each receives the run input and returns ``None``.
    """

    @abstractmethod
    def check(self, run_input: Union[RunInput, TeamRunInput]) -> None:
        """Run the guardrail against ``run_input`` synchronously."""

    @abstractmethod
    async def async_check(self, run_input: Union[RunInput, TeamRunInput]) -> None:
        """Run the guardrail against ``run_input`` asynchronously."""
|
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
from os import getenv
|
|
2
|
+
from typing import Any, Dict, List, Literal, Optional, Union
|
|
3
|
+
|
|
4
|
+
from agno.exceptions import CheckTrigger, InputCheckError
|
|
5
|
+
from agno.guardrails.base import BaseGuardrail
|
|
6
|
+
from agno.run.agent import RunInput
|
|
7
|
+
from agno.run.team import TeamRunInput
|
|
8
|
+
from agno.utils.log import log_debug
|
|
9
|
+
from agno.utils.openai import images_to_message
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class OpenAIModerationGuardrail(BaseGuardrail):
    """Guardrail for detecting content that violates OpenAI's content policy.

    Args:
        moderation_model (str): The model to use for moderation. Defaults to "omni-moderation-latest".
        raise_for_categories (List[str]): The categories to raise for.
            Options are: "sexual", "sexual/minors", "harassment",
            "harassment/threatening", "hate", "hate/threatening",
            "illicit", "illicit/violent", "self-harm", "self-harm/intent",
            "self-harm/instructions", "violence", "violence/graphic".
            Defaults to include all categories.
        api_key (str): The API key to use for moderation. Defaults to the OPENAI_API_KEY environment variable.
    """

    def __init__(
        self,
        moderation_model: str = "omni-moderation-latest",
        raise_for_categories: Optional[
            List[
                Literal[
                    "sexual",
                    "sexual/minors",
                    "harassment",
                    "harassment/threatening",
                    "hate",
                    "hate/threatening",
                    "illicit",
                    "illicit/violent",
                    "self-harm",
                    "self-harm/intent",
                    "self-harm/instructions",
                    "violence",
                    "violence/graphic",
                ]
            ]
        ] = None,
        api_key: Optional[str] = None,
    ):
        self.moderation_model = moderation_model
        self.api_key = api_key or getenv("OPENAI_API_KEY")
        self.raise_for_categories = raise_for_categories

    def check(self, run_input: Union[RunInput, TeamRunInput]) -> None:
        """Check for content that violates OpenAI's content policy.

        Raises:
            ImportError: If the `openai` package is not installed.
            InputCheckError: If the moderation result is flagged for a relevant category.
        """
        try:
            from openai import OpenAI as OpenAIClient
        except ImportError:
            raise ImportError("`openai` not installed. Please install using `pip install openai`")

        model_input = self._build_model_input(run_input)

        log_debug(f"Moderating content using {self.moderation_model}")
        client = OpenAIClient(api_key=self.api_key)
        response = client.moderations.create(model=self.moderation_model, input=model_input)  # type: ignore

        self._raise_if_violation(response.results[0])

    async def async_check(self, run_input: Union[RunInput, TeamRunInput]) -> None:
        """Asynchronously check for content that violates OpenAI's content policy.

        Raises:
            ImportError: If the `openai` package is not installed.
            InputCheckError: If the moderation result is flagged for a relevant category.
        """
        try:
            from openai import AsyncOpenAI as OpenAIClient
        except ImportError:
            raise ImportError("`openai` not installed. Please install using `pip install openai`")

        model_input = self._build_model_input(run_input)

        log_debug(f"Moderating content using {self.moderation_model}")
        client = OpenAIClient(api_key=self.api_key)
        response = await client.moderations.create(model=self.moderation_model, input=model_input)  # type: ignore

        self._raise_if_violation(response.results[0])

    def _build_model_input(self, run_input: Union[RunInput, TeamRunInput]) -> Union[str, List[Dict[str, Any]]]:
        """Return the moderation payload: plain text, or text plus image parts when images are set."""
        content = run_input.input_content_string()
        images = run_input.images
        if images is None:
            return content
        return [{"type": "text", "text": content}, *images_to_message(images=images)]

    @staticmethod
    def _is_category_flagged(categories: Dict[str, Any], category: str) -> bool:
        """Return True if ``category`` is flagged in the dumped categories mapping.

        ``raise_for_categories`` uses the API spelling ("sexual/minors", "self-harm"),
        but the dumped model may be keyed by Python field names ("sexual_minors",
        "self_harm"). Both spellings are looked up so a configured category never
        raises ``KeyError`` and is never silently missed.
        """
        field_name = category.replace("/", "_").replace("-", "_")
        return bool(categories.get(category) or categories.get(field_name))

    def _raise_if_violation(self, result: Any) -> None:
        """Raise ``InputCheckError`` when ``result`` is flagged for a category we care about."""
        if not result.flagged:
            return

        moderation_result = {
            "categories": result.categories.model_dump(),
            "category_scores": result.category_scores.model_dump(),
        }

        if self.raise_for_categories is None:
            # Since at least one category is flagged, we need to raise the check
            trigger_validation = True
        else:
            flagged = moderation_result["categories"]
            trigger_validation = any(
                self._is_category_flagged(flagged, category) for category in self.raise_for_categories
            )

        if trigger_validation:
            raise InputCheckError(
                "OpenAI moderation violation detected.",
                additional_data=moderation_result,
                check_trigger=CheckTrigger.INPUT_NOT_ALLOWED,
            )
|
agno/guardrails/pii.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
from re import Pattern
|
|
2
|
+
from typing import Dict, Optional, Union
|
|
3
|
+
|
|
4
|
+
from agno.exceptions import CheckTrigger, InputCheckError
|
|
5
|
+
from agno.guardrails.base import BaseGuardrail
|
|
6
|
+
from agno.run.agent import RunInput
|
|
7
|
+
from agno.run.team import TeamRunInput
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class PIIDetectionGuardrail(BaseGuardrail):
    """Guardrail for detecting Personally Identifiable Information (PII).

    Args:
        mask_pii: Whether to mask the PII in the input, rather than raising an error.
        enable_ssn_check: Whether to check for Social Security Numbers. True by default.
        enable_credit_card_check: Whether to check for credit cards. True by default.
        enable_email_check: Whether to check for emails. True by default.
        enable_phone_check: Whether to check for phone numbers. True by default.
        custom_patterns: A dictionary of custom PII patterns to detect. This is added to the default patterns.
    """

    def __init__(
        self,
        mask_pii: bool = False,
        enable_ssn_check: bool = True,
        enable_credit_card_check: bool = True,
        enable_email_check: bool = True,
        enable_phone_check: bool = True,
        custom_patterns: Optional[Dict[str, Pattern[str]]] = None,
    ):
        import re

        self.mask_pii = mask_pii
        self.pii_patterns: Dict[str, Pattern[str]] = {}

        if enable_ssn_check:
            self.pii_patterns["SSN"] = re.compile(r"\b\d{3}-\d{2}-\d{4}\b")
        if enable_credit_card_check:
            self.pii_patterns["Credit Card"] = re.compile(r"\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b")
        if enable_email_check:
            # The TLD class is [A-Za-z]; a "|" inside a character class is a
            # literal pipe, not alternation, so it must not appear here.
            self.pii_patterns["Email"] = re.compile(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b")
        if enable_phone_check:
            self.pii_patterns["Phone"] = re.compile(r"\b\d{3}[\s.-]?\d{3}[\s.-]?\d{4}\b")

        if custom_patterns:
            # Custom entries are merged last, so a custom key replaces a built-in of the same name.
            self.pii_patterns.update(custom_patterns)

    def check(self, run_input: Union[RunInput, TeamRunInput]) -> None:
        """Check for PII patterns in the input.

        Raises:
            InputCheckError: If PII is detected and ``mask_pii`` is False.
        """
        self._detect_or_mask(run_input)

    async def async_check(self, run_input: Union[RunInput, TeamRunInput]) -> None:
        """Asynchronously check for PII patterns in the input.

        Raises:
            InputCheckError: If PII is detected and ``mask_pii`` is False.
        """
        self._detect_or_mask(run_input)

    @staticmethod
    def _mask_match(match) -> str:
        """Return a run of ``*`` with the same length as the matched text."""
        return "*" * len(match.group(0))

    def _detect_or_mask(self, run_input: Union[RunInput, TeamRunInput]) -> None:
        """Shared implementation of the sync and async checks.

        Scans the input string with every configured pattern. When something is
        found, either overwrites ``run_input.input_content`` with a masked copy
        (``mask_pii=True``) or raises ``InputCheckError`` listing the detected types.
        """
        content = run_input.input_content_string()
        detected_pii = [pii_type for pii_type, pattern in self.pii_patterns.items() if pattern.search(content)]
        if not detected_pii:
            return

        if not self.mask_pii:
            raise InputCheckError(
                "Potential PII detected in input",
                additional_data={"detected_pii": detected_pii},
                check_trigger=CheckTrigger.PII_DETECTED,
            )

        for pii_type in detected_pii:
            content = self.pii_patterns[pii_type].sub(self._mask_match, content)
        run_input.input_content = content
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
from typing import List, Optional, Union
|
|
2
|
+
|
|
3
|
+
from agno.exceptions import CheckTrigger, InputCheckError
|
|
4
|
+
from agno.guardrails.base import BaseGuardrail
|
|
5
|
+
from agno.run.agent import RunInput
|
|
6
|
+
from agno.run.team import TeamRunInput
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class PromptInjectionGuardrail(BaseGuardrail):
    """Guardrail for detecting prompt injection attempts.

    Matching is a case-insensitive substring test of each pattern against the input string.

    Args:
        injection_patterns (Optional[List[str]]): A list of patterns to check for. Defaults to a list of common prompt injection patterns.
    """

    def __init__(self, injection_patterns: Optional[List[str]] = None):
        # An empty list also falls back to the defaults because of the `or`.
        self.injection_patterns = injection_patterns or [
            "ignore previous instructions",
            "ignore your instructions",
            "you are now a",
            "forget everything above",
            "developer mode",
            "override safety",
            "disregard guidelines",
            "system prompt",
            "jailbreak",
            "act as if",
            "pretend you are",
            "roleplay as",
            "simulate being",
            "bypass restrictions",
            "ignore safeguards",
            "admin override",
            "root access",
        ]

    def check(self, run_input: Union[RunInput, TeamRunInput]) -> None:
        """Check for prompt injection patterns in the input.

        Raises:
            InputCheckError: If any configured pattern occurs in the input.
        """
        self._raise_if_injection(run_input)

    async def async_check(self, run_input: Union[RunInput, TeamRunInput]) -> None:
        """Asynchronously check for prompt injection patterns in the input.

        Raises:
            InputCheckError: If any configured pattern occurs in the input.
        """
        self._raise_if_injection(run_input)

    def _raise_if_injection(self, run_input: Union[RunInput, TeamRunInput]) -> None:
        """Shared implementation of the sync and async checks."""
        # Lower-case the input once instead of once per pattern.
        haystack = run_input.input_content_string().lower()

        # Patterns are lower-cased too: the input has been lower-cased, so a
        # caller-supplied pattern containing capitals could otherwise never match.
        if any(keyword.lower() in haystack for keyword in self.injection_patterns):
            raise InputCheckError(
                "Potential jailbreaking or prompt injection detected.",
                check_trigger=CheckTrigger.PROMPT_INJECTION,
            )
|