agno 2.4.5__py3-none-any.whl → 2.4.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agno/agent/agent.py +2 -1
- agno/db/singlestore/singlestore.py +4 -5
- agno/db/surrealdb/models.py +1 -1
- agno/knowledge/chunking/agentic.py +1 -5
- agno/knowledge/chunking/code.py +1 -1
- agno/knowledge/chunking/document.py +22 -42
- agno/knowledge/chunking/fixed.py +1 -5
- agno/knowledge/chunking/markdown.py +9 -25
- agno/knowledge/chunking/recursive.py +1 -3
- agno/knowledge/chunking/row.py +3 -2
- agno/knowledge/chunking/semantic.py +1 -1
- agno/knowledge/chunking/strategy.py +19 -0
- agno/knowledge/embedder/aws_bedrock.py +325 -106
- agno/knowledge/knowledge.py +173 -14
- agno/knowledge/reader/text_reader.py +1 -1
- agno/knowledge/reranker/aws_bedrock.py +299 -0
- agno/learn/machine.py +5 -6
- agno/learn/stores/learned_knowledge.py +108 -131
- agno/run/workflow.py +3 -0
- agno/tools/mcp/mcp.py +26 -1
- agno/utils/print_response/agent.py +8 -8
- agno/utils/print_response/team.py +8 -8
- agno/vectordb/lancedb/lance_db.py +9 -9
- agno/workflow/condition.py +135 -56
- {agno-2.4.5.dist-info → agno-2.4.7.dist-info}/METADATA +34 -59
- {agno-2.4.5.dist-info → agno-2.4.7.dist-info}/RECORD +29 -28
- {agno-2.4.5.dist-info → agno-2.4.7.dist-info}/WHEEL +0 -0
- {agno-2.4.5.dist-info → agno-2.4.7.dist-info}/licenses/LICENSE +0 -0
- {agno-2.4.5.dist-info → agno-2.4.7.dist-info}/top_level.txt +0 -0
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import json
|
|
2
2
|
from dataclasses import dataclass
|
|
3
3
|
from os import getenv
|
|
4
|
-
from typing import Any, Dict, List, Optional, Tuple
|
|
4
|
+
from typing import Any, Dict, List, Literal, Optional, Tuple
|
|
5
5
|
|
|
6
6
|
from agno.exceptions import AgnoError, ModelProviderError
|
|
7
7
|
from agno.knowledge.embedder.base import Embedder
|
|
@@ -17,17 +17,23 @@ except ImportError:
|
|
|
17
17
|
|
|
18
18
|
try:
|
|
19
19
|
import aioboto3
|
|
20
|
-
from aioboto3.session import Session as AioSession
|
|
21
20
|
except ImportError:
|
|
22
|
-
|
|
21
|
+
log_warning("`aioboto3` not installed. Async methods will not be available. Install via `pip install aioboto3`.")
|
|
23
22
|
aioboto3 = None
|
|
24
|
-
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
# Type aliases for clarity
|
|
26
|
+
InputType = Literal["search_document", "search_query", "classification", "clustering"]
|
|
27
|
+
EmbeddingType = Literal["float", "int8", "uint8", "binary", "ubinary"]
|
|
28
|
+
TruncateV3 = Literal["NONE", "START", "END"]
|
|
29
|
+
TruncateV4 = Literal["NONE", "LEFT", "RIGHT"]
|
|
30
|
+
OutputDimension = Literal[256, 512, 1024, 1536]
|
|
25
31
|
|
|
26
32
|
|
|
27
33
|
@dataclass
|
|
28
34
|
class AwsBedrockEmbedder(Embedder):
|
|
29
35
|
"""
|
|
30
|
-
AWS Bedrock embedder.
|
|
36
|
+
AWS Bedrock embedder supporting Cohere Embed v3 and v4 models.
|
|
31
37
|
|
|
32
38
|
To use this embedder, you need to either:
|
|
33
39
|
1. Set the following environment variables:
|
|
@@ -38,13 +44,22 @@ class AwsBedrockEmbedder(Embedder):
|
|
|
38
44
|
|
|
39
45
|
Args:
|
|
40
46
|
id (str): The model ID to use. Default is 'cohere.embed-multilingual-v3'.
|
|
41
|
-
|
|
47
|
+
- v3 models: 'cohere.embed-multilingual-v3', 'cohere.embed-english-v3'
|
|
48
|
+
- v4 model: 'cohere.embed-v4:0'
|
|
49
|
+
dimensions (Optional[int]): The dimensions of the embeddings.
|
|
50
|
+
- v3: Fixed at 1024
|
|
51
|
+
- v4: Configurable via output_dimension (256, 512, 1024, 1536). Default 1536.
|
|
42
52
|
input_type (str): Prepends special tokens to differentiate types. Options:
|
|
43
53
|
'search_document', 'search_query', 'classification', 'clustering'. Default is 'search_query'.
|
|
44
54
|
truncate (Optional[str]): How to handle inputs longer than the maximum token length.
|
|
45
|
-
|
|
55
|
+
- v3: 'NONE', 'START', 'END'
|
|
56
|
+
- v4: 'NONE', 'LEFT', 'RIGHT'
|
|
46
57
|
embedding_types (Optional[List[str]]): Types of embeddings to return. Options:
|
|
47
58
|
'float', 'int8', 'uint8', 'binary', 'ubinary'. Default is ['float'].
|
|
59
|
+
output_dimension (Optional[int]): (v4 only) Vector length. Options: 256, 512, 1024, 1536.
|
|
60
|
+
Default is 1536 if unspecified.
|
|
61
|
+
max_tokens (Optional[int]): (v4 only) Truncation budget per input object.
|
|
62
|
+
The model supports up to ~128,000 tokens.
|
|
48
63
|
aws_region (Optional[str]): The AWS region to use.
|
|
49
64
|
aws_access_key_id (Optional[str]): The AWS access key ID to use.
|
|
50
65
|
aws_secret_access_key (Optional[str]): The AWS secret access key to use.
|
|
@@ -54,11 +69,14 @@ class AwsBedrockEmbedder(Embedder):
|
|
|
54
69
|
"""
|
|
55
70
|
|
|
56
71
|
id: str = "cohere.embed-multilingual-v3"
|
|
57
|
-
dimensions: int = 1024 #
|
|
58
|
-
input_type:
|
|
59
|
-
truncate: Optional[str] = None # 'NONE'
|
|
60
|
-
# 'float', 'int8', 'uint8', etc.
|
|
61
|
-
|
|
72
|
+
dimensions: int = 1024 # v3: 1024, v4: 1536 default (set in __post_init__)
|
|
73
|
+
input_type: InputType = "search_query"
|
|
74
|
+
truncate: Optional[str] = None # v3: 'NONE'|'START'|'END', v4: 'NONE'|'LEFT'|'RIGHT'
|
|
75
|
+
embedding_types: Optional[List[EmbeddingType]] = None # 'float', 'int8', 'uint8', etc.
|
|
76
|
+
|
|
77
|
+
# v4-specific parameters
|
|
78
|
+
output_dimension: Optional[OutputDimension] = None # 256, 512, 1024, 1536
|
|
79
|
+
max_tokens: Optional[int] = None # Up to 128000 for v4
|
|
62
80
|
|
|
63
81
|
aws_region: Optional[str] = None
|
|
64
82
|
aws_access_key_id: Optional[str] = None
|
|
@@ -74,10 +92,31 @@ class AwsBedrockEmbedder(Embedder):
|
|
|
74
92
|
log_warning("AwsBedrockEmbedder does not support batch embeddings, setting enable_batch to False")
|
|
75
93
|
self.enable_batch = False
|
|
76
94
|
|
|
95
|
+
# Set appropriate default dimensions based on model version
|
|
96
|
+
if self._is_v4_model():
|
|
97
|
+
# v4 default is 1536, but can be overridden by output_dimension
|
|
98
|
+
if self.output_dimension:
|
|
99
|
+
self.dimensions = self.output_dimension
|
|
100
|
+
else:
|
|
101
|
+
self.dimensions = 1536
|
|
102
|
+
else:
|
|
103
|
+
# v3 models are fixed at 1024
|
|
104
|
+
self.dimensions = 1024
|
|
105
|
+
|
|
106
|
+
def _is_v4_model(self) -> bool:
|
|
107
|
+
"""Check if the current model is a Cohere Embed v4 model."""
|
|
108
|
+
return "embed-v4" in self.id.lower()
|
|
109
|
+
|
|
77
110
|
def get_client(self) -> AwsClient:
|
|
78
111
|
"""
|
|
79
112
|
Returns an AWS Bedrock client.
|
|
80
113
|
|
|
114
|
+
Credentials are resolved in the following order:
|
|
115
|
+
1. Explicit session parameter
|
|
116
|
+
2. Explicit aws_access_key_id and aws_secret_access_key parameters
|
|
117
|
+
3. Environment variables (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY)
|
|
118
|
+
4. Default boto3 credential chain (~/.aws/credentials, SSO, IAM role, etc.)
|
|
119
|
+
|
|
81
120
|
Returns:
|
|
82
121
|
AwsClient: An instance of the AWS Bedrock client.
|
|
83
122
|
"""
|
|
@@ -88,29 +127,39 @@ class AwsBedrockEmbedder(Embedder):
|
|
|
88
127
|
self.client = self.session.client("bedrock-runtime")
|
|
89
128
|
return self.client
|
|
90
129
|
|
|
130
|
+
# Try explicit credentials or environment variables
|
|
91
131
|
self.aws_access_key_id = self.aws_access_key_id or getenv("AWS_ACCESS_KEY_ID")
|
|
92
132
|
self.aws_secret_access_key = self.aws_secret_access_key or getenv("AWS_SECRET_ACCESS_KEY")
|
|
93
133
|
self.aws_region = self.aws_region or getenv("AWS_REGION")
|
|
94
134
|
|
|
95
|
-
if
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
135
|
+
if self.aws_access_key_id and self.aws_secret_access_key:
|
|
136
|
+
# Use explicit credentials
|
|
137
|
+
self.client = AwsClient(
|
|
138
|
+
service_name="bedrock-runtime",
|
|
139
|
+
region_name=self.aws_region,
|
|
140
|
+
aws_access_key_id=self.aws_access_key_id,
|
|
141
|
+
aws_secret_access_key=self.aws_secret_access_key,
|
|
142
|
+
**(self.client_params or {}),
|
|
143
|
+
)
|
|
144
|
+
else:
|
|
145
|
+
# Fall back to default credential chain (SSO, credentials file, IAM role, etc.)
|
|
146
|
+
self.client = AwsClient(
|
|
147
|
+
service_name="bedrock-runtime",
|
|
148
|
+
region_name=self.aws_region,
|
|
149
|
+
**(self.client_params or {}),
|
|
99
150
|
)
|
|
100
|
-
|
|
101
|
-
self.client = AwsClient(
|
|
102
|
-
service_name="bedrock-runtime",
|
|
103
|
-
region_name=self.aws_region,
|
|
104
|
-
aws_access_key_id=self.aws_access_key_id,
|
|
105
|
-
aws_secret_access_key=self.aws_secret_access_key,
|
|
106
|
-
**(self.client_params or {}),
|
|
107
|
-
)
|
|
108
151
|
return self.client
|
|
109
152
|
|
|
110
153
|
def get_async_client(self):
|
|
111
154
|
"""
|
|
112
155
|
Returns an async AWS Bedrock client using aioboto3.
|
|
113
156
|
|
|
157
|
+
Credentials are resolved in the following order:
|
|
158
|
+
1. Explicit session parameter
|
|
159
|
+
2. Explicit aws_access_key_id and aws_secret_access_key parameters
|
|
160
|
+
3. Environment variables (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY)
|
|
161
|
+
4. Default credential chain (~/.aws/credentials, SSO, IAM role, etc.)
|
|
162
|
+
|
|
114
163
|
Returns:
|
|
115
164
|
An aioboto3 bedrock-runtime client context manager.
|
|
116
165
|
"""
|
|
@@ -129,21 +178,21 @@ class AwsBedrockEmbedder(Embedder):
|
|
|
129
178
|
region_name=self.session.region_name,
|
|
130
179
|
)
|
|
131
180
|
else:
|
|
181
|
+
# Try explicit credentials or environment variables
|
|
132
182
|
self.aws_access_key_id = self.aws_access_key_id or getenv("AWS_ACCESS_KEY_ID")
|
|
133
183
|
self.aws_secret_access_key = self.aws_secret_access_key or getenv("AWS_SECRET_ACCESS_KEY")
|
|
134
184
|
self.aws_region = self.aws_region or getenv("AWS_REGION")
|
|
135
185
|
|
|
136
|
-
if
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
186
|
+
if self.aws_access_key_id and self.aws_secret_access_key:
|
|
187
|
+
# Use explicit credentials
|
|
188
|
+
aio_session = aioboto3.Session(
|
|
189
|
+
aws_access_key_id=self.aws_access_key_id,
|
|
190
|
+
aws_secret_access_key=self.aws_secret_access_key,
|
|
191
|
+
region_name=self.aws_region,
|
|
140
192
|
)
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
aws_secret_access_key=self.aws_secret_access_key,
|
|
145
|
-
region_name=self.aws_region,
|
|
146
|
-
)
|
|
193
|
+
else:
|
|
194
|
+
# Fall back to default credential chain (SSO, credentials file, IAM role, etc.)
|
|
195
|
+
aio_session = aioboto3.Session(region_name=self.aws_region)
|
|
147
196
|
|
|
148
197
|
return aio_session.client("bedrock-runtime", **(self.client_params or {}))
|
|
149
198
|
|
|
@@ -157,7 +206,7 @@ class AwsBedrockEmbedder(Embedder):
|
|
|
157
206
|
Returns:
|
|
158
207
|
str: The formatted request body as a JSON string.
|
|
159
208
|
"""
|
|
160
|
-
request_body = {
|
|
209
|
+
request_body: Dict[str, Any] = {
|
|
161
210
|
"texts": [text],
|
|
162
211
|
"input_type": self.input_type,
|
|
163
212
|
}
|
|
@@ -168,12 +217,110 @@ class AwsBedrockEmbedder(Embedder):
|
|
|
168
217
|
if self.embedding_types:
|
|
169
218
|
request_body["embedding_types"] = self.embedding_types
|
|
170
219
|
|
|
220
|
+
# v4-specific parameters
|
|
221
|
+
if self._is_v4_model():
|
|
222
|
+
if self.output_dimension:
|
|
223
|
+
request_body["output_dimension"] = self.output_dimension
|
|
224
|
+
if self.max_tokens:
|
|
225
|
+
request_body["max_tokens"] = self.max_tokens
|
|
226
|
+
|
|
171
227
|
# Add additional request parameters if provided
|
|
172
228
|
if self.request_params:
|
|
173
229
|
request_body.update(self.request_params)
|
|
174
230
|
|
|
175
231
|
return json.dumps(request_body)
|
|
176
232
|
|
|
233
|
+
def _format_multimodal_request_body(
|
|
234
|
+
self,
|
|
235
|
+
texts: Optional[List[str]] = None,
|
|
236
|
+
images: Optional[List[str]] = None,
|
|
237
|
+
inputs: Optional[List[Dict[str, Any]]] = None,
|
|
238
|
+
) -> str:
|
|
239
|
+
"""
|
|
240
|
+
Format a multimodal request body for v4 models.
|
|
241
|
+
|
|
242
|
+
Args:
|
|
243
|
+
texts: List of text strings to embed (text-only mode)
|
|
244
|
+
images: List of base64 data URIs for images (image-only mode)
|
|
245
|
+
inputs: List of interleaved content items for mixed modality
|
|
246
|
+
|
|
247
|
+
Returns:
|
|
248
|
+
str: The formatted request body as a JSON string.
|
|
249
|
+
"""
|
|
250
|
+
if not self._is_v4_model():
|
|
251
|
+
raise AgnoError(
|
|
252
|
+
message="Multimodal embeddings are only supported with Cohere Embed v4 models.",
|
|
253
|
+
status_code=400,
|
|
254
|
+
)
|
|
255
|
+
|
|
256
|
+
request_body: Dict[str, Any] = {
|
|
257
|
+
"input_type": self.input_type,
|
|
258
|
+
}
|
|
259
|
+
|
|
260
|
+
# Set the appropriate input field
|
|
261
|
+
if inputs:
|
|
262
|
+
request_body["inputs"] = inputs
|
|
263
|
+
elif images:
|
|
264
|
+
request_body["images"] = images
|
|
265
|
+
elif texts:
|
|
266
|
+
request_body["texts"] = texts
|
|
267
|
+
else:
|
|
268
|
+
raise AgnoError(
|
|
269
|
+
message="At least one of texts, images, or inputs must be provided.",
|
|
270
|
+
status_code=400,
|
|
271
|
+
)
|
|
272
|
+
|
|
273
|
+
if self.truncate:
|
|
274
|
+
request_body["truncate"] = self.truncate
|
|
275
|
+
|
|
276
|
+
if self.embedding_types:
|
|
277
|
+
request_body["embedding_types"] = self.embedding_types
|
|
278
|
+
|
|
279
|
+
if self.output_dimension:
|
|
280
|
+
request_body["output_dimension"] = self.output_dimension
|
|
281
|
+
|
|
282
|
+
if self.max_tokens:
|
|
283
|
+
request_body["max_tokens"] = self.max_tokens
|
|
284
|
+
|
|
285
|
+
if self.request_params:
|
|
286
|
+
request_body.update(self.request_params)
|
|
287
|
+
|
|
288
|
+
return json.dumps(request_body)
|
|
289
|
+
|
|
290
|
+
def _extract_embeddings(self, response_body: Dict[str, Any]) -> List[float]:
|
|
291
|
+
"""
|
|
292
|
+
Extract embeddings from the response body, handling both v3 and v4 formats.
|
|
293
|
+
|
|
294
|
+
Args:
|
|
295
|
+
response_body: The parsed response body from the API.
|
|
296
|
+
|
|
297
|
+
Returns:
|
|
298
|
+
List[float]: The embedding vector.
|
|
299
|
+
"""
|
|
300
|
+
try:
|
|
301
|
+
if "embeddings" in response_body:
|
|
302
|
+
embeddings = response_body["embeddings"]
|
|
303
|
+
|
|
304
|
+
# Handle list format (single embedding type or v3 default)
|
|
305
|
+
if isinstance(embeddings, list):
|
|
306
|
+
return embeddings[0] if embeddings else []
|
|
307
|
+
|
|
308
|
+
# Handle dict format (multiple embedding types requested)
|
|
309
|
+
if isinstance(embeddings, dict):
|
|
310
|
+
# Prefer float embeddings
|
|
311
|
+
if "float" in embeddings:
|
|
312
|
+
return embeddings["float"][0]
|
|
313
|
+
# Fallback to first available type
|
|
314
|
+
for embedding_type in embeddings:
|
|
315
|
+
if embeddings[embedding_type]:
|
|
316
|
+
return embeddings[embedding_type][0]
|
|
317
|
+
|
|
318
|
+
log_warning("No embeddings found in response")
|
|
319
|
+
return []
|
|
320
|
+
except Exception as e:
|
|
321
|
+
log_warning(f"Error extracting embeddings: {e}")
|
|
322
|
+
return []
|
|
323
|
+
|
|
177
324
|
def response(self, text: str) -> Dict[str, Any]:
|
|
178
325
|
"""
|
|
179
326
|
Get embeddings from AWS Bedrock for the given text.
|
|
@@ -212,24 +359,7 @@ class AwsBedrockEmbedder(Embedder):
|
|
|
212
359
|
List[float]: The embedding vector.
|
|
213
360
|
"""
|
|
214
361
|
response = self.response(text=text)
|
|
215
|
-
|
|
216
|
-
# Check if response contains embeddings or embeddings by type
|
|
217
|
-
if "embeddings" in response:
|
|
218
|
-
if isinstance(response["embeddings"], list):
|
|
219
|
-
# Default 'float' embeddings response format
|
|
220
|
-
return response["embeddings"][0]
|
|
221
|
-
elif isinstance(response["embeddings"], dict):
|
|
222
|
-
# If embeddings_types parameter was used, select float embeddings
|
|
223
|
-
if "float" in response["embeddings"]:
|
|
224
|
-
return response["embeddings"]["float"][0]
|
|
225
|
-
# Fallback to the first available embedding type
|
|
226
|
-
for embedding_type in response["embeddings"]:
|
|
227
|
-
return response["embeddings"][embedding_type][0]
|
|
228
|
-
log_warning("No embeddings found in response")
|
|
229
|
-
return []
|
|
230
|
-
except Exception as e:
|
|
231
|
-
log_warning(f"Error extracting embeddings: {e}")
|
|
232
|
-
return []
|
|
362
|
+
return self._extract_embeddings(response)
|
|
233
363
|
|
|
234
364
|
def get_embedding_and_usage(self, text: str) -> Tuple[List[float], Optional[Dict[str, Any]]]:
|
|
235
365
|
"""
|
|
@@ -242,27 +372,88 @@ class AwsBedrockEmbedder(Embedder):
|
|
|
242
372
|
Tuple[List[float], Optional[Dict[str, Any]]]: The embedding vector and usage information.
|
|
243
373
|
"""
|
|
244
374
|
response = self.response(text=text)
|
|
375
|
+
embedding = self._extract_embeddings(response)
|
|
376
|
+
usage = response.get("usage")
|
|
377
|
+
return embedding, usage
|
|
245
378
|
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
if isinstance(response["embeddings"], list):
|
|
250
|
-
embedding = response["embeddings"][0]
|
|
251
|
-
elif isinstance(response["embeddings"], dict):
|
|
252
|
-
if "float" in response["embeddings"]:
|
|
253
|
-
embedding = response["embeddings"]["float"][0]
|
|
254
|
-
# Fallback to the first available embedding type
|
|
255
|
-
else:
|
|
256
|
-
for embedding_type in response["embeddings"]:
|
|
257
|
-
embedding = response["embeddings"][embedding_type][0]
|
|
258
|
-
break
|
|
259
|
-
|
|
260
|
-
# Extract usage metrics if available
|
|
261
|
-
usage = None
|
|
262
|
-
if "usage" in response:
|
|
263
|
-
usage = response["usage"]
|
|
379
|
+
def get_image_embedding(self, image_data_uri: str) -> List[float]:
|
|
380
|
+
"""
|
|
381
|
+
Get embeddings for an image (v4 only).
|
|
264
382
|
|
|
265
|
-
|
|
383
|
+
Args:
|
|
384
|
+
image_data_uri (str): Base64 data URI of the image
|
|
385
|
+
(e.g., "data:image/png;base64,...")
|
|
386
|
+
|
|
387
|
+
Returns:
|
|
388
|
+
List[float]: The embedding vector.
|
|
389
|
+
"""
|
|
390
|
+
if not self._is_v4_model():
|
|
391
|
+
raise AgnoError(
|
|
392
|
+
message="Image embeddings are only supported with Cohere Embed v4 models.",
|
|
393
|
+
status_code=400,
|
|
394
|
+
)
|
|
395
|
+
|
|
396
|
+
try:
|
|
397
|
+
body = self._format_multimodal_request_body(images=[image_data_uri])
|
|
398
|
+
response = self.get_client().invoke_model(
|
|
399
|
+
modelId=self.id,
|
|
400
|
+
body=body,
|
|
401
|
+
contentType="application/json",
|
|
402
|
+
accept="application/json",
|
|
403
|
+
)
|
|
404
|
+
response_body = json.loads(response["body"].read().decode("utf-8"))
|
|
405
|
+
return self._extract_embeddings(response_body)
|
|
406
|
+
except ClientError as e:
|
|
407
|
+
log_error(f"Unexpected error calling Bedrock API: {str(e)}")
|
|
408
|
+
raise ModelProviderError(message=str(e.response), model_name="AwsBedrockEmbedder", model_id=self.id) from e
|
|
409
|
+
except Exception as e:
|
|
410
|
+
log_error(f"Unexpected error calling Bedrock API: {str(e)}")
|
|
411
|
+
raise ModelProviderError(message=str(e), model_name="AwsBedrockEmbedder", model_id=self.id) from e
|
|
412
|
+
|
|
413
|
+
def get_multimodal_embedding(
|
|
414
|
+
self,
|
|
415
|
+
content: List[Dict[str, str]],
|
|
416
|
+
) -> List[float]:
|
|
417
|
+
"""
|
|
418
|
+
Get embeddings for interleaved text and image content (v4 only).
|
|
419
|
+
|
|
420
|
+
Args:
|
|
421
|
+
content: List of content parts, each being either:
|
|
422
|
+
- {"type": "text", "text": "..."}
|
|
423
|
+
- {"type": "image_url", "image_url": "data:image/png;base64,..."}
|
|
424
|
+
|
|
425
|
+
Returns:
|
|
426
|
+
List[float]: The embedding vector.
|
|
427
|
+
|
|
428
|
+
Example:
|
|
429
|
+
embedder.get_multimodal_embedding([
|
|
430
|
+
{"type": "text", "text": "Product description"},
|
|
431
|
+
{"type": "image_url", "image_url": "data:image/png;base64,..."}
|
|
432
|
+
])
|
|
433
|
+
"""
|
|
434
|
+
if not self._is_v4_model():
|
|
435
|
+
raise AgnoError(
|
|
436
|
+
message="Multimodal embeddings are only supported with Cohere Embed v4 models.",
|
|
437
|
+
status_code=400,
|
|
438
|
+
)
|
|
439
|
+
|
|
440
|
+
try:
|
|
441
|
+
inputs = [{"content": content}]
|
|
442
|
+
body = self._format_multimodal_request_body(inputs=inputs)
|
|
443
|
+
response = self.get_client().invoke_model(
|
|
444
|
+
modelId=self.id,
|
|
445
|
+
body=body,
|
|
446
|
+
contentType="application/json",
|
|
447
|
+
accept="application/json",
|
|
448
|
+
)
|
|
449
|
+
response_body = json.loads(response["body"].read().decode("utf-8"))
|
|
450
|
+
return self._extract_embeddings(response_body)
|
|
451
|
+
except ClientError as e:
|
|
452
|
+
log_error(f"Unexpected error calling Bedrock API: {str(e)}")
|
|
453
|
+
raise ModelProviderError(message=str(e.response), model_name="AwsBedrockEmbedder", model_id=self.id) from e
|
|
454
|
+
except Exception as e:
|
|
455
|
+
log_error(f"Unexpected error calling Bedrock API: {str(e)}")
|
|
456
|
+
raise ModelProviderError(message=str(e), model_name="AwsBedrockEmbedder", model_id=self.id) from e
|
|
266
457
|
|
|
267
458
|
async def async_get_embedding(self, text: str) -> List[float]:
|
|
268
459
|
"""
|
|
@@ -278,21 +469,7 @@ class AwsBedrockEmbedder(Embedder):
|
|
|
278
469
|
accept="application/json",
|
|
279
470
|
)
|
|
280
471
|
response_body = json.loads((await response["body"].read()).decode("utf-8"))
|
|
281
|
-
|
|
282
|
-
# Extract embeddings using the same logic as get_embedding
|
|
283
|
-
if "embeddings" in response_body:
|
|
284
|
-
if isinstance(response_body["embeddings"], list):
|
|
285
|
-
# Default 'float' embeddings response format
|
|
286
|
-
return response_body["embeddings"][0]
|
|
287
|
-
elif isinstance(response_body["embeddings"], dict):
|
|
288
|
-
# If embeddings_types parameter was used, select float embeddings
|
|
289
|
-
if "float" in response_body["embeddings"]:
|
|
290
|
-
return response_body["embeddings"]["float"][0]
|
|
291
|
-
# Fallback to the first available embedding type
|
|
292
|
-
for embedding_type in response_body["embeddings"]:
|
|
293
|
-
return response_body["embeddings"][embedding_type][0]
|
|
294
|
-
log_warning("No embeddings found in response")
|
|
295
|
-
return []
|
|
472
|
+
return self._extract_embeddings(response_body)
|
|
296
473
|
except ClientError as e:
|
|
297
474
|
log_error(f"Unexpected error calling Bedrock API: {str(e)}")
|
|
298
475
|
raise ModelProviderError(message=str(e.response), model_name="AwsBedrockEmbedder", model_id=self.id) from e
|
|
@@ -314,27 +491,69 @@ class AwsBedrockEmbedder(Embedder):
|
|
|
314
491
|
accept="application/json",
|
|
315
492
|
)
|
|
316
493
|
response_body = json.loads((await response["body"].read()).decode("utf-8"))
|
|
494
|
+
embedding = self._extract_embeddings(response_body)
|
|
495
|
+
usage = response_body.get("usage")
|
|
496
|
+
return embedding, usage
|
|
497
|
+
except ClientError as e:
|
|
498
|
+
log_error(f"Unexpected error calling Bedrock API: {str(e)}")
|
|
499
|
+
raise ModelProviderError(message=str(e.response), model_name="AwsBedrockEmbedder", model_id=self.id) from e
|
|
500
|
+
except Exception as e:
|
|
501
|
+
log_error(f"Unexpected error calling Bedrock API: {str(e)}")
|
|
502
|
+
raise ModelProviderError(message=str(e), model_name="AwsBedrockEmbedder", model_id=self.id) from e
|
|
317
503
|
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
else:
|
|
328
|
-
for embedding_type in response_body["embeddings"]:
|
|
329
|
-
embedding = response_body["embeddings"][embedding_type][0]
|
|
330
|
-
break
|
|
331
|
-
|
|
332
|
-
# Extract usage metrics if available
|
|
333
|
-
usage = None
|
|
334
|
-
if "usage" in response_body:
|
|
335
|
-
usage = response_body["usage"]
|
|
504
|
+
async def async_get_image_embedding(self, image_data_uri: str) -> List[float]:
|
|
505
|
+
"""
|
|
506
|
+
Async version of get_image_embedding() (v4 only).
|
|
507
|
+
"""
|
|
508
|
+
if not self._is_v4_model():
|
|
509
|
+
raise AgnoError(
|
|
510
|
+
message="Image embeddings are only supported with Cohere Embed v4 models.",
|
|
511
|
+
status_code=400,
|
|
512
|
+
)
|
|
336
513
|
|
|
337
|
-
|
|
514
|
+
try:
|
|
515
|
+
body = self._format_multimodal_request_body(images=[image_data_uri])
|
|
516
|
+
async with self.get_async_client() as client:
|
|
517
|
+
response = await client.invoke_model(
|
|
518
|
+
modelId=self.id,
|
|
519
|
+
body=body,
|
|
520
|
+
contentType="application/json",
|
|
521
|
+
accept="application/json",
|
|
522
|
+
)
|
|
523
|
+
response_body = json.loads((await response["body"].read()).decode("utf-8"))
|
|
524
|
+
return self._extract_embeddings(response_body)
|
|
525
|
+
except ClientError as e:
|
|
526
|
+
log_error(f"Unexpected error calling Bedrock API: {str(e)}")
|
|
527
|
+
raise ModelProviderError(message=str(e.response), model_name="AwsBedrockEmbedder", model_id=self.id) from e
|
|
528
|
+
except Exception as e:
|
|
529
|
+
log_error(f"Unexpected error calling Bedrock API: {str(e)}")
|
|
530
|
+
raise ModelProviderError(message=str(e), model_name="AwsBedrockEmbedder", model_id=self.id) from e
|
|
531
|
+
|
|
532
|
+
async def async_get_multimodal_embedding(
|
|
533
|
+
self,
|
|
534
|
+
content: List[Dict[str, str]],
|
|
535
|
+
) -> List[float]:
|
|
536
|
+
"""
|
|
537
|
+
Async version of get_multimodal_embedding() (v4 only).
|
|
538
|
+
"""
|
|
539
|
+
if not self._is_v4_model():
|
|
540
|
+
raise AgnoError(
|
|
541
|
+
message="Multimodal embeddings are only supported with Cohere Embed v4 models.",
|
|
542
|
+
status_code=400,
|
|
543
|
+
)
|
|
544
|
+
|
|
545
|
+
try:
|
|
546
|
+
inputs = [{"content": content}]
|
|
547
|
+
body = self._format_multimodal_request_body(inputs=inputs)
|
|
548
|
+
async with self.get_async_client() as client:
|
|
549
|
+
response = await client.invoke_model(
|
|
550
|
+
modelId=self.id,
|
|
551
|
+
body=body,
|
|
552
|
+
contentType="application/json",
|
|
553
|
+
accept="application/json",
|
|
554
|
+
)
|
|
555
|
+
response_body = json.loads((await response["body"].read()).decode("utf-8"))
|
|
556
|
+
return self._extract_embeddings(response_body)
|
|
338
557
|
except ClientError as e:
|
|
339
558
|
log_error(f"Unexpected error calling Bedrock API: {str(e)}")
|
|
340
559
|
raise ModelProviderError(message=str(e.response), model_name="AwsBedrockEmbedder", model_id=self.id) from e
|