agno 2.0.10__py3-none-any.whl → 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85):
  1. agno/agent/agent.py +608 -175
  2. agno/db/in_memory/in_memory_db.py +42 -29
  3. agno/db/postgres/postgres.py +6 -4
  4. agno/exceptions.py +62 -1
  5. agno/guardrails/__init__.py +6 -0
  6. agno/guardrails/base.py +19 -0
  7. agno/guardrails/openai.py +144 -0
  8. agno/guardrails/pii.py +94 -0
  9. agno/guardrails/prompt_injection.py +51 -0
  10. agno/knowledge/embedder/aws_bedrock.py +9 -4
  11. agno/knowledge/embedder/azure_openai.py +54 -0
  12. agno/knowledge/embedder/base.py +2 -0
  13. agno/knowledge/embedder/cohere.py +184 -5
  14. agno/knowledge/embedder/google.py +79 -1
  15. agno/knowledge/embedder/huggingface.py +9 -4
  16. agno/knowledge/embedder/jina.py +63 -0
  17. agno/knowledge/embedder/mistral.py +78 -11
  18. agno/knowledge/embedder/ollama.py +5 -0
  19. agno/knowledge/embedder/openai.py +18 -54
  20. agno/knowledge/embedder/voyageai.py +69 -16
  21. agno/knowledge/knowledge.py +5 -4
  22. agno/knowledge/reader/pdf_reader.py +4 -3
  23. agno/knowledge/reader/website_reader.py +3 -2
  24. agno/models/base.py +125 -32
  25. agno/models/cerebras/cerebras.py +1 -0
  26. agno/models/cerebras/cerebras_openai.py +1 -0
  27. agno/models/dashscope/dashscope.py +1 -0
  28. agno/models/google/gemini.py +27 -5
  29. agno/models/litellm/chat.py +17 -0
  30. agno/models/openai/chat.py +13 -4
  31. agno/models/perplexity/perplexity.py +2 -3
  32. agno/models/requesty/__init__.py +5 -0
  33. agno/models/requesty/requesty.py +49 -0
  34. agno/models/vllm/vllm.py +1 -0
  35. agno/models/xai/xai.py +1 -0
  36. agno/os/app.py +167 -148
  37. agno/os/interfaces/whatsapp/router.py +2 -0
  38. agno/os/mcp.py +1 -1
  39. agno/os/middleware/__init__.py +7 -0
  40. agno/os/middleware/jwt.py +233 -0
  41. agno/os/router.py +181 -45
  42. agno/os/routers/home.py +2 -2
  43. agno/os/routers/memory/memory.py +23 -1
  44. agno/os/routers/memory/schemas.py +1 -1
  45. agno/os/routers/session/session.py +20 -3
  46. agno/os/utils.py +172 -8
  47. agno/run/agent.py +120 -77
  48. agno/run/team.py +115 -72
  49. agno/run/workflow.py +5 -15
  50. agno/session/summary.py +9 -10
  51. agno/session/team.py +2 -1
  52. agno/team/team.py +720 -168
  53. agno/tools/firecrawl.py +4 -4
  54. agno/tools/function.py +42 -2
  55. agno/tools/knowledge.py +3 -3
  56. agno/tools/searxng.py +2 -2
  57. agno/tools/serper.py +2 -2
  58. agno/tools/spider.py +2 -2
  59. agno/tools/workflow.py +4 -5
  60. agno/utils/events.py +66 -1
  61. agno/utils/hooks.py +57 -0
  62. agno/utils/media.py +11 -9
  63. agno/utils/print_response/agent.py +43 -5
  64. agno/utils/print_response/team.py +48 -12
  65. agno/vectordb/cassandra/cassandra.py +44 -4
  66. agno/vectordb/chroma/chromadb.py +79 -8
  67. agno/vectordb/clickhouse/clickhousedb.py +43 -6
  68. agno/vectordb/couchbase/couchbase.py +76 -5
  69. agno/vectordb/lancedb/lance_db.py +38 -3
  70. agno/vectordb/llamaindex/__init__.py +3 -0
  71. agno/vectordb/milvus/milvus.py +76 -4
  72. agno/vectordb/mongodb/mongodb.py +76 -4
  73. agno/vectordb/pgvector/pgvector.py +50 -6
  74. agno/vectordb/pineconedb/pineconedb.py +39 -2
  75. agno/vectordb/qdrant/qdrant.py +76 -26
  76. agno/vectordb/singlestore/singlestore.py +77 -4
  77. agno/vectordb/upstashdb/upstashdb.py +42 -2
  78. agno/vectordb/weaviate/weaviate.py +39 -3
  79. agno/workflow/types.py +1 -0
  80. agno/workflow/workflow.py +58 -2
  81. {agno-2.0.10.dist-info → agno-2.1.0.dist-info}/METADATA +4 -3
  82. {agno-2.0.10.dist-info → agno-2.1.0.dist-info}/RECORD +85 -75
  83. {agno-2.0.10.dist-info → agno-2.1.0.dist-info}/WHEEL +0 -0
  84. {agno-2.0.10.dist-info → agno-2.1.0.dist-info}/licenses/LICENSE +0 -0
  85. {agno-2.0.10.dist-info → agno-2.1.0.dist-info}/top_level.txt +0 -0
agno/db/in_memory/in_memory_db.py CHANGED
@@ -1,4 +1,5 @@
1
1
  import time
2
+ from copy import deepcopy
2
3
  from datetime import date, datetime, timedelta, timezone
3
4
  from typing import Any, Dict, List, Optional, Tuple, Union
4
5
  from uuid import uuid4
@@ -107,15 +108,17 @@ class InMemoryDb(BaseDb):
107
108
  if session_data.get("session_type") != session_type_value:
108
109
  continue
109
110
 
111
+ session_data_copy = deepcopy(session_data)
112
+
110
113
  if not deserialize:
111
- return session_data
114
+ return session_data_copy
112
115
 
113
116
  if session_type == SessionType.AGENT:
114
- return AgentSession.from_dict(session_data)
117
+ return AgentSession.from_dict(session_data_copy)
115
118
  elif session_type == SessionType.TEAM:
116
- return TeamSession.from_dict(session_data)
119
+ return TeamSession.from_dict(session_data_copy)
117
120
  else:
118
- return WorkflowSession.from_dict(session_data)
121
+ return WorkflowSession.from_dict(session_data_copy)
119
122
 
120
123
  return None
121
124
 
@@ -188,7 +191,7 @@ class InMemoryDb(BaseDb):
188
191
  if session_data.get("session_type") != session_type_value:
189
192
  continue
190
193
 
191
- filtered_sessions.append(session_data)
194
+ filtered_sessions.append(deepcopy(session_data))
192
195
 
193
196
  total_count = len(filtered_sessions)
194
197
 
@@ -233,15 +236,16 @@ class InMemoryDb(BaseDb):
233
236
 
234
237
  log_debug(f"Renamed session with id '{session_id}' to '{session_name}'")
235
238
 
239
+ session_copy = deepcopy(session)
236
240
  if not deserialize:
237
- return session
241
+ return session_copy
238
242
 
239
243
  if session_type == SessionType.AGENT:
240
- return AgentSession.from_dict(session)
244
+ return AgentSession.from_dict(session_copy)
241
245
  elif session_type == SessionType.TEAM:
242
- return TeamSession.from_dict(session)
246
+ return TeamSession.from_dict(session_copy)
243
247
  else:
244
- return WorkflowSession.from_dict(session)
248
+ return WorkflowSession.from_dict(session_copy)
245
249
 
246
250
  return None
247
251
 
@@ -269,22 +273,26 @@ class InMemoryDb(BaseDb):
269
273
  if existing_session.get("session_id") == session_dict.get("session_id") and self._matches_session_key(
270
274
  existing_session, session
271
275
  ):
272
- # Update existing session
273
276
  session_dict["updated_at"] = int(time.time())
274
- self._sessions[i] = session_dict
277
+ self._sessions[i] = deepcopy(session_dict)
275
278
  session_updated = True
276
279
  break
277
280
 
278
281
  if not session_updated:
279
- # Add new session
280
282
  session_dict["created_at"] = session_dict.get("created_at", int(time.time()))
281
283
  session_dict["updated_at"] = session_dict.get("created_at")
282
- self._sessions.append(session_dict)
284
+ self._sessions.append(deepcopy(session_dict))
283
285
 
286
+ session_dict_copy = deepcopy(session_dict)
284
287
  if not deserialize:
285
- return session_dict
288
+ return session_dict_copy
286
289
 
287
- return session
290
+ if session_dict_copy["session_type"] == SessionType.AGENT:
291
+ return AgentSession.from_dict(session_dict_copy)
292
+ elif session_dict_copy["session_type"] == SessionType.TEAM:
293
+ return TeamSession.from_dict(session_dict_copy)
294
+ else:
295
+ return WorkflowSession.from_dict(session_dict_copy)
288
296
 
289
297
  except Exception as e:
290
298
  log_error(f"Exception upserting session: {e}")
@@ -378,9 +386,10 @@ class InMemoryDb(BaseDb):
378
386
  try:
379
387
  for memory_data in self._memories:
380
388
  if memory_data.get("memory_id") == memory_id:
389
+ memory_data_copy = deepcopy(memory_data)
381
390
  if not deserialize:
382
- return memory_data
383
- return UserMemory.from_dict(memory_data)
391
+ return memory_data_copy
392
+ return UserMemory.from_dict(memory_data_copy)
384
393
 
385
394
  return None
386
395
 
@@ -420,7 +429,7 @@ class InMemoryDb(BaseDb):
420
429
  if search_content.lower() not in memory_content.lower():
421
430
  continue
422
431
 
423
- filtered_memories.append(memory_data)
432
+ filtered_memories.append(deepcopy(memory_data))
424
433
 
425
434
  total_count = len(filtered_memories)
426
435
 
@@ -499,9 +508,11 @@ class InMemoryDb(BaseDb):
499
508
  if not memory_updated:
500
509
  self._memories.append(memory_dict)
501
510
 
511
+ memory_dict_copy = deepcopy(memory_dict)
502
512
  if not deserialize:
503
- return memory_dict
504
- return UserMemory.from_dict(memory_dict)
513
+ return memory_dict_copy
514
+
515
+ return UserMemory.from_dict(memory_dict_copy)
505
516
 
506
517
  except Exception as e:
507
518
  log_warning(f"Exception upserting user memory: {e}")
@@ -657,8 +668,8 @@ class InMemoryDb(BaseDb):
657
668
  # Only include necessary fields for metrics
658
669
  filtered_session = {
659
670
  "user_id": session.get("user_id"),
660
- "session_data": session.get("session_data"),
661
- "runs": session.get("runs"),
671
+ "session_data": deepcopy(session.get("session_data")),
672
+ "runs": deepcopy(session.get("runs")),
662
673
  "created_at": session.get("created_at"),
663
674
  "session_type": session.get("session_type"),
664
675
  }
@@ -688,7 +699,7 @@ class InMemoryDb(BaseDb):
688
699
  if ending_date and metric_date > ending_date:
689
700
  continue
690
701
 
691
- filtered_metrics.append(metric)
702
+ filtered_metrics.append(deepcopy(metric))
692
703
 
693
704
  updated_at = metric.get("updated_at")
694
705
  if updated_at and (latest_updated_at is None or updated_at > latest_updated_at):
@@ -763,7 +774,7 @@ class InMemoryDb(BaseDb):
763
774
  Exception: If an error occurs during retrieval.
764
775
  """
765
776
  try:
766
- knowledge_items = self._knowledge.copy()
777
+ knowledge_items = [deepcopy(item) for item in self._knowledge]
767
778
 
768
779
  total_count = len(knowledge_items)
769
780
 
@@ -858,9 +869,10 @@ class InMemoryDb(BaseDb):
858
869
  try:
859
870
  for run_data in self._eval_runs:
860
871
  if run_data.get("run_id") == eval_run_id:
872
+ run_data_copy = deepcopy(run_data)
861
873
  if not deserialize:
862
- return run_data
863
- return EvalRunRecord.model_validate(run_data)
874
+ return run_data_copy
875
+ return EvalRunRecord.model_validate(run_data_copy)
864
876
 
865
877
  return None
866
878
 
@@ -906,7 +918,7 @@ class InMemoryDb(BaseDb):
906
918
  elif filter_type == EvalFilterType.WORKFLOW and run_data.get("workflow_id") is None:
907
919
  continue
908
920
 
909
- filtered_runs.append(run_data)
921
+ filtered_runs.append(deepcopy(run_data))
910
922
 
911
923
  total_count = len(filtered_runs)
912
924
 
@@ -945,10 +957,11 @@ class InMemoryDb(BaseDb):
945
957
 
946
958
  log_debug(f"Renamed eval run with id '{eval_run_id}' to '{name}'")
947
959
 
960
+ run_data_copy = deepcopy(run_data)
948
961
  if not deserialize:
949
- return run_data
962
+ return run_data_copy
950
963
 
951
- return EvalRunRecord.model_validate(run_data)
964
+ return EvalRunRecord.model_validate(run_data_copy)
952
965
 
953
966
  return None
954
967
 
agno/db/postgres/postgres.py CHANGED
@@ -756,7 +756,7 @@ class PostgresDb(BaseDb):
756
756
  )
757
757
 
758
758
  with self.Session() as sess, sess.begin():
759
- stmt = postgresql.insert(table)
759
+ stmt: Any = postgresql.insert(table)
760
760
  update_columns = {
761
761
  col.name: stmt.excluded[col.name]
762
762
  for col in table.columns
@@ -1263,13 +1263,15 @@ class PostgresDb(BaseDb):
1263
1263
  results: List[Union[UserMemory, Dict[str, Any]]] = []
1264
1264
 
1265
1265
  with self.Session() as sess, sess.begin():
1266
- stmt = postgresql.insert(table)
1266
+ insert_stmt = postgresql.insert(table)
1267
1267
  update_columns = {
1268
- col.name: stmt.excluded[col.name]
1268
+ col.name: insert_stmt.excluded[col.name]
1269
1269
  for col in table.columns
1270
1270
  if col.name not in ["memory_id"] # Don't update primary key
1271
1271
  }
1272
- stmt = stmt.on_conflict_do_update(index_elements=["memory_id"], set_=update_columns).returning(table)
1272
+ stmt = insert_stmt.on_conflict_do_update(index_elements=["memory_id"], set_=update_columns).returning(
1273
+ table
1274
+ )
1273
1275
 
1274
1276
  result = sess.execute(stmt, memory_records)
1275
1277
  for row in result.fetchall():
agno/exceptions.py CHANGED
@@ -1,4 +1,5 @@
1
- from typing import List, Optional, Union
1
+ from enum import Enum
2
+ from typing import Any, Dict, List, Optional, Union
2
3
 
3
4
  from agno.models.message import Message
4
5
 
@@ -17,6 +18,8 @@ class AgentRunException(Exception):
17
18
  self.agent_message = agent_message
18
19
  self.messages = messages
19
20
  self.stop_execution = stop_execution
21
+ self.type = "agent_run_error"
22
+ self.error_id = "agent_run_error"
20
23
 
21
24
 
22
25
  class RetryAgentRun(AgentRunException):
@@ -32,6 +35,7 @@ class RetryAgentRun(AgentRunException):
32
35
  super().__init__(
33
36
  exc, user_message=user_message, agent_message=agent_message, messages=messages, stop_execution=False
34
37
  )
38
+ self.error_id = "retry_agent_run_error"
35
39
 
36
40
 
37
41
  class StopAgentRun(AgentRunException):
@@ -47,6 +51,7 @@ class StopAgentRun(AgentRunException):
47
51
  super().__init__(
48
52
  exc, user_message=user_message, agent_message=agent_message, messages=messages, stop_execution=True
49
53
  )
54
+ self.error_id = "stop_agent_run_error"
50
55
 
51
56
 
52
57
  class RunCancelledException(Exception):
@@ -54,6 +59,8 @@ class RunCancelledException(Exception):
54
59
 
55
60
  def __init__(self, message: str = "Operation cancelled by user"):
56
61
  super().__init__(message)
62
+ self.type = "run_cancelled_error"
63
+ self.error_id = "run_cancelled_error"
57
64
 
58
65
 
59
66
  class AgnoError(Exception):
@@ -63,6 +70,8 @@ class AgnoError(Exception):
63
70
  super().__init__(message)
64
71
  self.message = message
65
72
  self.status_code = status_code
73
+ self.type = "agno_error"
74
+ self.error_id = "agno_error"
66
75
 
67
76
  def __str__(self) -> str:
68
77
  return str(self.message)
@@ -78,6 +87,9 @@ class ModelProviderError(AgnoError):
78
87
  self.model_name = model_name
79
88
  self.model_id = model_id
80
89
 
90
+ self.type = "model_provider_error"
91
+ self.error_id = "model_provider_error"
92
+
81
93
 
82
94
  class ModelRateLimitError(ModelProviderError):
83
95
  """Exception raised when a model provider returns a rate limit error."""
@@ -86,9 +98,58 @@ class ModelRateLimitError(ModelProviderError):
86
98
  self, message: str, status_code: int = 429, model_name: Optional[str] = None, model_id: Optional[str] = None
87
99
  ):
88
100
  super().__init__(message, status_code, model_name, model_id)
101
+ self.error_id = "model_rate_limit_error"
89
102
 
90
103
 
91
104
class EvalError(Exception):
    """Raised to signal that an evaluation did not succeed."""
108
+
109
+
110
class CheckTrigger(Enum):
    """Identifies which guardrail check was triggered.

    Member values are plain strings; the check errors in this module use
    ``check_trigger.value`` as their ``error_id``.
    """

    # General-purpose outcomes
    OFF_TOPIC = "off_topic"
    INPUT_NOT_ALLOWED = "input_not_allowed"
    OUTPUT_NOT_ALLOWED = "output_not_allowed"
    VALIDATION_FAILED = "validation_failed"

    # Outcomes raised by the prompt-injection and PII guardrails
    PROMPT_INJECTION = "prompt_injection"
    PII_DETECTED = "pii_detected"
120
+
121
+
122
class InputCheckError(Exception):
    """Raised when a guardrail rejects the input of a run.

    Args:
        message: Human-readable explanation; also passed to ``Exception``.
        check_trigger: Which check fired. Its ``value`` becomes ``error_id``.
        additional_data: Optional extra details about the failure.
    """

    def __init__(
        self,
        message: str,
        check_trigger: CheckTrigger = CheckTrigger.INPUT_NOT_ALLOWED,
        additional_data: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(message)
        # Machine-readable classification, mirroring the `type` / `error_id`
        # attributes the other exceptions in this module set.
        self.type = "input_check_error"
        self.error_id = check_trigger.value

        # Keep the constructor arguments available to whoever catches this.
        self.message, self.check_trigger, self.additional_data = message, check_trigger, additional_data
138
+
139
+
140
class OutputCheckError(Exception):
    """Raised when a guardrail rejects the output of a run.

    Args:
        message: Human-readable explanation; also passed to ``Exception``.
        check_trigger: Which check fired. Its ``value`` becomes ``error_id``.
        additional_data: Optional extra details about the failure.
    """

    def __init__(
        self,
        message: str,
        check_trigger: CheckTrigger = CheckTrigger.OUTPUT_NOT_ALLOWED,
        additional_data: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(message)
        # Machine-readable classification, mirroring the `type` / `error_id`
        # attributes the other exceptions in this module set.
        self.type = "output_check_error"
        self.error_id = check_trigger.value

        # Keep the constructor arguments available to whoever catches this.
        self.message, self.check_trigger, self.additional_data = message, check_trigger, additional_data
agno/guardrails/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from agno.guardrails.base import BaseGuardrail
2
+ from agno.guardrails.openai import OpenAIModerationGuardrail
3
+ from agno.guardrails.pii import PIIDetectionGuardrail
4
+ from agno.guardrails.prompt_injection import PromptInjectionGuardrail
5
+
6
# Public API of the agno.guardrails package.
__all__ = [
    "BaseGuardrail",
    "OpenAIModerationGuardrail",
    "PIIDetectionGuardrail",
    "PromptInjectionGuardrail",
]
agno/guardrails/base.py ADDED
@@ -0,0 +1,19 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Union
3
+
4
+ from agno.run.agent import RunInput
5
+ from agno.run.team import TeamRunInput
6
+
7
+
8
class BaseGuardrail(ABC):
    """Interface shared by every guardrail.

    A concrete guardrail must implement both a blocking ``check`` and a
    coroutine ``async_check``. Both receive the run input and return None.
    Implementations in this package signal a failed check by raising, and some
    may instead rewrite the input in place.
    """

    @abstractmethod
    def check(self, run_input: Union[RunInput, TeamRunInput]) -> None:
        """Run this guardrail synchronously against ``run_input``."""
        ...

    @abstractmethod
    async def async_check(self, run_input: Union[RunInput, TeamRunInput]) -> None:
        """Run this guardrail asynchronously against ``run_input``."""
        ...
agno/guardrails/openai.py ADDED
@@ -0,0 +1,144 @@
1
+ from os import getenv
2
+ from typing import Any, Dict, List, Literal, Optional, Union
3
+
4
+ from agno.exceptions import CheckTrigger, InputCheckError
5
+ from agno.guardrails.base import BaseGuardrail
6
+ from agno.run.agent import RunInput
7
+ from agno.run.team import TeamRunInput
8
+ from agno.utils.log import log_debug
9
+ from agno.utils.openai import images_to_message
10
+
11
+
12
class OpenAIModerationGuardrail(BaseGuardrail):
    """Guardrail for detecting content that violates OpenAI's content policy.

    The run input's text (and any images) is sent to the OpenAI moderation
    endpoint. If the result is flagged, an ``InputCheckError`` is raised, either
    for any flagged category or only for the categories listed in
    ``raise_for_categories``.

    Args:
        moderation_model (str): The model to use for moderation. Defaults to "omni-moderation-latest".
        raise_for_categories (List[str]): The categories to raise for.
            Options are: "sexual", "sexual/minors", "harassment",
            "harassment/threatening", "hate", "hate/threatening",
            "illicit", "illicit/violent", "self-harm", "self-harm/intent",
            "self-harm/instructions", "violence", "violence/graphic".
            Defaults to None, in which case any flagged category raises.
            An empty list disables raising.
        api_key (str): The API key to use for moderation. Defaults to the OPENAI_API_KEY environment variable.
    """

    def __init__(
        self,
        moderation_model: str = "omni-moderation-latest",
        raise_for_categories: Optional[
            List[
                Literal[
                    "sexual",
                    "sexual/minors",
                    "harassment",
                    "harassment/threatening",
                    "hate",
                    "hate/threatening",
                    "illicit",
                    "illicit/violent",
                    "self-harm",
                    "self-harm/intent",
                    "self-harm/instructions",
                    "violence",
                    "violence/graphic",
                ]
            ]
        ] = None,
        api_key: Optional[str] = None,
    ):
        self.moderation_model = moderation_model
        self.api_key = api_key or getenv("OPENAI_API_KEY")
        self.raise_for_categories = raise_for_categories

    def check(self, run_input: Union[RunInput, TeamRunInput]) -> None:
        """Check for content that violates OpenAI's content policy.

        Raises:
            ImportError: If the `openai` package is not installed.
            InputCheckError: If the input is flagged for a category this guardrail raises for.
        """
        try:
            from openai import OpenAI as OpenAIClient
        except ImportError:
            raise ImportError("`openai` not installed. Please install using `pip install openai`")

        model_input = self._build_model_input(run_input)

        log_debug(f"Moderating content using {self.moderation_model}")
        client = OpenAIClient(api_key=self.api_key)
        response = client.moderations.create(model=self.moderation_model, input=model_input)  # type: ignore

        self._raise_on_violation(response.results[0])

    async def async_check(self, run_input: Union[RunInput, TeamRunInput]) -> None:
        """Check for content that violates OpenAI's content policy, using the async client.

        Raises:
            ImportError: If the `openai` package is not installed.
            InputCheckError: If the input is flagged for a category this guardrail raises for.
        """
        try:
            from openai import AsyncOpenAI as OpenAIClient
        except ImportError:
            raise ImportError("`openai` not installed. Please install using `pip install openai`")

        model_input = self._build_model_input(run_input)

        log_debug(f"Moderating content using {self.moderation_model}")
        client = OpenAIClient(api_key=self.api_key)
        response = await client.moderations.create(model=self.moderation_model, input=model_input)  # type: ignore

        self._raise_on_violation(response.results[0])

    @staticmethod
    def _build_model_input(run_input: Union[RunInput, TeamRunInput]) -> Union[str, List[Dict[str, Any]]]:
        """Build the moderation ``input``: plain text, or a text part followed by image parts."""
        content = run_input.input_content_string()
        images = run_input.images
        if images is None:
            return content
        return [{"type": "text", "text": content}, *images_to_message(images=images)]

    @staticmethod
    def _is_category_flagged(categories: Dict[str, Any], category: str) -> bool:
        """Return whether ``category`` is flagged in a dumped moderation ``categories`` dict.

        The OpenAI SDK exposes API category names such as "self-harm/intent" as
        Python fields such as "self_harm_intent". A plain ``model_dump()`` may be
        keyed by those field names (pydantic's default is ``by_alias=False``), so
        both spellings are tried instead of indexing directly (which raised KeyError).
        """
        if category in categories:
            return bool(categories[category])
        field_name = category.replace("/", "_").replace("-", "_")
        return bool(categories.get(field_name))

    def _raise_on_violation(self, result: Any) -> None:
        """Raise InputCheckError if a single moderation result should block the run."""
        if not result.flagged:
            return

        moderation_result = {
            "categories": result.categories.model_dump(),
            "category_scores": result.category_scores.model_dump(),
        }

        if self.raise_for_categories is None:
            # The result is flagged, so at least one category fired.
            trigger_validation = True
        else:
            trigger_validation = any(
                self._is_category_flagged(moderation_result["categories"], category)
                for category in self.raise_for_categories
            )

        if trigger_validation:
            raise InputCheckError(
                "OpenAI moderation violation detected.",
                additional_data=moderation_result,
                check_trigger=CheckTrigger.INPUT_NOT_ALLOWED,
            )
agno/guardrails/pii.py ADDED
@@ -0,0 +1,94 @@
1
+ from re import Pattern
2
+ from typing import Dict, Optional, Union
3
+
4
+ from agno.exceptions import CheckTrigger, InputCheckError
5
+ from agno.guardrails.base import BaseGuardrail
6
+ from agno.run.agent import RunInput
7
+ from agno.run.team import TeamRunInput
8
+
9
+
10
class PIIDetectionGuardrail(BaseGuardrail):
    """Guardrail for detecting Personally Identifiable Information (PII).

    Detection runs regular expressions over ``run_input.input_content_string()``.
    When PII is found, the guardrail either raises ``InputCheckError`` or, with
    ``mask_pii=True``, replaces every match with the same number of asterisks and
    writes the masked string back to ``run_input.input_content``.

    Args:
        mask_pii: Whether to mask the PII in the input, rather than raising an error.
        enable_ssn_check: Whether to check for Social Security Numbers. True by default.
        enable_credit_card_check: Whether to check for credit cards. True by default.
        enable_email_check: Whether to check for emails. True by default.
        enable_phone_check: Whether to check for phone numbers. True by default.
        custom_patterns: A dictionary of custom PII patterns to detect, keyed by PII type name.
            Values may be compiled patterns or pattern strings (compiled here). They are
            added to the default patterns, replacing a default that has the same name.
    """

    def __init__(
        self,
        mask_pii: bool = False,
        enable_ssn_check: bool = True,
        enable_credit_card_check: bool = True,
        enable_email_check: bool = True,
        enable_phone_check: bool = True,
        custom_patterns: Optional[Dict[str, Union[str, Pattern[str]]]] = None,
    ):
        import re

        self.mask_pii = mask_pii
        self.pii_patterns: Dict[str, Pattern[str]] = {}

        if enable_ssn_check:
            self.pii_patterns["SSN"] = re.compile(r"\b\d{3}-\d{2}-\d{4}\b")
        if enable_credit_card_check:
            self.pii_patterns["Credit Card"] = re.compile(r"\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b")
        if enable_email_check:
            # TLD class is `[A-Za-z]`, not `[A-Z|a-z]`: inside a character class
            # `|` is a literal character, not alternation.
            self.pii_patterns["Email"] = re.compile(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b")
        if enable_phone_check:
            self.pii_patterns["Phone"] = re.compile(r"\b\d{3}[\s.-]?\d{3}[\s.-]?\d{4}\b")

        if custom_patterns:
            for pii_type, pattern in custom_patterns.items():
                self.pii_patterns[pii_type] = re.compile(pattern) if isinstance(pattern, str) else pattern

    def check(self, run_input: Union[RunInput, TeamRunInput]) -> None:
        """Check for PII patterns in the input.

        Raises:
            InputCheckError: If PII is detected and ``mask_pii`` is False.
        """
        self._process(run_input)

    async def async_check(self, run_input: Union[RunInput, TeamRunInput]) -> None:
        """Asynchronously check for PII patterns in the input (same logic as ``check``).

        Raises:
            InputCheckError: If PII is detected and ``mask_pii`` is False.
        """
        self._process(run_input)

    @staticmethod
    def _mask_match(match) -> str:
        """Replace a regex match with the same number of asterisks."""
        return "*" * len(match.group(0))

    def _process(self, run_input: Union[RunInput, TeamRunInput]) -> None:
        """Detect PII in the run input, then mask it in place or raise InputCheckError."""
        content = run_input.input_content_string()
        detected_pii = [pii_type for pii_type, pattern in self.pii_patterns.items() if pattern.search(content)]
        if not detected_pii:
            return

        if not self.mask_pii:
            raise InputCheckError(
                "Potential PII detected in input",
                additional_data={"detected_pii": detected_pii},
                check_trigger=CheckTrigger.PII_DETECTED,
            )

        # Detection ran on the unmasked text; masking is applied type by type,
        # in detection order.
        for pii_type in detected_pii:
            content = self.pii_patterns[pii_type].sub(self._mask_match, content)
        # NOTE(review): this stores the masked *string* form even if input_content
        # was originally a non-string value — confirm that is intended.
        run_input.input_content = content
agno/guardrails/prompt_injection.py ADDED
@@ -0,0 +1,51 @@
1
+ from typing import List, Optional, Union
2
+
3
+ from agno.exceptions import CheckTrigger, InputCheckError
4
+ from agno.guardrails.base import BaseGuardrail
5
+ from agno.run.agent import RunInput
6
+ from agno.run.team import TeamRunInput
7
+
8
+
9
class PromptInjectionGuardrail(BaseGuardrail):
    """Guardrail for detecting prompt injection attempts.

    Matching is a case-insensitive substring search. Runs of whitespace are
    collapsed to a single space in both the input and the patterns, so
    "Ignore  Previous\\nInstructions" matches "ignore previous instructions".
    Previously, patterns containing capital letters could never match, because
    only the input was lower-cased.

    Args:
        injection_patterns (Optional[List[str]]): A list of patterns to check for. Defaults to a list of common prompt injection patterns.
            Patterns that are empty after normalization are ignored.
    """

    def __init__(self, injection_patterns: Optional[List[str]] = None):
        # `or` means an empty list also falls back to the default patterns.
        self.injection_patterns = injection_patterns or [
            "ignore previous instructions",
            "ignore your instructions",
            "you are now a",
            "forget everything above",
            "developer mode",
            "override safety",
            "disregard guidelines",
            "system prompt",
            "jailbreak",
            "act as if",
            "pretend you are",
            "roleplay as",
            "simulate being",
            "bypass restrictions",
            "ignore safeguards",
            "admin override",
            "root access",
        ]

    def check(self, run_input: Union[RunInput, TeamRunInput]) -> None:
        """Check for prompt injection patterns in the input.

        Raises:
            InputCheckError: If any pattern occurs in the input.
        """
        if self._contains_injection(run_input.input_content_string()):
            raise InputCheckError(
                "Potential jailbreaking or prompt injection detected.",
                check_trigger=CheckTrigger.PROMPT_INJECTION,
            )

    async def async_check(self, run_input: Union[RunInput, TeamRunInput]) -> None:
        """Asynchronously check for prompt injection patterns in the input.

        Raises:
            InputCheckError: If any pattern occurs in the input.
        """
        if self._contains_injection(run_input.input_content_string()):
            raise InputCheckError(
                "Potential jailbreaking or prompt injection detected.",
                check_trigger=CheckTrigger.PROMPT_INJECTION,
            )

    @staticmethod
    def _normalize(text: str) -> str:
        """Lower-case ``text`` and collapse every whitespace run to a single space."""
        return " ".join(text.lower().split())

    def _contains_injection(self, text: str) -> bool:
        """Return True if any normalized, non-empty pattern occurs in the normalized ``text``."""
        haystack = self._normalize(text)
        for pattern in self.injection_patterns:
            needle = self._normalize(pattern)
            if needle and needle in haystack:
                return True
        return False