agno-2.0.11-py3-none-any.whl → agno-2.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agno/agent/agent.py +607 -176
- agno/db/in_memory/in_memory_db.py +42 -29
- agno/db/mongo/mongo.py +65 -66
- agno/db/postgres/postgres.py +6 -4
- agno/db/utils.py +50 -22
- agno/exceptions.py +62 -1
- agno/guardrails/__init__.py +6 -0
- agno/guardrails/base.py +19 -0
- agno/guardrails/openai.py +144 -0
- agno/guardrails/pii.py +94 -0
- agno/guardrails/prompt_injection.py +51 -0
- agno/knowledge/embedder/aws_bedrock.py +9 -4
- agno/knowledge/embedder/azure_openai.py +54 -0
- agno/knowledge/embedder/base.py +2 -0
- agno/knowledge/embedder/cohere.py +184 -5
- agno/knowledge/embedder/google.py +79 -1
- agno/knowledge/embedder/huggingface.py +9 -4
- agno/knowledge/embedder/jina.py +63 -0
- agno/knowledge/embedder/mistral.py +78 -11
- agno/knowledge/embedder/ollama.py +5 -0
- agno/knowledge/embedder/openai.py +18 -54
- agno/knowledge/embedder/voyageai.py +69 -16
- agno/knowledge/knowledge.py +11 -4
- agno/knowledge/reader/pdf_reader.py +4 -3
- agno/knowledge/reader/website_reader.py +3 -2
- agno/models/base.py +125 -32
- agno/models/cerebras/cerebras.py +1 -0
- agno/models/cerebras/cerebras_openai.py +1 -0
- agno/models/dashscope/dashscope.py +1 -0
- agno/models/google/gemini.py +27 -5
- agno/models/openai/chat.py +13 -4
- agno/models/openai/responses.py +1 -1
- agno/models/perplexity/perplexity.py +2 -3
- agno/models/requesty/__init__.py +5 -0
- agno/models/requesty/requesty.py +49 -0
- agno/models/vllm/vllm.py +1 -0
- agno/models/xai/xai.py +1 -0
- agno/os/app.py +98 -126
- agno/os/interfaces/__init__.py +1 -0
- agno/os/interfaces/agui/agui.py +21 -5
- agno/os/interfaces/base.py +4 -2
- agno/os/interfaces/slack/slack.py +13 -8
- agno/os/interfaces/whatsapp/router.py +2 -0
- agno/os/interfaces/whatsapp/whatsapp.py +12 -5
- agno/os/mcp.py +2 -2
- agno/os/middleware/__init__.py +7 -0
- agno/os/middleware/jwt.py +233 -0
- agno/os/router.py +182 -46
- agno/os/routers/home.py +2 -2
- agno/os/routers/memory/memory.py +23 -1
- agno/os/routers/memory/schemas.py +1 -1
- agno/os/routers/session/session.py +20 -3
- agno/os/utils.py +74 -8
- agno/run/agent.py +120 -77
- agno/run/base.py +2 -13
- agno/run/team.py +115 -72
- agno/run/workflow.py +5 -15
- agno/session/summary.py +9 -10
- agno/session/team.py +2 -1
- agno/team/team.py +721 -169
- agno/tools/firecrawl.py +4 -4
- agno/tools/function.py +42 -2
- agno/tools/knowledge.py +3 -3
- agno/tools/searxng.py +2 -2
- agno/tools/serper.py +2 -2
- agno/tools/spider.py +2 -2
- agno/tools/workflow.py +4 -5
- agno/utils/events.py +66 -1
- agno/utils/hooks.py +57 -0
- agno/utils/media.py +11 -9
- agno/utils/print_response/agent.py +43 -5
- agno/utils/print_response/team.py +48 -12
- agno/utils/serialize.py +32 -0
- agno/vectordb/cassandra/cassandra.py +44 -4
- agno/vectordb/chroma/chromadb.py +79 -8
- agno/vectordb/clickhouse/clickhousedb.py +43 -6
- agno/vectordb/couchbase/couchbase.py +76 -5
- agno/vectordb/lancedb/lance_db.py +38 -3
- agno/vectordb/milvus/milvus.py +76 -4
- agno/vectordb/mongodb/mongodb.py +76 -4
- agno/vectordb/pgvector/pgvector.py +50 -6
- agno/vectordb/pineconedb/pineconedb.py +39 -2
- agno/vectordb/qdrant/qdrant.py +76 -26
- agno/vectordb/singlestore/singlestore.py +77 -4
- agno/vectordb/upstashdb/upstashdb.py +42 -2
- agno/vectordb/weaviate/weaviate.py +39 -3
- agno/workflow/types.py +5 -6
- agno/workflow/workflow.py +58 -2
- {agno-2.0.11.dist-info → agno-2.1.1.dist-info}/METADATA +4 -3
- {agno-2.0.11.dist-info → agno-2.1.1.dist-info}/RECORD +93 -82
- {agno-2.0.11.dist-info → agno-2.1.1.dist-info}/WHEEL +0 -0
- {agno-2.0.11.dist-info → agno-2.1.1.dist-info}/licenses/LICENSE +0 -0
- {agno-2.0.11.dist-info → agno-2.1.1.dist-info}/top_level.txt +0 -0
agno/knowledge/embedder/cohere.py
CHANGED

```diff
@@ -1,8 +1,9 @@
+import time
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 from agno.knowledge.embedder.base import Embedder
-from agno.utils.log import
+from agno.utils.log import log_debug, log_error, log_info, log_warning
 
 try:
     from cohere import AsyncClient as AsyncCohereClient
@@ -22,6 +23,7 @@ class CohereEmbedder(Embedder):
     client_params: Optional[Dict[str, Any]] = None
     cohere_client: Optional[CohereClient] = None
     async_client: Optional[AsyncCohereClient] = None
+    exponential_backoff: bool = False  # Enable exponential backoff on rate limits
 
     @property
     def client(self) -> CohereClient:
@@ -61,6 +63,111 @@ class CohereEmbedder(Embedder):
             request_params.update(self.request_params)
         return self.client.embed(texts=[text], **request_params)
 
+    def _get_batch_request_params(self) -> Dict[str, Any]:
+        """Get request parameters for batch embedding calls."""
+        request_params: Dict[str, Any] = {}
+
+        if self.id:
+            request_params["model"] = self.id
+        if self.input_type:
+            request_params["input_type"] = self.input_type
+        if self.embedding_types:
+            request_params["embedding_types"] = self.embedding_types
+        if self.request_params:
+            request_params.update(self.request_params)
+
+        return request_params
+
+    def _is_rate_limit_error(self, error: Exception) -> bool:
+        """Check if the error is a rate limiting error."""
+        if hasattr(error, "status_code") and error.status_code == 429:
+            return True
+        error_str = str(error).lower()
+        return any(
+            phrase in error_str
+            for phrase in ["rate limit", "too many requests", "429", "trial key", "api calls / minute"]
+        )
+
+    def _exponential_backoff_sleep(self, attempt: int, base_delay: float = 1.0) -> None:
+        """Sleep with exponential backoff."""
+        delay = base_delay * (2**attempt) + (time.time() % 1)  # Add jitter
+        log_debug(f"Rate limited, waiting {delay:.2f} seconds before retry (attempt {attempt + 1})")
+        time.sleep(delay)
+
+    async def _async_rate_limit_backoff_sleep(self, attempt: int) -> None:
+        """Async version of rate-limit-aware backoff for APIs with per-minute limits."""
+        import asyncio
+
+        # For 40 req/min APIs like Cohere Trial, we need longer waits
+        if attempt == 0:
+            delay = 15.0  # Wait 15 seconds (1/4 of minute window)
+        elif attempt == 1:
+            delay = 30.0  # Wait 30 seconds (1/2 of minute window)
+        else:
+            delay = 60.0  # Wait full minute for window reset
+
+        # Add small jitter
+        delay += time.time() % 3
+
+        log_debug(
+            f"Async rate limit backoff, waiting {delay:.1f} seconds for rate limit window reset (attempt {attempt + 1})"
+        )
+        await asyncio.sleep(delay)
+
+    async def _async_batch_with_retry(
+        self, texts: List[str], max_retries: int = 3
+    ) -> Tuple[List[List[float]], List[Optional[Dict]]]:
+        """Execute async batch embedding with rate-limit-aware backoff for rate limiting."""
+
+        log_debug(f"Starting async batch retry for {len(texts)} texts with max_retries={max_retries}")
+
+        for attempt in range(max_retries + 1):
+            try:
+                request_params = self._get_batch_request_params()
+                response: Union[
+                    EmbeddingsFloatsEmbedResponse, EmbeddingsByTypeEmbedResponse
+                ] = await self.aclient.embed(texts=texts, **request_params)
+
+                # Extract embeddings from response
+                if isinstance(response, EmbeddingsFloatsEmbedResponse):
+                    batch_embeddings = response.embeddings
+                elif isinstance(response, EmbeddingsByTypeEmbedResponse):
+                    batch_embeddings = response.embeddings.float_ if response.embeddings.float_ else []
+                else:
+                    log_warning("No embeddings found in response")
+                    batch_embeddings = []
+
+                # Extract usage information
+                usage = response.meta.billed_units if response.meta else None
+                usage_dict = usage.model_dump() if usage else None
+                all_usage = [usage_dict] * len(batch_embeddings)
+
+                log_debug(f"Async batch embedding succeeded on attempt {attempt + 1}")
+                return batch_embeddings, all_usage
+
+            except Exception as e:
+                if self._is_rate_limit_error(e):
+                    if not self.exponential_backoff:
+                        log_warning(
+                            "Rate limit detected. To enable automatic backoff retry, set exponential_backoff=True when creating the embedder."
+                        )
+                        raise e
+
+                    log_info(f"Async rate limit detected on attempt {attempt + 1}")
+                    if attempt < max_retries:
+                        await self._async_rate_limit_backoff_sleep(attempt)
+                        continue
+                    else:
+                        log_warning(f"Async max retries ({max_retries}) reached for rate limiting")
+                        raise e
+                else:
+                    log_debug(f"Async non-rate-limit error on attempt {attempt + 1}: {e}")
+                    raise e
+
+        # This should never be reached, but just in case
+        log_error("Could not create embeddings. End of retry loop reached.")
+        return [], []
+
     def get_embedding(self, text: str) -> List[float]:
         response: Union[EmbeddingsFloatsEmbedResponse, EmbeddingsByTypeEmbedResponse] = self.response(text=text)
         try:
@@ -69,10 +176,10 @@ class CohereEmbedder(Embedder):
             elif isinstance(response, EmbeddingsByTypeEmbedResponse):
                 return response.embeddings.float_[0] if response.embeddings.float_ else []
             else:
-
+                log_warning("No embeddings found")
                 return []
         except Exception as e:
-
+            log_warning(e)
             return []
 
     def get_embedding_and_usage(self, text: str) -> Tuple[List[float], Optional[Dict[str, Any]]]:
@@ -110,10 +217,10 @@ class CohereEmbedder(Embedder):
             elif isinstance(response, EmbeddingsByTypeEmbedResponse):
                 return response.embeddings.float_[0] if response.embeddings.float_ else []
             else:
-
+                log_warning("No embeddings found")
                 return []
         except Exception as e:
-
+            log_warning(e)
             return []
 
     async def async_get_embedding_and_usage(self, text: str) -> Tuple[List[float], Optional[Dict[str, Any]]]:
@@ -142,3 +249,75 @@ class CohereEmbedder(Embedder):
         if usage:
             return embedding, usage.model_dump()
         return embedding, None
+
+    async def async_get_embeddings_batch_and_usage(
+        self, texts: List[str]
+    ) -> Tuple[List[List[float]], List[Optional[Dict]]]:
+        """
+        Get embeddings and usage for multiple texts in batches (async version).
+
+        Args:
+            texts: List of text strings to embed
+
+        Returns:
+            Tuple of (List of embedding vectors, List of usage dictionaries)
+        """
+        all_embeddings = []
+        all_usage = []
+        log_info(f"Getting embeddings and usage for {len(texts)} texts in batches of {self.batch_size} (async)")
+
+        for i in range(0, len(texts), self.batch_size):
+            batch_texts = texts[i : i + self.batch_size]
+
+            try:
+                # Use retry logic for batch processing
+                batch_embeddings, batch_usage = await self._async_batch_with_retry(batch_texts)
+                all_embeddings.extend(batch_embeddings)
+                all_usage.extend(batch_usage)
+
+            except Exception as e:
+                log_warning(f"Async batch embedding failed after retries: {e}")
+
+                # Check if this is a rate limit error and backoff is disabled
+                if self._is_rate_limit_error(e) and not self.exponential_backoff:
+                    log_warning("Rate limit hit and backoff is disabled. Failing immediately.")
+                    raise e
+
+                # Only fall back to individual calls for non-rate-limit errors
+                # For rate limit errors, we should reduce batch size instead
+                if self._is_rate_limit_error(e):
+                    log_warning("Rate limit hit even after retries. Consider reducing batch_size or upgrading API key.")
+                    # Try with smaller batch size
+                    if len(batch_texts) > 1:
+                        smaller_batch_size = max(1, len(batch_texts) // 2)
+                        log_info(f"Retrying with smaller batch size: {smaller_batch_size}")
+                        for j in range(0, len(batch_texts), smaller_batch_size):
+                            small_batch = batch_texts[j : j + smaller_batch_size]
+                            try:
+                                small_embeddings, small_usage = await self._async_batch_with_retry(small_batch)
+                                all_embeddings.extend(small_embeddings)
+                                all_usage.extend(small_usage)
+                            except Exception as e3:
+                                log_error(f"Failed even with reduced batch size: {e3}")
+                                # Fall back to empty results for this batch
+                                all_embeddings.extend([[] for _ in small_batch])
+                                all_usage.extend([None for _ in small_batch])
+                    else:
+                        # Single item already failed, add empty result
+                        log_debug("Single item failed, adding empty result")
+                        all_embeddings.append([])
+                        all_usage.append(None)
+                else:
+                    # For non-rate-limit errors, fall back to individual calls
+                    log_debug("Non-rate-limit error, falling back to individual calls")
+                    for text in batch_texts:
+                        try:
+                            embedding, usage = await self.async_get_embedding_and_usage(text)
+                            all_embeddings.append(embedding)
+                            all_usage.append(usage)
+                        except Exception as e2:
+                            log_warning(f"Error in individual async embedding fallback: {e2}")
+                            all_embeddings.append([])
+                            all_usage.append(None)
+
+        return all_embeddings, all_usage
```
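Taken together, the Cohere changes add an opt-in rate-limit retry path: `_is_rate_limit_error` sniffs 429-style failures, `_async_rate_limit_backoff_sleep` waits a growing fraction of the per-minute window (15s/30s/60s), and `async_get_embeddings_batch_and_usage` batches texts and halves the batch size when retries are exhausted. A minimal usage sketch; the model id and batch size below are illustrative assumptions, not values from this diff, and a `COHERE_API_KEY` is assumed to be set:

```python
import asyncio

from agno.knowledge.embedder.cohere import CohereEmbedder

# exponential_backoff=False (the default) makes rate-limit errors raise immediately;
# True enables the backoff-and-retry path added in this release.
embedder = CohereEmbedder(
    id="embed-v4.0",          # assumed model id, for illustration only
    batch_size=32,            # inherited from the Embedder base class
    exponential_backoff=True,
)

async def main() -> None:
    texts = ["first document", "second document", "third document"]
    embeddings, usage = await embedder.async_get_embeddings_batch_and_usage(texts)
    # One usage dict (billed_units) is repeated for every embedding in a batch
    print(len(embeddings), usage[0])

asyncio.run(main())
```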
agno/knowledge/embedder/google.py
CHANGED

```diff
@@ -3,7 +3,7 @@ from os import getenv
 from typing import Any, Dict, List, Optional, Tuple
 
 from agno.knowledge.embedder.base import Embedder
-from agno.utils.log import log_error, log_info
+from agno.utils.log import log_error, log_info, log_warning
 
 try:
     from google import genai
@@ -178,3 +178,81 @@ class GeminiEmbedder(Embedder):
         except Exception as e:
             log_error(f"Error extracting embeddings: {e}")
             return [], usage
+
+    async def async_get_embeddings_batch_and_usage(
+        self, texts: List[str]
+    ) -> Tuple[List[List[float]], List[Optional[Dict[str, Any]]]]:
+        """
+        Get embeddings and usage for multiple texts in batches.
+
+        Args:
+            texts: List of text strings to embed
+
+        Returns:
+            Tuple of (List of embedding vectors, List of usage dictionaries)
+        """
+        all_embeddings: List[List[float]] = []
+        all_usage: List[Optional[Dict[str, Any]]] = []
+        log_info(f"Getting embeddings and usage for {len(texts)} texts in batches of {self.batch_size}")
+
+        for i in range(0, len(texts), self.batch_size):
+            batch_texts = texts[i : i + self.batch_size]
+
+            # If a user provides a model id with the `models/` prefix, we need to remove it
+            _id = self.id
+            if _id.startswith("models/"):
+                _id = _id.split("/")[-1]
+
+            _request_params: Dict[str, Any] = {"contents": batch_texts, "model": _id, "config": {}}
+            if self.dimensions:
+                _request_params["config"]["output_dimensionality"] = self.dimensions
+            if self.task_type:
+                _request_params["config"]["task_type"] = self.task_type
+            if self.title:
+                _request_params["config"]["title"] = self.title
+            if not _request_params["config"]:
+                del _request_params["config"]
+
+            if self.request_params:
+                _request_params.update(self.request_params)
+
+            try:
+                response = await self.aclient.aio.models.embed_content(**_request_params)
+
+                # Extract embeddings from batch response
+                if response.embeddings:
+                    batch_embeddings = []
+                    for embedding in response.embeddings:
+                        if embedding.values is not None:
+                            batch_embeddings.append(embedding.values)
+                        else:
+                            batch_embeddings.append([])
+                    all_embeddings.extend(batch_embeddings)
+                else:
+                    # If no embeddings, add empty lists for each text in batch
+                    all_embeddings.extend([[] for _ in batch_texts])
+
+                # Extract usage information
+                usage_dict = None
+                if response.metadata and hasattr(response.metadata, "billable_character_count"):
+                    usage_dict = {"billable_character_count": response.metadata.billable_character_count}
+
+                # Add same usage info for each embedding in the batch
+                all_usage.extend([usage_dict] * len(batch_texts))
+
+            except Exception as e:
+                log_warning(f"Error in async batch embedding: {e}")
+                # Fallback to individual calls for this batch
+                for text in batch_texts:
+                    try:
+                        text_embedding: List[float]
+                        text_usage: Optional[Dict[str, Any]]
+                        text_embedding, text_usage = await self.async_get_embedding_and_usage(text)
+                        all_embeddings.append(text_embedding)
+                        all_usage.append(text_usage)
+                    except Exception as e2:
+                        log_warning(f"Error in individual async embedding fallback: {e2}")
+                        all_embeddings.append([])
+                        all_usage.append(None)
+
+        return all_embeddings, all_usage
```
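The Gemini embedder gains the same batch entry point. Note that the `models/` prefix is stripped from the model id, and the per-batch `config` (`output_dimensionality`, `task_type`, `title`) is only sent when at least one value is set. A sketch, assuming a `GOOGLE_API_KEY` in the environment and an illustrative model id and dimension count:

```python
import asyncio

from agno.knowledge.embedder.google import GeminiEmbedder

# "models/gemini-embedding-001" and dimensions=768 are illustrative assumptions;
# the "models/" prefix is stripped internally before the embed_content call.
embedder = GeminiEmbedder(id="models/gemini-embedding-001", dimensions=768, batch_size=16)

async def main() -> None:
    embeddings, usage = await embedder.async_get_embeddings_batch_and_usage(["alpha", "beta"])
    print(len(embeddings), usage)  # usage holds billable_character_count when reported

asyncio.run(main())
```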
agno/knowledge/embedder/huggingface.py
CHANGED

```diff
@@ -3,12 +3,12 @@ from os import getenv
 from typing import Any, Dict, List, Optional, Tuple
 
 from agno.knowledge.embedder.base import Embedder
-from agno.utils.log import
+from agno.utils.log import log_error, log_warning
 
 try:
     from huggingface_hub import AsyncInferenceClient, InferenceClient
 except ImportError:
-
+    log_error("`huggingface-hub` not installed, please run `pip install huggingface-hub`")
     raise
 
 
@@ -22,6 +22,11 @@ class HuggingfaceCustomEmbedder(Embedder):
     huggingface_client: Optional[InferenceClient] = None
     async_client: Optional[AsyncInferenceClient] = None
 
+    def __post_init__(self):
+        if self.enable_batch:
+            log_warning("HuggingfaceEmbedder does not support batch embeddings, setting enable_batch to False")
+            self.enable_batch = False
+
     @property
     def client(self) -> InferenceClient:
         if self.huggingface_client:
@@ -61,7 +66,7 @@ class HuggingfaceCustomEmbedder(Embedder):
             else:
                 return list(response)
         except Exception as e:
-
+            log_warning(f"Failed to process embeddings: {e}")
             return []
 
     def get_embedding_and_usage(self, text: str) -> Tuple[List[float], Optional[Dict]]:
@@ -80,7 +85,7 @@ class HuggingfaceCustomEmbedder(Embedder):
             else:
                 return list(response)
         except Exception as e:
-
+            log_warning(f"Failed to process embeddings: {e}")
             return []
 
     async def async_get_embedding_and_usage(self, text: str) -> Tuple[List[float], Optional[Dict]]:
```
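Since the Hugging Face inference path has no batch endpoint, the new `__post_init__` downgrades `enable_batch` up front (with a warning) rather than failing later. A small sketch of the resulting behavior, assuming `enable_batch` is the constructor-settable flag added to the shared `Embedder` base class in this release:

```python
from agno.knowledge.embedder.huggingface import HuggingfaceCustomEmbedder

# Asking for batch mode is safe: __post_init__ logs a warning and turns it off.
embedder = HuggingfaceCustomEmbedder(enable_batch=True)
assert embedder.enable_batch is False
```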
agno/knowledge/embedder/jina.py
CHANGED

```diff
@@ -117,3 +117,66 @@ class JinaEmbedder(Embedder):
         except Exception as e:
             logger.warning(f"Failed to get embedding and usage: {e}")
             return [], None
+
+    async def _async_batch_response(self, texts: List[str]) -> Dict[str, Any]:
+        """Async batch version of _response using aiohttp."""
+        data = {
+            "model": self.id,
+            "late_chunking": self.late_chunking,
+            "dimensions": self.dimensions,
+            "embedding_type": self.embedding_type,
+            "input": texts,  # Jina API expects a list of texts for batch processing
+        }
+        if self.user is not None:
+            data["user"] = self.user
+        if self.request_params:
+            data.update(self.request_params)
+
+        timeout = aiohttp.ClientTimeout(total=self.timeout) if self.timeout else None
+
+        async with aiohttp.ClientSession(timeout=timeout) as session:
+            async with session.post(self.base_url, headers=self._get_headers(), json=data) as response:
+                response.raise_for_status()
+                return await response.json()
+
+    async def async_get_embeddings_batch_and_usage(
+        self, texts: List[str]
+    ) -> Tuple[List[List[float]], List[Optional[Dict]]]:
+        """
+        Get embeddings and usage for multiple texts in batches.
+
+        Args:
+            texts: List of text strings to embed
+
+        Returns:
+            Tuple of (List of embedding vectors, List of usage dictionaries)
+        """
+        all_embeddings = []
+        all_usage = []
+        logger.info(f"Getting embeddings and usage for {len(texts)} texts in batches of {self.batch_size}")
+
+        for i in range(0, len(texts), self.batch_size):
+            batch_texts = texts[i : i + self.batch_size]
+
+            try:
+                result = await self._async_batch_response(batch_texts)
+                batch_embeddings = [data["embedding"] for data in result["data"]]
+                all_embeddings.extend(batch_embeddings)
+
+                # For each embedding in the batch, add the same usage information
+                usage_dict = result.get("usage")
+                all_usage.extend([usage_dict] * len(batch_embeddings))
+            except Exception as e:
+                logger.warning(f"Error in async batch embedding: {e}")
+                # Fallback to individual calls for this batch
+                for text in batch_texts:
+                    try:
+                        embedding, usage = await self.async_get_embedding_and_usage(text)
+                        all_embeddings.append(embedding)
+                        all_usage.append(usage)
+                    except Exception as e2:
+                        logger.warning(f"Error in individual async embedding fallback: {e2}")
+                        all_embeddings.append([])
+                        all_usage.append(None)
+
+        return all_embeddings, all_usage
```
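For Jina, a batch is a single HTTP POST whose `input` field carries the whole list of texts; the response's one `usage` dict is fanned out per embedding. A usage sketch, with an assumed model id and batch size and a `JINA_API_KEY` assumed in the environment:

```python
import asyncio

from agno.knowledge.embedder.jina import JinaEmbedder

embedder = JinaEmbedder(id="jina-embeddings-v3", batch_size=64)  # illustrative values

async def main() -> None:
    embeddings, usage = await embedder.async_get_embeddings_batch_and_usage(
        ["query one", "query two"]
    )
    # The API reports usage once per request, so both entries are the same dict
    print(len(embeddings), usage)

asyncio.run(main())
```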
agno/knowledge/embedder/mistral.py
CHANGED

```diff
@@ -3,13 +3,13 @@ from os import getenv
 from typing import Any, Dict, List, Optional, Tuple
 
 from agno.knowledge.embedder.base import Embedder
-from agno.utils.log import
+from agno.utils.log import log_error, log_info, log_warning
 
 try:
     from mistralai import Mistral  # type: ignore
     from mistralai.models.embeddingresponse import EmbeddingResponse  # type: ignore
 except ImportError:
-
+    log_error("`mistralai` not installed")
     raise
 
 
@@ -50,7 +50,7 @@ class MistralEmbedder(Embedder):
 
     def _response(self, text: str) -> EmbeddingResponse:
         _request_params: Dict[str, Any] = {
-            "inputs": text,
+            "inputs": [text],  # Mistral API expects a list
            "model": self.id,
         }
         if self.request_params:
@@ -67,7 +67,7 @@ class MistralEmbedder(Embedder):
                 return response.data[0].embedding
             return []
         except Exception as e:
-
+            log_warning(f"Error getting embedding: {e}")
             return []
 
     def get_embedding_and_usage(self, text: str) -> Tuple[List[float], Dict[str, Any]]:
@@ -79,7 +79,7 @@ class MistralEmbedder(Embedder):
             usage: Dict[str, Any] = response.usage.model_dump() if response.usage else {}
             return embedding, usage
         except Exception as e:
-
+            log_warning(f"Error getting embedding and usage: {e}")
             return [], {}
 
     async def async_get_embedding(self, text: str) -> List[float]:
@@ -88,7 +88,7 @@ class MistralEmbedder(Embedder):
             # Check if the client has an async version of embeddings.create
             if hasattr(self.client.embeddings, "create_async"):
                 response: EmbeddingResponse = await self.client.embeddings.create_async(
-                    inputs=text, model=self.id, **self.request_params if self.request_params else {}
+                    inputs=[text], model=self.id, **self.request_params if self.request_params else {}
                 )
             else:
                 # Fallback to running sync method in thread executor
@@ -98,7 +98,7 @@ class MistralEmbedder(Embedder):
                 response: EmbeddingResponse = await loop.run_in_executor(  # type: ignore
                     None,
                     lambda: self.client.embeddings.create(
-                        inputs=text, model=self.id, **self.request_params if self.request_params else {}
+                        inputs=[text], model=self.id, **self.request_params if self.request_params else {}
                     ),
                 )
 
@@ -106,7 +106,7 @@ class MistralEmbedder(Embedder):
                 return response.data[0].embedding
             return []
         except Exception as e:
-
+            log_warning(f"Error getting embedding: {e}")
             return []
 
     async def async_get_embedding_and_usage(self, text: str) -> Tuple[List[float], Dict[str, Any]]:
@@ -115,7 +115,7 @@ class MistralEmbedder(Embedder):
             # Check if the client has an async version of embeddings.create
             if hasattr(self.client.embeddings, "create_async"):
                 response: EmbeddingResponse = await self.client.embeddings.create_async(
-                    inputs=text, model=self.id, **self.request_params if self.request_params else {}
+                    inputs=[text], model=self.id, **self.request_params if self.request_params else {}
                 )
             else:
                 # Fallback to running sync method in thread executor
@@ -125,7 +125,7 @@ class MistralEmbedder(Embedder):
                 response: EmbeddingResponse = await loop.run_in_executor(  # type: ignore
                     None,
                     lambda: self.client.embeddings.create(
-                        inputs=text, model=self.id, **self.request_params if self.request_params else {}
+                        inputs=[text], model=self.id, **self.request_params if self.request_params else {}
                    ),
                 )
 
@@ -135,5 +135,72 @@ class MistralEmbedder(Embedder):
             usage: Dict[str, Any] = response.usage.model_dump() if response.usage else {}
             return embedding, usage
         except Exception as e:
-
+            log_warning(f"Error getting embedding and usage: {e}")
             return [], {}
+
+    async def async_get_embeddings_batch_and_usage(
+        self, texts: List[str]
+    ) -> Tuple[List[List[float]], List[Optional[Dict[str, Any]]]]:
+        """
+        Get embeddings and usage for multiple texts in batches.
+
+        Args:
+            texts: List of text strings to embed
+
+        Returns:
+            Tuple of (List of embedding vectors, List of usage dictionaries)
+        """
+        all_embeddings = []
+        all_usage = []
+        log_info(f"Getting embeddings and usage for {len(texts)} texts in batches of {self.batch_size}")
+
+        for i in range(0, len(texts), self.batch_size):
+            batch_texts = texts[i : i + self.batch_size]
+
+            _request_params: Dict[str, Any] = {
+                "inputs": batch_texts,  # Mistral API expects a list for batch processing
+                "model": self.id,
+            }
+            if self.request_params:
+                _request_params.update(self.request_params)
+
+            try:
+                # Check if the client has an async version of embeddings.create
+                if hasattr(self.client.embeddings, "create_async"):
+                    response: EmbeddingResponse = await self.client.embeddings.create_async(**_request_params)
+                else:
+                    # Fallback to running sync method in thread executor
+                    import asyncio
+
+                    loop = asyncio.get_running_loop()
+                    response: EmbeddingResponse = await loop.run_in_executor(  # type: ignore
+                        None, lambda: self.client.embeddings.create(**_request_params)
+                    )
+
+                # Extract embeddings from batch response
+                if response.data:
+                    batch_embeddings = [data.embedding for data in response.data if data.embedding]
+                    all_embeddings.extend(batch_embeddings)
+                else:
+                    # If no embeddings, add empty lists for each text in batch
+                    all_embeddings.extend([[] for _ in batch_texts])
+
+                # Extract usage information
+                usage_dict = response.usage.model_dump() if response.usage else None
+                # Add same usage info for each embedding in the batch
+                all_usage.extend([usage_dict] * len(batch_texts))
+
+            except Exception as e:
+                log_warning(f"Error in async batch embedding: {e}")
+                # Fallback to individual calls for this batch
+                for text in batch_texts:
+                    try:
+                        embedding, usage = await self.async_get_embedding_and_usage(text)
+                        all_embeddings.append(embedding)
+                        all_usage.append(usage)
+                    except Exception as e2:
+                        log_warning(f"Error in individual async embedding fallback: {e2}")
+                        all_embeddings.append([])
+                        all_usage.append(None)
+
+        return all_embeddings, all_usage
```
|
|
|
45
45
|
ollama_client: Optional[OllamaClient] = None
|
|
46
46
|
async_client: Optional[AsyncOllamaClient] = None
|
|
47
47
|
|
|
48
|
+
def __post_init__(self):
|
|
49
|
+
if self.enable_batch:
|
|
50
|
+
logger.warning("OllamaEmbedder does not support batch embeddings, setting enable_batch to False")
|
|
51
|
+
self.enable_batch = False
|
|
52
|
+
|
|
48
53
|
@property
|
|
49
54
|
def client(self) -> OllamaClient:
|
|
50
55
|
if self.ollama_client:
|