agentrun-mem0ai 0.0.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentrun_mem0/__init__.py +6 -0
- agentrun_mem0/client/__init__.py +0 -0
- agentrun_mem0/client/main.py +1747 -0
- agentrun_mem0/client/project.py +931 -0
- agentrun_mem0/client/utils.py +115 -0
- agentrun_mem0/configs/__init__.py +0 -0
- agentrun_mem0/configs/base.py +90 -0
- agentrun_mem0/configs/embeddings/__init__.py +0 -0
- agentrun_mem0/configs/embeddings/base.py +110 -0
- agentrun_mem0/configs/enums.py +7 -0
- agentrun_mem0/configs/llms/__init__.py +0 -0
- agentrun_mem0/configs/llms/anthropic.py +56 -0
- agentrun_mem0/configs/llms/aws_bedrock.py +192 -0
- agentrun_mem0/configs/llms/azure.py +57 -0
- agentrun_mem0/configs/llms/base.py +62 -0
- agentrun_mem0/configs/llms/deepseek.py +56 -0
- agentrun_mem0/configs/llms/lmstudio.py +59 -0
- agentrun_mem0/configs/llms/ollama.py +56 -0
- agentrun_mem0/configs/llms/openai.py +79 -0
- agentrun_mem0/configs/llms/vllm.py +56 -0
- agentrun_mem0/configs/prompts.py +459 -0
- agentrun_mem0/configs/rerankers/__init__.py +0 -0
- agentrun_mem0/configs/rerankers/base.py +17 -0
- agentrun_mem0/configs/rerankers/cohere.py +15 -0
- agentrun_mem0/configs/rerankers/config.py +12 -0
- agentrun_mem0/configs/rerankers/huggingface.py +17 -0
- agentrun_mem0/configs/rerankers/llm.py +48 -0
- agentrun_mem0/configs/rerankers/sentence_transformer.py +16 -0
- agentrun_mem0/configs/rerankers/zero_entropy.py +28 -0
- agentrun_mem0/configs/vector_stores/__init__.py +0 -0
- agentrun_mem0/configs/vector_stores/alibabacloud_mysql.py +64 -0
- agentrun_mem0/configs/vector_stores/aliyun_tablestore.py +32 -0
- agentrun_mem0/configs/vector_stores/azure_ai_search.py +57 -0
- agentrun_mem0/configs/vector_stores/azure_mysql.py +84 -0
- agentrun_mem0/configs/vector_stores/baidu.py +27 -0
- agentrun_mem0/configs/vector_stores/chroma.py +58 -0
- agentrun_mem0/configs/vector_stores/databricks.py +61 -0
- agentrun_mem0/configs/vector_stores/elasticsearch.py +65 -0
- agentrun_mem0/configs/vector_stores/faiss.py +37 -0
- agentrun_mem0/configs/vector_stores/langchain.py +30 -0
- agentrun_mem0/configs/vector_stores/milvus.py +42 -0
- agentrun_mem0/configs/vector_stores/mongodb.py +25 -0
- agentrun_mem0/configs/vector_stores/neptune.py +27 -0
- agentrun_mem0/configs/vector_stores/opensearch.py +41 -0
- agentrun_mem0/configs/vector_stores/pgvector.py +52 -0
- agentrun_mem0/configs/vector_stores/pinecone.py +55 -0
- agentrun_mem0/configs/vector_stores/qdrant.py +47 -0
- agentrun_mem0/configs/vector_stores/redis.py +24 -0
- agentrun_mem0/configs/vector_stores/s3_vectors.py +28 -0
- agentrun_mem0/configs/vector_stores/supabase.py +44 -0
- agentrun_mem0/configs/vector_stores/upstash_vector.py +34 -0
- agentrun_mem0/configs/vector_stores/valkey.py +15 -0
- agentrun_mem0/configs/vector_stores/vertex_ai_vector_search.py +28 -0
- agentrun_mem0/configs/vector_stores/weaviate.py +41 -0
- agentrun_mem0/embeddings/__init__.py +0 -0
- agentrun_mem0/embeddings/aws_bedrock.py +100 -0
- agentrun_mem0/embeddings/azure_openai.py +55 -0
- agentrun_mem0/embeddings/base.py +31 -0
- agentrun_mem0/embeddings/configs.py +30 -0
- agentrun_mem0/embeddings/gemini.py +39 -0
- agentrun_mem0/embeddings/huggingface.py +44 -0
- agentrun_mem0/embeddings/langchain.py +35 -0
- agentrun_mem0/embeddings/lmstudio.py +29 -0
- agentrun_mem0/embeddings/mock.py +11 -0
- agentrun_mem0/embeddings/ollama.py +53 -0
- agentrun_mem0/embeddings/openai.py +49 -0
- agentrun_mem0/embeddings/together.py +31 -0
- agentrun_mem0/embeddings/vertexai.py +64 -0
- agentrun_mem0/exceptions.py +503 -0
- agentrun_mem0/graphs/__init__.py +0 -0
- agentrun_mem0/graphs/configs.py +105 -0
- agentrun_mem0/graphs/neptune/__init__.py +0 -0
- agentrun_mem0/graphs/neptune/base.py +497 -0
- agentrun_mem0/graphs/neptune/neptunedb.py +511 -0
- agentrun_mem0/graphs/neptune/neptunegraph.py +474 -0
- agentrun_mem0/graphs/tools.py +371 -0
- agentrun_mem0/graphs/utils.py +97 -0
- agentrun_mem0/llms/__init__.py +0 -0
- agentrun_mem0/llms/anthropic.py +87 -0
- agentrun_mem0/llms/aws_bedrock.py +665 -0
- agentrun_mem0/llms/azure_openai.py +141 -0
- agentrun_mem0/llms/azure_openai_structured.py +91 -0
- agentrun_mem0/llms/base.py +131 -0
- agentrun_mem0/llms/configs.py +34 -0
- agentrun_mem0/llms/deepseek.py +107 -0
- agentrun_mem0/llms/gemini.py +201 -0
- agentrun_mem0/llms/groq.py +88 -0
- agentrun_mem0/llms/langchain.py +94 -0
- agentrun_mem0/llms/litellm.py +87 -0
- agentrun_mem0/llms/lmstudio.py +114 -0
- agentrun_mem0/llms/ollama.py +117 -0
- agentrun_mem0/llms/openai.py +147 -0
- agentrun_mem0/llms/openai_structured.py +52 -0
- agentrun_mem0/llms/sarvam.py +89 -0
- agentrun_mem0/llms/together.py +88 -0
- agentrun_mem0/llms/vllm.py +107 -0
- agentrun_mem0/llms/xai.py +52 -0
- agentrun_mem0/memory/__init__.py +0 -0
- agentrun_mem0/memory/base.py +63 -0
- agentrun_mem0/memory/graph_memory.py +698 -0
- agentrun_mem0/memory/kuzu_memory.py +713 -0
- agentrun_mem0/memory/main.py +2229 -0
- agentrun_mem0/memory/memgraph_memory.py +689 -0
- agentrun_mem0/memory/setup.py +56 -0
- agentrun_mem0/memory/storage.py +218 -0
- agentrun_mem0/memory/telemetry.py +90 -0
- agentrun_mem0/memory/utils.py +208 -0
- agentrun_mem0/proxy/__init__.py +0 -0
- agentrun_mem0/proxy/main.py +189 -0
- agentrun_mem0/reranker/__init__.py +9 -0
- agentrun_mem0/reranker/base.py +20 -0
- agentrun_mem0/reranker/cohere_reranker.py +85 -0
- agentrun_mem0/reranker/huggingface_reranker.py +147 -0
- agentrun_mem0/reranker/llm_reranker.py +142 -0
- agentrun_mem0/reranker/sentence_transformer_reranker.py +107 -0
- agentrun_mem0/reranker/zero_entropy_reranker.py +96 -0
- agentrun_mem0/utils/factory.py +283 -0
- agentrun_mem0/utils/gcp_auth.py +167 -0
- agentrun_mem0/vector_stores/__init__.py +0 -0
- agentrun_mem0/vector_stores/alibabacloud_mysql.py +547 -0
- agentrun_mem0/vector_stores/aliyun_tablestore.py +252 -0
- agentrun_mem0/vector_stores/azure_ai_search.py +396 -0
- agentrun_mem0/vector_stores/azure_mysql.py +463 -0
- agentrun_mem0/vector_stores/baidu.py +368 -0
- agentrun_mem0/vector_stores/base.py +58 -0
- agentrun_mem0/vector_stores/chroma.py +332 -0
- agentrun_mem0/vector_stores/configs.py +67 -0
- agentrun_mem0/vector_stores/databricks.py +761 -0
- agentrun_mem0/vector_stores/elasticsearch.py +237 -0
- agentrun_mem0/vector_stores/faiss.py +479 -0
- agentrun_mem0/vector_stores/langchain.py +180 -0
- agentrun_mem0/vector_stores/milvus.py +250 -0
- agentrun_mem0/vector_stores/mongodb.py +310 -0
- agentrun_mem0/vector_stores/neptune_analytics.py +467 -0
- agentrun_mem0/vector_stores/opensearch.py +292 -0
- agentrun_mem0/vector_stores/pgvector.py +404 -0
- agentrun_mem0/vector_stores/pinecone.py +382 -0
- agentrun_mem0/vector_stores/qdrant.py +270 -0
- agentrun_mem0/vector_stores/redis.py +295 -0
- agentrun_mem0/vector_stores/s3_vectors.py +176 -0
- agentrun_mem0/vector_stores/supabase.py +237 -0
- agentrun_mem0/vector_stores/upstash_vector.py +293 -0
- agentrun_mem0/vector_stores/valkey.py +824 -0
- agentrun_mem0/vector_stores/vertex_ai_vector_search.py +635 -0
- agentrun_mem0/vector_stores/weaviate.py +343 -0
- agentrun_mem0ai-0.0.11.data/data/README.md +205 -0
- agentrun_mem0ai-0.0.11.dist-info/METADATA +277 -0
- agentrun_mem0ai-0.0.11.dist-info/RECORD +150 -0
- agentrun_mem0ai-0.0.11.dist-info/WHEEL +4 -0
- agentrun_mem0ai-0.0.11.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
from typing import Dict, List, Optional, Union
|
|
4
|
+
|
|
5
|
+
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
|
|
6
|
+
from openai import AzureOpenAI
|
|
7
|
+
|
|
8
|
+
from agentrun_mem0.configs.llms.azure import AzureOpenAIConfig
|
|
9
|
+
from agentrun_mem0.configs.llms.base import BaseLlmConfig
|
|
10
|
+
from agentrun_mem0.llms.base import LLMBase
|
|
11
|
+
from agentrun_mem0.memory.utils import extract_json
|
|
12
|
+
|
|
13
|
+
# Scope string handed to get_bearer_token_provider() when no usable API key is
# configured and the client authenticates through DefaultAzureCredential instead.
SCOPE = "https://cognitiveservices.azure.com/.default"
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class AzureOpenAILLM(LLMBase):
    """Chat-completion provider that talks to an Azure OpenAI deployment.

    An API key is used when one is configured (or found in the
    ``LLM_AZURE_OPENAI_API_KEY`` environment variable). Otherwise the client
    authenticates with ``DefaultAzureCredential`` through a bearer-token
    provider.
    """

    def __init__(self, config: Optional[Union[BaseLlmConfig, AzureOpenAIConfig, Dict]] = None):
        """Normalise ``config`` to an ``AzureOpenAIConfig`` and build the client.

        Args:
            config: ``None`` (all defaults), a dict of ``AzureOpenAIConfig``
                keyword arguments, an ``AzureOpenAIConfig`` instance, or a
                generic ``BaseLlmConfig`` whose common fields are copied.
        """
        # Convert to AzureOpenAIConfig if needed
        if config is None:
            config = AzureOpenAIConfig()
        elif isinstance(config, dict):
            config = AzureOpenAIConfig(**config)
        elif isinstance(config, BaseLlmConfig) and not isinstance(config, AzureOpenAIConfig):
            # Convert BaseLlmConfig to AzureOpenAIConfig
            # NOTE(review): ``config.http_client`` is passed as
            # ``http_client_proxies`` -- confirm against BaseLlmConfig that
            # this attribute really holds proxy settings.
            config = AzureOpenAIConfig(
                model=config.model,
                temperature=config.temperature,
                api_key=config.api_key,
                max_tokens=config.max_tokens,
                top_p=config.top_p,
                top_k=config.top_k,
                enable_vision=config.enable_vision,
                vision_details=config.vision_details,
                http_client_proxies=config.http_client,
            )

        super().__init__(config)

        # Model name should match the custom deployment name chosen for it.
        if not self.config.model:
            self.config.model = "gpt-4.1-nano-2025-04-14"

        # Explicit configuration wins; environment variables are the fallback.
        api_key = self.config.azure_kwargs.api_key or os.getenv("LLM_AZURE_OPENAI_API_KEY")
        azure_deployment = self.config.azure_kwargs.azure_deployment or os.getenv("LLM_AZURE_DEPLOYMENT")
        azure_endpoint = self.config.azure_kwargs.azure_endpoint or os.getenv("LLM_AZURE_ENDPOINT")
        api_version = self.config.azure_kwargs.api_version or os.getenv("LLM_AZURE_API_VERSION")
        default_headers = self.config.azure_kwargs.default_headers

        # If the API key is not provided or is a placeholder, use DefaultAzureCredential.
        if api_key is None or api_key == "" or api_key == "your-api-key":
            self.credential = DefaultAzureCredential()
            azure_ad_token_provider = get_bearer_token_provider(
                self.credential,
                SCOPE,
            )
            # Make sure the placeholder/empty key is never sent to the client.
            api_key = None
        else:
            azure_ad_token_provider = None

        self.client = AzureOpenAI(
            azure_deployment=azure_deployment,
            azure_endpoint=azure_endpoint,
            azure_ad_token_provider=azure_ad_token_provider,
            api_version=api_version,
            api_key=api_key,
            http_client=self.config.http_client,
            default_headers=default_headers,
        )

    def _parse_response(self, response, tools):
        """
        Process the response based on whether tools are used or not.

        Args:
            response: The raw response from API.
            tools: The list of tools provided in the request.

        Returns:
            str or dict: The message content when no tools were supplied;
            otherwise a dict with ``"content"`` and a ``"tool_calls"`` list
            of ``{"name", "arguments"}`` entries (arguments JSON-decoded).
        """
        if tools:
            processed_response = {
                "content": response.choices[0].message.content,
                "tool_calls": [],
            }

            if response.choices[0].message.tool_calls:
                for tool_call in response.choices[0].message.tool_calls:
                    processed_response["tool_calls"].append(
                        {
                            "name": tool_call.function.name,
                            "arguments": json.loads(extract_json(tool_call.function.arguments)),
                        }
                    )

            return processed_response
        else:
            return response.choices[0].message.content

    def generate_response(
        self,
        messages: List[Dict[str, str]],
        response_format=None,
        tools: Optional[List[Dict]] = None,
        tool_choice: str = "auto",
        **kwargs,
    ):
        """
        Generate a response based on the given messages using Azure OpenAI.

        Args:
            messages (list): List of message dicts containing 'role' and 'content'.
            response_format (str or object, optional): Format of the response.
                Forwarded to the API only when truthy. Defaults to None.
            tools (list, optional): List of tools that the model can call. Defaults to None.
            tool_choice (str, optional): Tool choice method. Defaults to "auto".
            **kwargs: Additional Azure OpenAI-specific parameters.

        Returns:
            str or dict: The message content, or a dict with ``"content"`` and
            ``"tool_calls"`` when ``tools`` is supplied.
        """
        # Work on a shallow copy so the caller's list and message dicts are
        # never modified by the substitution below.
        messages = list(messages)
        if messages:
            last_message = messages[-1]
            content = last_message.get("content") if isinstance(last_message, dict) else None
            # Only plain-text content can be rewritten; other content shapes
            # are passed through untouched.
            # NOTE(review): the reason for replacing "assistant" with "ai" in
            # the final message is not visible in this file.
            if isinstance(content, str):
                messages[-1] = {**last_message, "content": content.replace("assistant", "ai")}

        params = self._get_supported_params(messages=messages, **kwargs)

        # Add model and messages
        params.update({
            "model": self.config.model,
            "messages": messages,
        })

        # Honour the documented ``response_format`` argument instead of
        # silently dropping it.
        if response_format:
            params["response_format"] = response_format

        if tools:
            params["tools"] = tools
            params["tool_choice"] = tool_choice

        response = self.client.chat.completions.create(**params)
        return self._parse_response(response, tools)
|
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import Dict, List, Optional
|
|
3
|
+
|
|
4
|
+
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
|
|
5
|
+
from openai import AzureOpenAI
|
|
6
|
+
|
|
7
|
+
from agentrun_mem0.configs.llms.base import BaseLlmConfig
|
|
8
|
+
from agentrun_mem0.llms.base import LLMBase
|
|
9
|
+
|
|
10
|
+
# Scope string handed to get_bearer_token_provider() when no usable API key is
# configured and the client authenticates through DefaultAzureCredential instead.
SCOPE = "https://cognitiveservices.azure.com/.default"
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class AzureOpenAIStructuredLLM(LLMBase):
    """Azure OpenAI chat-completion provider that forwards ``response_format``.

    An API key is used when one is configured (or found in the
    ``LLM_AZURE_OPENAI_API_KEY`` environment variable). Otherwise the client
    authenticates with ``DefaultAzureCredential`` through a bearer-token
    provider.
    """

    def __init__(self, config: Optional[BaseLlmConfig] = None):
        """Build the Azure OpenAI client from ``config``.

        Args:
            config: LLM configuration; handled by ``LLMBase.__init__``.
        """
        super().__init__(config)

        # Model name should match the custom deployment name chosen for it.
        if not self.config.model:
            self.config.model = "gpt-4.1-nano-2025-04-14"

        # NOTE(review): assumes the config object exposes ``azure_kwargs`` --
        # confirm that BaseLlmConfig provides it.
        # Explicit configuration wins; environment variables are the fallback.
        api_key = self.config.azure_kwargs.api_key or os.getenv("LLM_AZURE_OPENAI_API_KEY")
        azure_deployment = self.config.azure_kwargs.azure_deployment or os.getenv("LLM_AZURE_DEPLOYMENT")
        azure_endpoint = self.config.azure_kwargs.azure_endpoint or os.getenv("LLM_AZURE_ENDPOINT")
        api_version = self.config.azure_kwargs.api_version or os.getenv("LLM_AZURE_API_VERSION")
        default_headers = self.config.azure_kwargs.default_headers

        # If the API key is not provided or is a placeholder, use DefaultAzureCredential.
        if api_key is None or api_key == "" or api_key == "your-api-key":
            self.credential = DefaultAzureCredential()
            azure_ad_token_provider = get_bearer_token_provider(
                self.credential,
                SCOPE,
            )
            # Make sure the placeholder/empty key is never sent to the client.
            api_key = None
        else:
            azure_ad_token_provider = None

        # NOTE(review): a warning about a mismatched model / api-version
        # combination could be emitted here.
        self.client = AzureOpenAI(
            azure_deployment=azure_deployment,
            azure_endpoint=azure_endpoint,
            azure_ad_token_provider=azure_ad_token_provider,
            api_version=api_version,
            api_key=api_key,
            http_client=self.config.http_client,
            default_headers=default_headers,
        )

    def _parse_response(self, response, tools):
        """
        Process the response based on whether tools are used or not.

        Args:
            response: The raw response from API.
            tools: The list of tools provided in the request.

        Returns:
            str or dict: The message content when no tools were supplied;
            otherwise a dict with ``"content"`` and a ``"tool_calls"`` list
            of ``{"name", "arguments"}`` entries (arguments JSON-decoded).
        """
        message = response.choices[0].message
        if not tools:
            return message.content

        # Imported here because this module has no top-level import for
        # either name.
        import json

        from agentrun_mem0.memory.utils import extract_json

        return {
            "content": message.content,
            "tool_calls": [
                {
                    "name": tool_call.function.name,
                    "arguments": json.loads(extract_json(tool_call.function.arguments)),
                }
                for tool_call in (message.tool_calls or [])
            ],
        }

    def generate_response(
        self,
        messages: List[Dict[str, str]],
        response_format: Optional[str] = None,
        tools: Optional[List[Dict]] = None,
        tool_choice: str = "auto",
    ):
        """
        Generate a response based on the given messages using Azure OpenAI.

        Args:
            messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key.
            response_format (Optional[str]): The desired format of the response. Defaults to None.
            tools (Optional[List[Dict]]): Tools that the model can call. Defaults to None.
            tool_choice (str): Tool choice method. Defaults to "auto".

        Returns:
            str or dict: The message content, or a dict with ``"content"`` and
            ``"tool_calls"`` when ``tools`` is supplied.
        """
        # Work on a shallow copy so the caller's list and message dicts are
        # never modified by the substitution below.
        messages = list(messages)
        if messages:
            last_message = messages[-1]
            content = last_message.get("content") if isinstance(last_message, dict) else None
            # Only plain-text content can be rewritten; other content shapes
            # are passed through untouched.
            # NOTE(review): the reason for replacing "assistant" with "ai" in
            # the final message is not visible in this file.
            if isinstance(content, str):
                messages[-1] = {**last_message, "content": content.replace("assistant", "ai")}

        params = {
            "model": self.config.model,
            "messages": messages,
            "temperature": self.config.temperature,
            "max_tokens": self.config.max_tokens,
            "top_p": self.config.top_p,
        }
        if response_format:
            params["response_format"] = response_format
        if tools:
            params["tools"] = tools
            params["tool_choice"] = tool_choice

        response = self.client.chat.completions.create(**params)
        return self._parse_response(response, tools)
|
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Dict, List, Optional, Union
|
|
3
|
+
|
|
4
|
+
from agentrun_mem0.configs.llms.base import BaseLlmConfig
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class LLMBase(ABC):
    """
    Base class for all LLM providers.
    Handles common functionality and delegates provider-specific logic to subclasses.
    """

    def __init__(self, config: Optional[Union[BaseLlmConfig, Dict]] = None):
        """Initialize a base LLM class

        :param config: LLM configuration option class or dict, defaults to None
        :type config: Optional[Union[BaseLlmConfig, Dict]], optional
        """
        if config is None:
            self.config = BaseLlmConfig()
        elif isinstance(config, dict):
            # Handle dict-based configuration (backward compatibility)
            self.config = BaseLlmConfig(**config)
        else:
            self.config = config

        # Validate configuration
        self._validate_config()

    def _validate_config(self):
        """
        Validate the configuration.
        Override in subclasses to add provider-specific validation.

        Raises:
            ValueError: If the configuration object has no ``model`` attribute.
        """
        if not hasattr(self.config, "model"):
            raise ValueError("Configuration must have a 'model' attribute")

        # API keys are deliberately not checked here: each provider resolves
        # its own key (from the config or an environment variable).

    def _is_reasoning_model(self, model: str) -> bool:
        """
        Check if the model is a reasoning model or GPT-5 series that doesn't support certain parameters.

        Args:
            model: The model name to check. ``None``, an empty string or a
                non-string value is treated as "not a reasoning model".

        Returns:
            bool: True if the model is a reasoning model or GPT-5 series
        """
        # A config without a model name yields None here; treat that as a
        # regular model rather than crashing on ``None.lower()``.
        if not isinstance(model, str) or not model:
            return False

        # Case-insensitive substring test: any name that contains one of these
        # markers anywhere is classified as a reasoning model.
        model_lower = model.lower()
        return any(marker in model_lower for marker in ("gpt-5", "o1", "o3"))

    def _get_supported_params(self, **kwargs) -> Dict:
        """
        Get parameters that are supported by the current model.
        Filters out unsupported parameters for reasoning models and GPT-5 series.

        Args:
            **kwargs: Additional parameters to include

        Returns:
            Dict: Filtered parameters dictionary
        """
        model = getattr(self.config, "model", "")

        if not self._is_reasoning_model(model):
            # For regular models, include all common parameters
            return self._get_common_params(**kwargs)

        # Reasoning models: forward only these keys and drop everything else
        # (including the temperature / max_tokens / top_p defaults).
        passthrough_keys = ("messages", "response_format", "tools", "tool_choice")
        return {key: kwargs[key] for key in passthrough_keys if key in kwargs}

    @abstractmethod
    def generate_response(
        self, messages: List[Dict[str, str]], tools: Optional[List[Dict]] = None, tool_choice: str = "auto", **kwargs
    ):
        """
        Generate a response based on the given messages.

        Args:
            messages (list): List of message dicts containing 'role' and 'content'.
            tools (list, optional): List of tools that the model can call. Defaults to None.
            tool_choice (str, optional): Tool choice method. Defaults to "auto".
            **kwargs: Additional provider-specific parameters.

        Returns:
            str or dict: The generated response.
        """
        pass

    def _get_common_params(self, **kwargs) -> Dict:
        """
        Get common parameters that most providers use.

        Args:
            **kwargs: Extra parameters; they override the config-derived
                defaults when the same key is present.

        Returns:
            Dict: Common parameters dictionary.
        """
        params = {
            "temperature": self.config.temperature,
            "max_tokens": self.config.max_tokens,
            "top_p": self.config.top_p,
        }

        # Add provider-specific parameters from kwargs
        params.update(kwargs)

        return params
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, Field, field_validator
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class LlmConfig(BaseModel):
    # Selects an LLM provider by name and carries that provider's settings.
    # ``provider`` defaults to "openai"; ``config`` defaults to an empty dict.
    provider: str = Field(default="openai", description="Provider of the LLM (e.g., 'ollama', 'openai')")
    config: Optional[dict] = Field(default={}, description="Configuration for the specific LLM")

    @field_validator("config")
    def validate_config(cls, v, values):
        """Return ``v`` unchanged when the already-validated provider is supported.

        The provider name is read from ``values.data``; an unknown or missing
        provider raises ``ValueError``.
        """
        supported_providers = {
            "openai",
            "ollama",
            "anthropic",
            "groq",
            "together",
            "aws_bedrock",
            "litellm",
            "azure_openai",
            "openai_structured",
            "azure_openai_structured",
            "gemini",
            "deepseek",
            "xai",
            "sarvam",
            "lmstudio",
            "vllm",
            "langchain",
        }
        provider = values.data.get("provider")
        if provider not in supported_providers:
            raise ValueError(f"Unsupported LLM provider: {provider}")
        return v
|
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
from typing import Dict, List, Optional, Union
|
|
4
|
+
|
|
5
|
+
from openai import OpenAI
|
|
6
|
+
|
|
7
|
+
from agentrun_mem0.configs.llms.base import BaseLlmConfig
|
|
8
|
+
from agentrun_mem0.configs.llms.deepseek import DeepSeekConfig
|
|
9
|
+
from agentrun_mem0.llms.base import LLMBase
|
|
10
|
+
from agentrun_mem0.memory.utils import extract_json
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class DeepSeekLLM(LLMBase):
    """Chat-completion provider that reaches DeepSeek through the OpenAI client."""

    def __init__(self, config: Optional[Union[BaseLlmConfig, DeepSeekConfig, Dict]] = None):
        """Normalise ``config`` to a ``DeepSeekConfig`` and build the client.

        Args:
            config: ``None`` (all defaults), a dict of ``DeepSeekConfig``
                keyword arguments, a ``DeepSeekConfig`` instance, or a generic
                ``BaseLlmConfig`` whose common fields are copied.
        """
        # Convert to DeepSeekConfig if needed
        if config is None:
            config = DeepSeekConfig()
        elif isinstance(config, dict):
            config = DeepSeekConfig(**config)
        elif isinstance(config, BaseLlmConfig) and not isinstance(config, DeepSeekConfig):
            # Convert BaseLlmConfig to DeepSeekConfig
            # NOTE(review): ``config.http_client`` is passed as
            # ``http_client_proxies`` -- confirm against BaseLlmConfig that
            # this attribute really holds proxy settings.
            config = DeepSeekConfig(
                model=config.model,
                temperature=config.temperature,
                api_key=config.api_key,
                max_tokens=config.max_tokens,
                top_p=config.top_p,
                top_k=config.top_k,
                enable_vision=config.enable_vision,
                vision_details=config.vision_details,
                http_client_proxies=config.http_client,
            )

        super().__init__(config)

        if not self.config.model:
            self.config.model = "deepseek-chat"

        # Explicit configuration wins, then the environment, then the
        # hard-coded public endpoint.
        api_key = self.config.api_key or os.getenv("DEEPSEEK_API_KEY")
        base_url = self.config.deepseek_base_url or os.getenv("DEEPSEEK_API_BASE") or "https://api.deepseek.com"
        self.client = OpenAI(api_key=api_key, base_url=base_url)

    def _parse_response(self, response, tools):
        """
        Process the response based on whether tools are used or not.

        Args:
            response: The raw response from API.
            tools: The list of tools provided in the request.

        Returns:
            str or dict: The message content when no tools were supplied;
            otherwise a dict with ``"content"`` and a ``"tool_calls"`` list
            of ``{"name", "arguments"}`` entries (arguments JSON-decoded).
        """
        if tools:
            processed_response = {
                "content": response.choices[0].message.content,
                "tool_calls": [],
            }

            if response.choices[0].message.tool_calls:
                for tool_call in response.choices[0].message.tool_calls:
                    processed_response["tool_calls"].append(
                        {
                            "name": tool_call.function.name,
                            "arguments": json.loads(extract_json(tool_call.function.arguments)),
                        }
                    )

            return processed_response
        else:
            return response.choices[0].message.content

    def generate_response(
        self,
        messages: List[Dict[str, str]],
        response_format=None,
        tools: Optional[List[Dict]] = None,
        tool_choice: str = "auto",
        **kwargs,
    ):
        """
        Generate a response based on the given messages using DeepSeek.

        Args:
            messages (list): List of message dicts containing 'role' and 'content'.
            response_format (str or object, optional): Format of the response.
                Forwarded to the API only when truthy. Defaults to None.
            tools (list, optional): List of tools that the model can call. Defaults to None.
            tool_choice (str, optional): Tool choice method. Defaults to "auto".
            **kwargs: Additional DeepSeek-specific parameters.

        Returns:
            str or dict: The message content, or a dict with ``"content"`` and
            ``"tool_calls"`` when ``tools`` is supplied.
        """
        params = self._get_supported_params(messages=messages, **kwargs)
        params.update(
            {
                "model": self.config.model,
                "messages": messages,
            }
        )

        # Honour the documented ``response_format`` argument instead of
        # silently dropping it.
        if response_format:
            params["response_format"] = response_format

        if tools:
            params["tools"] = tools
            params["tool_choice"] = tool_choice

        response = self.client.chat.completions.create(**params)
        return self._parse_response(response, tools)
|