agno 2.0.11__py3-none-any.whl → 2.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93) hide show
  1. agno/agent/agent.py +607 -176
  2. agno/db/in_memory/in_memory_db.py +42 -29
  3. agno/db/mongo/mongo.py +65 -66
  4. agno/db/postgres/postgres.py +6 -4
  5. agno/db/utils.py +50 -22
  6. agno/exceptions.py +62 -1
  7. agno/guardrails/__init__.py +6 -0
  8. agno/guardrails/base.py +19 -0
  9. agno/guardrails/openai.py +144 -0
  10. agno/guardrails/pii.py +94 -0
  11. agno/guardrails/prompt_injection.py +51 -0
  12. agno/knowledge/embedder/aws_bedrock.py +9 -4
  13. agno/knowledge/embedder/azure_openai.py +54 -0
  14. agno/knowledge/embedder/base.py +2 -0
  15. agno/knowledge/embedder/cohere.py +184 -5
  16. agno/knowledge/embedder/google.py +79 -1
  17. agno/knowledge/embedder/huggingface.py +9 -4
  18. agno/knowledge/embedder/jina.py +63 -0
  19. agno/knowledge/embedder/mistral.py +78 -11
  20. agno/knowledge/embedder/ollama.py +5 -0
  21. agno/knowledge/embedder/openai.py +18 -54
  22. agno/knowledge/embedder/voyageai.py +69 -16
  23. agno/knowledge/knowledge.py +11 -4
  24. agno/knowledge/reader/pdf_reader.py +4 -3
  25. agno/knowledge/reader/website_reader.py +3 -2
  26. agno/models/base.py +125 -32
  27. agno/models/cerebras/cerebras.py +1 -0
  28. agno/models/cerebras/cerebras_openai.py +1 -0
  29. agno/models/dashscope/dashscope.py +1 -0
  30. agno/models/google/gemini.py +27 -5
  31. agno/models/openai/chat.py +13 -4
  32. agno/models/openai/responses.py +1 -1
  33. agno/models/perplexity/perplexity.py +2 -3
  34. agno/models/requesty/__init__.py +5 -0
  35. agno/models/requesty/requesty.py +49 -0
  36. agno/models/vllm/vllm.py +1 -0
  37. agno/models/xai/xai.py +1 -0
  38. agno/os/app.py +98 -126
  39. agno/os/interfaces/__init__.py +1 -0
  40. agno/os/interfaces/agui/agui.py +21 -5
  41. agno/os/interfaces/base.py +4 -2
  42. agno/os/interfaces/slack/slack.py +13 -8
  43. agno/os/interfaces/whatsapp/router.py +2 -0
  44. agno/os/interfaces/whatsapp/whatsapp.py +12 -5
  45. agno/os/mcp.py +2 -2
  46. agno/os/middleware/__init__.py +7 -0
  47. agno/os/middleware/jwt.py +233 -0
  48. agno/os/router.py +182 -46
  49. agno/os/routers/home.py +2 -2
  50. agno/os/routers/memory/memory.py +23 -1
  51. agno/os/routers/memory/schemas.py +1 -1
  52. agno/os/routers/session/session.py +20 -3
  53. agno/os/utils.py +74 -8
  54. agno/run/agent.py +120 -77
  55. agno/run/base.py +2 -13
  56. agno/run/team.py +115 -72
  57. agno/run/workflow.py +5 -15
  58. agno/session/summary.py +9 -10
  59. agno/session/team.py +2 -1
  60. agno/team/team.py +721 -169
  61. agno/tools/firecrawl.py +4 -4
  62. agno/tools/function.py +42 -2
  63. agno/tools/knowledge.py +3 -3
  64. agno/tools/searxng.py +2 -2
  65. agno/tools/serper.py +2 -2
  66. agno/tools/spider.py +2 -2
  67. agno/tools/workflow.py +4 -5
  68. agno/utils/events.py +66 -1
  69. agno/utils/hooks.py +57 -0
  70. agno/utils/media.py +11 -9
  71. agno/utils/print_response/agent.py +43 -5
  72. agno/utils/print_response/team.py +48 -12
  73. agno/utils/serialize.py +32 -0
  74. agno/vectordb/cassandra/cassandra.py +44 -4
  75. agno/vectordb/chroma/chromadb.py +79 -8
  76. agno/vectordb/clickhouse/clickhousedb.py +43 -6
  77. agno/vectordb/couchbase/couchbase.py +76 -5
  78. agno/vectordb/lancedb/lance_db.py +38 -3
  79. agno/vectordb/milvus/milvus.py +76 -4
  80. agno/vectordb/mongodb/mongodb.py +76 -4
  81. agno/vectordb/pgvector/pgvector.py +50 -6
  82. agno/vectordb/pineconedb/pineconedb.py +39 -2
  83. agno/vectordb/qdrant/qdrant.py +76 -26
  84. agno/vectordb/singlestore/singlestore.py +77 -4
  85. agno/vectordb/upstashdb/upstashdb.py +42 -2
  86. agno/vectordb/weaviate/weaviate.py +39 -3
  87. agno/workflow/types.py +5 -6
  88. agno/workflow/workflow.py +58 -2
  89. {agno-2.0.11.dist-info → agno-2.1.1.dist-info}/METADATA +4 -3
  90. {agno-2.0.11.dist-info → agno-2.1.1.dist-info}/RECORD +93 -82
  91. {agno-2.0.11.dist-info → agno-2.1.1.dist-info}/WHEEL +0 -0
  92. {agno-2.0.11.dist-info → agno-2.1.1.dist-info}/licenses/LICENSE +0 -0
  93. {agno-2.0.11.dist-info → agno-2.1.1.dist-info}/top_level.txt +0 -0
agno/exceptions.py CHANGED
@@ -1,4 +1,5 @@
1
- from typing import List, Optional, Union
1
+ from enum import Enum
2
+ from typing import Any, Dict, List, Optional, Union
2
3
 
3
4
  from agno.models.message import Message
4
5
 
@@ -17,6 +18,8 @@ class AgentRunException(Exception):
17
18
  self.agent_message = agent_message
18
19
  self.messages = messages
19
20
  self.stop_execution = stop_execution
21
+ self.type = "agent_run_error"
22
+ self.error_id = "agent_run_error"
20
23
 
21
24
 
22
25
  class RetryAgentRun(AgentRunException):
@@ -32,6 +35,7 @@ class RetryAgentRun(AgentRunException):
32
35
  super().__init__(
33
36
  exc, user_message=user_message, agent_message=agent_message, messages=messages, stop_execution=False
34
37
  )
38
+ self.error_id = "retry_agent_run_error"
35
39
 
36
40
 
37
41
  class StopAgentRun(AgentRunException):
@@ -47,6 +51,7 @@ class StopAgentRun(AgentRunException):
47
51
  super().__init__(
48
52
  exc, user_message=user_message, agent_message=agent_message, messages=messages, stop_execution=True
49
53
  )
54
+ self.error_id = "stop_agent_run_error"
50
55
 
51
56
 
52
57
  class RunCancelledException(Exception):
@@ -54,6 +59,8 @@ class RunCancelledException(Exception):
54
59
 
55
60
  def __init__(self, message: str = "Operation cancelled by user"):
56
61
  super().__init__(message)
62
+ self.type = "run_cancelled_error"
63
+ self.error_id = "run_cancelled_error"
57
64
 
58
65
 
59
66
  class AgnoError(Exception):
@@ -63,6 +70,8 @@ class AgnoError(Exception):
63
70
  super().__init__(message)
64
71
  self.message = message
65
72
  self.status_code = status_code
73
+ self.type = "agno_error"
74
+ self.error_id = "agno_error"
66
75
 
67
76
  def __str__(self) -> str:
68
77
  return str(self.message)
@@ -78,6 +87,9 @@ class ModelProviderError(AgnoError):
78
87
  self.model_name = model_name
79
88
  self.model_id = model_id
80
89
 
90
+ self.type = "model_provider_error"
91
+ self.error_id = "model_provider_error"
92
+
81
93
 
82
94
  class ModelRateLimitError(ModelProviderError):
83
95
  """Exception raised when a model provider returns a rate limit error."""
@@ -86,9 +98,58 @@ class ModelRateLimitError(ModelProviderError):
86
98
  self, message: str, status_code: int = 429, model_name: Optional[str] = None, model_id: Optional[str] = None
87
99
  ):
88
100
  super().__init__(message, status_code, model_name, model_id)
101
+ self.error_id = "model_rate_limit_error"
89
102
 
90
103
 
91
104
  class EvalError(Exception):
92
105
  """Exception raised when an evaluation fails."""
93
106
 
94
107
  pass
108
+
109
+
110
class CheckTrigger(Enum):
    """Enum for guardrail triggers.

    The value of each member is used as the `error_id` of the check error
    that reports it (see InputCheckError / OutputCheckError).
    """

    # Generic triggers available to user-defined checks.
    OFF_TOPIC = "off_topic"
    INPUT_NOT_ALLOWED = "input_not_allowed"
    OUTPUT_NOT_ALLOWED = "output_not_allowed"
    VALIDATION_FAILED = "validation_failed"

    # Triggers raised by the built-in guardrails.
    PROMPT_INJECTION = "prompt_injection"
    PII_DETECTED = "pii_detected"
120
+
121
+
122
class InputCheckError(Exception):
    """Raised when a guardrail input check fails.

    Carries the triggering check, an optional payload of additional data,
    and stable `type` / `error_id` identifiers for error reporting.
    """

    def __init__(
        self,
        message: str,
        check_trigger: CheckTrigger = CheckTrigger.INPUT_NOT_ALLOWED,
        additional_data: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(message)
        # Keep the raw arguments available to handlers of the error.
        self.message = message
        self.check_trigger = check_trigger
        self.additional_data = additional_data
        # Identifiers used when reporting the error: a fixed error type plus
        # the value of the trigger that fired.
        self.type = "input_check_error"
        self.error_id = check_trigger.value
138
+
139
+
140
class OutputCheckError(Exception):
    """Raised when a guardrail output check fails.

    Carries the triggering check, an optional payload of additional data,
    and stable `type` / `error_id` identifiers for error reporting.
    """

    def __init__(
        self,
        message: str,
        check_trigger: CheckTrigger = CheckTrigger.OUTPUT_NOT_ALLOWED,
        additional_data: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(message)
        # Keep the raw arguments available to handlers of the error.
        self.message = message
        self.check_trigger = check_trigger
        self.additional_data = additional_data
        # Identifiers used when reporting the error: a fixed error type plus
        # the value of the trigger that fired.
        self.type = "output_check_error"
        self.error_id = check_trigger.value
@@ -0,0 +1,6 @@
1
+ from agno.guardrails.base import BaseGuardrail
2
+ from agno.guardrails.openai import OpenAIModerationGuardrail
3
+ from agno.guardrails.pii import PIIDetectionGuardrail
4
+ from agno.guardrails.prompt_injection import PromptInjectionGuardrail
5
+
6
+ __all__ = ["BaseGuardrail", "OpenAIModerationGuardrail", "PIIDetectionGuardrail", "PromptInjectionGuardrail"]
@@ -0,0 +1,19 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Union
3
+
4
+ from agno.run.agent import RunInput
5
+ from agno.run.team import TeamRunInput
6
+
7
+
8
class BaseGuardrail(ABC):
    """Abstract base class for all guardrail implementations.

    Subclasses implement both entry points; a check that fails is expected to
    raise (the built-in guardrails raise InputCheckError) and a passing check
    simply returns None.
    """

    @abstractmethod
    def check(self, run_input: Union[RunInput, TeamRunInput]) -> None:
        """Perform synchronous guardrail check."""
        ...

    @abstractmethod
    async def async_check(self, run_input: Union[RunInput, TeamRunInput]) -> None:
        """Perform asynchronous guardrail check."""
        ...
@@ -0,0 +1,144 @@
1
+ from os import getenv
2
+ from typing import Any, Dict, List, Literal, Optional, Union
3
+
4
+ from agno.exceptions import CheckTrigger, InputCheckError
5
+ from agno.guardrails.base import BaseGuardrail
6
+ from agno.run.agent import RunInput
7
+ from agno.run.team import TeamRunInput
8
+ from agno.utils.log import log_debug
9
+ from agno.utils.openai import images_to_message
10
+
11
+
12
class OpenAIModerationGuardrail(BaseGuardrail):
    """Guardrail for detecting content that violates OpenAI's content policy.

    Args:
        moderation_model (str): The model to use for moderation. Defaults to "omni-moderation-latest".
        raise_for_categories (List[str]): The categories to raise for.
            Options are: "sexual", "sexual/minors", "harassment",
            "harassment/threatening", "hate", "hate/threatening",
            "illicit", "illicit/violent", "self-harm", "self-harm/intent",
            "self-harm/instructions", "violence", "violence/graphic".
            Defaults to None, meaning any flagged category triggers the check.
        api_key (str): The API key to use for moderation. Defaults to the OPENAI_API_KEY environment variable.

    Raises:
        InputCheckError: When the moderation endpoint flags the input.
    """

    def __init__(
        self,
        moderation_model: str = "omni-moderation-latest",
        raise_for_categories: Optional[
            List[
                Literal[
                    "sexual",
                    "sexual/minors",
                    "harassment",
                    "harassment/threatening",
                    "hate",
                    "hate/threatening",
                    "illicit",
                    "illicit/violent",
                    "self-harm",
                    "self-harm/intent",
                    "self-harm/instructions",
                    "violence",
                    "violence/graphic",
                ]
            ]
        ] = None,
        api_key: Optional[str] = None,
    ):
        self.moderation_model = moderation_model
        self.api_key = api_key or getenv("OPENAI_API_KEY")
        self.raise_for_categories = raise_for_categories

    def _build_model_input(self, run_input: Union[RunInput, TeamRunInput]) -> Union[str, List[Dict[str, Any]]]:
        """Build the moderation payload: plain text, or text plus any attached images."""
        content = run_input.input_content_string()
        if run_input.images is not None:
            return [{"type": "text", "text": content}, *images_to_message(images=run_input.images)]
        return content

    def _raise_if_flagged(self, result: Any) -> None:
        """Raise InputCheckError when the moderation result should trigger this guardrail.

        When `raise_for_categories` is configured, only those categories trigger;
        otherwise any flagged result triggers.
        """
        if not result.flagged:
            return

        moderation_result = {
            "categories": result.categories.model_dump(),
            "category_scores": result.category_scores.model_dump(),
        }

        if self.raise_for_categories is not None:
            categories: Dict[str, Any] = moderation_result["categories"]
            # model_dump() may emit python-safe field names (e.g. "sexual_minors",
            # "self_harm") rather than the API's slashed/dashed aliases, so look up
            # both spellings instead of indexing directly (which could KeyError).
            trigger_validation = any(
                categories.get(category) or categories.get(category.replace("/", "_").replace("-", "_"))
                for category in self.raise_for_categories
            )
        else:
            # Since at least one category is flagged, we need to raise the check
            trigger_validation = True

        if trigger_validation:
            raise InputCheckError(
                "OpenAI moderation violation detected.",
                additional_data=moderation_result,
                check_trigger=CheckTrigger.INPUT_NOT_ALLOWED,
            )

    def check(self, run_input: Union[RunInput, TeamRunInput]) -> None:
        """Check for content that violates OpenAI's content policy."""
        try:
            from openai import OpenAI as OpenAIClient
        except ImportError:
            raise ImportError("`openai` not installed. Please install using `pip install openai`")

        log_debug(f"Moderating content using {self.moderation_model}")
        client = OpenAIClient(api_key=self.api_key)
        response = client.moderations.create(model=self.moderation_model, input=self._build_model_input(run_input))  # type: ignore
        self._raise_if_flagged(response.results[0])

    async def async_check(self, run_input: Union[RunInput, TeamRunInput]) -> None:
        """Check for content that violates OpenAI's content policy."""
        try:
            from openai import AsyncOpenAI as OpenAIClient
        except ImportError:
            raise ImportError("`openai` not installed. Please install using `pip install openai`")

        log_debug(f"Moderating content using {self.moderation_model}")
        client = OpenAIClient(api_key=self.api_key)
        response = await client.moderations.create(model=self.moderation_model, input=self._build_model_input(run_input))  # type: ignore
        self._raise_if_flagged(response.results[0])
agno/guardrails/pii.py ADDED
@@ -0,0 +1,94 @@
1
+ from re import Pattern
2
+ from typing import Dict, Optional, Union
3
+
4
+ from agno.exceptions import CheckTrigger, InputCheckError
5
+ from agno.guardrails.base import BaseGuardrail
6
+ from agno.run.agent import RunInput
7
+ from agno.run.team import TeamRunInput
8
+
9
+
10
class PIIDetectionGuardrail(BaseGuardrail):
    """Guardrail for detecting Personally Identifiable Information (PII).

    Args:
        mask_pii: Whether to mask the PII in the input, rather than raising an error.
        enable_ssn_check: Whether to check for Social Security Numbers. True by default.
        enable_credit_card_check: Whether to check for credit cards. True by default.
        enable_email_check: Whether to check for emails. True by default.
        enable_phone_check: Whether to check for phone numbers. True by default.
        custom_patterns: A dictionary of custom PII patterns to detect. This is added to the default patterns.

    Raises:
        InputCheckError: When PII is detected and `mask_pii` is False.
    """

    def __init__(
        self,
        mask_pii: bool = False,
        enable_ssn_check: bool = True,
        enable_credit_card_check: bool = True,
        enable_email_check: bool = True,
        enable_phone_check: bool = True,
        custom_patterns: Optional[Dict[str, Pattern[str]]] = None,
    ):
        import re

        self.mask_pii = mask_pii
        self.pii_patterns: Dict[str, Pattern[str]] = {}

        if enable_ssn_check:
            self.pii_patterns["SSN"] = re.compile(r"\b\d{3}-\d{2}-\d{4}\b")
        if enable_credit_card_check:
            self.pii_patterns["Credit Card"] = re.compile(r"\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b")
        if enable_email_check:
            self.pii_patterns["Email"] = re.compile(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b")
        if enable_phone_check:
            self.pii_patterns["Phone"] = re.compile(r"\b\d{3}[\s.-]?\d{3}[\s.-]?\d{4}\b")

        if custom_patterns:
            self.pii_patterns.update(custom_patterns)

    def _detect(self, content: str) -> List[str]:
        """Return the names of all configured PII patterns found in `content`."""
        return [pii_type for pii_type, pattern in self.pii_patterns.items() if pattern.search(content)]

    def _resolve(self, run_input: Union[RunInput, TeamRunInput], content: str, detected_pii: List[str]) -> None:
        """Mask the detected PII in place, or raise InputCheckError."""
        if not self.mask_pii:
            raise InputCheckError(
                "Potential PII detected in input",
                additional_data={"detected_pii": detected_pii},
                check_trigger=CheckTrigger.PII_DETECTED,
            )
        for pii_type in detected_pii:
            # Replace each match with asterisks of the same length.
            content = self.pii_patterns[pii_type].sub(lambda m: "*" * len(m.group(0)), content)
        # NOTE(review): overwrites input_content with the masked string even when the
        # original input was not a plain string — confirm against RunInput semantics.
        run_input.input_content = content

    def check(self, run_input: Union[RunInput, TeamRunInput]) -> None:
        """Check for PII patterns in the input."""
        content = run_input.input_content_string()
        detected_pii = self._detect(content)
        if detected_pii:
            self._resolve(run_input, content, detected_pii)

    async def async_check(self, run_input: Union[RunInput, TeamRunInput]) -> None:
        """Asynchronously check for PII patterns in the input."""
        content = run_input.input_content_string()
        detected_pii = self._detect(content)
        if detected_pii:
            self._resolve(run_input, content, detected_pii)
@@ -0,0 +1,51 @@
1
+ from typing import List, Optional, Union
2
+
3
+ from agno.exceptions import CheckTrigger, InputCheckError
4
+ from agno.guardrails.base import BaseGuardrail
5
+ from agno.run.agent import RunInput
6
+ from agno.run.team import TeamRunInput
7
+
8
+
9
class PromptInjectionGuardrail(BaseGuardrail):
    """Guardrail for detecting prompt injection attempts.

    Detection is a case-insensitive substring match of the input against a
    list of known injection phrases.

    Args:
        injection_patterns (Optional[List[str]]): A list of patterns to check for. Defaults to a list of common prompt injection patterns.

    Raises:
        InputCheckError: When any injection pattern is found in the input.
    """

    def __init__(self, injection_patterns: Optional[List[str]] = None):
        self.injection_patterns = injection_patterns or [
            "ignore previous instructions",
            "ignore your instructions",
            "you are now a",
            "forget everything above",
            "developer mode",
            "override safety",
            "disregard guidelines",
            "system prompt",
            "jailbreak",
            "act as if",
            "pretend you are",
            "roleplay as",
            "simulate being",
            "bypass restrictions",
            "ignore safeguards",
            "admin override",
            "root access",
        ]

    def _raise_if_injection(self, run_input: Union[RunInput, TeamRunInput]) -> None:
        """Raise InputCheckError if any injection pattern appears in the input."""
        content = run_input.input_content_string().lower()
        # Lowercase both sides: previously a custom pattern containing uppercase
        # characters could never match the already-lowercased content.
        if any(keyword.lower() in content for keyword in self.injection_patterns):
            raise InputCheckError(
                "Potential jailbreaking or prompt injection detected.",
                check_trigger=CheckTrigger.PROMPT_INJECTION,
            )

    def check(self, run_input: Union[RunInput, TeamRunInput]) -> None:
        """Check for prompt injection patterns in the input."""
        self._raise_if_injection(run_input)

    async def async_check(self, run_input: Union[RunInput, TeamRunInput]) -> None:
        """Asynchronously check for prompt injection patterns in the input."""
        self._raise_if_injection(run_input)
@@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple
5
5
 
6
6
  from agno.exceptions import AgnoError, ModelProviderError
7
7
  from agno.knowledge.embedder.base import Embedder
8
- from agno.utils.log import log_error, logger
8
+ from agno.utils.log import log_error, log_warning
9
9
 
10
10
  try:
11
11
  from boto3 import client as AwsClient
@@ -69,6 +69,11 @@ class AwsBedrockEmbedder(Embedder):
69
69
  client_params: Optional[Dict[str, Any]] = None
70
70
  client: Optional[AwsClient] = None
71
71
 
72
+ def __post_init__(self):
73
+ if self.enable_batch:
74
+ log_warning("AwsBedrockEmbedder does not support batch embeddings, setting enable_batch to False")
75
+ self.enable_batch = False
76
+
72
77
  def get_client(self) -> AwsClient:
73
78
  """
74
79
  Returns an AWS Bedrock client.
@@ -220,10 +225,10 @@ class AwsBedrockEmbedder(Embedder):
220
225
  # Fallback to the first available embedding type
221
226
  for embedding_type in response["embeddings"]:
222
227
  return response["embeddings"][embedding_type][0]
223
- logger.warning("No embeddings found in response")
228
+ log_warning("No embeddings found in response")
224
229
  return []
225
230
  except Exception as e:
226
- logger.warning(f"Error extracting embeddings: {e}")
231
+ log_warning(f"Error extracting embeddings: {e}")
227
232
  return []
228
233
 
229
234
  def get_embedding_and_usage(self, text: str) -> Tuple[List[float], Optional[Dict[str, Any]]]:
@@ -286,7 +291,7 @@ class AwsBedrockEmbedder(Embedder):
286
291
  # Fallback to the first available embedding type
287
292
  for embedding_type in response_body["embeddings"]:
288
293
  return response_body["embeddings"][embedding_type][0]
289
- logger.warning("No embeddings found in response")
294
+ log_warning("No embeddings found in response")
290
295
  return []
291
296
  except ClientError as e:
292
297
  log_error(f"Unexpected error calling Bedrock API: {str(e)}")
@@ -154,3 +154,57 @@ class AzureOpenAIEmbedder(Embedder):
154
154
  embedding = response.data[0].embedding
155
155
  usage = response.usage
156
156
  return embedding, usage.model_dump()
157
+
158
+ async def async_get_embeddings_batch_and_usage(
159
+ self, texts: List[str]
160
+ ) -> Tuple[List[List[float]], List[Optional[Dict]]]:
161
+ """
162
+ Get embeddings and usage for multiple texts in batches.
163
+
164
+ Args:
165
+ texts: List of text strings to embed
166
+
167
+ Returns:
168
+ Tuple of (List of embedding vectors, List of usage dictionaries)
169
+ """
170
+ all_embeddings = []
171
+ all_usage = []
172
+ logger.info(f"Getting embeddings and usage for {len(texts)} texts in batches of {self.batch_size}")
173
+
174
+ for i in range(0, len(texts), self.batch_size):
175
+ batch_texts = texts[i : i + self.batch_size]
176
+
177
+ req: Dict[str, Any] = {
178
+ "input": batch_texts,
179
+ "model": self.id,
180
+ "encoding_format": self.encoding_format,
181
+ }
182
+ if self.user is not None:
183
+ req["user"] = self.user
184
+ if self.id.startswith("text-embedding-3"):
185
+ req["dimensions"] = self.dimensions
186
+ if self.request_params:
187
+ req.update(self.request_params)
188
+
189
+ try:
190
+ response: CreateEmbeddingResponse = await self.aclient.embeddings.create(**req)
191
+ batch_embeddings = [data.embedding for data in response.data]
192
+ all_embeddings.extend(batch_embeddings)
193
+
194
+ # For each embedding in the batch, add the same usage information
195
+ usage_dict = response.usage.model_dump() if response.usage else None
196
+ all_usage.extend([usage_dict] * len(batch_embeddings))
197
+ except Exception as e:
198
+ logger.warning(f"Error in async batch embedding: {e}")
199
+ # Fallback to individual calls for this batch
200
+ for text in batch_texts:
201
+ try:
202
+ embedding, usage = await self.async_get_embedding_and_usage(text)
203
+ all_embeddings.append(embedding)
204
+ all_usage.append(usage)
205
+ except Exception as e2:
206
+ logger.warning(f"Error in individual async embedding fallback: {e2}")
207
+ all_embeddings.append([])
208
+ all_usage.append(None)
209
+
210
+ return all_embeddings, all_usage
@@ -7,6 +7,8 @@ class Embedder:
7
7
  """Base class for managing embedders"""
8
8
 
9
9
  dimensions: Optional[int] = 1536
10
+ enable_batch: bool = False
11
+ batch_size: int = 100 # Number of texts to process in each API call
10
12
 
11
13
  def get_embedding(self, text: str) -> List[float]:
12
14
  raise NotImplementedError