agno-2.0.10-py3-none-any.whl → agno-2.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. agno/agent/agent.py +608 -175
  2. agno/db/in_memory/in_memory_db.py +42 -29
  3. agno/db/postgres/postgres.py +6 -4
  4. agno/exceptions.py +62 -1
  5. agno/guardrails/__init__.py +6 -0
  6. agno/guardrails/base.py +19 -0
  7. agno/guardrails/openai.py +144 -0
  8. agno/guardrails/pii.py +94 -0
  9. agno/guardrails/prompt_injection.py +51 -0
  10. agno/knowledge/embedder/aws_bedrock.py +9 -4
  11. agno/knowledge/embedder/azure_openai.py +54 -0
  12. agno/knowledge/embedder/base.py +2 -0
  13. agno/knowledge/embedder/cohere.py +184 -5
  14. agno/knowledge/embedder/google.py +79 -1
  15. agno/knowledge/embedder/huggingface.py +9 -4
  16. agno/knowledge/embedder/jina.py +63 -0
  17. agno/knowledge/embedder/mistral.py +78 -11
  18. agno/knowledge/embedder/ollama.py +5 -0
  19. agno/knowledge/embedder/openai.py +18 -54
  20. agno/knowledge/embedder/voyageai.py +69 -16
  21. agno/knowledge/knowledge.py +5 -4
  22. agno/knowledge/reader/pdf_reader.py +4 -3
  23. agno/knowledge/reader/website_reader.py +3 -2
  24. agno/models/base.py +125 -32
  25. agno/models/cerebras/cerebras.py +1 -0
  26. agno/models/cerebras/cerebras_openai.py +1 -0
  27. agno/models/dashscope/dashscope.py +1 -0
  28. agno/models/google/gemini.py +27 -5
  29. agno/models/litellm/chat.py +17 -0
  30. agno/models/openai/chat.py +13 -4
  31. agno/models/perplexity/perplexity.py +2 -3
  32. agno/models/requesty/__init__.py +5 -0
  33. agno/models/requesty/requesty.py +49 -0
  34. agno/models/vllm/vllm.py +1 -0
  35. agno/models/xai/xai.py +1 -0
  36. agno/os/app.py +167 -148
  37. agno/os/interfaces/whatsapp/router.py +2 -0
  38. agno/os/mcp.py +1 -1
  39. agno/os/middleware/__init__.py +7 -0
  40. agno/os/middleware/jwt.py +233 -0
  41. agno/os/router.py +181 -45
  42. agno/os/routers/home.py +2 -2
  43. agno/os/routers/memory/memory.py +23 -1
  44. agno/os/routers/memory/schemas.py +1 -1
  45. agno/os/routers/session/session.py +20 -3
  46. agno/os/utils.py +172 -8
  47. agno/run/agent.py +120 -77
  48. agno/run/team.py +115 -72
  49. agno/run/workflow.py +5 -15
  50. agno/session/summary.py +9 -10
  51. agno/session/team.py +2 -1
  52. agno/team/team.py +720 -168
  53. agno/tools/firecrawl.py +4 -4
  54. agno/tools/function.py +42 -2
  55. agno/tools/knowledge.py +3 -3
  56. agno/tools/searxng.py +2 -2
  57. agno/tools/serper.py +2 -2
  58. agno/tools/spider.py +2 -2
  59. agno/tools/workflow.py +4 -5
  60. agno/utils/events.py +66 -1
  61. agno/utils/hooks.py +57 -0
  62. agno/utils/media.py +11 -9
  63. agno/utils/print_response/agent.py +43 -5
  64. agno/utils/print_response/team.py +48 -12
  65. agno/vectordb/cassandra/cassandra.py +44 -4
  66. agno/vectordb/chroma/chromadb.py +79 -8
  67. agno/vectordb/clickhouse/clickhousedb.py +43 -6
  68. agno/vectordb/couchbase/couchbase.py +76 -5
  69. agno/vectordb/lancedb/lance_db.py +38 -3
  70. agno/vectordb/llamaindex/__init__.py +3 -0
  71. agno/vectordb/milvus/milvus.py +76 -4
  72. agno/vectordb/mongodb/mongodb.py +76 -4
  73. agno/vectordb/pgvector/pgvector.py +50 -6
  74. agno/vectordb/pineconedb/pineconedb.py +39 -2
  75. agno/vectordb/qdrant/qdrant.py +76 -26
  76. agno/vectordb/singlestore/singlestore.py +77 -4
  77. agno/vectordb/upstashdb/upstashdb.py +42 -2
  78. agno/vectordb/weaviate/weaviate.py +39 -3
  79. agno/workflow/types.py +1 -0
  80. agno/workflow/workflow.py +58 -2
  81. {agno-2.0.10.dist-info → agno-2.1.0.dist-info}/METADATA +4 -3
  82. {agno-2.0.10.dist-info → agno-2.1.0.dist-info}/RECORD +85 -75
  83. {agno-2.0.10.dist-info → agno-2.1.0.dist-info}/WHEEL +0 -0
  84. {agno-2.0.10.dist-info → agno-2.1.0.dist-info}/licenses/LICENSE +0 -0
  85. {agno-2.0.10.dist-info → agno-2.1.0.dist-info}/top_level.txt +0 -0
agno/knowledge/embedder/aws_bedrock.py

@@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple
 
 from agno.exceptions import AgnoError, ModelProviderError
 from agno.knowledge.embedder.base import Embedder
-from agno.utils.log import log_error, logger
+from agno.utils.log import log_error, log_warning
 
 try:
     from boto3 import client as AwsClient
@@ -69,6 +69,11 @@ class AwsBedrockEmbedder(Embedder):
     client_params: Optional[Dict[str, Any]] = None
     client: Optional[AwsClient] = None
 
+    def __post_init__(self):
+        if self.enable_batch:
+            log_warning("AwsBedrockEmbedder does not support batch embeddings, setting enable_batch to False")
+            self.enable_batch = False
+
     def get_client(self) -> AwsClient:
         """
         Returns an AWS Bedrock client.
@@ -220,10 +225,10 @@ class AwsBedrockEmbedder(Embedder):
             # Fallback to the first available embedding type
             for embedding_type in response["embeddings"]:
                 return response["embeddings"][embedding_type][0]
-            logger.warning("No embeddings found in response")
+            log_warning("No embeddings found in response")
             return []
         except Exception as e:
-            logger.warning(f"Error extracting embeddings: {e}")
+            log_warning(f"Error extracting embeddings: {e}")
             return []
 
     def get_embedding_and_usage(self, text: str) -> Tuple[List[float], Optional[Dict[str, Any]]]:
@@ -286,7 +291,7 @@ class AwsBedrockEmbedder(Embedder):
             # Fallback to the first available embedding type
             for embedding_type in response_body["embeddings"]:
                 return response_body["embeddings"][embedding_type][0]
-            logger.warning("No embeddings found in response")
+            log_warning("No embeddings found in response")
             return []
         except ClientError as e:
             log_error(f"Unexpected error calling Bedrock API: {str(e)}")
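
The new `__post_init__` guard means a batch request against Bedrock is downgraded with a warning rather than failing. A minimal sketch of the observable behaviour, assuming the embedder is a dataclass (as in 2.0.x) and AWS credentials are configured:

import asyncio  # noqa: F401 (not needed here; shown for parity with later examples)
from agno.knowledge.embedder.aws_bedrock import AwsBedrockEmbedder

# enable_batch is one of the new Embedder base fields (see base.py below);
# the __post_init__ hook resets it because Bedrock exposes no batch endpoint.
embedder = AwsBedrockEmbedder(enable_batch=True)  # emits the log_warning above
assert embedder.enable_batch is False
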
agno/knowledge/embedder/azure_openai.py

@@ -154,3 +154,57 @@ class AzureOpenAIEmbedder(Embedder):
         embedding = response.data[0].embedding
         usage = response.usage
         return embedding, usage.model_dump()
+
+    async def async_get_embeddings_batch_and_usage(
+        self, texts: List[str]
+    ) -> Tuple[List[List[float]], List[Optional[Dict]]]:
+        """
+        Get embeddings and usage for multiple texts in batches.
+
+        Args:
+            texts: List of text strings to embed
+
+        Returns:
+            Tuple of (List of embedding vectors, List of usage dictionaries)
+        """
+        all_embeddings = []
+        all_usage = []
+        logger.info(f"Getting embeddings and usage for {len(texts)} texts in batches of {self.batch_size}")
+
+        for i in range(0, len(texts), self.batch_size):
+            batch_texts = texts[i : i + self.batch_size]
+
+            req: Dict[str, Any] = {
+                "input": batch_texts,
+                "model": self.id,
+                "encoding_format": self.encoding_format,
+            }
+            if self.user is not None:
+                req["user"] = self.user
+            if self.id.startswith("text-embedding-3"):
+                req["dimensions"] = self.dimensions
+            if self.request_params:
+                req.update(self.request_params)
+
+            try:
+                response: CreateEmbeddingResponse = await self.aclient.embeddings.create(**req)
+                batch_embeddings = [data.embedding for data in response.data]
+                all_embeddings.extend(batch_embeddings)
+
+                # For each embedding in the batch, add the same usage information
+                usage_dict = response.usage.model_dump() if response.usage else None
+                all_usage.extend([usage_dict] * len(batch_embeddings))
+            except Exception as e:
+                logger.warning(f"Error in async batch embedding: {e}")
+                # Fallback to individual calls for this batch
+                for text in batch_texts:
+                    try:
+                        embedding, usage = await self.async_get_embedding_and_usage(text)
+                        all_embeddings.append(embedding)
+                        all_usage.append(usage)
+                    except Exception as e2:
+                        logger.warning(f"Error in individual async embedding fallback: {e2}")
+                        all_embeddings.append([])
+                        all_usage.append(None)
+
+        return all_embeddings, all_usage
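
A usage sketch for the new batch method; the sizes are illustrative, and the embedder is assumed to pick up Azure credentials and deployment settings from its usual fields or environment variables:

import asyncio

from agno.knowledge.embedder.azure_openai import AzureOpenAIEmbedder

async def main():
    embedder = AzureOpenAIEmbedder(enable_batch=True, batch_size=50)
    texts = ["first document", "second document", "third document"]
    # One vector per input text; the same usage dict is repeated for every
    # embedding that came from the same API call.
    embeddings, usage = await embedder.async_get_embeddings_batch_and_usage(texts)
    print(len(embeddings), usage[0])

asyncio.run(main())
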
agno/knowledge/embedder/base.py

@@ -7,6 +7,8 @@ class Embedder:
     """Base class for managing embedders"""
 
     dimensions: Optional[int] = 1536
+    enable_batch: bool = False
+    batch_size: int = 100  # Number of texts to process in each API call
 
     def get_embedding(self, text: str) -> List[float]:
         raise NotImplementedError
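
The batching contract these two fields establish is simple: subclasses slice the input list into `batch_size` chunks and make one API call per chunk. A provider-independent sketch of that arithmetic:

texts = [f"doc {n}" for n in range(250)]
batch_size = 100  # the new Embedder default

# one slice per API call; the final batch may be short
batches = [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]
print([len(b) for b in batches])  # -> [100, 100, 50], i.e. 3 calls instead of 250
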
agno/knowledge/embedder/cohere.py

@@ -1,8 +1,9 @@
+import time
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 from agno.knowledge.embedder.base import Embedder
-from agno.utils.log import logger
+from agno.utils.log import log_debug, log_error, log_info, log_warning
 
 try:
     from cohere import AsyncClient as AsyncCohereClient
@@ -22,6 +23,7 @@ class CohereEmbedder(Embedder):
     client_params: Optional[Dict[str, Any]] = None
     cohere_client: Optional[CohereClient] = None
     async_client: Optional[AsyncCohereClient] = None
+    exponential_backoff: bool = False  # Enable exponential backoff on rate limits
 
     @property
     def client(self) -> CohereClient:
@@ -61,6 +63,111 @@ class CohereEmbedder(Embedder):
             request_params.update(self.request_params)
         return self.client.embed(texts=[text], **request_params)
 
+    def _get_batch_request_params(self) -> Dict[str, Any]:
+        """Get request parameters for batch embedding calls."""
+        request_params: Dict[str, Any] = {}
+
+        if self.id:
+            request_params["model"] = self.id
+        if self.input_type:
+            request_params["input_type"] = self.input_type
+        if self.embedding_types:
+            request_params["embedding_types"] = self.embedding_types
+        if self.request_params:
+            request_params.update(self.request_params)
+
+        return request_params
+
+    def _is_rate_limit_error(self, error: Exception) -> bool:
+        """Check if the error is a rate limiting error."""
+        if hasattr(error, "status_code") and error.status_code == 429:
+            return True
+        error_str = str(error).lower()
+        return any(
+            phrase in error_str
+            for phrase in ["rate limit", "too many requests", "429", "trial key", "api calls / minute"]
+        )
+
+    def _exponential_backoff_sleep(self, attempt: int, base_delay: float = 1.0) -> None:
+        """Sleep with exponential backoff."""
+        delay = base_delay * (2**attempt) + (time.time() % 1)  # Add jitter
+        log_debug(f"Rate limited, waiting {delay:.2f} seconds before retry (attempt {attempt + 1})")
+        time.sleep(delay)
+
+    async def _async_rate_limit_backoff_sleep(self, attempt: int) -> None:
+        """Async version of rate-limit-aware backoff for APIs with per-minute limits."""
+        import asyncio
+
+        # For 40 req/min APIs like Cohere Trial, we need longer waits
+        if attempt == 0:
+            delay = 15.0  # Wait 15 seconds (1/4 of minute window)
+        elif attempt == 1:
+            delay = 30.0  # Wait 30 seconds (1/2 of minute window)
+        else:
+            delay = 60.0  # Wait full minute for window reset
+
+        # Add small jitter
+        delay += time.time() % 3
+
+        log_debug(
+            f"Async rate limit backoff, waiting {delay:.1f} seconds for rate limit window reset (attempt {attempt + 1})"
+        )
+        await asyncio.sleep(delay)
+
+    async def _async_batch_with_retry(
+        self, texts: List[str], max_retries: int = 3
+    ) -> Tuple[List[List[float]], List[Optional[Dict]]]:
+        """Execute async batch embedding with rate-limit-aware backoff for rate limiting."""
+
+        log_debug(f"Starting async batch retry for {len(texts)} texts with max_retries={max_retries}")
+
+        for attempt in range(max_retries + 1):
+            try:
+                request_params = self._get_batch_request_params()
+                response: Union[
+                    EmbeddingsFloatsEmbedResponse, EmbeddingsByTypeEmbedResponse
+                ] = await self.aclient.embed(texts=texts, **request_params)
+
+                # Extract embeddings from response
+                if isinstance(response, EmbeddingsFloatsEmbedResponse):
+                    batch_embeddings = response.embeddings
+                elif isinstance(response, EmbeddingsByTypeEmbedResponse):
+                    batch_embeddings = response.embeddings.float_ if response.embeddings.float_ else []
+                else:
+                    log_warning("No embeddings found in response")
+                    batch_embeddings = []
+
+                # Extract usage information
+                usage = response.meta.billed_units if response.meta else None
+                usage_dict = usage.model_dump() if usage else None
+                all_usage = [usage_dict] * len(batch_embeddings)
+
+                log_debug(f"Async batch embedding succeeded on attempt {attempt + 1}")
+                return batch_embeddings, all_usage
+
+            except Exception as e:
+                if self._is_rate_limit_error(e):
+                    if not self.exponential_backoff:
+                        log_warning(
+                            "Rate limit detected. To enable automatic backoff retry, set exponential_backoff=True when creating the embedder."
+                        )
+                        raise e
+
+                    log_info(f"Async rate limit detected on attempt {attempt + 1}")
+                    if attempt < max_retries:
+                        await self._async_rate_limit_backoff_sleep(attempt)
+                        continue
+                    else:
+                        log_warning(f"Async max retries ({max_retries}) reached for rate limiting")
+                        raise e
+                else:
+                    log_debug(f"Async non-rate-limit error on attempt {attempt + 1}: {e}")
+                    raise e
+
+        # This should never be reached, but just in case
+        log_error("Could not create embeddings. End of retry loop reached.")
+        return [], []
+
     def get_embedding(self, text: str) -> List[float]:
         response: Union[EmbeddingsFloatsEmbedResponse, EmbeddingsByTypeEmbedResponse] = self.response(text=text)
         try:
@@ -69,10 +176,10 @@ class CohereEmbedder(Embedder):
             elif isinstance(response, EmbeddingsByTypeEmbedResponse):
                 return response.embeddings.float_[0] if response.embeddings.float_ else []
             else:
-                logger.warning("No embeddings found")
+                log_warning("No embeddings found")
                 return []
         except Exception as e:
-            logger.warning(e)
+            log_warning(e)
             return []
 
     def get_embedding_and_usage(self, text: str) -> Tuple[List[float], Optional[Dict[str, Any]]]:
@@ -110,10 +217,10 @@ class CohereEmbedder(Embedder):
             elif isinstance(response, EmbeddingsByTypeEmbedResponse):
                 return response.embeddings.float_[0] if response.embeddings.float_ else []
             else:
-                logger.warning("No embeddings found")
+                log_warning("No embeddings found")
                 return []
         except Exception as e:
-            logger.warning(e)
+            log_warning(e)
             return []
 
     async def async_get_embedding_and_usage(self, text: str) -> Tuple[List[float], Optional[Dict[str, Any]]]:
@@ -142,3 +249,75 @@ class CohereEmbedder(Embedder):
         if usage:
             return embedding, usage.model_dump()
         return embedding, None
+
+    async def async_get_embeddings_batch_and_usage(
+        self, texts: List[str]
+    ) -> Tuple[List[List[float]], List[Optional[Dict]]]:
+        """
+        Get embeddings and usage for multiple texts in batches (async version).
+
+        Args:
+            texts: List of text strings to embed
+
+        Returns:
+            Tuple of (List of embedding vectors, List of usage dictionaries)
+        """
+        all_embeddings = []
+        all_usage = []
+        log_info(f"Getting embeddings and usage for {len(texts)} texts in batches of {self.batch_size} (async)")
+
+        for i in range(0, len(texts), self.batch_size):
+            batch_texts = texts[i : i + self.batch_size]
+
+            try:
+                # Use retry logic for batch processing
+                batch_embeddings, batch_usage = await self._async_batch_with_retry(batch_texts)
+                all_embeddings.extend(batch_embeddings)
+                all_usage.extend(batch_usage)
+
+            except Exception as e:
+                log_warning(f"Async batch embedding failed after retries: {e}")
+
+                # Check if this is a rate limit error and backoff is disabled
+                if self._is_rate_limit_error(e) and not self.exponential_backoff:
+                    log_warning("Rate limit hit and backoff is disabled. Failing immediately.")
+                    raise e
+
+                # Only fall back to individual calls for non-rate-limit errors
+                # For rate limit errors, we should reduce batch size instead
+                if self._is_rate_limit_error(e):
+                    log_warning("Rate limit hit even after retries. Consider reducing batch_size or upgrading API key.")
+                    # Try with smaller batch size
+                    if len(batch_texts) > 1:
+                        smaller_batch_size = max(1, len(batch_texts) // 2)
+                        log_info(f"Retrying with smaller batch size: {smaller_batch_size}")
+                        for j in range(0, len(batch_texts), smaller_batch_size):
+                            small_batch = batch_texts[j : j + smaller_batch_size]
+                            try:
+                                small_embeddings, small_usage = await self._async_batch_with_retry(small_batch)
+                                all_embeddings.extend(small_embeddings)
+                                all_usage.extend(small_usage)
+                            except Exception as e3:
+                                log_error(f"Failed even with reduced batch size: {e3}")
+                                # Fall back to empty results for this batch
+                                all_embeddings.extend([[] for _ in small_batch])
+                                all_usage.extend([None for _ in small_batch])
+                    else:
+                        # Single item already failed, add empty result
+                        log_debug("Single item failed, adding empty result")
+                        all_embeddings.append([])
+                        all_usage.append(None)
+                else:
+                    # For non-rate-limit errors, fall back to individual calls
+                    log_debug("Non-rate-limit error, falling back to individual calls")
+                    for text in batch_texts:
+                        try:
+                            embedding, usage = await self.async_get_embedding_and_usage(text)
+                            all_embeddings.append(embedding)
+                            all_usage.append(usage)
+                        except Exception as e2:
+                            log_warning(f"Error in individual async embedding fallback: {e2}")
+                            all_embeddings.append([])
+                            all_usage.append(None)
+
+        return all_embeddings, all_usage
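
Backoff is opt-in: with `exponential_backoff` left at its False default, the first 429 is re-raised immediately. A hedged sketch of opting in; the batch size of 96 is an assumption based on Cohere's documented per-call text limit, not something this diff enforces:

import asyncio

from agno.knowledge.embedder.cohere import CohereEmbedder

async def main():
    # Trial keys allow roughly 40 calls/minute, which is what the
    # 15s/30s/60s retry schedule above is tuned for.
    embedder = CohereEmbedder(enable_batch=True, batch_size=96, exponential_backoff=True)
    embeddings, usage = await embedder.async_get_embeddings_batch_and_usage(["a", "b", "c"])
    print(len(embeddings))

asyncio.run(main())
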
agno/knowledge/embedder/google.py

@@ -3,7 +3,7 @@ from os import getenv
 from typing import Any, Dict, List, Optional, Tuple
 
 from agno.knowledge.embedder.base import Embedder
-from agno.utils.log import log_error, log_info
+from agno.utils.log import log_error, log_info, log_warning
 
 try:
     from google import genai
@@ -178,3 +178,81 @@ class GeminiEmbedder(Embedder):
         except Exception as e:
             log_error(f"Error extracting embeddings: {e}")
             return [], usage
+
+    async def async_get_embeddings_batch_and_usage(
+        self, texts: List[str]
+    ) -> Tuple[List[List[float]], List[Optional[Dict[str, Any]]]]:
+        """
+        Get embeddings and usage for multiple texts in batches.
+
+        Args:
+            texts: List of text strings to embed
+
+        Returns:
+            Tuple of (List of embedding vectors, List of usage dictionaries)
+        """
+        all_embeddings: List[List[float]] = []
+        all_usage: List[Optional[Dict[str, Any]]] = []
+        log_info(f"Getting embeddings and usage for {len(texts)} texts in batches of {self.batch_size}")
+
+        for i in range(0, len(texts), self.batch_size):
+            batch_texts = texts[i : i + self.batch_size]
+
+            # If a user provides a model id with the `models/` prefix, we need to remove it
+            _id = self.id
+            if _id.startswith("models/"):
+                _id = _id.split("/")[-1]
+
+            _request_params: Dict[str, Any] = {"contents": batch_texts, "model": _id, "config": {}}
+            if self.dimensions:
+                _request_params["config"]["output_dimensionality"] = self.dimensions
+            if self.task_type:
+                _request_params["config"]["task_type"] = self.task_type
+            if self.title:
+                _request_params["config"]["title"] = self.title
+            if not _request_params["config"]:
+                del _request_params["config"]
+
+            if self.request_params:
+                _request_params.update(self.request_params)
+
+            try:
+                response = await self.aclient.aio.models.embed_content(**_request_params)
+
+                # Extract embeddings from batch response
+                if response.embeddings:
+                    batch_embeddings = []
+                    for embedding in response.embeddings:
+                        if embedding.values is not None:
+                            batch_embeddings.append(embedding.values)
+                        else:
+                            batch_embeddings.append([])
+                    all_embeddings.extend(batch_embeddings)
+                else:
+                    # If no embeddings, add empty lists for each text in batch
+                    all_embeddings.extend([[] for _ in batch_texts])
+
+                # Extract usage information
+                usage_dict = None
+                if response.metadata and hasattr(response.metadata, "billable_character_count"):
+                    usage_dict = {"billable_character_count": response.metadata.billable_character_count}
+
+                # Add same usage info for each embedding in the batch
+                all_usage.extend([usage_dict] * len(batch_texts))
+
+            except Exception as e:
+                log_warning(f"Error in async batch embedding: {e}")
+                # Fallback to individual calls for this batch
+                for text in batch_texts:
+                    try:
+                        text_embedding: List[float]
+                        text_usage: Optional[Dict[str, Any]]
+                        text_embedding, text_usage = await self.async_get_embedding_and_usage(text)
+                        all_embeddings.append(text_embedding)
+                        all_usage.append(text_usage)
+                    except Exception as e2:
+                        log_warning(f"Error in individual async embedding fallback: {e2}")
+                        all_embeddings.append([])
+                        all_usage.append(None)
+
+        return all_embeddings, all_usage
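
Note the `models/` prefix normalization: both spellings of the model id should now hit the same endpoint in batch calls. A sketch, assuming the constructor fields carried over from 2.0.x and a configured Google API key:

import asyncio

from agno.knowledge.embedder.google import GeminiEmbedder

async def main():
    # "models/text-embedding-004" is normalized to "text-embedding-004"
    # before the batch request is issued.
    embedder = GeminiEmbedder(id="models/text-embedding-004", enable_batch=True)
    embeddings, usage = await embedder.async_get_embeddings_batch_and_usage(["hello", "world"])
    print(len(embeddings))

asyncio.run(main())
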
agno/knowledge/embedder/huggingface.py

@@ -3,12 +3,12 @@ from os import getenv
 from typing import Any, Dict, List, Optional, Tuple
 
 from agno.knowledge.embedder.base import Embedder
-from agno.utils.log import logger
+from agno.utils.log import log_error, log_warning
 
 try:
     from huggingface_hub import AsyncInferenceClient, InferenceClient
 except ImportError:
-    logger.error("`huggingface-hub` not installed, please run `pip install huggingface-hub`")
+    log_error("`huggingface-hub` not installed, please run `pip install huggingface-hub`")
     raise
 
 
@@ -22,6 +22,11 @@ class HuggingfaceCustomEmbedder(Embedder):
     huggingface_client: Optional[InferenceClient] = None
     async_client: Optional[AsyncInferenceClient] = None
 
+    def __post_init__(self):
+        if self.enable_batch:
+            log_warning("HuggingfaceEmbedder does not support batch embeddings, setting enable_batch to False")
+            self.enable_batch = False
+
     @property
     def client(self) -> InferenceClient:
         if self.huggingface_client:
@@ -61,7 +66,7 @@ class HuggingfaceCustomEmbedder(Embedder):
             else:
                 return list(response)
         except Exception as e:
-            logger.warning(f"Failed to process embeddings: {e}")
+            log_warning(f"Failed to process embeddings: {e}")
             return []
 
     def get_embedding_and_usage(self, text: str) -> Tuple[List[float], Optional[Dict]]:
@@ -80,7 +85,7 @@ class HuggingfaceCustomEmbedder(Embedder):
             else:
                 return list(response)
         except Exception as e:
-            logger.warning(f"Failed to process embeddings: {e}")
+            log_warning(f"Failed to process embeddings: {e}")
             return []
 
     async def async_get_embedding_and_usage(self, text: str) -> Tuple[List[float], Optional[Dict]]:
agno/knowledge/embedder/jina.py

@@ -117,3 +117,66 @@ class JinaEmbedder(Embedder):
         except Exception as e:
             logger.warning(f"Failed to get embedding and usage: {e}")
             return [], None
+
+    async def _async_batch_response(self, texts: List[str]) -> Dict[str, Any]:
+        """Async batch version of _response using aiohttp."""
+        data = {
+            "model": self.id,
+            "late_chunking": self.late_chunking,
+            "dimensions": self.dimensions,
+            "embedding_type": self.embedding_type,
+            "input": texts,  # Jina API expects a list of texts for batch processing
+        }
+        if self.user is not None:
+            data["user"] = self.user
+        if self.request_params:
+            data.update(self.request_params)
+
+        timeout = aiohttp.ClientTimeout(total=self.timeout) if self.timeout else None
+
+        async with aiohttp.ClientSession(timeout=timeout) as session:
+            async with session.post(self.base_url, headers=self._get_headers(), json=data) as response:
+                response.raise_for_status()
+                return await response.json()
+
+    async def async_get_embeddings_batch_and_usage(
+        self, texts: List[str]
+    ) -> Tuple[List[List[float]], List[Optional[Dict]]]:
+        """
+        Get embeddings and usage for multiple texts in batches.
+
+        Args:
+            texts: List of text strings to embed
+
+        Returns:
+            Tuple of (List of embedding vectors, List of usage dictionaries)
+        """
+        all_embeddings = []
+        all_usage = []
+        logger.info(f"Getting embeddings and usage for {len(texts)} texts in batches of {self.batch_size}")
+
+        for i in range(0, len(texts), self.batch_size):
+            batch_texts = texts[i : i + self.batch_size]
+
+            try:
+                result = await self._async_batch_response(batch_texts)
+                batch_embeddings = [data["embedding"] for data in result["data"]]
+                all_embeddings.extend(batch_embeddings)
+
+                # For each embedding in the batch, add the same usage information
+                usage_dict = result.get("usage")
+                all_usage.extend([usage_dict] * len(batch_embeddings))
+            except Exception as e:
+                logger.warning(f"Error in async batch embedding: {e}")
+                # Fallback to individual calls for this batch
+                for text in batch_texts:
+                    try:
+                        embedding, usage = await self.async_get_embedding_and_usage(text)
+                        all_embeddings.append(embedding)
+                        all_usage.append(usage)
+                    except Exception as e2:
+                        logger.warning(f"Error in individual async embedding fallback: {e2}")
+                        all_embeddings.append([])
+                        all_usage.append(None)
+
+        return all_embeddings, all_usage
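
One contract shared by all of these batch methods is worth calling out: a text that fails both the batch call and the per-item fallback yields an empty vector and a None usage entry instead of raising. A minimal defensive pattern on the caller's side; embed_safely is a hypothetical helper, and embedder can be any of the batch-capable embedders above:

from typing import List

async def embed_safely(embedder, texts: List[str]) -> List[List[float]]:
    embeddings, _usage = await embedder.async_get_embeddings_batch_and_usage(texts)
    for text, vector in zip(texts, embeddings):
        if not vector:  # empty list -> failed in batch *and* in the fallback
            print(f"embedding failed for: {text[:40]!r}")
    return embeddings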