agentrun-mem0ai 0.0.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. agentrun_mem0/__init__.py +6 -0
  2. agentrun_mem0/client/__init__.py +0 -0
  3. agentrun_mem0/client/main.py +1747 -0
  4. agentrun_mem0/client/project.py +931 -0
  5. agentrun_mem0/client/utils.py +115 -0
  6. agentrun_mem0/configs/__init__.py +0 -0
  7. agentrun_mem0/configs/base.py +90 -0
  8. agentrun_mem0/configs/embeddings/__init__.py +0 -0
  9. agentrun_mem0/configs/embeddings/base.py +110 -0
  10. agentrun_mem0/configs/enums.py +7 -0
  11. agentrun_mem0/configs/llms/__init__.py +0 -0
  12. agentrun_mem0/configs/llms/anthropic.py +56 -0
  13. agentrun_mem0/configs/llms/aws_bedrock.py +192 -0
  14. agentrun_mem0/configs/llms/azure.py +57 -0
  15. agentrun_mem0/configs/llms/base.py +62 -0
  16. agentrun_mem0/configs/llms/deepseek.py +56 -0
  17. agentrun_mem0/configs/llms/lmstudio.py +59 -0
  18. agentrun_mem0/configs/llms/ollama.py +56 -0
  19. agentrun_mem0/configs/llms/openai.py +79 -0
  20. agentrun_mem0/configs/llms/vllm.py +56 -0
  21. agentrun_mem0/configs/prompts.py +459 -0
  22. agentrun_mem0/configs/rerankers/__init__.py +0 -0
  23. agentrun_mem0/configs/rerankers/base.py +17 -0
  24. agentrun_mem0/configs/rerankers/cohere.py +15 -0
  25. agentrun_mem0/configs/rerankers/config.py +12 -0
  26. agentrun_mem0/configs/rerankers/huggingface.py +17 -0
  27. agentrun_mem0/configs/rerankers/llm.py +48 -0
  28. agentrun_mem0/configs/rerankers/sentence_transformer.py +16 -0
  29. agentrun_mem0/configs/rerankers/zero_entropy.py +28 -0
  30. agentrun_mem0/configs/vector_stores/__init__.py +0 -0
  31. agentrun_mem0/configs/vector_stores/alibabacloud_mysql.py +64 -0
  32. agentrun_mem0/configs/vector_stores/aliyun_tablestore.py +32 -0
  33. agentrun_mem0/configs/vector_stores/azure_ai_search.py +57 -0
  34. agentrun_mem0/configs/vector_stores/azure_mysql.py +84 -0
  35. agentrun_mem0/configs/vector_stores/baidu.py +27 -0
  36. agentrun_mem0/configs/vector_stores/chroma.py +58 -0
  37. agentrun_mem0/configs/vector_stores/databricks.py +61 -0
  38. agentrun_mem0/configs/vector_stores/elasticsearch.py +65 -0
  39. agentrun_mem0/configs/vector_stores/faiss.py +37 -0
  40. agentrun_mem0/configs/vector_stores/langchain.py +30 -0
  41. agentrun_mem0/configs/vector_stores/milvus.py +42 -0
  42. agentrun_mem0/configs/vector_stores/mongodb.py +25 -0
  43. agentrun_mem0/configs/vector_stores/neptune.py +27 -0
  44. agentrun_mem0/configs/vector_stores/opensearch.py +41 -0
  45. agentrun_mem0/configs/vector_stores/pgvector.py +52 -0
  46. agentrun_mem0/configs/vector_stores/pinecone.py +55 -0
  47. agentrun_mem0/configs/vector_stores/qdrant.py +47 -0
  48. agentrun_mem0/configs/vector_stores/redis.py +24 -0
  49. agentrun_mem0/configs/vector_stores/s3_vectors.py +28 -0
  50. agentrun_mem0/configs/vector_stores/supabase.py +44 -0
  51. agentrun_mem0/configs/vector_stores/upstash_vector.py +34 -0
  52. agentrun_mem0/configs/vector_stores/valkey.py +15 -0
  53. agentrun_mem0/configs/vector_stores/vertex_ai_vector_search.py +28 -0
  54. agentrun_mem0/configs/vector_stores/weaviate.py +41 -0
  55. agentrun_mem0/embeddings/__init__.py +0 -0
  56. agentrun_mem0/embeddings/aws_bedrock.py +100 -0
  57. agentrun_mem0/embeddings/azure_openai.py +55 -0
  58. agentrun_mem0/embeddings/base.py +31 -0
  59. agentrun_mem0/embeddings/configs.py +30 -0
  60. agentrun_mem0/embeddings/gemini.py +39 -0
  61. agentrun_mem0/embeddings/huggingface.py +44 -0
  62. agentrun_mem0/embeddings/langchain.py +35 -0
  63. agentrun_mem0/embeddings/lmstudio.py +29 -0
  64. agentrun_mem0/embeddings/mock.py +11 -0
  65. agentrun_mem0/embeddings/ollama.py +53 -0
  66. agentrun_mem0/embeddings/openai.py +49 -0
  67. agentrun_mem0/embeddings/together.py +31 -0
  68. agentrun_mem0/embeddings/vertexai.py +64 -0
  69. agentrun_mem0/exceptions.py +503 -0
  70. agentrun_mem0/graphs/__init__.py +0 -0
  71. agentrun_mem0/graphs/configs.py +105 -0
  72. agentrun_mem0/graphs/neptune/__init__.py +0 -0
  73. agentrun_mem0/graphs/neptune/base.py +497 -0
  74. agentrun_mem0/graphs/neptune/neptunedb.py +511 -0
  75. agentrun_mem0/graphs/neptune/neptunegraph.py +474 -0
  76. agentrun_mem0/graphs/tools.py +371 -0
  77. agentrun_mem0/graphs/utils.py +97 -0
  78. agentrun_mem0/llms/__init__.py +0 -0
  79. agentrun_mem0/llms/anthropic.py +87 -0
  80. agentrun_mem0/llms/aws_bedrock.py +665 -0
  81. agentrun_mem0/llms/azure_openai.py +141 -0
  82. agentrun_mem0/llms/azure_openai_structured.py +91 -0
  83. agentrun_mem0/llms/base.py +131 -0
  84. agentrun_mem0/llms/configs.py +34 -0
  85. agentrun_mem0/llms/deepseek.py +107 -0
  86. agentrun_mem0/llms/gemini.py +201 -0
  87. agentrun_mem0/llms/groq.py +88 -0
  88. agentrun_mem0/llms/langchain.py +94 -0
  89. agentrun_mem0/llms/litellm.py +87 -0
  90. agentrun_mem0/llms/lmstudio.py +114 -0
  91. agentrun_mem0/llms/ollama.py +117 -0
  92. agentrun_mem0/llms/openai.py +147 -0
  93. agentrun_mem0/llms/openai_structured.py +52 -0
  94. agentrun_mem0/llms/sarvam.py +89 -0
  95. agentrun_mem0/llms/together.py +88 -0
  96. agentrun_mem0/llms/vllm.py +107 -0
  97. agentrun_mem0/llms/xai.py +52 -0
  98. agentrun_mem0/memory/__init__.py +0 -0
  99. agentrun_mem0/memory/base.py +63 -0
  100. agentrun_mem0/memory/graph_memory.py +698 -0
  101. agentrun_mem0/memory/kuzu_memory.py +713 -0
  102. agentrun_mem0/memory/main.py +2229 -0
  103. agentrun_mem0/memory/memgraph_memory.py +689 -0
  104. agentrun_mem0/memory/setup.py +56 -0
  105. agentrun_mem0/memory/storage.py +218 -0
  106. agentrun_mem0/memory/telemetry.py +90 -0
  107. agentrun_mem0/memory/utils.py +208 -0
  108. agentrun_mem0/proxy/__init__.py +0 -0
  109. agentrun_mem0/proxy/main.py +189 -0
  110. agentrun_mem0/reranker/__init__.py +9 -0
  111. agentrun_mem0/reranker/base.py +20 -0
  112. agentrun_mem0/reranker/cohere_reranker.py +85 -0
  113. agentrun_mem0/reranker/huggingface_reranker.py +147 -0
  114. agentrun_mem0/reranker/llm_reranker.py +142 -0
  115. agentrun_mem0/reranker/sentence_transformer_reranker.py +107 -0
  116. agentrun_mem0/reranker/zero_entropy_reranker.py +96 -0
  117. agentrun_mem0/utils/factory.py +283 -0
  118. agentrun_mem0/utils/gcp_auth.py +167 -0
  119. agentrun_mem0/vector_stores/__init__.py +0 -0
  120. agentrun_mem0/vector_stores/alibabacloud_mysql.py +547 -0
  121. agentrun_mem0/vector_stores/aliyun_tablestore.py +252 -0
  122. agentrun_mem0/vector_stores/azure_ai_search.py +396 -0
  123. agentrun_mem0/vector_stores/azure_mysql.py +463 -0
  124. agentrun_mem0/vector_stores/baidu.py +368 -0
  125. agentrun_mem0/vector_stores/base.py +58 -0
  126. agentrun_mem0/vector_stores/chroma.py +332 -0
  127. agentrun_mem0/vector_stores/configs.py +67 -0
  128. agentrun_mem0/vector_stores/databricks.py +761 -0
  129. agentrun_mem0/vector_stores/elasticsearch.py +237 -0
  130. agentrun_mem0/vector_stores/faiss.py +479 -0
  131. agentrun_mem0/vector_stores/langchain.py +180 -0
  132. agentrun_mem0/vector_stores/milvus.py +250 -0
  133. agentrun_mem0/vector_stores/mongodb.py +310 -0
  134. agentrun_mem0/vector_stores/neptune_analytics.py +467 -0
  135. agentrun_mem0/vector_stores/opensearch.py +292 -0
  136. agentrun_mem0/vector_stores/pgvector.py +404 -0
  137. agentrun_mem0/vector_stores/pinecone.py +382 -0
  138. agentrun_mem0/vector_stores/qdrant.py +270 -0
  139. agentrun_mem0/vector_stores/redis.py +295 -0
  140. agentrun_mem0/vector_stores/s3_vectors.py +176 -0
  141. agentrun_mem0/vector_stores/supabase.py +237 -0
  142. agentrun_mem0/vector_stores/upstash_vector.py +293 -0
  143. agentrun_mem0/vector_stores/valkey.py +824 -0
  144. agentrun_mem0/vector_stores/vertex_ai_vector_search.py +635 -0
  145. agentrun_mem0/vector_stores/weaviate.py +343 -0
  146. agentrun_mem0ai-0.0.11.data/data/README.md +205 -0
  147. agentrun_mem0ai-0.0.11.dist-info/METADATA +277 -0
  148. agentrun_mem0ai-0.0.11.dist-info/RECORD +150 -0
  149. agentrun_mem0ai-0.0.11.dist-info/WHEEL +4 -0
  150. agentrun_mem0ai-0.0.11.dist-info/licenses/LICENSE +201 -0
agentrun_mem0/llms/aws_bedrock.py
@@ -0,0 +1,665 @@
+ import json
+ import logging
+ import re
+ from typing import Any, Dict, List, Optional, Union
+
+ try:
+     import boto3
+     from botocore.exceptions import ClientError, NoCredentialsError
+ except ImportError:
+     raise ImportError("The 'boto3' library is required. Please install it using 'pip install boto3'.")
+
+ from agentrun_mem0.configs.llms.base import BaseLlmConfig
+ from agentrun_mem0.configs.llms.aws_bedrock import AWSBedrockConfig
+ from agentrun_mem0.llms.base import LLMBase
+ from agentrun_mem0.memory.utils import extract_json
+
+ logger = logging.getLogger(__name__)
+
+ PROVIDERS = [
+     "ai21", "amazon", "anthropic", "cohere", "meta", "mistral", "stability", "writer",
+     "deepseek", "gpt-oss", "perplexity", "snowflake", "titan", "command", "j2", "llama"
+ ]
+
+
+ def extract_provider(model: str) -> str:
+     """Extract provider from model identifier."""
+     for provider in PROVIDERS:
+         if re.search(rf"\b{re.escape(provider)}\b", model):
+             return provider
+     raise ValueError(f"Unknown provider in model: {model}")
+
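+ # Example: extract_provider("anthropic.claude-3-sonnet-20240229-v1:0") returns "anthropic";
+ # provider names are matched as whole words anywhere in the Bedrock model ID.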
+
+ class AWSBedrockLLM(LLMBase):
+     """
+     AWS Bedrock LLM integration for Mem0.
+
+     Supports all available Bedrock models with automatic provider detection.
+     """
+
+     def __init__(self, config: Optional[Union[AWSBedrockConfig, BaseLlmConfig, Dict]] = None):
+         """
+         Initialize AWS Bedrock LLM.
+
+         Args:
+             config: AWS Bedrock configuration object
+         """
+         # Convert to AWSBedrockConfig if needed
+         if config is None:
+             config = AWSBedrockConfig()
+         elif isinstance(config, dict):
+             config = AWSBedrockConfig(**config)
+         elif isinstance(config, BaseLlmConfig) and not isinstance(config, AWSBedrockConfig):
+             # Convert BaseLlmConfig to AWSBedrockConfig
+             config = AWSBedrockConfig(
+                 model=config.model,
+                 temperature=config.temperature,
+                 max_tokens=config.max_tokens,
+                 top_p=config.top_p,
+                 top_k=config.top_k,
+                 enable_vision=getattr(config, "enable_vision", False),
+             )
+
+         super().__init__(config)
+         self.config = config
+
+         # Initialize AWS client
+         self._initialize_aws_client()
+
+         # Get model configuration
+         self.model_config = self.config.get_model_config()
+         self.provider = extract_provider(self.config.model)
+
+         # Initialize provider-specific settings
+         self._initialize_provider_settings()
+
+     def _initialize_aws_client(self):
+         """Initialize AWS Bedrock client with proper credentials."""
+         try:
+             aws_config = self.config.get_aws_config()
+
+             # Create Bedrock runtime client
+             self.client = boto3.client("bedrock-runtime", **aws_config)
+
+             # Test connection
+             self._test_connection()
+
+         except NoCredentialsError:
+             raise ValueError(
+                 "AWS credentials not found. Please set AWS_ACCESS_KEY_ID, "
+                 "AWS_SECRET_ACCESS_KEY, and AWS_REGION environment variables, "
+                 "or provide them in the config."
+             )
+         except ClientError as e:
+             if e.response["Error"]["Code"] == "UnauthorizedOperation":
+                 raise ValueError(
+                     f"Unauthorized access to Bedrock. Please ensure your AWS credentials "
+                     f"have permission to access Bedrock in region {self.config.aws_region}."
+                 )
+             else:
+                 raise ValueError(f"AWS Bedrock error: {e}")
+
+     def _test_connection(self):
+         """Test connection to AWS Bedrock service."""
+         try:
+             # List available models to test connection
+             bedrock_client = boto3.client("bedrock", **self.config.get_aws_config())
+             response = bedrock_client.list_foundation_models()
+             self.available_models = [model["modelId"] for model in response["modelSummaries"]]
+
+             # Check if our model is available
+             if self.config.model not in self.available_models:
+                 logger.warning(f"Model {self.config.model} may not be available in region {self.config.aws_region}")
+                 logger.info(f"Available models: {', '.join(self.available_models[:5])}...")
+
+         except Exception as e:
+             logger.warning(f"Could not verify model availability: {e}")
+             self.available_models = []
+
+     def _initialize_provider_settings(self):
+         """Initialize provider-specific settings and capabilities."""
+         # Determine capabilities based on provider and model
+         self.supports_tools = self.provider in ["anthropic", "cohere", "amazon"]
+         self.supports_vision = self.provider in ["anthropic", "amazon", "meta", "mistral"]
+         self.supports_streaming = self.provider in ["anthropic", "cohere", "mistral", "amazon", "meta"]
+
+         # Set message formatting method
+         if self.provider == "anthropic":
+             self._format_messages = self._format_messages_anthropic
+         elif self.provider == "cohere":
+             self._format_messages = self._format_messages_cohere
+         elif self.provider == "amazon":
+             self._format_messages = self._format_messages_amazon
+         elif self.provider == "meta":
+             self._format_messages = self._format_messages_meta
+         elif self.provider == "mistral":
+             self._format_messages = self._format_messages_mistral
+         else:
+             self._format_messages = self._format_messages_generic
+
+     def _format_messages_anthropic(self, messages: List[Dict[str, str]]) -> tuple[List[Dict[str, Any]], Optional[str]]:
+         """Format messages for Anthropic models."""
+         formatted_messages = []
+         system_message = None
+
+         for message in messages:
+             role = message["role"]
+             content = message["content"]
+
+             if role == "system":
+                 # Anthropic supports system messages as a separate parameter
+                 # see: https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/system-prompts
+                 system_message = content
+             elif role == "user":
+                 # Use Converse API format
+                 formatted_messages.append({"role": "user", "content": [{"text": content}]})
+             elif role == "assistant":
+                 # Use Converse API format
+                 formatted_messages.append({"role": "assistant", "content": [{"text": content}]})
+
+         return formatted_messages, system_message
+
+     def _format_messages_cohere(self, messages: List[Dict[str, str]]) -> str:
+         """Format messages for Cohere models."""
+         formatted_messages = []
+
+         for message in messages:
+             role = message["role"].capitalize()
+             content = message["content"]
+             formatted_messages.append(f"{role}: {content}")
+
+         return "\n".join(formatted_messages)
+
+     def _format_messages_amazon(self, messages: List[Dict[str, str]]) -> List[Dict[str, Any]]:
+         """Format messages for Amazon models (including Nova)."""
+         formatted_messages = []
+
+         for message in messages:
+             role = message["role"]
+             content = message["content"]
+
+             # The Converse API used for these models expects content as a
+             # list of text blocks, not a bare string
+             if role == "system":
+                 # Amazon models support system messages
+                 formatted_messages.append({"role": "system", "content": [{"text": content}]})
+             elif role == "user":
+                 formatted_messages.append({"role": "user", "content": [{"text": content}]})
+             elif role == "assistant":
+                 formatted_messages.append({"role": "assistant", "content": [{"text": content}]})
+
+         return formatted_messages
+
+     def _format_messages_meta(self, messages: List[Dict[str, str]]) -> str:
+         """Format messages for Meta models."""
+         formatted_messages = []
+
+         for message in messages:
+             role = message["role"].capitalize()
+             content = message["content"]
+             formatted_messages.append(f"{role}: {content}")
+
+         return "\n".join(formatted_messages)
+
+     def _format_messages_mistral(self, messages: List[Dict[str, str]]) -> List[Dict[str, Any]]:
+         """Format messages for Mistral models."""
+         formatted_messages = []
+
+         for message in messages:
+             role = message["role"]
+             content = message["content"]
+
+             if role == "system":
+                 # Mistral supports system messages
+                 formatted_messages.append({"role": "system", "content": content})
+             elif role == "user":
+                 formatted_messages.append({"role": "user", "content": content})
+             elif role == "assistant":
+                 formatted_messages.append({"role": "assistant", "content": content})
+
+         return formatted_messages
+
+     def _format_messages_generic(self, messages: List[Dict[str, str]]) -> str:
+         """Generic message formatting for other providers."""
+         formatted_messages = []
+
+         for message in messages:
+             role = message["role"].capitalize()
+             content = message["content"]
+             formatted_messages.append(f"\n\n{role}: {content}")
+
+         return "\n\nHuman: " + "".join(formatted_messages) + "\n\nAssistant:"
+
+     def _prepare_input(self, prompt: str) -> Dict[str, Any]:
+         """
+         Prepare input for the current provider's model.
+
+         Args:
+             prompt: Text prompt to process
+
+         Returns:
+             Prepared input dictionary
+         """
+         # Base configuration
+         input_body = {"prompt": prompt}
+
+         # Provider-specific parameter mappings
+         provider_mappings = {
+             "meta": {"max_tokens": "max_gen_len"},
+             "ai21": {"max_tokens": "maxTokens", "top_p": "topP"},
+             "mistral": {"max_tokens": "max_tokens"},
+             "cohere": {"max_tokens": "max_tokens", "top_p": "p"},
+             "amazon": {"max_tokens": "maxTokenCount", "top_p": "topP"},
+             "anthropic": {"max_tokens": "max_tokens", "top_p": "top_p"},
+         }
+
+         # Apply provider mappings
+         if self.provider in provider_mappings:
+             for old_key, new_key in provider_mappings[self.provider].items():
+                 if old_key in self.model_config:
+                     input_body[new_key] = self.model_config[old_key]
+
+         # Special handling for specific providers
+         if self.provider == "cohere" and "cohere.command" in self.config.model:
+             input_body["message"] = input_body.pop("prompt")
+         elif self.provider == "amazon":
+             # Amazon Nova and other Amazon models
+             if "nova" in self.config.model.lower():
+                 # Nova models use the converse API format
+                 input_body = {
+                     "messages": [{"role": "user", "content": prompt}],
+                     "max_tokens": self.model_config.get("max_tokens", 5000),
+                     "temperature": self.model_config.get("temperature", 0.1),
+                     "top_p": self.model_config.get("top_p", 0.9),
+                 }
+             else:
+                 # Legacy Amazon models
+                 input_body = {
+                     "inputText": prompt,
+                     "textGenerationConfig": {
+                         "maxTokenCount": self.model_config.get("max_tokens", 5000),
+                         "topP": self.model_config.get("top_p", 0.9),
+                         "temperature": self.model_config.get("temperature", 0.1),
+                     },
+                 }
+                 # Remove None values
+                 input_body["textGenerationConfig"] = {
+                     k: v for k, v in input_body["textGenerationConfig"].items() if v is not None
+                 }
+         elif self.provider == "anthropic":
+             input_body = {
+                 "messages": [{"role": "user", "content": [{"type": "text", "text": prompt}]}],
+                 "max_tokens": self.model_config.get("max_tokens", 2000),
+                 "temperature": self.model_config.get("temperature", 0.1),
+                 "top_p": self.model_config.get("top_p", 0.9),
+                 "anthropic_version": "bedrock-2023-05-31",
+             }
+         elif self.provider == "meta":
+             input_body = {
+                 "prompt": prompt,
+                 "max_gen_len": self.model_config.get("max_tokens", 5000),
+                 "temperature": self.model_config.get("temperature", 0.1),
+                 "top_p": self.model_config.get("top_p", 0.9),
+             }
+         elif self.provider == "mistral":
+             input_body = {
+                 "prompt": prompt,
+                 "max_tokens": self.model_config.get("max_tokens", 5000),
+                 "temperature": self.model_config.get("temperature", 0.1),
+                 "top_p": self.model_config.get("top_p", 0.9),
+             }
+         else:
+             # Generic case - add all model config parameters
+             input_body.update(self.model_config)
+
+         return input_body
+
+     def _convert_tool_format(self, original_tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+         """
+         Convert tools to Bedrock-compatible format.
+
+         Args:
+             original_tools: List of tool definitions
+
+         Returns:
+             Converted tools in Bedrock format
+         """
+         new_tools = []
+
+         for tool in original_tools:
+             if tool["type"] == "function":
+                 function = tool["function"]
+                 new_tool = {
+                     "toolSpec": {
+                         "name": function["name"],
+                         "description": function.get("description", ""),
+                         "inputSchema": {
+                             "json": {
+                                 "type": "object",
+                                 "properties": {},
+                                 "required": function["parameters"].get("required", []),
+                             }
+                         },
+                     }
+                 }
+
+                 # Add properties
+                 for prop, details in function["parameters"].get("properties", {}).items():
+                     new_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop] = details
+
+                 new_tools.append(new_tool)
+
+         return new_tools
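+
+     # Example (hypothetical "get_weather" tool): the OpenAI-style definition
+     #   {"type": "function", "function": {"name": "get_weather", "parameters":
+     #       {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}}}
+     # becomes a Bedrock toolSpec whose inputSchema "json" carries the same
+     # properties and required fields.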
+
+     def _parse_response(
+         self, response: Dict[str, Any], tools: Optional[List[Dict]] = None
+     ) -> Union[str, Dict[str, Any]]:
+         """
+         Parse response from Bedrock API.
+
+         Args:
+             response: Raw API response
+             tools: List of tools if used
+
+         Returns:
+             Parsed response
+         """
+         if tools:
+             # Handle tool-enabled responses
+             processed_response = {"tool_calls": []}
+
+             if response.get("output", {}).get("message", {}).get("content"):
+                 for item in response["output"]["message"]["content"]:
+                     if "toolUse" in item:
+                         processed_response["tool_calls"].append(
+                             {
+                                 "name": item["toolUse"]["name"],
+                                 "arguments": json.loads(extract_json(json.dumps(item["toolUse"]["input"]))),
+                             }
+                         )
+
+             return processed_response
+
+         # Handle regular text responses
+         try:
+             # Converse API responses (e.g. from Nova models) are plain dicts with
+             # no streaming body; read the text directly
+             if "output" in response:
+                 return response["output"]["message"]["content"][0].get("text", "")
+
+             response_body = response.get("body").read().decode()
+             response_json = json.loads(response_body)
+
+             # Provider-specific response parsing
+             if self.provider == "anthropic":
+                 return response_json.get("content", [{"text": ""}])[0].get("text", "")
+             elif self.provider == "amazon":
+                 # Handle both Nova and legacy Amazon models
+                 if "nova" in self.config.model.lower():
+                     # Nova models return content in a different format
+                     if "content" in response_json:
+                         return response_json["content"][0]["text"]
+                     elif "completion" in response_json:
+                         return response_json["completion"]
+                 else:
+                     # Legacy Amazon (Titan) models return results[].outputText
+                     if "results" in response_json:
+                         return response_json["results"][0].get("outputText", "")
+                     return response_json.get("completion", "")
+             elif self.provider == "meta":
+                 return response_json.get("generation", "")
+             elif self.provider == "mistral":
+                 return response_json.get("outputs", [{"text": ""}])[0].get("text", "")
+             elif self.provider == "cohere":
+                 return response_json.get("generations", [{"text": ""}])[0].get("text", "")
+             elif self.provider == "ai21":
+                 return response_json.get("completions", [{"data": {"text": ""}}])[0].get("data", {}).get("text", "")
+             else:
+                 # Generic parsing - try common response fields
+                 for field in ["content", "text", "completion", "generation"]:
+                     if field in response_json:
+                         if isinstance(response_json[field], list) and response_json[field]:
+                             return response_json[field][0].get("text", "")
+                         elif isinstance(response_json[field], str):
+                             return response_json[field]
+
+             # Fallback
+             return str(response_json)
+
+         except Exception as e:
+             logger.warning(f"Could not parse response: {e}")
+             return "Error parsing response"
+
+     def generate_response(
+         self,
+         messages: List[Dict[str, str]],
+         response_format: Optional[str] = None,
+         tools: Optional[List[Dict]] = None,
+         tool_choice: str = "auto",
+         stream: bool = False,
+         **kwargs,
+     ) -> Union[str, Dict[str, Any]]:
+         """
+         Generate response using AWS Bedrock.
+
+         Args:
+             messages: List of message dictionaries
+             response_format: Response format specification
+             tools: List of tools for function calling
+             tool_choice: Tool choice method
+             stream: Whether to stream the response
+             **kwargs: Additional parameters
+
+         Returns:
+             Generated response
+         """
+         try:
+             if tools and self.supports_tools:
+                 # Use converse method for tool-enabled models
+                 return self._generate_with_tools(messages, tools, stream)
+             else:
+                 # Use standard invoke_model method
+                 return self._generate_standard(messages, stream)
+
+         except Exception as e:
+             logger.error(f"Failed to generate response: {e}")
+             raise RuntimeError(f"Failed to generate response: {e}")
+
+     @staticmethod
+     def _convert_tools_to_converse_format(tools: List[Dict]) -> List[Dict]:
+         """Convert OpenAI-style tools to Converse API format."""
+         if not tools:
+             return []
+
+         converse_tools = []
+         for tool in tools:
+             if tool.get("type") == "function" and "function" in tool:
+                 func = tool["function"]
+                 converse_tool = {
+                     "toolSpec": {
+                         "name": func["name"],
+                         "description": func.get("description", ""),
+                         "inputSchema": {
+                             "json": func.get("parameters", {})
+                         }
+                     }
+                 }
+                 converse_tools.append(converse_tool)
+
+         return converse_tools
+
+     def _generate_with_tools(self, messages: List[Dict[str, str]], tools: List[Dict], stream: bool = False) -> Dict[str, Any]:
+         """Generate response with tool calling support using correct message format."""
+         # Format messages for tool-enabled models
+         system_message = None
+         if self.provider == "anthropic":
+             formatted_messages, system_message = self._format_messages_anthropic(messages)
+         elif self.provider == "amazon":
+             formatted_messages = self._format_messages_amazon(messages)
+         else:
+             formatted_messages = [{"role": "user", "content": [{"text": messages[-1]["content"]}]}]
+
+         # Prepare tool configuration in Converse API format
+         tool_config = None
+         if tools:
+             converse_tools = self._convert_tools_to_converse_format(tools)
+             if converse_tools:
+                 tool_config = {"tools": converse_tools}
+
+         # Prepare converse parameters
+         converse_params = {
+             "modelId": self.config.model,
+             "messages": formatted_messages,
+             "inferenceConfig": {
+                 "maxTokens": self.model_config.get("max_tokens", 2000),
+                 "temperature": self.model_config.get("temperature", 0.1),
+                 "topP": self.model_config.get("top_p", 0.9),
+             }
+         }
+
+         # Add system message if present (for Anthropic)
+         if system_message:
+             converse_params["system"] = [{"text": system_message}]
+
+         # Add tool config if present
+         if tool_config:
+             converse_params["toolConfig"] = tool_config
+
+         # Make API call
+         response = self.client.converse(**converse_params)
+
+         return self._parse_response(response, tools)
+
+     def _generate_standard(self, messages: List[Dict[str, str]], stream: bool = False) -> str:
+         """Generate standard text response using Converse API for Anthropic models."""
+         # For Anthropic models, always use Converse API
+         if self.provider == "anthropic":
+             formatted_messages, system_message = self._format_messages_anthropic(messages)
+
+             # Prepare converse parameters
+             converse_params = {
+                 "modelId": self.config.model,
+                 "messages": formatted_messages,
+                 "inferenceConfig": {
+                     "maxTokens": self.model_config.get("max_tokens", 2000),
+                     "temperature": self.model_config.get("temperature", 0.1),
+                     "topP": self.model_config.get("top_p", 0.9),
+                 }
+             }
+
+             # Add system message if present
+             if system_message:
+                 converse_params["system"] = [{"text": system_message}]
+
+             # Use converse API for Anthropic models
+             response = self.client.converse(**converse_params)
+
+             # Parse Converse API response (boto3 returns a plain dict)
+             if 'output' in response and 'message' in response['output']:
+                 return response['output']['message']['content'][0]['text']
+             else:
+                 return str(response)
+
+         elif self.provider == "amazon" and "nova" in self.config.model.lower():
+             # Nova models use converse API even without tools
+             formatted_messages = self._format_messages_amazon(messages)
+             input_body = {
+                 "messages": formatted_messages,
+                 "max_tokens": self.model_config.get("max_tokens", 5000),
+                 "temperature": self.model_config.get("temperature", 0.1),
+                 "top_p": self.model_config.get("top_p", 0.9),
+             }
+
+             # Use converse API for Nova models
+             response = self.client.converse(
+                 modelId=self.config.model,
+                 messages=input_body["messages"],
+                 inferenceConfig={
+                     "maxTokens": input_body["max_tokens"],
+                     "temperature": input_body["temperature"],
+                     "topP": input_body["top_p"],
+                 }
+             )
+
+             return self._parse_response(response)
+         else:
+             # For other providers and legacy Amazon models (like Titan)
+             if self.provider == "amazon":
+                 # Legacy Amazon models need string formatting, not array formatting
+                 prompt = self._format_messages_generic(messages)
+             else:
+                 prompt = self._format_messages(messages)
+                 if not isinstance(prompt, str):
+                     # Formatters that return structured messages (e.g. Mistral's)
+                     # are flattened to a plain prompt string for invoke_model
+                     prompt = self._format_messages_generic(messages)
+             input_body = self._prepare_input(prompt)
+
+             # Convert to JSON
+             body = json.dumps(input_body)
+
+             # Make API call
+             response = self.client.invoke_model(
+                 body=body,
+                 modelId=self.config.model,
+                 accept="application/json",
+                 contentType="application/json",
+             )
+
+             return self._parse_response(response)
+
+     def list_available_models(self) -> List[Dict[str, Any]]:
+         """List all available models in the current region."""
+         try:
+             bedrock_client = boto3.client("bedrock", **self.config.get_aws_config())
+             response = bedrock_client.list_foundation_models()
+
+             models = []
+             for model in response["modelSummaries"]:
+                 provider = extract_provider(model["modelId"])
+                 models.append(
+                     {
+                         "model_id": model["modelId"],
+                         "provider": provider,
+                         "model_name": model["modelId"].split(".", 1)[1]
+                         if "." in model["modelId"]
+                         else model["modelId"],
+                         "modelArn": model.get("modelArn", ""),
+                         "providerName": model.get("providerName", ""),
+                         "inputModalities": model.get("inputModalities", []),
+                         "outputModalities": model.get("outputModalities", []),
+                         "responseStreamingSupported": model.get("responseStreamingSupported", False),
+                     }
+                 )
+
+             return models
+
+         except Exception as e:
+             logger.warning(f"Could not list models: {e}")
+             return []
+
+     def get_model_capabilities(self) -> Dict[str, Any]:
+         """Get capabilities of the current model."""
+         return {
+             "model_id": self.config.model,
+             "provider": self.provider,
+             "model_name": self.config.model_name,
+             "supports_tools": self.supports_tools,
+             "supports_vision": self.supports_vision,
+             "supports_streaming": self.supports_streaming,
+             "max_tokens": self.model_config.get("max_tokens", 2000),
+         }
+
+     def validate_model_access(self) -> bool:
+         """Validate if the model is accessible."""
+         try:
+             # Try to invoke the model with a minimal request
+             if self.provider == "amazon" and "nova" in self.config.model.lower():
+                 # Test Nova model with converse API (content must be a list of blocks)
+                 test_messages = [{"role": "user", "content": [{"text": "test"}]}]
+                 self.client.converse(
+                     modelId=self.config.model,
+                     messages=test_messages,
+                     inferenceConfig={"maxTokens": 10}
+                 )
+             else:
+                 # Test other models with invoke_model
+                 test_body = json.dumps({"prompt": "test"})
+                 self.client.invoke_model(
+                     body=test_body,
+                     modelId=self.config.model,
+                     accept="application/json",
+                     contentType="application/json",
+                 )
+             return True
+         except Exception:
+             return False
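
For reference, a minimal usage sketch of the new Bedrock integration. This example is illustrative and not part of the published diff: the model ID and region are placeholders, AWS credentials are assumed to be available in the environment, and the config keys are assumed to be accepted by AWSBedrockConfig (the module reads config.model and config.aws_region).

    from agentrun_mem0.llms.aws_bedrock import AWSBedrockLLM

    # Dict configs are converted to AWSBedrockConfig in __init__
    llm = AWSBedrockLLM(
        config={
            "model": "anthropic.claude-3-sonnet-20240229-v1:0",  # placeholder model ID
            "aws_region": "us-east-1",  # placeholder region
            "temperature": 0.1,
            "max_tokens": 2000,
        }
    )

    # Anthropic models are routed through the Bedrock Converse API
    reply = llm.generate_response(
        messages=[
            {"role": "system", "content": "You are a concise assistant."},
            {"role": "user", "content": "Summarize what AWS Bedrock does."},
        ]
    )
    print(reply)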