mem0ai-azure-mysql 0.1.115.2__py3-none-any.whl → 0.1.116.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mem0/client/main.py +20 -17
- mem0/configs/llms/anthropic.py +56 -0
- mem0/configs/llms/aws_bedrock.py +191 -0
- mem0/configs/llms/azure.py +57 -0
- mem0/configs/llms/base.py +29 -119
- mem0/configs/llms/deepseek.py +56 -0
- mem0/configs/llms/lmstudio.py +59 -0
- mem0/configs/llms/ollama.py +56 -0
- mem0/configs/llms/openai.py +76 -0
- mem0/configs/llms/vllm.py +56 -0
- mem0/configs/vector_stores/databricks.py +63 -0
- mem0/configs/vector_stores/elasticsearch.py +18 -0
- mem0/configs/vector_stores/milvus.py +1 -0
- mem0/configs/vector_stores/pgvector.py +15 -2
- mem0/configs/vector_stores/pinecone.py +1 -0
- mem0/embeddings/azure_openai.py +0 -3
- mem0/embeddings/ollama.py +1 -1
- mem0/graphs/configs.py +26 -2
- mem0/graphs/neptune/main.py +1 -0
- mem0/graphs/tools.py +6 -6
- mem0/llms/anthropic.py +33 -10
- mem0/llms/aws_bedrock.py +484 -154
- mem0/llms/azure_openai.py +30 -19
- mem0/llms/azure_openai_structured.py +19 -4
- mem0/llms/base.py +105 -6
- mem0/llms/deepseek.py +31 -9
- mem0/llms/lmstudio.py +75 -14
- mem0/llms/ollama.py +44 -32
- mem0/llms/openai.py +39 -22
- mem0/llms/vllm.py +32 -14
- mem0/memory/base.py +2 -2
- mem0/memory/graph_memory.py +166 -54
- mem0/memory/kuzu_memory.py +710 -0
- mem0/memory/main.py +59 -37
- mem0/memory/memgraph_memory.py +43 -35
- mem0/memory/utils.py +51 -0
- mem0/proxy/main.py +5 -10
- mem0/utils/factory.py +132 -25
- mem0/vector_stores/azure_ai_search.py +0 -3
- mem0/vector_stores/chroma.py +27 -2
- mem0/vector_stores/configs.py +1 -0
- mem0/vector_stores/databricks.py +759 -0
- mem0/vector_stores/elasticsearch.py +2 -0
- mem0/vector_stores/langchain.py +3 -2
- mem0/vector_stores/milvus.py +3 -1
- mem0/vector_stores/mongodb.py +20 -1
- mem0/vector_stores/pgvector.py +83 -9
- mem0/vector_stores/pinecone.py +17 -8
- mem0/vector_stores/qdrant.py +30 -0
- mem0ai_azure_mysql-0.1.116.2.data/data/README.md +24 -0
- mem0ai_azure_mysql-0.1.116.2.dist-info/METADATA +88 -0
- {mem0ai_azure_mysql-0.1.115.2.dist-info → mem0ai_azure_mysql-0.1.116.2.dist-info}/RECORD +53 -43
- mem0ai_azure_mysql-0.1.115.2.data/data/README.md +0 -169
- mem0ai_azure_mysql-0.1.115.2.dist-info/METADATA +0 -224
- mem0ai_azure_mysql-0.1.115.2.dist-info/licenses/LICENSE +0 -201
- {mem0ai_azure_mysql-0.1.115.2.dist-info → mem0ai_azure_mysql-0.1.116.2.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
from mem0.configs.llms.base import BaseLlmConfig
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class OllamaConfig(BaseLlmConfig):
    """
    LLM configuration for the Ollama provider.

    Takes every option that ``BaseLlmConfig`` accepts (forwarded to it
    unchanged) and additionally stores the Ollama server base URL.
    """

    def __init__(
        self,
        # Options shared with every provider; forwarded to BaseLlmConfig
        model: Optional[str] = None,
        temperature: float = 0.1,
        api_key: Optional[str] = None,
        max_tokens: int = 2000,
        top_p: float = 0.1,
        top_k: int = 1,
        enable_vision: bool = False,
        vision_details: Optional[str] = "auto",
        http_client_proxies: Optional[dict] = None,
        # Option that only Ollama understands
        ollama_base_url: Optional[str] = None,
    ):
        """
        Create an Ollama configuration.

        Args:
            model: Name of the Ollama model; None by default.
            temperature: Sampling temperature; 0.1 by default.
            api_key: API key for the Ollama endpoint; None by default.
            max_tokens: Upper bound on generated tokens; 2000 by default.
            top_p: Nucleus-sampling threshold; 0.1 by default.
            top_k: Top-k sampling cutoff; 1 by default.
            enable_vision: Whether vision input is enabled; False by default.
            vision_details: Requested vision detail level; "auto" by default.
            http_client_proxies: Proxy settings for the HTTP client; None by default.
            ollama_base_url: Base URL of the Ollama server; None by default.
        """
        base_options = dict(
            model=model,
            temperature=temperature,
            api_key=api_key,
            max_tokens=max_tokens,
            top_p=top_p,
            top_k=top_k,
            enable_vision=enable_vision,
            vision_details=vision_details,
            http_client_proxies=http_client_proxies,
        )
        super().__init__(**base_options)

        # Stored exactly as given (None when the caller supplied nothing).
        self.ollama_base_url = ollama_base_url
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
from typing import Any, Callable, List, Optional
|
|
2
|
+
|
|
3
|
+
from mem0.configs.llms.base import BaseLlmConfig
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class OpenAIConfig(BaseLlmConfig):
    """
    LLM configuration for OpenAI and OpenRouter.

    Accepts every ``BaseLlmConfig`` option (forwarded to it unchanged) plus
    OpenAI/OpenRouter connection settings and an optional response callback.
    """

    def __init__(
        self,
        # Options shared with every provider; forwarded to BaseLlmConfig
        model: Optional[str] = None,
        temperature: float = 0.1,
        api_key: Optional[str] = None,
        max_tokens: int = 2000,
        top_p: float = 0.1,
        top_k: int = 1,
        enable_vision: bool = False,
        vision_details: Optional[str] = "auto",
        http_client_proxies: Optional[dict] = None,
        # OpenAI / OpenRouter options
        openai_base_url: Optional[str] = None,
        models: Optional[List[str]] = None,
        route: Optional[str] = "fallback",
        openrouter_base_url: Optional[str] = None,
        site_url: Optional[str] = None,
        app_name: Optional[str] = None,
        # Hook for observing responses
        response_callback: Optional[Callable[[Any, dict, dict], None]] = None,
    ):
        """
        Create an OpenAI / OpenRouter configuration.

        Args:
            model: OpenAI model name; None by default.
            temperature: Sampling temperature; 0.1 by default.
            api_key: OpenAI API key; None by default.
            max_tokens: Upper bound on generated tokens; 2000 by default.
            top_p: Nucleus-sampling threshold; 0.1 by default.
            top_k: Top-k sampling cutoff; 1 by default.
            enable_vision: Whether vision input is enabled; False by default.
            vision_details: Requested vision detail level; "auto" by default.
            http_client_proxies: Proxy settings for the HTTP client; None by default.
            openai_base_url: Base URL of the OpenAI API; None by default.
            models: Candidate model list for OpenRouter; None by default.
            route: OpenRouter routing strategy; "fallback" by default.
            openrouter_base_url: Base URL of OpenRouter; None by default.
            site_url: Site URL reported to OpenRouter; None by default.
            app_name: Application name reported to OpenRouter; None by default.
            response_callback: Optional hook for monitoring LLM responses. Its
                type hint is ``Callable[[Any, dict, dict], None]``; where it is
                invoked is outside this class.
        """
        base_options = dict(
            model=model,
            temperature=temperature,
            api_key=api_key,
            max_tokens=max_tokens,
            top_p=top_p,
            top_k=top_k,
            enable_vision=enable_vision,
            vision_details=vision_details,
            http_client_proxies=http_client_proxies,
        )
        super().__init__(**base_options)

        # Provider-specific settings, stored verbatim.
        self.openai_base_url = openai_base_url
        self.models = models
        self.route = route
        self.openrouter_base_url = openrouter_base_url
        self.site_url = site_url
        self.app_name = app_name

        # Response monitoring hook (may be None).
        self.response_callback = response_callback
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
from mem0.configs.llms.base import BaseLlmConfig
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class VllmConfig(BaseLlmConfig):
    """
    LLM configuration for a vLLM server.

    Accepts every ``BaseLlmConfig`` option (forwarded to it unchanged) and the
    vLLM base URL, which falls back to a localhost address when not given.
    """

    def __init__(
        self,
        # Options shared with every provider; forwarded to BaseLlmConfig
        model: Optional[str] = None,
        temperature: float = 0.1,
        api_key: Optional[str] = None,
        max_tokens: int = 2000,
        top_p: float = 0.1,
        top_k: int = 1,
        enable_vision: bool = False,
        vision_details: Optional[str] = "auto",
        http_client_proxies: Optional[dict] = None,
        # Option that only vLLM understands
        vllm_base_url: Optional[str] = None,
    ):
        """
        Create a vLLM configuration.

        Args:
            model: vLLM model name; None by default.
            temperature: Sampling temperature; 0.1 by default.
            api_key: API key for the vLLM endpoint; None by default.
            max_tokens: Upper bound on generated tokens; 2000 by default.
            top_p: Nucleus-sampling threshold; 0.1 by default.
            top_k: Top-k sampling cutoff; 1 by default.
            enable_vision: Whether vision input is enabled; False by default.
            vision_details: Requested vision detail level; "auto" by default.
            http_client_proxies: Proxy settings for the HTTP client; None by default.
            vllm_base_url: Base URL of the vLLM server. Any falsy value
                (None or "") is replaced by "http://localhost:8000/v1".
        """
        base_options = dict(
            model=model,
            temperature=temperature,
            api_key=api_key,
            max_tokens=max_tokens,
            top_p=top_p,
            top_k=top_k,
            enable_vision=enable_vision,
            vision_details=vision_details,
            http_client_proxies=http_client_proxies,
        )
        super().__init__(**base_options)

        # Fall back to the local default endpoint when no URL was supplied.
        self.vllm_base_url = vllm_base_url if vllm_base_url else "http://localhost:8000/v1"
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
from typing import Any, Dict, Optional
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, Field, model_validator
|
|
4
|
+
|
|
5
|
+
from databricks.sdk.service.vectorsearch import EndpointType, VectorIndexType, PipelineType
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class DatabricksConfig(BaseModel):
    """Configuration for the Databricks Vector Search vector store.

    Authentication needs either ``access_token`` or a complete service-principal
    pair (``client_id``/``client_secret`` or
    ``azure_client_id``/``azure_client_secret``); this is enforced by
    ``validate_authentication``. Keys that are not declared fields are rejected
    by ``validate_extra_fields``.
    """

    workspace_url: str = Field(..., description="Databricks workspace URL")
    access_token: Optional[str] = Field(None, description="Personal access token for authentication")
    client_id: Optional[str] = Field(None, description="Databricks Service principal client ID")
    client_secret: Optional[str] = Field(None, description="Databricks Service principal client secret")
    azure_client_id: Optional[str] = Field(None, description="Azure AD application client ID (for Azure Databricks)")
    azure_client_secret: Optional[str] = Field(
        None, description="Azure AD application client secret (for Azure Databricks)"
    )
    endpoint_name: str = Field(..., description="Vector search endpoint name")
    catalog: str = Field(..., description="The Unity Catalog catalog name")
    # NOTE(review): the field name `schema` shadows the deprecated
    # `BaseModel.schema()` method, and pydantic may emit a warning about it.
    # The name is kept for backward compatibility with existing configs.
    schema: str = Field(..., description="The Unity Catalog schema name")
    table_name: str = Field(..., description="Source Delta table name")
    collection_name: str = Field("mem0", description="Vector search index name")
    # NOTE(review): pydantic does not validate defaults unless `validate_default`
    # is set. The enum-typed fields below therefore hold a plain string when left
    # at their default, but an enum member when a user value is coerced.
    # Confirm that downstream code handles both forms.
    index_type: VectorIndexType = Field("DELTA_SYNC", description="Index type: DELTA_SYNC or DIRECT_ACCESS")
    embedding_model_endpoint_name: Optional[str] = Field(
        None, description="Embedding model endpoint for Databricks-computed embeddings"
    )
    embedding_dimension: int = Field(1536, description="Vector embedding dimensions")
    endpoint_type: EndpointType = Field("STANDARD", description="Endpoint type: STANDARD or STORAGE_OPTIMIZED")
    pipeline_type: PipelineType = Field("TRIGGERED", description="Sync pipeline type: TRIGGERED or CONTINUOUS")
    warehouse_name: Optional[str] = Field(None, description="Databricks SQL warehouse Name")
    query_type: str = Field("ANN", description="Query type: `ANN` and `HYBRID`")

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject input keys that are not declared fields of this model.

        Args:
            values: Raw input passed to the model before field validation.

        Returns:
            ``values`` unchanged.

        Raises:
            ValueError: If ``values`` is a dict containing undeclared keys.
        """
        # A "before" validator may receive non-dict input; leave that case to
        # pydantic's own validation instead of failing on `.keys()`.
        if not isinstance(values, dict):
            return values
        allowed_fields = set(cls.model_fields.keys())
        extra_fields = set(values.keys()) - allowed_fields
        if extra_fields:
            # Sort both lists so the message is stable across runs (set order is not).
            raise ValueError(
                f"Extra fields not allowed: {', '.join(sorted(extra_fields))}. Please input only the following fields: {', '.join(sorted(allowed_fields))}"
            )
        return values

    @model_validator(mode="after")
    def validate_authentication(self):
        """Require an access token or a complete service-principal credential pair.

        Returns:
            The validated model instance.

        Raises:
            ValueError: If neither ``access_token`` nor a complete
                ``client_id``/``client_secret`` or
                ``azure_client_id``/``azure_client_secret`` pair is set.
        """
        has_token = self.access_token is not None
        has_service_principal = (self.client_id is not None and self.client_secret is not None) or (
            self.azure_client_id is not None and self.azure_client_secret is not None
        )

        if not has_token and not has_service_principal:
            raise ValueError(
                "Either access_token or both client_id/client_secret or azure_client_id/azure_client_secret must be provided"
            )

        return self

    model_config = {
        "arbitrary_types_allowed": True,
    }
|
|
@@ -19,6 +19,7 @@ class ElasticsearchConfig(BaseModel):
|
|
|
19
19
|
custom_search_query: Optional[Callable[[List[float], int, Optional[Dict]], Dict]] = Field(
|
|
20
20
|
None, description="Custom search query function. Parameters: (query, limit, filters) -> Dict"
|
|
21
21
|
)
|
|
22
|
+
headers: Optional[Dict[str, str]] = Field(None, description="Custom headers to include in requests")
|
|
22
23
|
|
|
23
24
|
@model_validator(mode="before")
|
|
24
25
|
@classmethod
|
|
@@ -33,6 +34,23 @@ class ElasticsearchConfig(BaseModel):
|
|
|
33
34
|
|
|
34
35
|
return values
|
|
35
36
|
|
|
37
|
+
@model_validator(mode="before")
@classmethod
def validate_headers(cls, values: Dict[str, Any]) -> Dict[str, Any]:
    """Check that the optional ``headers`` entry maps strings to strings.

    Args:
        values: Raw model input, read with ``values.get``.

    Returns:
        ``values`` unchanged.

    Raises:
        ValueError: If ``headers`` is present but is not a dict, or if any of
            its keys or values is not a ``str``.
    """
    headers = values.get("headers")
    if headers is None:
        # Headers are optional; nothing to check.
        return values

    if not isinstance(headers, dict):
        raise ValueError("headers must be a dictionary")

    # `all` stops at the first offending pair, just as an explicit loop would.
    if not all(isinstance(name, str) and isinstance(content, str) for name, content in headers.items()):
        raise ValueError("All header keys and values must be strings")

    return values
|
|
53
|
+
|
|
36
54
|
@model_validator(mode="before")
|
|
37
55
|
@classmethod
|
|
38
56
|
def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
|
@@ -25,6 +25,7 @@ class MilvusDBConfig(BaseModel):
|
|
|
25
25
|
collection_name: str = Field("mem0", description="Name of the collection")
|
|
26
26
|
embedding_model_dims: int = Field(1536, description="Dimensions of the embedding model")
|
|
27
27
|
metric_type: str = Field("L2", description="Metric type for similarity search")
|
|
28
|
+
db_name: str = Field("", description="Name of the database")
|
|
28
29
|
|
|
29
30
|
@model_validator(mode="before")
|
|
30
31
|
@classmethod
|
|
@@ -13,15 +13,28 @@ class PGVectorConfig(BaseModel):
|
|
|
13
13
|
port: Optional[int] = Field(None, description="Database port. Default is 1536")
|
|
14
14
|
diskann: Optional[bool] = Field(True, description="Use diskann for approximate nearest neighbors search")
|
|
15
15
|
hnsw: Optional[bool] = Field(False, description="Use hnsw for faster search")
|
|
16
|
+
# New SSL and connection options
|
|
17
|
+
sslmode: Optional[str] = Field(None, description="SSL mode for PostgreSQL connection (e.g., 'require', 'prefer', 'disable')")
|
|
18
|
+
connection_string: Optional[str] = Field(None, description="PostgreSQL connection string (overrides individual connection parameters)")
|
|
19
|
+
connection_pool: Optional[Any] = Field(None, description="psycopg2 connection pool object (overrides connection string and individual parameters)")
|
|
16
20
|
|
|
17
21
|
@model_validator(mode="before")
def check_auth_and_connection(cls, values):
    """Validate discrete connection settings unless a pool or DSN is supplied.

    Args:
        values: Raw model input, read with ``values.get``.

    Returns:
        ``values`` unchanged.

    Raises:
        ValueError: If neither ``connection_pool`` nor ``connection_string`` is
            set and either both ``user`` and ``password`` or both ``host`` and
            ``port`` are falsy.
    """
    # A ready-made pool or a connection string replaces the individual
    # parameters, so those are not checked in that case.
    if values.get("connection_pool") is not None or values.get("connection_string") is not None:
        return values

    # NOTE(review): each check below fails only when BOTH values are missing,
    # although the messages say both are required. Confirm whether `or`
    # semantics were intended.
    if not (values.get("user") or values.get("password")):
        raise ValueError("Both 'user' and 'password' must be provided when not using connection_string or connection_pool.")
    if not (values.get("host") or values.get("port")):
        raise ValueError("Both 'host' and 'port' must be provided when not using connection_string or connection_pool.")
    return values
|
|
26
39
|
|
|
27
40
|
@model_validator(mode="before")
|
|
@@ -18,6 +18,7 @@ class PineconeConfig(BaseModel):
|
|
|
18
18
|
metric: str = Field("cosine", description="Distance metric for vector similarity")
|
|
19
19
|
batch_size: int = Field(100, description="Batch size for operations")
|
|
20
20
|
extra_params: Optional[Dict[str, Any]] = Field(None, description="Additional parameters for Pinecone client")
|
|
21
|
+
namespace: Optional[str] = Field(None, description="Namespace for the collection")
|
|
21
22
|
|
|
22
23
|
@model_validator(mode="before")
|
|
23
24
|
@classmethod
|
mem0/embeddings/azure_openai.py
CHANGED
|
@@ -24,13 +24,10 @@ class AzureOpenAIEmbedding(EmbeddingBase):
|
|
|
24
24
|
|
|
25
25
|
auth_kwargs = {}
|
|
26
26
|
if api_key:
|
|
27
|
-
print("Using API key for Azure OpenAI Embedding authentication.")
|
|
28
27
|
auth_kwargs["api_key"] = api_key
|
|
29
28
|
elif azure_ad_token:
|
|
30
|
-
print("Using Azure AD token for Azure OpenAI Embedding authentication.")
|
|
31
29
|
auth_kwargs["azure_ad_token"] = azure_ad_token
|
|
32
30
|
else:
|
|
33
|
-
print("Using DefaultAzureCredential for Azure OpenAI Embedding authentication.")
|
|
34
31
|
credential = DefaultAzureCredential()
|
|
35
32
|
token_provider = get_bearer_token_provider(credential, "https://cognitiveservices.azure.com/.default")
|
|
36
33
|
auth_kwargs["azure_ad_token_provider"] = token_provider
|
mem0/embeddings/ollama.py
CHANGED
|
@@ -36,7 +36,7 @@ class OllamaEmbedding(EmbeddingBase):
|
|
|
36
36
|
Ensure the specified model exists locally. If not, pull it from Ollama.
|
|
37
37
|
"""
|
|
38
38
|
local_models = self.client.list()["models"]
|
|
39
|
-
if not any(model.get("name") == self.config.model for model in local_models):
|
|
39
|
+
if not any(model.get("name") == self.config.model or model.get("model") == self.config.model for model in local_models):
|
|
40
40
|
self.client.pull(self.config.model)
|
|
41
41
|
|
|
42
42
|
def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
|
mem0/graphs/configs.py
CHANGED
|
@@ -5,12 +5,24 @@ from pydantic import BaseModel, Field, field_validator, model_validator
|
|
|
5
5
|
from mem0.llms.configs import LlmConfig
|
|
6
6
|
|
|
7
7
|
|
|
8
|
+
class RerankConfig(BaseModel):
    """Reranker settings referenced by ``Neo4jConfig.rerank``.

    ``provider`` names the rerank backend and ``config`` carries that
    provider's options. Neo4jConfig's validator in this module currently
    accepts only the "cohere" provider.
    """

    provider: str = Field(
        description="Provider of the rerank model (e.g., 'openai', 'azure', 'cohere')",
        default="cohere",
    )
    # default_factory builds a fresh dict per instance explicitly, rather than
    # declaring a mutable `{}` literal as the default and relying on pydantic
    # to copy it.
    config: Optional[dict] = Field(
        description="Configuration for the specific rerank model", default_factory=dict
    )
|
|
16
|
+
|
|
8
17
|
class Neo4jConfig(BaseModel):
|
|
9
18
|
url: Optional[str] = Field(None, description="Host address for the graph database")
|
|
10
19
|
username: Optional[str] = Field(None, description="Username for the graph database")
|
|
11
20
|
password: Optional[str] = Field(None, description="Password for the graph database")
|
|
12
21
|
database: Optional[str] = Field(None, description="Database for the graph database")
|
|
13
22
|
base_label: Optional[bool] = Field(None, description="Whether to use base node label __Entity__ for all entities")
|
|
23
|
+
similarity_threshold: float = Field(0.7, description="Threshold for the similarity of nodes")
|
|
24
|
+
top_k: int = Field(5, description="Number of top scored results to return")
|
|
25
|
+
rerank: Optional[RerankConfig] = Field(None, description="Rerank configuration")
|
|
14
26
|
|
|
15
27
|
@model_validator(mode="before")
|
|
16
28
|
def check_host_port_or_path(cls, values):
|
|
@@ -21,6 +33,12 @@ class Neo4jConfig(BaseModel):
|
|
|
21
33
|
)
|
|
22
34
|
if not url or not username or not password:
|
|
23
35
|
raise ValueError("Please provide 'url', 'username' and 'password'.")
|
|
36
|
+
|
|
37
|
+
if values.get("rerank") is not None:
|
|
38
|
+
values["rerank"] = RerankConfig(**values.get("rerank"))
|
|
39
|
+
if values["rerank"].provider not in ("cohere"):
|
|
40
|
+
raise ValueError("Invalid rerank provider. Supported providers are: cohere")
|
|
41
|
+
|
|
24
42
|
return values
|
|
25
43
|
|
|
26
44
|
|
|
@@ -70,12 +88,16 @@ class NeptuneConfig(BaseModel):
|
|
|
70
88
|
)
|
|
71
89
|
|
|
72
90
|
|
|
91
|
+
class KuzuConfig(BaseModel):
    """Settings for the Kuzu graph store provider."""

    # Defaults to ":memory:" (presumably an in-memory database; confirm
    # against the Kuzu documentation) when no file path is given.
    db: Optional[str] = Field(
        default=":memory:",
        description="Path to a Kuzu database file",
    )
|
|
93
|
+
|
|
94
|
+
|
|
73
95
|
class GraphStoreConfig(BaseModel):
|
|
74
96
|
provider: str = Field(
|
|
75
|
-
description="Provider of the data store (e.g., 'neo4j', 'memgraph', 'neptune')",
|
|
97
|
+
description="Provider of the data store (e.g., 'neo4j', 'memgraph', 'neptune', 'kuzu')",
|
|
76
98
|
default="neo4j",
|
|
77
99
|
)
|
|
78
|
-
config: Union[Neo4jConfig, MemgraphConfig, NeptuneConfig] = Field(
|
|
100
|
+
config: Union[Neo4jConfig, MemgraphConfig, NeptuneConfig, KuzuConfig] = Field(
|
|
79
101
|
description="Configuration for the specific data store", default=None
|
|
80
102
|
)
|
|
81
103
|
llm: Optional[LlmConfig] = Field(description="LLM configuration for querying the graph store", default=None)
|
|
@@ -92,5 +114,7 @@ class GraphStoreConfig(BaseModel):
|
|
|
92
114
|
return MemgraphConfig(**v.model_dump())
|
|
93
115
|
elif provider == "neptune":
|
|
94
116
|
return NeptuneConfig(**v.model_dump())
|
|
117
|
+
elif provider == "kuzu":
|
|
118
|
+
return KuzuConfig(**v.model_dump())
|
|
95
119
|
else:
|
|
96
120
|
raise ValueError(f"Unsupported graph store provider: {provider}")
|
mem0/graphs/neptune/main.py
CHANGED
mem0/graphs/tools.py
CHANGED
|
@@ -249,23 +249,23 @@ RELATIONS_STRUCT_TOOL = {
|
|
|
249
249
|
"items": {
|
|
250
250
|
"type": "object",
|
|
251
251
|
"properties": {
|
|
252
|
-
"
|
|
252
|
+
"source": {
|
|
253
253
|
"type": "string",
|
|
254
254
|
"description": "The source entity of the relationship.",
|
|
255
255
|
},
|
|
256
|
-
"
|
|
256
|
+
"relationship": {
|
|
257
257
|
"type": "string",
|
|
258
258
|
"description": "The relationship between the source and destination entities.",
|
|
259
259
|
},
|
|
260
|
-
"
|
|
260
|
+
"destination": {
|
|
261
261
|
"type": "string",
|
|
262
262
|
"description": "The destination entity of the relationship.",
|
|
263
263
|
},
|
|
264
264
|
},
|
|
265
265
|
"required": [
|
|
266
|
-
"
|
|
267
|
-
"
|
|
268
|
-
"
|
|
266
|
+
"source",
|
|
267
|
+
"relationship",
|
|
268
|
+
"destination",
|
|
269
269
|
],
|
|
270
270
|
"additionalProperties": False,
|
|
271
271
|
},
|
mem0/llms/anthropic.py
CHANGED
|
@@ -1,17 +1,37 @@
|
|
|
1
1
|
import os
|
|
2
|
-
from typing import Dict, List, Optional
|
|
2
|
+
from typing import Dict, List, Optional, Union
|
|
3
3
|
|
|
4
4
|
try:
|
|
5
5
|
import anthropic
|
|
6
6
|
except ImportError:
|
|
7
7
|
raise ImportError("The 'anthropic' library is required. Please install it using 'pip install anthropic'.")
|
|
8
8
|
|
|
9
|
+
from mem0.configs.llms.anthropic import AnthropicConfig
|
|
9
10
|
from mem0.configs.llms.base import BaseLlmConfig
|
|
10
11
|
from mem0.llms.base import LLMBase
|
|
11
12
|
|
|
12
13
|
|
|
13
14
|
class AnthropicLLM(LLMBase):
|
|
14
|
-
def __init__(self, config: Optional[BaseLlmConfig] = None):
|
|
15
|
+
def __init__(self, config: Optional[Union[BaseLlmConfig, AnthropicConfig, Dict]] = None):
|
|
16
|
+
# Convert to AnthropicConfig if needed
|
|
17
|
+
if config is None:
|
|
18
|
+
config = AnthropicConfig()
|
|
19
|
+
elif isinstance(config, dict):
|
|
20
|
+
config = AnthropicConfig(**config)
|
|
21
|
+
elif isinstance(config, BaseLlmConfig) and not isinstance(config, AnthropicConfig):
|
|
22
|
+
# Convert BaseLlmConfig to AnthropicConfig
|
|
23
|
+
config = AnthropicConfig(
|
|
24
|
+
model=config.model,
|
|
25
|
+
temperature=config.temperature,
|
|
26
|
+
api_key=config.api_key,
|
|
27
|
+
max_tokens=config.max_tokens,
|
|
28
|
+
top_p=config.top_p,
|
|
29
|
+
top_k=config.top_k,
|
|
30
|
+
enable_vision=config.enable_vision,
|
|
31
|
+
vision_details=config.vision_details,
|
|
32
|
+
http_client_proxies=config.http_client,
|
|
33
|
+
)
|
|
34
|
+
|
|
15
35
|
super().__init__(config)
|
|
16
36
|
|
|
17
37
|
if not self.config.model:
|
|
@@ -26,6 +46,7 @@ class AnthropicLLM(LLMBase):
|
|
|
26
46
|
response_format=None,
|
|
27
47
|
tools: Optional[List[Dict]] = None,
|
|
28
48
|
tool_choice: str = "auto",
|
|
49
|
+
**kwargs,
|
|
29
50
|
):
|
|
30
51
|
"""
|
|
31
52
|
Generate a response based on the given messages using Anthropic.
|
|
@@ -35,6 +56,7 @@ class AnthropicLLM(LLMBase):
|
|
|
35
56
|
response_format (str or object, optional): Format of the response. Defaults to "text".
|
|
36
57
|
tools (list, optional): List of tools that the model can call. Defaults to None.
|
|
37
58
|
tool_choice (str, optional): Tool choice method. Defaults to "auto".
|
|
59
|
+
**kwargs: Additional Anthropic-specific parameters.
|
|
38
60
|
|
|
39
61
|
Returns:
|
|
40
62
|
str: The generated response.
|
|
@@ -48,14 +70,15 @@ class AnthropicLLM(LLMBase):
|
|
|
48
70
|
else:
|
|
49
71
|
filtered_messages.append(message)
|
|
50
72
|
|
|
51
|
-
params =
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
73
|
+
params = self._get_supported_params(messages=messages, **kwargs)
|
|
74
|
+
params.update(
|
|
75
|
+
{
|
|
76
|
+
"model": self.config.model,
|
|
77
|
+
"messages": filtered_messages,
|
|
78
|
+
"system": system_message,
|
|
79
|
+
}
|
|
80
|
+
)
|
|
81
|
+
|
|
59
82
|
if tools: # TODO: Remove tools if no issues found with new memory addition logic
|
|
60
83
|
params["tools"] = tools
|
|
61
84
|
params["tool_choice"] = tool_choice
|