agentrun-mem0ai 0.0.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentrun_mem0/__init__.py +6 -0
- agentrun_mem0/client/__init__.py +0 -0
- agentrun_mem0/client/main.py +1747 -0
- agentrun_mem0/client/project.py +931 -0
- agentrun_mem0/client/utils.py +115 -0
- agentrun_mem0/configs/__init__.py +0 -0
- agentrun_mem0/configs/base.py +90 -0
- agentrun_mem0/configs/embeddings/__init__.py +0 -0
- agentrun_mem0/configs/embeddings/base.py +110 -0
- agentrun_mem0/configs/enums.py +7 -0
- agentrun_mem0/configs/llms/__init__.py +0 -0
- agentrun_mem0/configs/llms/anthropic.py +56 -0
- agentrun_mem0/configs/llms/aws_bedrock.py +192 -0
- agentrun_mem0/configs/llms/azure.py +57 -0
- agentrun_mem0/configs/llms/base.py +62 -0
- agentrun_mem0/configs/llms/deepseek.py +56 -0
- agentrun_mem0/configs/llms/lmstudio.py +59 -0
- agentrun_mem0/configs/llms/ollama.py +56 -0
- agentrun_mem0/configs/llms/openai.py +79 -0
- agentrun_mem0/configs/llms/vllm.py +56 -0
- agentrun_mem0/configs/prompts.py +459 -0
- agentrun_mem0/configs/rerankers/__init__.py +0 -0
- agentrun_mem0/configs/rerankers/base.py +17 -0
- agentrun_mem0/configs/rerankers/cohere.py +15 -0
- agentrun_mem0/configs/rerankers/config.py +12 -0
- agentrun_mem0/configs/rerankers/huggingface.py +17 -0
- agentrun_mem0/configs/rerankers/llm.py +48 -0
- agentrun_mem0/configs/rerankers/sentence_transformer.py +16 -0
- agentrun_mem0/configs/rerankers/zero_entropy.py +28 -0
- agentrun_mem0/configs/vector_stores/__init__.py +0 -0
- agentrun_mem0/configs/vector_stores/alibabacloud_mysql.py +64 -0
- agentrun_mem0/configs/vector_stores/aliyun_tablestore.py +32 -0
- agentrun_mem0/configs/vector_stores/azure_ai_search.py +57 -0
- agentrun_mem0/configs/vector_stores/azure_mysql.py +84 -0
- agentrun_mem0/configs/vector_stores/baidu.py +27 -0
- agentrun_mem0/configs/vector_stores/chroma.py +58 -0
- agentrun_mem0/configs/vector_stores/databricks.py +61 -0
- agentrun_mem0/configs/vector_stores/elasticsearch.py +65 -0
- agentrun_mem0/configs/vector_stores/faiss.py +37 -0
- agentrun_mem0/configs/vector_stores/langchain.py +30 -0
- agentrun_mem0/configs/vector_stores/milvus.py +42 -0
- agentrun_mem0/configs/vector_stores/mongodb.py +25 -0
- agentrun_mem0/configs/vector_stores/neptune.py +27 -0
- agentrun_mem0/configs/vector_stores/opensearch.py +41 -0
- agentrun_mem0/configs/vector_stores/pgvector.py +52 -0
- agentrun_mem0/configs/vector_stores/pinecone.py +55 -0
- agentrun_mem0/configs/vector_stores/qdrant.py +47 -0
- agentrun_mem0/configs/vector_stores/redis.py +24 -0
- agentrun_mem0/configs/vector_stores/s3_vectors.py +28 -0
- agentrun_mem0/configs/vector_stores/supabase.py +44 -0
- agentrun_mem0/configs/vector_stores/upstash_vector.py +34 -0
- agentrun_mem0/configs/vector_stores/valkey.py +15 -0
- agentrun_mem0/configs/vector_stores/vertex_ai_vector_search.py +28 -0
- agentrun_mem0/configs/vector_stores/weaviate.py +41 -0
- agentrun_mem0/embeddings/__init__.py +0 -0
- agentrun_mem0/embeddings/aws_bedrock.py +100 -0
- agentrun_mem0/embeddings/azure_openai.py +55 -0
- agentrun_mem0/embeddings/base.py +31 -0
- agentrun_mem0/embeddings/configs.py +30 -0
- agentrun_mem0/embeddings/gemini.py +39 -0
- agentrun_mem0/embeddings/huggingface.py +44 -0
- agentrun_mem0/embeddings/langchain.py +35 -0
- agentrun_mem0/embeddings/lmstudio.py +29 -0
- agentrun_mem0/embeddings/mock.py +11 -0
- agentrun_mem0/embeddings/ollama.py +53 -0
- agentrun_mem0/embeddings/openai.py +49 -0
- agentrun_mem0/embeddings/together.py +31 -0
- agentrun_mem0/embeddings/vertexai.py +64 -0
- agentrun_mem0/exceptions.py +503 -0
- agentrun_mem0/graphs/__init__.py +0 -0
- agentrun_mem0/graphs/configs.py +105 -0
- agentrun_mem0/graphs/neptune/__init__.py +0 -0
- agentrun_mem0/graphs/neptune/base.py +497 -0
- agentrun_mem0/graphs/neptune/neptunedb.py +511 -0
- agentrun_mem0/graphs/neptune/neptunegraph.py +474 -0
- agentrun_mem0/graphs/tools.py +371 -0
- agentrun_mem0/graphs/utils.py +97 -0
- agentrun_mem0/llms/__init__.py +0 -0
- agentrun_mem0/llms/anthropic.py +87 -0
- agentrun_mem0/llms/aws_bedrock.py +665 -0
- agentrun_mem0/llms/azure_openai.py +141 -0
- agentrun_mem0/llms/azure_openai_structured.py +91 -0
- agentrun_mem0/llms/base.py +131 -0
- agentrun_mem0/llms/configs.py +34 -0
- agentrun_mem0/llms/deepseek.py +107 -0
- agentrun_mem0/llms/gemini.py +201 -0
- agentrun_mem0/llms/groq.py +88 -0
- agentrun_mem0/llms/langchain.py +94 -0
- agentrun_mem0/llms/litellm.py +87 -0
- agentrun_mem0/llms/lmstudio.py +114 -0
- agentrun_mem0/llms/ollama.py +117 -0
- agentrun_mem0/llms/openai.py +147 -0
- agentrun_mem0/llms/openai_structured.py +52 -0
- agentrun_mem0/llms/sarvam.py +89 -0
- agentrun_mem0/llms/together.py +88 -0
- agentrun_mem0/llms/vllm.py +107 -0
- agentrun_mem0/llms/xai.py +52 -0
- agentrun_mem0/memory/__init__.py +0 -0
- agentrun_mem0/memory/base.py +63 -0
- agentrun_mem0/memory/graph_memory.py +698 -0
- agentrun_mem0/memory/kuzu_memory.py +713 -0
- agentrun_mem0/memory/main.py +2229 -0
- agentrun_mem0/memory/memgraph_memory.py +689 -0
- agentrun_mem0/memory/setup.py +56 -0
- agentrun_mem0/memory/storage.py +218 -0
- agentrun_mem0/memory/telemetry.py +90 -0
- agentrun_mem0/memory/utils.py +208 -0
- agentrun_mem0/proxy/__init__.py +0 -0
- agentrun_mem0/proxy/main.py +189 -0
- agentrun_mem0/reranker/__init__.py +9 -0
- agentrun_mem0/reranker/base.py +20 -0
- agentrun_mem0/reranker/cohere_reranker.py +85 -0
- agentrun_mem0/reranker/huggingface_reranker.py +147 -0
- agentrun_mem0/reranker/llm_reranker.py +142 -0
- agentrun_mem0/reranker/sentence_transformer_reranker.py +107 -0
- agentrun_mem0/reranker/zero_entropy_reranker.py +96 -0
- agentrun_mem0/utils/factory.py +283 -0
- agentrun_mem0/utils/gcp_auth.py +167 -0
- agentrun_mem0/vector_stores/__init__.py +0 -0
- agentrun_mem0/vector_stores/alibabacloud_mysql.py +547 -0
- agentrun_mem0/vector_stores/aliyun_tablestore.py +252 -0
- agentrun_mem0/vector_stores/azure_ai_search.py +396 -0
- agentrun_mem0/vector_stores/azure_mysql.py +463 -0
- agentrun_mem0/vector_stores/baidu.py +368 -0
- agentrun_mem0/vector_stores/base.py +58 -0
- agentrun_mem0/vector_stores/chroma.py +332 -0
- agentrun_mem0/vector_stores/configs.py +67 -0
- agentrun_mem0/vector_stores/databricks.py +761 -0
- agentrun_mem0/vector_stores/elasticsearch.py +237 -0
- agentrun_mem0/vector_stores/faiss.py +479 -0
- agentrun_mem0/vector_stores/langchain.py +180 -0
- agentrun_mem0/vector_stores/milvus.py +250 -0
- agentrun_mem0/vector_stores/mongodb.py +310 -0
- agentrun_mem0/vector_stores/neptune_analytics.py +467 -0
- agentrun_mem0/vector_stores/opensearch.py +292 -0
- agentrun_mem0/vector_stores/pgvector.py +404 -0
- agentrun_mem0/vector_stores/pinecone.py +382 -0
- agentrun_mem0/vector_stores/qdrant.py +270 -0
- agentrun_mem0/vector_stores/redis.py +295 -0
- agentrun_mem0/vector_stores/s3_vectors.py +176 -0
- agentrun_mem0/vector_stores/supabase.py +237 -0
- agentrun_mem0/vector_stores/upstash_vector.py +293 -0
- agentrun_mem0/vector_stores/valkey.py +824 -0
- agentrun_mem0/vector_stores/vertex_ai_vector_search.py +635 -0
- agentrun_mem0/vector_stores/weaviate.py +343 -0
- agentrun_mem0ai-0.0.11.data/data/README.md +205 -0
- agentrun_mem0ai-0.0.11.dist-info/METADATA +277 -0
- agentrun_mem0ai-0.0.11.dist-info/RECORD +150 -0
- agentrun_mem0ai-0.0.11.dist-info/WHEEL +4 -0
- agentrun_mem0ai-0.0.11.dist-info/licenses/LICENSE +201 -0
agentrun_mem0/llms/aws_bedrock.py
@@ -0,0 +1,665 @@
import json
import logging
import re
from typing import Any, Dict, List, Optional, Union

try:
    import boto3
    from botocore.exceptions import ClientError, NoCredentialsError
except ImportError:
    raise ImportError("The 'boto3' library is required. Please install it using 'pip install boto3'.")

from agentrun_mem0.configs.llms.base import BaseLlmConfig
from agentrun_mem0.configs.llms.aws_bedrock import AWSBedrockConfig
from agentrun_mem0.llms.base import LLMBase
from agentrun_mem0.memory.utils import extract_json

logger = logging.getLogger(__name__)

PROVIDERS = [
    "ai21", "amazon", "anthropic", "cohere", "meta", "mistral", "stability", "writer",
    "deepseek", "gpt-oss", "perplexity", "snowflake", "titan", "command", "j2", "llama"
]


def extract_provider(model: str) -> str:
    """Extract provider from model identifier."""
    for provider in PROVIDERS:
        if re.search(rf"\b{re.escape(provider)}\b", model):
            return provider
    raise ValueError(f"Unknown provider in model: {model}")


class AWSBedrockLLM(LLMBase):
    """
    AWS Bedrock LLM integration for Mem0.

    Supports all available Bedrock models with automatic provider detection.
    """

    def __init__(self, config: Optional[Union[AWSBedrockConfig, BaseLlmConfig, Dict]] = None):
        """
        Initialize AWS Bedrock LLM.

        Args:
            config: AWS Bedrock configuration object
        """
        # Convert to AWSBedrockConfig if needed
        if config is None:
            config = AWSBedrockConfig()
        elif isinstance(config, dict):
            config = AWSBedrockConfig(**config)
        elif isinstance(config, BaseLlmConfig) and not isinstance(config, AWSBedrockConfig):
            # Convert BaseLlmConfig to AWSBedrockConfig
            config = AWSBedrockConfig(
                model=config.model,
                temperature=config.temperature,
                max_tokens=config.max_tokens,
                top_p=config.top_p,
                top_k=config.top_k,
                enable_vision=getattr(config, "enable_vision", False),
            )

        super().__init__(config)
        self.config = config

        # Initialize AWS client
        self._initialize_aws_client()

        # Get model configuration
        self.model_config = self.config.get_model_config()
        self.provider = extract_provider(self.config.model)

        # Initialize provider-specific settings
        self._initialize_provider_settings()

    def _initialize_aws_client(self):
        """Initialize AWS Bedrock client with proper credentials."""
        try:
            aws_config = self.config.get_aws_config()

            # Create Bedrock runtime client
            self.client = boto3.client("bedrock-runtime", **aws_config)

            # Test connection
            self._test_connection()

        except NoCredentialsError:
            raise ValueError(
                "AWS credentials not found. Please set AWS_ACCESS_KEY_ID, "
                "AWS_SECRET_ACCESS_KEY, and AWS_REGION environment variables, "
                "or provide them in the config."
            )
        except ClientError as e:
            if e.response["Error"]["Code"] == "UnauthorizedOperation":
                raise ValueError(
                    f"Unauthorized access to Bedrock. Please ensure your AWS credentials "
                    f"have permission to access Bedrock in region {self.config.aws_region}."
                )
            else:
                raise ValueError(f"AWS Bedrock error: {e}")

    def _test_connection(self):
        """Test connection to AWS Bedrock service."""
        try:
            # List available models to test connection
            bedrock_client = boto3.client("bedrock", **self.config.get_aws_config())
            response = bedrock_client.list_foundation_models()
            self.available_models = [model["modelId"] for model in response["modelSummaries"]]

            # Check if our model is available
            if self.config.model not in self.available_models:
                logger.warning(f"Model {self.config.model} may not be available in region {self.config.aws_region}")
                logger.info(f"Available models: {', '.join(self.available_models[:5])}...")

        except Exception as e:
            logger.warning(f"Could not verify model availability: {e}")
            self.available_models = []

    def _initialize_provider_settings(self):
        """Initialize provider-specific settings and capabilities."""
        # Determine capabilities based on provider and model
        self.supports_tools = self.provider in ["anthropic", "cohere", "amazon"]
        self.supports_vision = self.provider in ["anthropic", "amazon", "meta", "mistral"]
        self.supports_streaming = self.provider in ["anthropic", "cohere", "mistral", "amazon", "meta"]

        # Set message formatting method
        if self.provider == "anthropic":
            self._format_messages = self._format_messages_anthropic
        elif self.provider == "cohere":
            self._format_messages = self._format_messages_cohere
        elif self.provider == "amazon":
            self._format_messages = self._format_messages_amazon
        elif self.provider == "meta":
            self._format_messages = self._format_messages_meta
        elif self.provider == "mistral":
            self._format_messages = self._format_messages_mistral
        else:
            self._format_messages = self._format_messages_generic

    def _format_messages_anthropic(self, messages: List[Dict[str, str]]) -> tuple[List[Dict[str, Any]], Optional[str]]:
        """Format messages for Anthropic models."""
        formatted_messages = []
        system_message = None

        for message in messages:
            role = message["role"]
            content = message["content"]

            if role == "system":
                # Anthropic supports system messages as a separate parameter
                # see: https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/system-prompts
                system_message = content
            elif role == "user":
                # Use Converse API format
                formatted_messages.append({"role": "user", "content": [{"text": content}]})
            elif role == "assistant":
                # Use Converse API format
                formatted_messages.append({"role": "assistant", "content": [{"text": content}]})

        return formatted_messages, system_message

    def _format_messages_cohere(self, messages: List[Dict[str, str]]) -> str:
        """Format messages for Cohere models."""
        formatted_messages = []

        for message in messages:
            role = message["role"].capitalize()
            content = message["content"]
            formatted_messages.append(f"{role}: {content}")

        return "\n".join(formatted_messages)

    def _format_messages_amazon(self, messages: List[Dict[str, str]]) -> List[Dict[str, Any]]:
        """Format messages for Amazon models (including Nova)."""
        formatted_messages = []

        for message in messages:
            role = message["role"]
            content = message["content"]

            if role == "system":
                # Amazon models support system messages
                formatted_messages.append({"role": "system", "content": content})
            elif role == "user":
                formatted_messages.append({"role": "user", "content": content})
            elif role == "assistant":
                formatted_messages.append({"role": "assistant", "content": content})

        return formatted_messages

    def _format_messages_meta(self, messages: List[Dict[str, str]]) -> str:
        """Format messages for Meta models."""
        formatted_messages = []

        for message in messages:
            role = message["role"].capitalize()
            content = message["content"]
            formatted_messages.append(f"{role}: {content}")

        return "\n".join(formatted_messages)

    def _format_messages_mistral(self, messages: List[Dict[str, str]]) -> List[Dict[str, Any]]:
        """Format messages for Mistral models."""
        formatted_messages = []

        for message in messages:
            role = message["role"]
            content = message["content"]

            if role == "system":
                # Mistral supports system messages
                formatted_messages.append({"role": "system", "content": content})
            elif role == "user":
                formatted_messages.append({"role": "user", "content": content})
            elif role == "assistant":
                formatted_messages.append({"role": "assistant", "content": content})

        return formatted_messages

    def _format_messages_generic(self, messages: List[Dict[str, str]]) -> str:
        """Generic message formatting for other providers."""
        formatted_messages = []

        for message in messages:
            role = message["role"].capitalize()
            content = message["content"]
            formatted_messages.append(f"\n\n{role}: {content}")

        return "\n\nHuman: " + "".join(formatted_messages) + "\n\nAssistant:"

    def _prepare_input(self, prompt: str) -> Dict[str, Any]:
        """
        Prepare input for the current provider's model.

        Args:
            prompt: Text prompt to process

        Returns:
            Prepared input dictionary
        """
        # Base configuration
        input_body = {"prompt": prompt}

        # Provider-specific parameter mappings
        provider_mappings = {
            "meta": {"max_tokens": "max_gen_len"},
            "ai21": {"max_tokens": "maxTokens", "top_p": "topP"},
            "mistral": {"max_tokens": "max_tokens"},
            "cohere": {"max_tokens": "max_tokens", "top_p": "p"},
            "amazon": {"max_tokens": "maxTokenCount", "top_p": "topP"},
            "anthropic": {"max_tokens": "max_tokens", "top_p": "top_p"},
        }

        # Apply provider mappings
        if self.provider in provider_mappings:
            for old_key, new_key in provider_mappings[self.provider].items():
                if old_key in self.model_config:
                    input_body[new_key] = self.model_config[old_key]

        # Special handling for specific providers
        if self.provider == "cohere" and "cohere.command" in self.config.model:
            input_body["message"] = input_body.pop("prompt")
        elif self.provider == "amazon":
            # Amazon Nova and other Amazon models
            if "nova" in self.config.model.lower():
                # Nova models use the converse API format
                input_body = {
                    "messages": [{"role": "user", "content": prompt}],
                    "max_tokens": self.model_config.get("max_tokens", 5000),
                    "temperature": self.model_config.get("temperature", 0.1),
                    "top_p": self.model_config.get("top_p", 0.9),
                }
            else:
                # Legacy Amazon models
                input_body = {
                    "inputText": prompt,
                    "textGenerationConfig": {
                        "maxTokenCount": self.model_config.get("max_tokens", 5000),
                        "topP": self.model_config.get("top_p", 0.9),
                        "temperature": self.model_config.get("temperature", 0.1),
                    },
                }
                # Remove None values
                input_body["textGenerationConfig"] = {
                    k: v for k, v in input_body["textGenerationConfig"].items() if v is not None
                }
        elif self.provider == "anthropic":
            input_body = {
                "messages": [{"role": "user", "content": [{"type": "text", "text": prompt}]}],
                "max_tokens": self.model_config.get("max_tokens", 2000),
                "temperature": self.model_config.get("temperature", 0.1),
                "top_p": self.model_config.get("top_p", 0.9),
                "anthropic_version": "bedrock-2023-05-31",
            }
        elif self.provider == "meta":
            input_body = {
                "prompt": prompt,
                "max_gen_len": self.model_config.get("max_tokens", 5000),
                "temperature": self.model_config.get("temperature", 0.1),
                "top_p": self.model_config.get("top_p", 0.9),
            }
        elif self.provider == "mistral":
            input_body = {
                "prompt": prompt,
                "max_tokens": self.model_config.get("max_tokens", 5000),
                "temperature": self.model_config.get("temperature", 0.1),
                "top_p": self.model_config.get("top_p", 0.9),
            }
        else:
            # Generic case - add all model config parameters
            input_body.update(self.model_config)

        return input_body

    def _convert_tool_format(self, original_tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Convert tools to Bedrock-compatible format.

        Args:
            original_tools: List of tool definitions

        Returns:
            Converted tools in Bedrock format
        """
        new_tools = []

        for tool in original_tools:
            if tool["type"] == "function":
                function = tool["function"]
                new_tool = {
                    "toolSpec": {
                        "name": function["name"],
                        "description": function.get("description", ""),
                        "inputSchema": {
                            "json": {
                                "type": "object",
                                "properties": {},
                                "required": function["parameters"].get("required", []),
                            }
                        },
                    }
                }

                # Add properties
                for prop, details in function["parameters"].get("properties", {}).items():
                    new_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop] = details

                new_tools.append(new_tool)

        return new_tools

    def _parse_response(
        self, response: Dict[str, Any], tools: Optional[List[Dict]] = None
    ) -> Union[str, Dict[str, Any]]:
        """
        Parse response from Bedrock API.

        Args:
            response: Raw API response
            tools: List of tools if used

        Returns:
            Parsed response
        """
        if tools:
            # Handle tool-enabled responses
            processed_response = {"tool_calls": []}

            if response.get("output", {}).get("message", {}).get("content"):
                for item in response["output"]["message"]["content"]:
                    if "toolUse" in item:
                        processed_response["tool_calls"].append(
                            {
                                "name": item["toolUse"]["name"],
                                "arguments": json.loads(extract_json(json.dumps(item["toolUse"]["input"]))),
                            }
                        )

            return processed_response

        # Handle regular text responses
        try:
            response_body = response.get("body").read().decode()
            response_json = json.loads(response_body)

            # Provider-specific response parsing
            if self.provider == "anthropic":
                return response_json.get("content", [{"text": ""}])[0].get("text", "")
            elif self.provider == "amazon":
                # Handle both Nova and legacy Amazon models
                if "nova" in self.config.model.lower():
                    # Nova models return content in a different format
                    if "content" in response_json:
                        return response_json["content"][0]["text"]
                    elif "completion" in response_json:
                        return response_json["completion"]
                else:
                    # Legacy Amazon models
                    return response_json.get("completion", "")
            elif self.provider == "meta":
                return response_json.get("generation", "")
            elif self.provider == "mistral":
                return response_json.get("outputs", [{"text": ""}])[0].get("text", "")
            elif self.provider == "cohere":
                return response_json.get("generations", [{"text": ""}])[0].get("text", "")
elif self.provider == "ai21":
|
|
407
|
+
return response_json.get("completions", [{"data", {"text": ""}}])[0].get("data", {}).get("text", "")
            else:
                # Generic parsing - try common response fields
                for field in ["content", "text", "completion", "generation"]:
                    if field in response_json:
                        if isinstance(response_json[field], list) and response_json[field]:
                            return response_json[field][0].get("text", "")
                        elif isinstance(response_json[field], str):
                            return response_json[field]

            # Fallback
            return str(response_json)

        except Exception as e:
            logger.warning(f"Could not parse response: {e}")
            return "Error parsing response"

    def generate_response(
        self,
        messages: List[Dict[str, str]],
        response_format: Optional[str] = None,
        tools: Optional[List[Dict]] = None,
        tool_choice: str = "auto",
        stream: bool = False,
        **kwargs,
    ) -> Union[str, Dict[str, Any]]:
        """
        Generate response using AWS Bedrock.

        Args:
            messages: List of message dictionaries
            response_format: Response format specification
            tools: List of tools for function calling
            tool_choice: Tool choice method
            stream: Whether to stream the response
            **kwargs: Additional parameters

        Returns:
            Generated response
        """
        try:
            if tools and self.supports_tools:
                # Use converse method for tool-enabled models
                return self._generate_with_tools(messages, tools, stream)
            else:
                # Use standard invoke_model method
                return self._generate_standard(messages, stream)

        except Exception as e:
            logger.error(f"Failed to generate response: {e}")
            raise RuntimeError(f"Failed to generate response: {e}")

    @staticmethod
    def _convert_tools_to_converse_format(tools: List[Dict]) -> List[Dict]:
        """Convert OpenAI-style tools to Converse API format."""
        if not tools:
            return []

        converse_tools = []
        for tool in tools:
            if tool.get("type") == "function" and "function" in tool:
                func = tool["function"]
                converse_tool = {
                    "toolSpec": {
                        "name": func["name"],
                        "description": func.get("description", ""),
                        "inputSchema": {
                            "json": func.get("parameters", {})
                        }
                    }
                }
                converse_tools.append(converse_tool)

        return converse_tools

    def _generate_with_tools(self, messages: List[Dict[str, str]], tools: List[Dict], stream: bool = False) -> Dict[str, Any]:
        """Generate response with tool calling support using correct message format."""
        # Format messages for tool-enabled models
        system_message = None
        if self.provider == "anthropic":
            formatted_messages, system_message = self._format_messages_anthropic(messages)
        elif self.provider == "amazon":
            formatted_messages = self._format_messages_amazon(messages)
        else:
            formatted_messages = [{"role": "user", "content": [{"text": messages[-1]["content"]}]}]

        # Prepare tool configuration in Converse API format
        tool_config = None
        if tools:
            converse_tools = self._convert_tools_to_converse_format(tools)
            if converse_tools:
                tool_config = {"tools": converse_tools}

        # Prepare converse parameters
        converse_params = {
            "modelId": self.config.model,
            "messages": formatted_messages,
            "inferenceConfig": {
                "maxTokens": self.model_config.get("max_tokens", 2000),
                "temperature": self.model_config.get("temperature", 0.1),
                "topP": self.model_config.get("top_p", 0.9),
            }
        }

        # Add system message if present (for Anthropic)
        if system_message:
            converse_params["system"] = [{"text": system_message}]

        # Add tool config if present
        if tool_config:
            converse_params["toolConfig"] = tool_config

        # Make API call
        response = self.client.converse(**converse_params)

        return self._parse_response(response, tools)

    def _generate_standard(self, messages: List[Dict[str, str]], stream: bool = False) -> str:
        """Generate standard text response using Converse API for Anthropic models."""
        # For Anthropic models, always use Converse API
        if self.provider == "anthropic":
            formatted_messages, system_message = self._format_messages_anthropic(messages)

            # Prepare converse parameters
            converse_params = {
                "modelId": self.config.model,
                "messages": formatted_messages,
                "inferenceConfig": {
                    "maxTokens": self.model_config.get("max_tokens", 2000),
                    "temperature": self.model_config.get("temperature", 0.1),
                    "topP": self.model_config.get("top_p", 0.9),
                }
            }

            # Add system message if present
            if system_message:
                converse_params["system"] = [{"text": system_message}]

            # Use converse API for Anthropic models
            response = self.client.converse(**converse_params)

            # Parse Converse API response
            if hasattr(response, 'output') and hasattr(response.output, 'message'):
                return response.output.message.content[0].text
            elif 'output' in response and 'message' in response['output']:
                return response['output']['message']['content'][0]['text']
            else:
                return str(response)

        elif self.provider == "amazon" and "nova" in self.config.model.lower():
            # Nova models use converse API even without tools
            formatted_messages = self._format_messages_amazon(messages)
            input_body = {
                "messages": formatted_messages,
                "max_tokens": self.model_config.get("max_tokens", 5000),
                "temperature": self.model_config.get("temperature", 0.1),
                "top_p": self.model_config.get("top_p", 0.9),
            }

            # Use converse API for Nova models
            response = self.client.converse(
                modelId=self.config.model,
                messages=input_body["messages"],
                inferenceConfig={
                    "maxTokens": input_body["max_tokens"],
                    "temperature": input_body["temperature"],
                    "topP": input_body["top_p"],
                }
            )

            return self._parse_response(response)
        else:
            # For other providers and legacy Amazon models (like Titan)
            if self.provider == "amazon":
                # Legacy Amazon models need string formatting, not array formatting
                prompt = self._format_messages_generic(messages)
            else:
                prompt = self._format_messages(messages)
            input_body = self._prepare_input(prompt)

            # Convert to JSON
            body = json.dumps(input_body)

            # Make API call
            response = self.client.invoke_model(
                body=body,
                modelId=self.config.model,
                accept="application/json",
                contentType="application/json",
            )

            return self._parse_response(response)

    def list_available_models(self) -> List[Dict[str, Any]]:
        """List all available models in the current region."""
        try:
            bedrock_client = boto3.client("bedrock", **self.config.get_aws_config())
            response = bedrock_client.list_foundation_models()

            models = []
            for model in response["modelSummaries"]:
                provider = extract_provider(model["modelId"])
                models.append(
                    {
                        "model_id": model["modelId"],
                        "provider": provider,
                        "model_name": model["modelId"].split(".", 1)[1]
                        if "." in model["modelId"]
                        else model["modelId"],
                        "modelArn": model.get("modelArn", ""),
                        "providerName": model.get("providerName", ""),
                        "inputModalities": model.get("inputModalities", []),
                        "outputModalities": model.get("outputModalities", []),
                        "responseStreamingSupported": model.get("responseStreamingSupported", False),
                    }
                )

            return models

        except Exception as e:
            logger.warning(f"Could not list models: {e}")
            return []

    def get_model_capabilities(self) -> Dict[str, Any]:
        """Get capabilities of the current model."""
        return {
            "model_id": self.config.model,
            "provider": self.provider,
            "model_name": self.config.model_name,
            "supports_tools": self.supports_tools,
            "supports_vision": self.supports_vision,
            "supports_streaming": self.supports_streaming,
            "max_tokens": self.model_config.get("max_tokens", 2000),
        }

    def validate_model_access(self) -> bool:
        """Validate if the model is accessible."""
        try:
            # Try to invoke the model with a minimal request
            if self.provider == "amazon" and "nova" in self.config.model.lower():
                # Test Nova model with converse API
                test_messages = [{"role": "user", "content": "test"}]
                self.client.converse(
                    modelId=self.config.model,
                    messages=test_messages,
                    inferenceConfig={"maxTokens": 10}
                )
            else:
                # Test other models with invoke_model
                test_body = json.dumps({"prompt": "test"})
                self.client.invoke_model(
                    body=test_body,
                    modelId=self.config.model,
                    accept="application/json",
                    contentType="application/json",
                )
            return True
        except Exception:
            return False
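
For orientation, a minimal usage sketch of the AWSBedrockLLM class above. It assumes AWS credentials are configured in the environment and that the chosen model is enabled for the account and region; the model ID is illustrative, and the config keys rely on AWSBedrockConfig accepting model, temperature, and max_tokens keywords (consistent with the conversion performed in __init__).

# Hypothetical usage sketch, not part of the published diff.
from agentrun_mem0.llms.aws_bedrock import AWSBedrockLLM

# A dict config is forwarded to AWSBedrockConfig(**config) by __init__.
llm = AWSBedrockLLM(
    config={
        "model": "anthropic.claude-3-haiku-20240307-v1:0",  # illustrative model ID
        "temperature": 0.1,
        "max_tokens": 2000,
    }
)

# Text-only generation: for the "anthropic" provider this routes through
# _generate_standard(), which calls the Bedrock Converse API and passes the
# system message as a separate "system" parameter.
reply = llm.generate_response(
    messages=[
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user", "content": "Say hello in one sentence."},
    ]
)
print(reply)

# Capability introspection reads the flags set in _initialize_provider_settings().
print(llm.get_model_capabilities())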