mem0ai_azure_mysql-0.1.115-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mem0/__init__.py +6 -0
- mem0/client/__init__.py +0 -0
- mem0/client/main.py +1535 -0
- mem0/client/project.py +860 -0
- mem0/client/utils.py +29 -0
- mem0/configs/__init__.py +0 -0
- mem0/configs/base.py +90 -0
- mem0/configs/dbs/__init__.py +4 -0
- mem0/configs/dbs/base.py +41 -0
- mem0/configs/dbs/mysql.py +25 -0
- mem0/configs/embeddings/__init__.py +0 -0
- mem0/configs/embeddings/base.py +108 -0
- mem0/configs/enums.py +7 -0
- mem0/configs/llms/__init__.py +0 -0
- mem0/configs/llms/base.py +152 -0
- mem0/configs/prompts.py +333 -0
- mem0/configs/vector_stores/__init__.py +0 -0
- mem0/configs/vector_stores/azure_ai_search.py +59 -0
- mem0/configs/vector_stores/baidu.py +29 -0
- mem0/configs/vector_stores/chroma.py +40 -0
- mem0/configs/vector_stores/elasticsearch.py +47 -0
- mem0/configs/vector_stores/faiss.py +39 -0
- mem0/configs/vector_stores/langchain.py +32 -0
- mem0/configs/vector_stores/milvus.py +43 -0
- mem0/configs/vector_stores/mongodb.py +25 -0
- mem0/configs/vector_stores/opensearch.py +41 -0
- mem0/configs/vector_stores/pgvector.py +37 -0
- mem0/configs/vector_stores/pinecone.py +56 -0
- mem0/configs/vector_stores/qdrant.py +49 -0
- mem0/configs/vector_stores/redis.py +26 -0
- mem0/configs/vector_stores/supabase.py +44 -0
- mem0/configs/vector_stores/upstash_vector.py +36 -0
- mem0/configs/vector_stores/vertex_ai_vector_search.py +27 -0
- mem0/configs/vector_stores/weaviate.py +43 -0
- mem0/dbs/__init__.py +4 -0
- mem0/dbs/base.py +68 -0
- mem0/dbs/configs.py +21 -0
- mem0/dbs/mysql.py +321 -0
- mem0/embeddings/__init__.py +0 -0
- mem0/embeddings/aws_bedrock.py +100 -0
- mem0/embeddings/azure_openai.py +43 -0
- mem0/embeddings/base.py +31 -0
- mem0/embeddings/configs.py +30 -0
- mem0/embeddings/gemini.py +39 -0
- mem0/embeddings/huggingface.py +41 -0
- mem0/embeddings/langchain.py +35 -0
- mem0/embeddings/lmstudio.py +29 -0
- mem0/embeddings/mock.py +11 -0
- mem0/embeddings/ollama.py +53 -0
- mem0/embeddings/openai.py +49 -0
- mem0/embeddings/together.py +31 -0
- mem0/embeddings/vertexai.py +54 -0
- mem0/graphs/__init__.py +0 -0
- mem0/graphs/configs.py +96 -0
- mem0/graphs/neptune/__init__.py +0 -0
- mem0/graphs/neptune/base.py +410 -0
- mem0/graphs/neptune/main.py +372 -0
- mem0/graphs/tools.py +371 -0
- mem0/graphs/utils.py +97 -0
- mem0/llms/__init__.py +0 -0
- mem0/llms/anthropic.py +64 -0
- mem0/llms/aws_bedrock.py +270 -0
- mem0/llms/azure_openai.py +114 -0
- mem0/llms/azure_openai_structured.py +76 -0
- mem0/llms/base.py +32 -0
- mem0/llms/configs.py +34 -0
- mem0/llms/deepseek.py +85 -0
- mem0/llms/gemini.py +201 -0
- mem0/llms/groq.py +88 -0
- mem0/llms/langchain.py +65 -0
- mem0/llms/litellm.py +87 -0
- mem0/llms/lmstudio.py +53 -0
- mem0/llms/ollama.py +94 -0
- mem0/llms/openai.py +124 -0
- mem0/llms/openai_structured.py +52 -0
- mem0/llms/sarvam.py +89 -0
- mem0/llms/together.py +88 -0
- mem0/llms/vllm.py +89 -0
- mem0/llms/xai.py +52 -0
- mem0/memory/__init__.py +0 -0
- mem0/memory/base.py +63 -0
- mem0/memory/graph_memory.py +632 -0
- mem0/memory/main.py +1843 -0
- mem0/memory/memgraph_memory.py +630 -0
- mem0/memory/setup.py +56 -0
- mem0/memory/storage.py +218 -0
- mem0/memory/telemetry.py +90 -0
- mem0/memory/utils.py +133 -0
- mem0/proxy/__init__.py +0 -0
- mem0/proxy/main.py +194 -0
- mem0/utils/factory.py +132 -0
- mem0/vector_stores/__init__.py +0 -0
- mem0/vector_stores/azure_ai_search.py +383 -0
- mem0/vector_stores/baidu.py +368 -0
- mem0/vector_stores/base.py +58 -0
- mem0/vector_stores/chroma.py +229 -0
- mem0/vector_stores/configs.py +60 -0
- mem0/vector_stores/elasticsearch.py +235 -0
- mem0/vector_stores/faiss.py +473 -0
- mem0/vector_stores/langchain.py +179 -0
- mem0/vector_stores/milvus.py +245 -0
- mem0/vector_stores/mongodb.py +293 -0
- mem0/vector_stores/opensearch.py +281 -0
- mem0/vector_stores/pgvector.py +294 -0
- mem0/vector_stores/pinecone.py +373 -0
- mem0/vector_stores/qdrant.py +240 -0
- mem0/vector_stores/redis.py +295 -0
- mem0/vector_stores/supabase.py +237 -0
- mem0/vector_stores/upstash_vector.py +293 -0
- mem0/vector_stores/vertex_ai_vector_search.py +629 -0
- mem0/vector_stores/weaviate.py +316 -0
- mem0ai_azure_mysql-0.1.115.data/data/README.md +169 -0
- mem0ai_azure_mysql-0.1.115.dist-info/METADATA +224 -0
- mem0ai_azure_mysql-0.1.115.dist-info/RECORD +116 -0
- mem0ai_azure_mysql-0.1.115.dist-info/WHEEL +4 -0
- mem0ai_azure_mysql-0.1.115.dist-info/licenses/LICENSE +201 -0
mem0/configs/vector_stores/opensearch.py ADDED

@@ -0,0 +1,41 @@
+from typing import Any, Dict, Optional, Type, Union
+
+from pydantic import BaseModel, Field, model_validator
+
+
+class OpenSearchConfig(BaseModel):
+    collection_name: str = Field("mem0", description="Name of the index")
+    host: str = Field("localhost", description="OpenSearch host")
+    port: int = Field(9200, description="OpenSearch port")
+    user: Optional[str] = Field(None, description="Username for authentication")
+    password: Optional[str] = Field(None, description="Password for authentication")
+    api_key: Optional[str] = Field(None, description="API key for authentication (if applicable)")
+    embedding_model_dims: int = Field(1536, description="Dimension of the embedding vector")
+    verify_certs: bool = Field(False, description="Verify SSL certificates (default False for OpenSearch)")
+    use_ssl: bool = Field(False, description="Use SSL for connection (default False for OpenSearch)")
+    http_auth: Optional[object] = Field(None, description="HTTP authentication method / AWS SigV4")
+    connection_class: Optional[Union[str, Type]] = Field(
+        "RequestsHttpConnection", description="Connection class for OpenSearch"
+    )
+    pool_maxsize: int = Field(20, description="Maximum number of connections in the pool")
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_auth(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        # Check if host is provided
+        if not values.get("host"):
+            raise ValueError("Host must be provided for OpenSearch")
+
+        return values
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        allowed_fields = set(cls.model_fields.keys())
+        input_fields = set(values.keys())
+        extra_fields = input_fields - allowed_fields
+        if extra_fields:
+            raise ValueError(
+                f"Extra fields not allowed: {', '.join(extra_fields)}. Allowed fields: {', '.join(allowed_fields)}"
+            )
+        return values
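The two before-mode validators make construction fail fast on a missing host and on unrecognized keys. A minimal usage sketch, assuming the wheel is installed so that mem0.configs.vector_stores.opensearch is importable (host and credentials are placeholders):

from mem0.configs.vector_stores.opensearch import OpenSearchConfig

# Defaults target a local, unsecured node: port 9200, no SSL, no cert checks.
cfg = OpenSearchConfig(host="localhost", port=9200, user="admin", password="admin")

# Unknown keys are rejected by validate_extra_fields:
# OpenSearchConfig(hosts=["localhost"])  # ValueError: Extra fields not allowed: hosts ...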
mem0/configs/vector_stores/pgvector.py ADDED

@@ -0,0 +1,37 @@
+from typing import Any, Dict, Optional
+
+from pydantic import BaseModel, Field, model_validator
+
+
+class PGVectorConfig(BaseModel):
+    dbname: str = Field("postgres", description="Default name for the database")
+    collection_name: str = Field("mem0", description="Default name for the collection")
+    embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model")
+    user: Optional[str] = Field(None, description="Database user")
+    password: Optional[str] = Field(None, description="Database password")
+    host: Optional[str] = Field(None, description="Database host. Default is localhost")
+    port: Optional[int] = Field(None, description="Database port. Default is 1536")
+    diskann: Optional[bool] = Field(True, description="Use diskann for approximate nearest neighbors search")
+    hnsw: Optional[bool] = Field(False, description="Use hnsw for faster search")
+
+    @model_validator(mode="before")
+    def check_auth_and_connection(cls, values):
+        user, password = values.get("user"), values.get("password")
+        host, port = values.get("host"), values.get("port")
+        if not user and not password:
+            raise ValueError("Both 'user' and 'password' must be provided.")
+        if not host and not port:
+            raise ValueError("Both 'host' and 'port' must be provided.")
+        return values
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        allowed_fields = set(cls.model_fields.keys())
+        input_fields = set(values.keys())
+        extra_fields = input_fields - allowed_fields
+        if extra_fields:
+            raise ValueError(
+                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
+            )
+        return values
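Two quirks worth noting: check_auth_and_connection only raises when both members of a pair are absent ("not user and not password"), so a half-specified pair passes validation, and the "Default is 1536" in the port description appears to copy the embedding dimension (PostgreSQL's standard port is 5432). A sketch under the same assumptions as above, with placeholder credentials:

from mem0.configs.vector_stores.pgvector import PGVectorConfig

cfg = PGVectorConfig(
    user="mem0",       # placeholder credentials
    password="secret",
    host="localhost",
    port=5432,         # standard PostgreSQL port
)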
mem0/configs/vector_stores/pinecone.py ADDED

@@ -0,0 +1,56 @@
+import os
+from typing import Any, Dict, Optional
+
+from pydantic import BaseModel, Field, model_validator
+
+
+class PineconeConfig(BaseModel):
+    """Configuration for Pinecone vector database."""
+
+    collection_name: str = Field("mem0", description="Name of the index/collection")
+    embedding_model_dims: int = Field(1536, description="Dimensions of the embedding model")
+    client: Optional[Any] = Field(None, description="Existing Pinecone client instance")
+    api_key: Optional[str] = Field(None, description="API key for Pinecone")
+    environment: Optional[str] = Field(None, description="Pinecone environment")
+    serverless_config: Optional[Dict[str, Any]] = Field(None, description="Configuration for serverless deployment")
+    pod_config: Optional[Dict[str, Any]] = Field(None, description="Configuration for pod-based deployment")
+    hybrid_search: bool = Field(False, description="Whether to enable hybrid search")
+    metric: str = Field("cosine", description="Distance metric for vector similarity")
+    batch_size: int = Field(100, description="Batch size for operations")
+    extra_params: Optional[Dict[str, Any]] = Field(None, description="Additional parameters for Pinecone client")
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_api_key_or_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        api_key, client = values.get("api_key"), values.get("client")
+        if not api_key and not client and "PINECONE_API_KEY" not in os.environ:
+            raise ValueError(
+                "Either 'api_key' or 'client' must be provided, or PINECONE_API_KEY environment variable must be set."
+            )
+        return values
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_pod_or_serverless(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        pod_config, serverless_config = values.get("pod_config"), values.get("serverless_config")
+        if pod_config and serverless_config:
+            raise ValueError(
+                "Both 'pod_config' and 'serverless_config' cannot be specified. Choose one deployment option."
+            )
+        return values
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        allowed_fields = set(cls.model_fields.keys())
+        input_fields = set(values.keys())
+        extra_fields = input_fields - allowed_fields
+        if extra_fields:
+            raise ValueError(
+                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
+            )
+        return values
+
+    model_config = {
+        "arbitrary_types_allowed": True,
+    }
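check_api_key_or_client accepts an api_key, a pre-built client, or the PINECONE_API_KEY environment variable, while check_pod_or_serverless enforces that at most one deployment shape is given. A sketch with placeholder values (the serverless dict keys are illustrative; the model treats the dict as opaque):

import os
from mem0.configs.vector_stores.pinecone import PineconeConfig

os.environ["PINECONE_API_KEY"] = "pc-placeholder"  # or pass api_key=...

# pod_config and serverless_config are mutually exclusive.
cfg = PineconeConfig(serverless_config={"cloud": "aws", "region": "us-east-1"})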
mem0/configs/vector_stores/qdrant.py ADDED

@@ -0,0 +1,49 @@
+from typing import Any, ClassVar, Dict, Optional
+
+from pydantic import BaseModel, Field, model_validator
+
+
+class QdrantConfig(BaseModel):
+    from qdrant_client import QdrantClient
+
+    QdrantClient: ClassVar[type] = QdrantClient
+
+    collection_name: str = Field("mem0", description="Name of the collection")
+    embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model")
+    client: Optional[QdrantClient] = Field(None, description="Existing Qdrant client instance")
+    host: Optional[str] = Field(None, description="Host address for Qdrant server")
+    port: Optional[int] = Field(None, description="Port for Qdrant server")
+    path: Optional[str] = Field("/tmp/qdrant", description="Path for local Qdrant database")
+    url: Optional[str] = Field(None, description="Full URL for Qdrant server")
+    api_key: Optional[str] = Field(None, description="API key for Qdrant server")
+    on_disk: Optional[bool] = Field(False, description="Enables persistent storage")
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_host_port_or_path(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        host, port, path, url, api_key = (
+            values.get("host"),
+            values.get("port"),
+            values.get("path"),
+            values.get("url"),
+            values.get("api_key"),
+        )
+        if not path and not (host and port) and not (url and api_key):
+            raise ValueError("Either 'host' and 'port' or 'url' and 'api_key' or 'path' must be provided.")
+        return values
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        allowed_fields = set(cls.model_fields.keys())
+        input_fields = set(values.keys())
+        extra_fields = input_fields - allowed_fields
+        if extra_fields:
+            raise ValueError(
+                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
+            )
+        return values
+
+    model_config = {
+        "arbitrary_types_allowed": True,
+    }
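Because check_host_port_or_path runs in before mode it sees only the raw input, so the "/tmp/qdrant" field default has not been applied yet and QdrantConfig() with no arguments raises; one connection option must be passed explicitly. A sketch (assumes qdrant-client is installed, since the class body imports it at definition time; the URL and key are placeholders):

from mem0.configs.vector_stores.qdrant import QdrantConfig

local = QdrantConfig(path="/tmp/qdrant")  # embedded/local mode
remote = QdrantConfig(url="https://example.qdrant.cloud", api_key="qdrant-placeholder")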
mem0/configs/vector_stores/redis.py ADDED

@@ -0,0 +1,26 @@
+from typing import Any, Dict
+
+from pydantic import BaseModel, Field, model_validator
+
+
+# TODO: Upgrade to latest pydantic version
+class RedisDBConfig(BaseModel):
+    redis_url: str = Field(..., description="Redis URL")
+    collection_name: str = Field("mem0", description="Collection name")
+    embedding_model_dims: int = Field(1536, description="Embedding model dimensions")
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        allowed_fields = set(cls.model_fields.keys())
+        input_fields = set(values.keys())
+        extra_fields = input_fields - allowed_fields
+        if extra_fields:
+            raise ValueError(
+                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
+            )
+        return values
+
+    model_config = {
+        "arbitrary_types_allowed": True,
+    }
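Only redis_url is required; the other two fields have defaults. A sketch with a placeholder URL:

from mem0.configs.vector_stores.redis import RedisDBConfig

cfg = RedisDBConfig(redis_url="redis://localhost:6379")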
mem0/configs/vector_stores/supabase.py ADDED

@@ -0,0 +1,44 @@
+from enum import Enum
+from typing import Any, Dict, Optional
+
+from pydantic import BaseModel, Field, model_validator
+
+
+class IndexMethod(str, Enum):
+    AUTO = "auto"
+    HNSW = "hnsw"
+    IVFFLAT = "ivfflat"
+
+
+class IndexMeasure(str, Enum):
+    COSINE = "cosine_distance"
+    L2 = "l2_distance"
+    L1 = "l1_distance"
+    MAX_INNER_PRODUCT = "max_inner_product"
+
+
+class SupabaseConfig(BaseModel):
+    connection_string: str = Field(..., description="PostgreSQL connection string")
+    collection_name: str = Field("mem0", description="Name for the vector collection")
+    embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model")
+    index_method: Optional[IndexMethod] = Field(IndexMethod.AUTO, description="Index method to use")
+    index_measure: Optional[IndexMeasure] = Field(IndexMeasure.COSINE, description="Distance measure to use")
+
+    @model_validator(mode="before")
+    def check_connection_string(cls, values):
+        conn_str = values.get("connection_string")
+        if not conn_str or not conn_str.startswith("postgresql://"):
+            raise ValueError("A valid PostgreSQL connection string must be provided")
+        return values
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        allowed_fields = set(cls.model_fields.keys())
+        input_fields = set(values.keys())
+        extra_fields = input_fields - allowed_fields
+        if extra_fields:
+            raise ValueError(
+                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
+            )
+        return values
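check_connection_string insists on a postgresql:// prefix, and the two enums constrain index choices. A sketch with a placeholder connection string:

from mem0.configs.vector_stores.supabase import IndexMeasure, IndexMethod, SupabaseConfig

cfg = SupabaseConfig(
    connection_string="postgresql://user:pass@db.example.supabase.co:5432/postgres",
    index_method=IndexMethod.HNSW,
    index_measure=IndexMeasure.COSINE,
)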
mem0/configs/vector_stores/upstash_vector.py ADDED

@@ -0,0 +1,36 @@
+import os
+from typing import Any, ClassVar, Dict, Optional
+
+from pydantic import BaseModel, Field, model_validator
+
+try:
+    from upstash_vector import Index
+except ImportError:
+    raise ImportError("The 'upstash_vector' library is required. Please install it using 'pip install upstash_vector'.")
+
+
+class UpstashVectorConfig(BaseModel):
+    Index: ClassVar[type] = Index
+
+    url: Optional[str] = Field(None, description="URL for Upstash Vector index")
+    token: Optional[str] = Field(None, description="Token for Upstash Vector index")
+    client: Optional[Index] = Field(None, description="Existing `upstash_vector.Index` client instance")
+    collection_name: str = Field("mem0", description="Namespace to use for the index")
+    enable_embeddings: bool = Field(
+        False, description="Whether to use built-in upstash embeddings or not. Default is True."
+    )
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_credentials_or_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        client = values.get("client")
+        url = values.get("url") or os.environ.get("UPSTASH_VECTOR_REST_URL")
+        token = values.get("token") or os.environ.get("UPSTASH_VECTOR_REST_TOKEN")
+
+        if not client and not (url and token):
+            raise ValueError("Either a client or URL and token must be provided.")
+        return values
+
+    model_config = {
+        "arbitrary_types_allowed": True,
+    }
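The module raises ImportError at import time if upstash_vector is missing, and check_credentials_or_client also honors the UPSTASH_VECTOR_REST_URL/UPSTASH_VECTOR_REST_TOKEN environment variables. Note the enable_embeddings description says "Default is True" while the actual default is False. A sketch with placeholder credentials:

from mem0.configs.vector_stores.upstash_vector import UpstashVectorConfig

cfg = UpstashVectorConfig(
    url="https://example-vector.upstash.io",
    token="upstash-placeholder-token",
)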
mem0/configs/vector_stores/vertex_ai_vector_search.py ADDED

@@ -0,0 +1,27 @@
+from typing import Optional
+
+from pydantic import BaseModel, Field
+
+
+class GoogleMatchingEngineConfig(BaseModel):
+    project_id: str = Field(description="Google Cloud project ID")
+    project_number: str = Field(description="Google Cloud project number")
+    region: str = Field(description="Google Cloud region")
+    endpoint_id: str = Field(description="Vertex AI Vector Search endpoint ID")
+    index_id: str = Field(description="Vertex AI Vector Search index ID")
+    deployment_index_id: str = Field(description="Deployment-specific index ID")
+    collection_name: Optional[str] = Field(None, description="Collection name, defaults to index_id")
+    credentials_path: Optional[str] = Field(None, description="Path to service account credentials file")
+    vector_search_api_endpoint: Optional[str] = Field(None, description="Vector search API endpoint")
+
+    model_config = {"extra": "forbid"}
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        if not self.collection_name:
+            self.collection_name = self.index_id
+
+    def model_post_init(self, _context) -> None:
+        """Set collection_name to index_id if not provided"""
+        if self.collection_name is None:
+            self.collection_name = self.index_id
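The collection_name fallback to index_id is implemented twice, in __init__ and again in model_post_init; either alone would suffice under pydantic v2. A sketch with placeholder identifiers:

from mem0.configs.vector_stores.vertex_ai_vector_search import GoogleMatchingEngineConfig

cfg = GoogleMatchingEngineConfig(
    project_id="my-project",  # all values below are placeholders
    project_number="123456789012",
    region="us-central1",
    endpoint_id="my-endpoint",
    index_id="my-index",
    deployment_index_id="my-deployed-index",
)
assert cfg.collection_name == "my-index"  # fallback applied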
mem0/configs/vector_stores/weaviate.py ADDED

@@ -0,0 +1,43 @@
+from typing import Any, ClassVar, Dict, Optional
+
+from pydantic import BaseModel, Field, model_validator
+
+
+class WeaviateConfig(BaseModel):
+    from weaviate import WeaviateClient
+
+    WeaviateClient: ClassVar[type] = WeaviateClient
+
+    collection_name: str = Field("mem0", description="Name of the collection")
+    embedding_model_dims: int = Field(1536, description="Dimensions of the embedding model")
+    cluster_url: Optional[str] = Field(None, description="URL for Weaviate server")
+    auth_client_secret: Optional[str] = Field(None, description="API key for Weaviate authentication")
+    additional_headers: Optional[Dict[str, str]] = Field(None, description="Additional headers for requests")
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_connection_params(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        cluster_url = values.get("cluster_url")
+
+        if not cluster_url:
+            raise ValueError("'cluster_url' must be provided.")
+
+        return values
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        allowed_fields = set(cls.model_fields.keys())
+        input_fields = set(values.keys())
+        extra_fields = input_fields - allowed_fields
+
+        if extra_fields:
+            raise ValueError(
+                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
+            )
+
+        return values
+
+    model_config = {
+        "arbitrary_types_allowed": True,
+    }
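As with QdrantConfig, the client class is imported inside the class body, so the weaviate package must be installed just to import this module; the only hard requirement at construction time is cluster_url. A sketch with a placeholder URL:

from mem0.configs.vector_stores.weaviate import WeaviateConfig

cfg = WeaviateConfig(
    cluster_url="http://localhost:8080",
    auth_client_secret=None,  # optional API key
)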
mem0/dbs/__init__.py ADDED

mem0/dbs/base.py ADDED
@@ -0,0 +1,68 @@
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional
+
+from mem0.configs.dbs.base import BaseDBConfig
+
+
+class DBBase(ABC):
+    """Initialized a base database class
+
+    :param config: Database configuration option class, defaults to None
+    :type config: Optional[BaseDBConfig], optional
+    """
+
+    def __init__(self, config: Optional[BaseDBConfig] = None):
+        if config is None:
+            self.config = BaseDBConfig()
+        else:
+            self.config = config
+
+    @abstractmethod
+    def add_history(
+        self,
+        memory_id: str,
+        old_memory: Optional[str],
+        new_memory: Optional[str],
+        event: str,
+        *,
+        created_at: Optional[str] = None,
+        updated_at: Optional[str] = None,
+        is_deleted: int = 0,
+        actor_id: Optional[str] = None,
+        role: Optional[str] = None,
+    ) -> None:
+        """Add a history record to the database.
+
+        :param memory_id: The ID of the memory being tracked
+        :param old_memory: The previous memory content
+        :param new_memory: The new memory content
+        :param event: The type of event that occurred
+        :param created_at: When the record was created
+        :param updated_at: When the record was last updated
+        :param is_deleted: Whether the record is deleted (0 or 1)
+        :param actor_id: ID of the actor who made the change
+        :param role: Role of the actor
+        """
+        pass
+
+    @abstractmethod
+    def get_history(self, memory_id: str) -> List[Dict[str, Any]]:
+        """Retrieve history records for a given memory ID.
+
+        :param memory_id: The ID of the memory to get history for
+        :return: List of history records as dictionaries
+        """
+        pass
+
+    @abstractmethod
+    def reset(self) -> None:
+        """Reset/clear all data in the database."""
+        pass
+
+    @abstractmethod
+    def close(self) -> None:
+        """Close the database connection and clean up resources."""
+        pass
+
+    def __del__(self):
+        self.close()
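DBBase defines the history-store contract that mem0/dbs/mysql.py implements; note that __del__ calls close(), so subclasses must tolerate close() being invoked during interpreter teardown. A toy in-memory subclass to illustrate the contract (hypothetical, not part of the package, and assuming BaseDBConfig() constructs with defaults):

from typing import Any, Dict, List

from mem0.dbs.base import DBBase


class InMemoryHistoryDB(DBBase):  # illustrative only
    def __init__(self, config=None):
        super().__init__(config)
        self._rows: List[Dict[str, Any]] = []

    def add_history(self, memory_id, old_memory, new_memory, event, *,
                    created_at=None, updated_at=None, is_deleted=0,
                    actor_id=None, role=None) -> None:
        self._rows.append({"memory_id": memory_id, "old_memory": old_memory,
                           "new_memory": new_memory, "event": event,
                           "is_deleted": is_deleted})

    def get_history(self, memory_id: str) -> List[Dict[str, Any]]:
        return [r for r in self._rows if r["memory_id"] == memory_id]

    def reset(self) -> None:
        self._rows.clear()

    def close(self) -> None:
        pass  # nothing to release for an in-memory store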
mem0/dbs/configs.py ADDED

@@ -0,0 +1,21 @@
+from typing import Optional
+
+from pydantic import BaseModel, Field, field_validator
+
+
+class DBConfig(BaseModel):
+    provider: str = Field(
+        description="Provider of the database (e.g., 'mysql')",
+        default="mysql",
+    )
+    config: Optional[dict] = Field(description="Configuration for the specific database", default={})
+
+    @field_validator("config")
+    def validate_config(cls, v, values):
+        provider = values.data.get("provider")
+        if provider in [
+            "mysql",
+        ]:
+            return v
+        else:
+            raise ValueError(f"Unsupported database provider: {provider}")
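validate_config runs only when a config value is actually supplied (pydantic v2 field validators skip defaults unless validate_default is set), and its allow-list currently contains just "mysql". A sketch; the keys inside config are illustrative, since the model passes the dict through untouched:

from mem0.dbs.configs import DBConfig

cfg = DBConfig(provider="mysql", config={"host": "localhost", "port": 3306})

# Any other provider is rejected once a config dict is passed:
# DBConfig(provider="sqlite", config={})  # ValueError: Unsupported database provider: sqlite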