agentrun_mem0ai-0.0.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. agentrun_mem0/__init__.py +6 -0
  2. agentrun_mem0/client/__init__.py +0 -0
  3. agentrun_mem0/client/main.py +1747 -0
  4. agentrun_mem0/client/project.py +931 -0
  5. agentrun_mem0/client/utils.py +115 -0
  6. agentrun_mem0/configs/__init__.py +0 -0
  7. agentrun_mem0/configs/base.py +90 -0
  8. agentrun_mem0/configs/embeddings/__init__.py +0 -0
  9. agentrun_mem0/configs/embeddings/base.py +110 -0
  10. agentrun_mem0/configs/enums.py +7 -0
  11. agentrun_mem0/configs/llms/__init__.py +0 -0
  12. agentrun_mem0/configs/llms/anthropic.py +56 -0
  13. agentrun_mem0/configs/llms/aws_bedrock.py +192 -0
  14. agentrun_mem0/configs/llms/azure.py +57 -0
  15. agentrun_mem0/configs/llms/base.py +62 -0
  16. agentrun_mem0/configs/llms/deepseek.py +56 -0
  17. agentrun_mem0/configs/llms/lmstudio.py +59 -0
  18. agentrun_mem0/configs/llms/ollama.py +56 -0
  19. agentrun_mem0/configs/llms/openai.py +79 -0
  20. agentrun_mem0/configs/llms/vllm.py +56 -0
  21. agentrun_mem0/configs/prompts.py +459 -0
  22. agentrun_mem0/configs/rerankers/__init__.py +0 -0
  23. agentrun_mem0/configs/rerankers/base.py +17 -0
  24. agentrun_mem0/configs/rerankers/cohere.py +15 -0
  25. agentrun_mem0/configs/rerankers/config.py +12 -0
  26. agentrun_mem0/configs/rerankers/huggingface.py +17 -0
  27. agentrun_mem0/configs/rerankers/llm.py +48 -0
  28. agentrun_mem0/configs/rerankers/sentence_transformer.py +16 -0
  29. agentrun_mem0/configs/rerankers/zero_entropy.py +28 -0
  30. agentrun_mem0/configs/vector_stores/__init__.py +0 -0
  31. agentrun_mem0/configs/vector_stores/alibabacloud_mysql.py +64 -0
  32. agentrun_mem0/configs/vector_stores/aliyun_tablestore.py +32 -0
  33. agentrun_mem0/configs/vector_stores/azure_ai_search.py +57 -0
  34. agentrun_mem0/configs/vector_stores/azure_mysql.py +84 -0
  35. agentrun_mem0/configs/vector_stores/baidu.py +27 -0
  36. agentrun_mem0/configs/vector_stores/chroma.py +58 -0
  37. agentrun_mem0/configs/vector_stores/databricks.py +61 -0
  38. agentrun_mem0/configs/vector_stores/elasticsearch.py +65 -0
  39. agentrun_mem0/configs/vector_stores/faiss.py +37 -0
  40. agentrun_mem0/configs/vector_stores/langchain.py +30 -0
  41. agentrun_mem0/configs/vector_stores/milvus.py +42 -0
  42. agentrun_mem0/configs/vector_stores/mongodb.py +25 -0
  43. agentrun_mem0/configs/vector_stores/neptune.py +27 -0
  44. agentrun_mem0/configs/vector_stores/opensearch.py +41 -0
  45. agentrun_mem0/configs/vector_stores/pgvector.py +52 -0
  46. agentrun_mem0/configs/vector_stores/pinecone.py +55 -0
  47. agentrun_mem0/configs/vector_stores/qdrant.py +47 -0
  48. agentrun_mem0/configs/vector_stores/redis.py +24 -0
  49. agentrun_mem0/configs/vector_stores/s3_vectors.py +28 -0
  50. agentrun_mem0/configs/vector_stores/supabase.py +44 -0
  51. agentrun_mem0/configs/vector_stores/upstash_vector.py +34 -0
  52. agentrun_mem0/configs/vector_stores/valkey.py +15 -0
  53. agentrun_mem0/configs/vector_stores/vertex_ai_vector_search.py +28 -0
  54. agentrun_mem0/configs/vector_stores/weaviate.py +41 -0
  55. agentrun_mem0/embeddings/__init__.py +0 -0
  56. agentrun_mem0/embeddings/aws_bedrock.py +100 -0
  57. agentrun_mem0/embeddings/azure_openai.py +55 -0
  58. agentrun_mem0/embeddings/base.py +31 -0
  59. agentrun_mem0/embeddings/configs.py +30 -0
  60. agentrun_mem0/embeddings/gemini.py +39 -0
  61. agentrun_mem0/embeddings/huggingface.py +44 -0
  62. agentrun_mem0/embeddings/langchain.py +35 -0
  63. agentrun_mem0/embeddings/lmstudio.py +29 -0
  64. agentrun_mem0/embeddings/mock.py +11 -0
  65. agentrun_mem0/embeddings/ollama.py +53 -0
  66. agentrun_mem0/embeddings/openai.py +49 -0
  67. agentrun_mem0/embeddings/together.py +31 -0
  68. agentrun_mem0/embeddings/vertexai.py +64 -0
  69. agentrun_mem0/exceptions.py +503 -0
  70. agentrun_mem0/graphs/__init__.py +0 -0
  71. agentrun_mem0/graphs/configs.py +105 -0
  72. agentrun_mem0/graphs/neptune/__init__.py +0 -0
  73. agentrun_mem0/graphs/neptune/base.py +497 -0
  74. agentrun_mem0/graphs/neptune/neptunedb.py +511 -0
  75. agentrun_mem0/graphs/neptune/neptunegraph.py +474 -0
  76. agentrun_mem0/graphs/tools.py +371 -0
  77. agentrun_mem0/graphs/utils.py +97 -0
  78. agentrun_mem0/llms/__init__.py +0 -0
  79. agentrun_mem0/llms/anthropic.py +87 -0
  80. agentrun_mem0/llms/aws_bedrock.py +665 -0
  81. agentrun_mem0/llms/azure_openai.py +141 -0
  82. agentrun_mem0/llms/azure_openai_structured.py +91 -0
  83. agentrun_mem0/llms/base.py +131 -0
  84. agentrun_mem0/llms/configs.py +34 -0
  85. agentrun_mem0/llms/deepseek.py +107 -0
  86. agentrun_mem0/llms/gemini.py +201 -0
  87. agentrun_mem0/llms/groq.py +88 -0
  88. agentrun_mem0/llms/langchain.py +94 -0
  89. agentrun_mem0/llms/litellm.py +87 -0
  90. agentrun_mem0/llms/lmstudio.py +114 -0
  91. agentrun_mem0/llms/ollama.py +117 -0
  92. agentrun_mem0/llms/openai.py +147 -0
  93. agentrun_mem0/llms/openai_structured.py +52 -0
  94. agentrun_mem0/llms/sarvam.py +89 -0
  95. agentrun_mem0/llms/together.py +88 -0
  96. agentrun_mem0/llms/vllm.py +107 -0
  97. agentrun_mem0/llms/xai.py +52 -0
  98. agentrun_mem0/memory/__init__.py +0 -0
  99. agentrun_mem0/memory/base.py +63 -0
  100. agentrun_mem0/memory/graph_memory.py +698 -0
  101. agentrun_mem0/memory/kuzu_memory.py +713 -0
  102. agentrun_mem0/memory/main.py +2229 -0
  103. agentrun_mem0/memory/memgraph_memory.py +689 -0
  104. agentrun_mem0/memory/setup.py +56 -0
  105. agentrun_mem0/memory/storage.py +218 -0
  106. agentrun_mem0/memory/telemetry.py +90 -0
  107. agentrun_mem0/memory/utils.py +208 -0
  108. agentrun_mem0/proxy/__init__.py +0 -0
  109. agentrun_mem0/proxy/main.py +189 -0
  110. agentrun_mem0/reranker/__init__.py +9 -0
  111. agentrun_mem0/reranker/base.py +20 -0
  112. agentrun_mem0/reranker/cohere_reranker.py +85 -0
  113. agentrun_mem0/reranker/huggingface_reranker.py +147 -0
  114. agentrun_mem0/reranker/llm_reranker.py +142 -0
  115. agentrun_mem0/reranker/sentence_transformer_reranker.py +107 -0
  116. agentrun_mem0/reranker/zero_entropy_reranker.py +96 -0
  117. agentrun_mem0/utils/factory.py +283 -0
  118. agentrun_mem0/utils/gcp_auth.py +167 -0
  119. agentrun_mem0/vector_stores/__init__.py +0 -0
  120. agentrun_mem0/vector_stores/alibabacloud_mysql.py +547 -0
  121. agentrun_mem0/vector_stores/aliyun_tablestore.py +252 -0
  122. agentrun_mem0/vector_stores/azure_ai_search.py +396 -0
  123. agentrun_mem0/vector_stores/azure_mysql.py +463 -0
  124. agentrun_mem0/vector_stores/baidu.py +368 -0
  125. agentrun_mem0/vector_stores/base.py +58 -0
  126. agentrun_mem0/vector_stores/chroma.py +332 -0
  127. agentrun_mem0/vector_stores/configs.py +67 -0
  128. agentrun_mem0/vector_stores/databricks.py +761 -0
  129. agentrun_mem0/vector_stores/elasticsearch.py +237 -0
  130. agentrun_mem0/vector_stores/faiss.py +479 -0
  131. agentrun_mem0/vector_stores/langchain.py +180 -0
  132. agentrun_mem0/vector_stores/milvus.py +250 -0
  133. agentrun_mem0/vector_stores/mongodb.py +310 -0
  134. agentrun_mem0/vector_stores/neptune_analytics.py +467 -0
  135. agentrun_mem0/vector_stores/opensearch.py +292 -0
  136. agentrun_mem0/vector_stores/pgvector.py +404 -0
  137. agentrun_mem0/vector_stores/pinecone.py +382 -0
  138. agentrun_mem0/vector_stores/qdrant.py +270 -0
  139. agentrun_mem0/vector_stores/redis.py +295 -0
  140. agentrun_mem0/vector_stores/s3_vectors.py +176 -0
  141. agentrun_mem0/vector_stores/supabase.py +237 -0
  142. agentrun_mem0/vector_stores/upstash_vector.py +293 -0
  143. agentrun_mem0/vector_stores/valkey.py +824 -0
  144. agentrun_mem0/vector_stores/vertex_ai_vector_search.py +635 -0
  145. agentrun_mem0/vector_stores/weaviate.py +343 -0
  146. agentrun_mem0ai-0.0.11.data/data/README.md +205 -0
  147. agentrun_mem0ai-0.0.11.dist-info/METADATA +277 -0
  148. agentrun_mem0ai-0.0.11.dist-info/RECORD +150 -0
  149. agentrun_mem0ai-0.0.11.dist-info/WHEEL +4 -0
  150. agentrun_mem0ai-0.0.11.dist-info/licenses/LICENSE +201 -0
--- /dev/null
+++ b/agentrun_mem0/configs/vector_stores/alibabacloud_mysql.py
@@ -0,0 +1,64 @@
+from typing import Any, Dict, Optional
+
+from pydantic import BaseModel, Field, model_validator
+
+
+class MySQLVectorConfig(BaseModel):
+    dbname: str = Field("mem0", description="Default name for the database")
+    collection_name: str = Field("mem0", description="Default name for the collection")
+    embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model")
+    user: Optional[str] = Field(None, description="Database user")
+    password: Optional[str] = Field(None, description="Database password")
+    host: Optional[str] = Field(None, description="Database host. Default is localhost")
+    port: Optional[int] = Field(None, description="Database port. Default is 3306")
+    distance_function: Optional[str] = Field("cosine", description="Distance function for vector index ('euclidean' or 'cosine')")
+    m_value: Optional[int] = Field(16, description="M parameter for HNSW index (3-200). Higher values = more accurate but slower")
+    # SSL and connection options
+    ssl_disabled: Optional[bool] = Field(False, description="Disable SSL connection")
+    ssl_ca: Optional[str] = Field(None, description="SSL CA certificate file path")
+    ssl_cert: Optional[str] = Field(None, description="SSL certificate file path")
+    ssl_key: Optional[str] = Field(None, description="SSL key file path")
+    connection_string: Optional[str] = Field(None, description="AlibabaCloud MySQL connection string (overrides individual connection parameters)")
+    charset: Optional[str] = Field("utf8mb4", description="Character set for the connection")
+    autocommit: Optional[bool] = Field(True, description="Enable autocommit mode")
+
+    @model_validator(mode="before")
+    def check_auth_and_connection(cls, values):
+        # If connection_string is provided, skip validation of individual connection parameters
+        if values.get("connection_string") is not None:
+            return values
+
+        # Otherwise, validate individual connection parameters
+        user, password = values.get("user"), values.get("password")
+        host, port = values.get("host"), values.get("port")
+        if not user or not password:
+            raise ValueError("Both 'user' and 'password' must be provided when not using connection_string.")
+        if not host or not port:
+            raise ValueError("Both 'host' and 'port' must be provided when not using connection_string.")
+        return values
+
+    @model_validator(mode="before")
+    def validate_distance_function(cls, values):
+        distance_function = values.get("distance_function", "cosine")
+        if distance_function not in ["euclidean", "cosine"]:
+            raise ValueError("distance_function must be either 'euclidean' or 'cosine'")
+        return values
+
+    @model_validator(mode="before")
+    def validate_m_value(cls, values):
+        m_value = values.get("m_value", 16)
+        if not (3 <= m_value <= 200):
+            raise ValueError("m_value must be between 3 and 200")
+        return values
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        allowed_fields = set(cls.model_fields.keys())
+        input_fields = set(values.keys())
+        extra_fields = input_fields - allowed_fields
+        if extra_fields:
+            raise ValueError(
+                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
+            )
+        return values
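A minimal usage sketch for this config (values are illustrative, not from the package; the import path follows the file list above): either a single connection_string or the full user/password/host/port set satisfies the before-mode validator.

# Hedged sketch: placeholder credentials for illustration only.
from agentrun_mem0.configs.vector_stores.alibabacloud_mysql import MySQLVectorConfig

# A connection string bypasses the individual credential checks entirely.
cfg = MySQLVectorConfig(connection_string="mysql://root:secret@db.example.com:3306/mem0")

# Without one, user, password, host, and port must all be present,
# or check_auth_and_connection raises ValueError before the model is built.
cfg = MySQLVectorConfig(user="root", password="secret", host="127.0.0.1", port=3306)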
--- /dev/null
+++ b/agentrun_mem0/configs/vector_stores/aliyun_tablestore.py
@@ -0,0 +1,32 @@
+from typing import Any, Dict, Optional
+
+from pydantic import BaseModel, Field, model_validator
+
+class AliyunTableStoreConfig(BaseModel):
+    endpoint: str = Field(description="endpoint of tablestore")
+    instance_name: str = Field(description="instance_name of tablestore")
+    access_key_id: str = Field(description="access_key_id of tablestore")
+    access_key_secret: str = Field(description="access_key_secret of tablestore")
+    vector_dimension: int = Field(1536, description="dimension of vector")
+    sts_token: Optional[str] = Field(None, description="sts_token of tablestore")
+    collection_name: str = Field("mem0", description="name of the collection")
+    search_index_name: str = Field("mem0_search_index", description="index name")
+    text_field: str = Field("text", description="name of the text field in the table")
+    embedding_field: str = Field("embedding", description="name of the embedding field")
+    vector_metric_type: str = Field("VM_COSINE", description="metric type for vector")
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        allowed_fields = set(cls.model_fields.keys())
+        input_fields = set(values.keys())
+        extra_fields = input_fields - allowed_fields
+        if extra_fields:
+            raise ValueError(
+                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
+            )
+        return values
+
+    model_config = {
+        "arbitrary_types_allowed": True,
+    }
--- /dev/null
+++ b/agentrun_mem0/configs/vector_stores/azure_ai_search.py
@@ -0,0 +1,57 @@
+from typing import Any, Dict, Optional
+
+from pydantic import BaseModel, ConfigDict, Field, model_validator
+
+
+class AzureAISearchConfig(BaseModel):
+    collection_name: str = Field("mem0", description="Name of the collection")
+    service_name: str = Field(None, description="Azure AI Search service name")
+    api_key: str = Field(None, description="API key for the Azure AI Search service")
+    embedding_model_dims: int = Field(1536, description="Dimension of the embedding vector")
+    compression_type: Optional[str] = Field(
+        None, description="Type of vector compression to use. Options: 'scalar', 'binary', or None"
+    )
+    use_float16: bool = Field(
+        False,
+        description="Whether to store vectors in half precision (Edm.Half) instead of full precision (Edm.Single)",
+    )
+    hybrid_search: bool = Field(
+        False, description="Whether to use hybrid search. If True, vector_filter_mode must be 'preFilter'"
+    )
+    vector_filter_mode: Optional[str] = Field(
+        "preFilter", description="Mode for vector filtering. Options: 'preFilter', 'postFilter'"
+    )
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        allowed_fields = set(cls.model_fields.keys())
+        input_fields = set(values.keys())
+        extra_fields = input_fields - allowed_fields
+
+        # Check for use_compression to provide a helpful error
+        if "use_compression" in extra_fields:
+            raise ValueError(
+                "The parameter 'use_compression' is no longer supported. "
+                "Please use 'compression_type=\"scalar\"' instead of 'use_compression=True' "
+                "or 'compression_type=None' instead of 'use_compression=False'."
+            )
+
+        if extra_fields:
+            raise ValueError(
+                f"Extra fields not allowed: {', '.join(extra_fields)}. "
+                f"Please input only the following fields: {', '.join(allowed_fields)}"
+            )
+
+        # Validate compression_type values
+        if "compression_type" in values and values["compression_type"] is not None:
+            valid_types = ["scalar", "binary"]
+            if values["compression_type"].lower() not in valid_types:
+                raise ValueError(
+                    f"Invalid compression_type: {values['compression_type']}. "
+                    f"Must be one of: {', '.join(valid_types)}, or None"
+                )
+
+        return values
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
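The validator above doubles as a migration shim: the retired use_compression flag is intercepted before the generic extra-field check so the error can point at its replacement. A hedged sketch (service name and key are placeholders):

from agentrun_mem0.configs.vector_stores.azure_ai_search import AzureAISearchConfig

# The retired flag fails with a targeted hint rather than a generic extra-field error.
try:
    AzureAISearchConfig(service_name="my-search", api_key="key", use_compression=True)
except ValueError as exc:
    print(exc)  # suggests compression_type="scalar" or compression_type=None

# The replacement parameter, validated case-insensitively against 'scalar'/'binary'.
cfg = AzureAISearchConfig(service_name="my-search", api_key="key", compression_type="scalar")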
--- /dev/null
+++ b/agentrun_mem0/configs/vector_stores/azure_mysql.py
@@ -0,0 +1,84 @@
+from typing import Any, Dict, Optional
+
+from pydantic import BaseModel, Field, model_validator
+
+
+class AzureMySQLConfig(BaseModel):
+    """Configuration for Azure MySQL vector database."""
+
+    host: str = Field(..., description="MySQL server host (e.g., myserver.mysql.database.azure.com)")
+    port: int = Field(3306, description="MySQL server port")
+    user: str = Field(..., description="Database user")
+    password: Optional[str] = Field(None, description="Database password (not required if using Azure credential)")
+    database: str = Field(..., description="Database name")
+    collection_name: str = Field("mem0", description="Collection/table name")
+    embedding_model_dims: int = Field(1536, description="Dimensions of the embedding model")
+    use_azure_credential: bool = Field(
+        False,
+        description="Use Azure DefaultAzureCredential for authentication instead of password"
+    )
+    ssl_ca: Optional[str] = Field(None, description="Path to SSL CA certificate")
+    ssl_disabled: bool = Field(False, description="Disable SSL connection (not recommended for production)")
+    minconn: int = Field(1, description="Minimum number of connections in the pool")
+    maxconn: int = Field(5, description="Maximum number of connections in the pool")
+    connection_pool: Optional[Any] = Field(
+        None,
+        description="Pre-configured connection pool object (overrides other connection parameters)"
+    )
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_auth(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        """Validate authentication parameters."""
+        # If connection_pool is provided, skip validation
+        if values.get("connection_pool") is not None:
+            return values
+
+        use_azure_credential = values.get("use_azure_credential", False)
+        password = values.get("password")
+
+        # Either password or Azure credential must be provided
+        if not use_azure_credential and not password:
+            raise ValueError(
+                "Either 'password' must be provided or 'use_azure_credential' must be set to True"
+            )
+
+        return values
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_required_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        """Validate required fields."""
+        # If connection_pool is provided, skip validation of individual parameters
+        if values.get("connection_pool") is not None:
+            return values
+
+        required_fields = ["host", "user", "database"]
+        missing_fields = [field for field in required_fields if not values.get(field)]
+
+        if missing_fields:
+            raise ValueError(
+                f"Missing required fields: {', '.join(missing_fields)}. "
+                f"These fields are required when not using a pre-configured connection_pool."
+            )
+
+        return values
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        """Validate that no extra fields are provided."""
+        allowed_fields = set(cls.model_fields.keys())
+        input_fields = set(values.keys())
+        extra_fields = input_fields - allowed_fields
+
+        if extra_fields:
+            raise ValueError(
+                f"Extra fields not allowed: {', '.join(extra_fields)}. "
+                f"Please input only the following fields: {', '.join(allowed_fields)}"
+            )
+
+        return values
+
+    class Config:
+        arbitrary_types_allowed = True
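A sketch of the two authentication paths (placeholder values): password auth, or use_azure_credential=True with no password; a pre-built connection_pool skips both checks.

from agentrun_mem0.configs.vector_stores.azure_mysql import AzureMySQLConfig

# Password-based authentication.
cfg = AzureMySQLConfig(
    host="myserver.mysql.database.azure.com", user="admin",
    password="secret", database="mem0",
)

# Entra ID authentication via DefaultAzureCredential; no password needed.
cfg = AzureMySQLConfig(
    host="myserver.mysql.database.azure.com", user="admin",
    database="mem0", use_azure_credential=True,
)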
--- /dev/null
+++ b/agentrun_mem0/configs/vector_stores/baidu.py
@@ -0,0 +1,27 @@
+from typing import Any, Dict
+
+from pydantic import BaseModel, ConfigDict, Field, model_validator
+
+
+class BaiduDBConfig(BaseModel):
+    endpoint: str = Field("http://localhost:8287", description="Endpoint URL for Baidu VectorDB")
+    account: str = Field("root", description="Account for Baidu VectorDB")
+    api_key: str = Field(None, description="API Key for Baidu VectorDB")
+    database_name: str = Field("mem0", description="Name of the database")
+    table_name: str = Field("mem0", description="Name of the table")
+    embedding_model_dims: int = Field(1536, description="Dimensions of the embedding model")
+    metric_type: str = Field("L2", description="Metric type for similarity search")
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        allowed_fields = set(cls.model_fields.keys())
+        input_fields = set(values.keys())
+        extra_fields = input_fields - allowed_fields
+        if extra_fields:
+            raise ValueError(
+                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
+            )
+        return values
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
--- /dev/null
+++ b/agentrun_mem0/configs/vector_stores/chroma.py
@@ -0,0 +1,58 @@
+from typing import Any, ClassVar, Dict, Optional
+
+from pydantic import BaseModel, ConfigDict, Field, model_validator
+
+
+class ChromaDbConfig(BaseModel):
+    try:
+        from chromadb.api.client import Client
+    except ImportError:
+        raise ImportError("The 'chromadb' library is required. Please install it using 'pip install chromadb'.")
+    Client: ClassVar[type] = Client
+
+    collection_name: str = Field("mem0", description="Default name for the collection/database")
+    client: Optional[Client] = Field(None, description="Existing ChromaDB client instance")
+    path: Optional[str] = Field(None, description="Path to the database directory")
+    host: Optional[str] = Field(None, description="Database connection remote host")
+    port: Optional[int] = Field(None, description="Database connection remote port")
+    # ChromaDB Cloud configuration
+    api_key: Optional[str] = Field(None, description="ChromaDB Cloud API key")
+    tenant: Optional[str] = Field(None, description="ChromaDB Cloud tenant ID")
+
+    @model_validator(mode="before")
+    def check_connection_config(cls, values):
+        host, port, path = values.get("host"), values.get("port"), values.get("path")
+        api_key, tenant = values.get("api_key"), values.get("tenant")
+
+        # Check if cloud configuration is provided
+        cloud_config = bool(api_key and tenant)
+
+        # If cloud configuration is provided, remove any default path that might have been added
+        if cloud_config and path == "/tmp/chroma":
+            values.pop("path", None)
+            return values
+
+        # Check if local/server configuration is provided (excluding default tmp path for cloud config)
+        local_config = bool(path and path != "/tmp/chroma") or bool(host and port)
+
+        if not cloud_config and not local_config:
+            raise ValueError("Either ChromaDB Cloud configuration (api_key, tenant) or local configuration (path or host/port) must be provided.")
+
+        if cloud_config and local_config:
+            raise ValueError("Cannot specify both cloud configuration and local configuration. Choose one.")
+
+        return values
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        allowed_fields = set(cls.model_fields.keys())
+        input_fields = set(values.keys())
+        extra_fields = input_fields - allowed_fields
+        if extra_fields:
+            raise ValueError(
+                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
+            )
+        return values
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
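check_connection_config makes local and cloud modes mutually exclusive. A hedged sketch with placeholder values (requires chromadb to be installed, since the class imports it at definition time):

from agentrun_mem0.configs.vector_stores.chroma import ChromaDbConfig

local = ChromaDbConfig(path="./chroma_data")                # on-disk persistence
server = ChromaDbConfig(host="chroma.internal", port=8000)  # remote server
cloud = ChromaDbConfig(api_key="ck-placeholder", tenant="my-tenant")

# Mixing modes raises ValueError ("Cannot specify both ..."):
# ChromaDbConfig(path="./chroma_data", api_key="ck-placeholder", tenant="my-tenant")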
--- /dev/null
+++ b/agentrun_mem0/configs/vector_stores/databricks.py
@@ -0,0 +1,61 @@
+from typing import Any, Dict, Optional
+
+from pydantic import BaseModel, ConfigDict, Field, model_validator
+
+from databricks.sdk.service.vectorsearch import EndpointType, VectorIndexType, PipelineType
+
+
+class DatabricksConfig(BaseModel):
+    """Configuration for Databricks Vector Search vector store."""
+
+    workspace_url: str = Field(..., description="Databricks workspace URL")
+    access_token: Optional[str] = Field(None, description="Personal access token for authentication")
+    client_id: Optional[str] = Field(None, description="Databricks service principal client ID")
+    client_secret: Optional[str] = Field(None, description="Databricks service principal client secret")
+    azure_client_id: Optional[str] = Field(None, description="Azure AD application client ID (for Azure Databricks)")
+    azure_client_secret: Optional[str] = Field(
+        None, description="Azure AD application client secret (for Azure Databricks)"
+    )
+    endpoint_name: str = Field(..., description="Vector search endpoint name")
+    catalog: str = Field(..., description="The Unity Catalog catalog name")
+    schema: str = Field(..., description="The Unity Catalog schema name")
+    table_name: str = Field(..., description="Source Delta table name")
+    collection_name: str = Field("mem0", description="Vector search index name")
+    index_type: VectorIndexType = Field("DELTA_SYNC", description="Index type: DELTA_SYNC or DIRECT_ACCESS")
+    embedding_model_endpoint_name: Optional[str] = Field(
+        None, description="Embedding model endpoint for Databricks-computed embeddings"
+    )
+    embedding_dimension: int = Field(1536, description="Vector embedding dimensions")
+    endpoint_type: EndpointType = Field("STANDARD", description="Endpoint type: STANDARD or STORAGE_OPTIMIZED")
+    pipeline_type: PipelineType = Field("TRIGGERED", description="Sync pipeline type: TRIGGERED or CONTINUOUS")
+    warehouse_name: Optional[str] = Field(None, description="Databricks SQL warehouse name")
+    query_type: str = Field("ANN", description="Query type: `ANN` or `HYBRID`")
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        allowed_fields = set(cls.model_fields.keys())
+        input_fields = set(values.keys())
+        extra_fields = input_fields - allowed_fields
+        if extra_fields:
+            raise ValueError(
+                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
+            )
+        return values
+
+    @model_validator(mode="after")
+    def validate_authentication(self):
+        """Validate that either access_token or service principal credentials are provided."""
+        has_token = self.access_token is not None
+        has_service_principal = (self.client_id is not None and self.client_secret is not None) or (
+            self.azure_client_id is not None and self.azure_client_secret is not None
+        )
+
+        if not has_token and not has_service_principal:
+            raise ValueError(
+                "Either access_token or both client_id/client_secret or azure_client_id/azure_client_secret must be provided"
+            )
+
+        return self
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
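validate_authentication runs in mode="after", so the typed fields are already populated when it checks that one of the three credential forms is present. A sketch with placeholder identifiers (importing this module requires the databricks-sdk package, which the file pulls in at module level):

from agentrun_mem0.configs.vector_stores.databricks import DatabricksConfig

cfg = DatabricksConfig(
    workspace_url="https://adb-1234567890.0.azuredatabricks.net",
    access_token="dapi-placeholder",  # or client_id/client_secret,
                                      # or azure_client_id/azure_client_secret
    endpoint_name="mem0-endpoint",
    catalog="main",
    schema="default",
    table_name="memories",
)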
--- /dev/null
+++ b/agentrun_mem0/configs/vector_stores/elasticsearch.py
@@ -0,0 +1,65 @@
+from collections.abc import Callable
+from typing import Any, Dict, List, Optional
+
+from pydantic import BaseModel, Field, model_validator
+
+
+class ElasticsearchConfig(BaseModel):
+    collection_name: str = Field("mem0", description="Name of the index")
+    host: str = Field("localhost", description="Elasticsearch host")
+    port: int = Field(9200, description="Elasticsearch port")
+    user: Optional[str] = Field(None, description="Username for authentication")
+    password: Optional[str] = Field(None, description="Password for authentication")
+    cloud_id: Optional[str] = Field(None, description="Cloud ID for Elastic Cloud")
+    api_key: Optional[str] = Field(None, description="API key for authentication")
+    embedding_model_dims: int = Field(1536, description="Dimension of the embedding vector")
+    verify_certs: bool = Field(True, description="Verify SSL certificates")
+    use_ssl: bool = Field(True, description="Use SSL for connection")
+    auto_create_index: bool = Field(True, description="Automatically create index during initialization")
+    custom_search_query: Optional[Callable[[List[float], int, Optional[Dict]], Dict]] = Field(
+        None, description="Custom search query function. Parameters: (query, limit, filters) -> Dict"
+    )
+    headers: Optional[Dict[str, str]] = Field(None, description="Custom headers to include in requests")
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_auth(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        # Check if either cloud_id or host/port is provided
+        if not values.get("cloud_id") and not values.get("host"):
+            raise ValueError("Either cloud_id or host must be provided")
+
+        # Check if authentication is provided
+        if not any([values.get("api_key"), (values.get("user") and values.get("password"))]):
+            raise ValueError("Either api_key or user/password must be provided")
+
+        return values
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_headers(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        """Validate headers format and content"""
+        headers = values.get("headers")
+        if headers is not None:
+            # Check if headers is a dictionary
+            if not isinstance(headers, dict):
+                raise ValueError("headers must be a dictionary")
+
+            # Check if all keys and values are strings
+            for key, value in headers.items():
+                if not isinstance(key, str) or not isinstance(value, str):
+                    raise ValueError("All header keys and values must be strings")
+
+        return values
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        allowed_fields = set(cls.model_fields.keys())
+        input_fields = set(values.keys())
+        extra_fields = input_fields - allowed_fields
+        if extra_fields:
+            raise ValueError(
+                f"Extra fields not allowed: {', '.join(extra_fields)}. "
+                f"Please input only the following fields: {', '.join(allowed_fields)}"
+            )
+        return values
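Because host defaults to "localhost", the first check in validate_auth only fires when host is explicitly cleared, so in practice the validator is about credentials. A hedged sketch of the accepted combinations (placeholder secrets):

from agentrun_mem0.configs.vector_stores.elasticsearch import ElasticsearchConfig

cfg = ElasticsearchConfig(api_key="es-key-placeholder")            # API-key auth
cfg = ElasticsearchConfig(user="elastic", password="placeholder")  # basic auth

# With neither form of auth, validate_auth raises ValueError.
# headers, if given, must be a Dict[str, str]:
cfg = ElasticsearchConfig(api_key="es-key-placeholder", headers={"X-Team": "mem0"})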
--- /dev/null
+++ b/agentrun_mem0/configs/vector_stores/faiss.py
@@ -0,0 +1,37 @@
+from typing import Any, Dict, Optional
+
+from pydantic import BaseModel, ConfigDict, Field, model_validator
+
+
+class FAISSConfig(BaseModel):
+    collection_name: str = Field("mem0", description="Default name for the collection")
+    path: Optional[str] = Field(None, description="Path to store FAISS index and metadata")
+    distance_strategy: str = Field(
+        "euclidean", description="Distance strategy to use. Options: 'euclidean', 'inner_product', 'cosine'"
+    )
+    normalize_L2: bool = Field(
+        False, description="Whether to normalize L2 vectors (only applicable for euclidean distance)"
+    )
+    embedding_model_dims: int = Field(1536, description="Dimension of the embedding vector")
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_distance_strategy(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        distance_strategy = values.get("distance_strategy")
+        if distance_strategy and distance_strategy not in ["euclidean", "inner_product", "cosine"]:
+            raise ValueError("Invalid distance_strategy. Must be one of: 'euclidean', 'inner_product', 'cosine'")
+        return values
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        allowed_fields = set(cls.model_fields.keys())
+        input_fields = set(values.keys())
+        extra_fields = input_fields - allowed_fields
+        if extra_fields:
+            raise ValueError(
+                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
+            )
+        return values
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
--- /dev/null
+++ b/agentrun_mem0/configs/vector_stores/langchain.py
@@ -0,0 +1,30 @@
+from typing import Any, ClassVar, Dict
+
+from pydantic import BaseModel, ConfigDict, Field, model_validator
+
+
+class LangchainConfig(BaseModel):
+    try:
+        from langchain_community.vectorstores import VectorStore
+    except ImportError:
+        raise ImportError(
+            "The 'langchain_community' library is required. Please install it using 'pip install langchain_community'."
+        )
+    VectorStore: ClassVar[type] = VectorStore
+
+    client: VectorStore = Field(description="Existing VectorStore instance")
+    collection_name: str = Field("mem0", description="Name of the collection to use")
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        allowed_fields = set(cls.model_fields.keys())
+        input_fields = set(values.keys())
+        extra_fields = input_fields - allowed_fields
+        if extra_fields:
+            raise ValueError(
+                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
+            )
+        return values
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
--- /dev/null
+++ b/agentrun_mem0/configs/vector_stores/milvus.py
@@ -0,0 +1,42 @@
+from enum import Enum
+from typing import Any, Dict
+
+from pydantic import BaseModel, ConfigDict, Field, model_validator
+
+
+class MetricType(str, Enum):
+    """
+    Metric constants for the Milvus/Zilliz server.
+    """
+
+    def __str__(self) -> str:
+        return str(self.value)
+
+    L2 = "L2"
+    IP = "IP"
+    COSINE = "COSINE"
+    HAMMING = "HAMMING"
+    JACCARD = "JACCARD"
+
+
+class MilvusDBConfig(BaseModel):
+    url: str = Field("http://localhost:19530", description="Full URL for Milvus/Zilliz server")
+    token: str = Field(None, description="Token for Zilliz server; defaults to None for a local setup")
+    collection_name: str = Field("mem0", description="Name of the collection")
+    embedding_model_dims: int = Field(1536, description="Dimensions of the embedding model")
+    metric_type: str = Field("L2", description="Metric type for similarity search")
+    db_name: str = Field("", description="Name of the database")
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        allowed_fields = set(cls.model_fields.keys())
+        input_fields = set(values.keys())
+        extra_fields = input_fields - allowed_fields
+        if extra_fields:
+            raise ValueError(
+                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
+            )
+        return values
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
--- /dev/null
+++ b/agentrun_mem0/configs/vector_stores/mongodb.py
@@ -0,0 +1,25 @@
+from typing import Any, Dict, Optional
+
+from pydantic import BaseModel, Field, model_validator
+
+
+class MongoDBConfig(BaseModel):
+    """Configuration for MongoDB vector database."""
+
+    db_name: str = Field("mem0_db", description="Name of the MongoDB database")
+    collection_name: str = Field("mem0", description="Name of the MongoDB collection")
+    embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding vectors")
+    mongo_uri: str = Field("mongodb://localhost:27017", description="MongoDB URI. Default is mongodb://localhost:27017")
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        allowed_fields = set(cls.model_fields.keys())
+        input_fields = set(values.keys())
+        extra_fields = input_fields - allowed_fields
+        if extra_fields:
+            raise ValueError(
+                f"Extra fields not allowed: {', '.join(extra_fields)}. "
+                f"Please provide only the following fields: {', '.join(allowed_fields)}."
+            )
+        return values
--- /dev/null
+++ b/agentrun_mem0/configs/vector_stores/neptune.py
@@ -0,0 +1,27 @@
+"""
+Configuration for Amazon Neptune Analytics vector store.
+
+This module provides configuration settings for integrating with Amazon Neptune Analytics
+as a vector store backend for Mem0's memory layer.
+"""
+
+from pydantic import BaseModel, Field
+
+
+class NeptuneAnalyticsConfig(BaseModel):
+    """
+    Configuration class for Amazon Neptune Analytics vector store.
+
+    Amazon Neptune Analytics is a graph analytics engine that can be used as a vector store
+    for storing and retrieving memory embeddings in Mem0.
+
+    Attributes:
+        collection_name (str): Name of the collection to store vectors. Defaults to "mem0".
+        endpoint (str): Neptune Analytics graph endpoint URL or Graph ID for the runtime.
+    """
+    collection_name: str = Field("mem0", description="Default name for the collection")
+    endpoint: str = Field("endpoint", description="Graph ID for the runtime")
+
+    model_config = {
+        "arbitrary_types_allowed": False,
+    }