agentrun-mem0ai 0.0.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150):
  1. agentrun_mem0/__init__.py +6 -0
  2. agentrun_mem0/client/__init__.py +0 -0
  3. agentrun_mem0/client/main.py +1747 -0
  4. agentrun_mem0/client/project.py +931 -0
  5. agentrun_mem0/client/utils.py +115 -0
  6. agentrun_mem0/configs/__init__.py +0 -0
  7. agentrun_mem0/configs/base.py +90 -0
  8. agentrun_mem0/configs/embeddings/__init__.py +0 -0
  9. agentrun_mem0/configs/embeddings/base.py +110 -0
  10. agentrun_mem0/configs/enums.py +7 -0
  11. agentrun_mem0/configs/llms/__init__.py +0 -0
  12. agentrun_mem0/configs/llms/anthropic.py +56 -0
  13. agentrun_mem0/configs/llms/aws_bedrock.py +192 -0
  14. agentrun_mem0/configs/llms/azure.py +57 -0
  15. agentrun_mem0/configs/llms/base.py +62 -0
  16. agentrun_mem0/configs/llms/deepseek.py +56 -0
  17. agentrun_mem0/configs/llms/lmstudio.py +59 -0
  18. agentrun_mem0/configs/llms/ollama.py +56 -0
  19. agentrun_mem0/configs/llms/openai.py +79 -0
  20. agentrun_mem0/configs/llms/vllm.py +56 -0
  21. agentrun_mem0/configs/prompts.py +459 -0
  22. agentrun_mem0/configs/rerankers/__init__.py +0 -0
  23. agentrun_mem0/configs/rerankers/base.py +17 -0
  24. agentrun_mem0/configs/rerankers/cohere.py +15 -0
  25. agentrun_mem0/configs/rerankers/config.py +12 -0
  26. agentrun_mem0/configs/rerankers/huggingface.py +17 -0
  27. agentrun_mem0/configs/rerankers/llm.py +48 -0
  28. agentrun_mem0/configs/rerankers/sentence_transformer.py +16 -0
  29. agentrun_mem0/configs/rerankers/zero_entropy.py +28 -0
  30. agentrun_mem0/configs/vector_stores/__init__.py +0 -0
  31. agentrun_mem0/configs/vector_stores/alibabacloud_mysql.py +64 -0
  32. agentrun_mem0/configs/vector_stores/aliyun_tablestore.py +32 -0
  33. agentrun_mem0/configs/vector_stores/azure_ai_search.py +57 -0
  34. agentrun_mem0/configs/vector_stores/azure_mysql.py +84 -0
  35. agentrun_mem0/configs/vector_stores/baidu.py +27 -0
  36. agentrun_mem0/configs/vector_stores/chroma.py +58 -0
  37. agentrun_mem0/configs/vector_stores/databricks.py +61 -0
  38. agentrun_mem0/configs/vector_stores/elasticsearch.py +65 -0
  39. agentrun_mem0/configs/vector_stores/faiss.py +37 -0
  40. agentrun_mem0/configs/vector_stores/langchain.py +30 -0
  41. agentrun_mem0/configs/vector_stores/milvus.py +42 -0
  42. agentrun_mem0/configs/vector_stores/mongodb.py +25 -0
  43. agentrun_mem0/configs/vector_stores/neptune.py +27 -0
  44. agentrun_mem0/configs/vector_stores/opensearch.py +41 -0
  45. agentrun_mem0/configs/vector_stores/pgvector.py +52 -0
  46. agentrun_mem0/configs/vector_stores/pinecone.py +55 -0
  47. agentrun_mem0/configs/vector_stores/qdrant.py +47 -0
  48. agentrun_mem0/configs/vector_stores/redis.py +24 -0
  49. agentrun_mem0/configs/vector_stores/s3_vectors.py +28 -0
  50. agentrun_mem0/configs/vector_stores/supabase.py +44 -0
  51. agentrun_mem0/configs/vector_stores/upstash_vector.py +34 -0
  52. agentrun_mem0/configs/vector_stores/valkey.py +15 -0
  53. agentrun_mem0/configs/vector_stores/vertex_ai_vector_search.py +28 -0
  54. agentrun_mem0/configs/vector_stores/weaviate.py +41 -0
  55. agentrun_mem0/embeddings/__init__.py +0 -0
  56. agentrun_mem0/embeddings/aws_bedrock.py +100 -0
  57. agentrun_mem0/embeddings/azure_openai.py +55 -0
  58. agentrun_mem0/embeddings/base.py +31 -0
  59. agentrun_mem0/embeddings/configs.py +30 -0
  60. agentrun_mem0/embeddings/gemini.py +39 -0
  61. agentrun_mem0/embeddings/huggingface.py +44 -0
  62. agentrun_mem0/embeddings/langchain.py +35 -0
  63. agentrun_mem0/embeddings/lmstudio.py +29 -0
  64. agentrun_mem0/embeddings/mock.py +11 -0
  65. agentrun_mem0/embeddings/ollama.py +53 -0
  66. agentrun_mem0/embeddings/openai.py +49 -0
  67. agentrun_mem0/embeddings/together.py +31 -0
  68. agentrun_mem0/embeddings/vertexai.py +64 -0
  69. agentrun_mem0/exceptions.py +503 -0
  70. agentrun_mem0/graphs/__init__.py +0 -0
  71. agentrun_mem0/graphs/configs.py +105 -0
  72. agentrun_mem0/graphs/neptune/__init__.py +0 -0
  73. agentrun_mem0/graphs/neptune/base.py +497 -0
  74. agentrun_mem0/graphs/neptune/neptunedb.py +511 -0
  75. agentrun_mem0/graphs/neptune/neptunegraph.py +474 -0
  76. agentrun_mem0/graphs/tools.py +371 -0
  77. agentrun_mem0/graphs/utils.py +97 -0
  78. agentrun_mem0/llms/__init__.py +0 -0
  79. agentrun_mem0/llms/anthropic.py +87 -0
  80. agentrun_mem0/llms/aws_bedrock.py +665 -0
  81. agentrun_mem0/llms/azure_openai.py +141 -0
  82. agentrun_mem0/llms/azure_openai_structured.py +91 -0
  83. agentrun_mem0/llms/base.py +131 -0
  84. agentrun_mem0/llms/configs.py +34 -0
  85. agentrun_mem0/llms/deepseek.py +107 -0
  86. agentrun_mem0/llms/gemini.py +201 -0
  87. agentrun_mem0/llms/groq.py +88 -0
  88. agentrun_mem0/llms/langchain.py +94 -0
  89. agentrun_mem0/llms/litellm.py +87 -0
  90. agentrun_mem0/llms/lmstudio.py +114 -0
  91. agentrun_mem0/llms/ollama.py +117 -0
  92. agentrun_mem0/llms/openai.py +147 -0
  93. agentrun_mem0/llms/openai_structured.py +52 -0
  94. agentrun_mem0/llms/sarvam.py +89 -0
  95. agentrun_mem0/llms/together.py +88 -0
  96. agentrun_mem0/llms/vllm.py +107 -0
  97. agentrun_mem0/llms/xai.py +52 -0
  98. agentrun_mem0/memory/__init__.py +0 -0
  99. agentrun_mem0/memory/base.py +63 -0
  100. agentrun_mem0/memory/graph_memory.py +698 -0
  101. agentrun_mem0/memory/kuzu_memory.py +713 -0
  102. agentrun_mem0/memory/main.py +2229 -0
  103. agentrun_mem0/memory/memgraph_memory.py +689 -0
  104. agentrun_mem0/memory/setup.py +56 -0
  105. agentrun_mem0/memory/storage.py +218 -0
  106. agentrun_mem0/memory/telemetry.py +90 -0
  107. agentrun_mem0/memory/utils.py +208 -0
  108. agentrun_mem0/proxy/__init__.py +0 -0
  109. agentrun_mem0/proxy/main.py +189 -0
  110. agentrun_mem0/reranker/__init__.py +9 -0
  111. agentrun_mem0/reranker/base.py +20 -0
  112. agentrun_mem0/reranker/cohere_reranker.py +85 -0
  113. agentrun_mem0/reranker/huggingface_reranker.py +147 -0
  114. agentrun_mem0/reranker/llm_reranker.py +142 -0
  115. agentrun_mem0/reranker/sentence_transformer_reranker.py +107 -0
  116. agentrun_mem0/reranker/zero_entropy_reranker.py +96 -0
  117. agentrun_mem0/utils/factory.py +283 -0
  118. agentrun_mem0/utils/gcp_auth.py +167 -0
  119. agentrun_mem0/vector_stores/__init__.py +0 -0
  120. agentrun_mem0/vector_stores/alibabacloud_mysql.py +547 -0
  121. agentrun_mem0/vector_stores/aliyun_tablestore.py +252 -0
  122. agentrun_mem0/vector_stores/azure_ai_search.py +396 -0
  123. agentrun_mem0/vector_stores/azure_mysql.py +463 -0
  124. agentrun_mem0/vector_stores/baidu.py +368 -0
  125. agentrun_mem0/vector_stores/base.py +58 -0
  126. agentrun_mem0/vector_stores/chroma.py +332 -0
  127. agentrun_mem0/vector_stores/configs.py +67 -0
  128. agentrun_mem0/vector_stores/databricks.py +761 -0
  129. agentrun_mem0/vector_stores/elasticsearch.py +237 -0
  130. agentrun_mem0/vector_stores/faiss.py +479 -0
  131. agentrun_mem0/vector_stores/langchain.py +180 -0
  132. agentrun_mem0/vector_stores/milvus.py +250 -0
  133. agentrun_mem0/vector_stores/mongodb.py +310 -0
  134. agentrun_mem0/vector_stores/neptune_analytics.py +467 -0
  135. agentrun_mem0/vector_stores/opensearch.py +292 -0
  136. agentrun_mem0/vector_stores/pgvector.py +404 -0
  137. agentrun_mem0/vector_stores/pinecone.py +382 -0
  138. agentrun_mem0/vector_stores/qdrant.py +270 -0
  139. agentrun_mem0/vector_stores/redis.py +295 -0
  140. agentrun_mem0/vector_stores/s3_vectors.py +176 -0
  141. agentrun_mem0/vector_stores/supabase.py +237 -0
  142. agentrun_mem0/vector_stores/upstash_vector.py +293 -0
  143. agentrun_mem0/vector_stores/valkey.py +824 -0
  144. agentrun_mem0/vector_stores/vertex_ai_vector_search.py +635 -0
  145. agentrun_mem0/vector_stores/weaviate.py +343 -0
  146. agentrun_mem0ai-0.0.11.data/data/README.md +205 -0
  147. agentrun_mem0ai-0.0.11.dist-info/METADATA +277 -0
  148. agentrun_mem0ai-0.0.11.dist-info/RECORD +150 -0
  149. agentrun_mem0ai-0.0.11.dist-info/WHEEL +4 -0
  150. agentrun_mem0ai-0.0.11.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,41 @@
1
+ from typing import Any, Dict, Optional, Type, Union
2
+
3
+ from pydantic import BaseModel, Field, model_validator
4
+
5
+
6
class OpenSearchConfig(BaseModel):
    """Connection and index settings for the OpenSearch vector store.

    Two ``mode="before"`` validators inspect the raw input mapping: one
    requires a truthy ``host`` entry, the other rejects keys that are not
    declared fields.
    """

    collection_name: str = Field(default="mem0", description="Name of the index")
    host: str = Field(default="localhost", description="OpenSearch host")
    port: int = Field(default=9200, description="OpenSearch port")
    user: Optional[str] = Field(default=None, description="Username for authentication")
    password: Optional[str] = Field(default=None, description="Password for authentication")
    api_key: Optional[str] = Field(default=None, description="API key for authentication (if applicable)")
    embedding_model_dims: int = Field(default=1536, description="Dimension of the embedding vector")
    verify_certs: bool = Field(
        default=False, description="Verify SSL certificates (default False for OpenSearch)"
    )
    use_ssl: bool = Field(default=False, description="Use SSL for connection (default False for OpenSearch)")
    http_auth: Optional[object] = Field(default=None, description="HTTP authentication method / AWS SigV4")
    connection_class: Optional[Union[str, Type]] = Field(
        default="RequestsHttpConnection",
        description="Connection class for OpenSearch",
    )
    pool_maxsize: int = Field(default=20, description="Maximum number of connections in the pool")

    @model_validator(mode="before")
    @classmethod
    def validate_auth(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Require a truthy ``host`` entry in the raw input.

        NOTE(review): this inspects the input before field defaults are
        applied, so omitting ``host`` appears to raise even though the field
        declares "localhost" as its default -- confirm that is intended.

        Raises:
            ValueError: If ``host`` is missing or falsy.
        """
        host = values.get("host")
        if host:
            return values
        raise ValueError("Host must be provided for OpenSearch")

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject input keys that are not declared model fields.

        Raises:
            ValueError: If the input contains any unknown key.
        """
        allowed_fields = set(cls.model_fields.keys())
        unexpected = set(values.keys()) - allowed_fields
        if not unexpected:
            return values
        raise ValueError(
            f"Extra fields not allowed: {', '.join(unexpected)}. Allowed fields: {', '.join(allowed_fields)}"
        )
@@ -0,0 +1,52 @@
1
+ from typing import Any, Dict, Optional
2
+
3
+ from pydantic import BaseModel, Field, model_validator
4
+
5
+
6
class PGVectorConfig(BaseModel):
    """Settings for the pgvector (PostgreSQL) vector store.

    A connection can be described in three ways, checked in this order by
    ``check_auth_and_connection``: an existing ``connection_pool``, a
    ``connection_string``, or the individual ``user``/``password``/``host``/
    ``port`` parameters.
    """

    dbname: str = Field("postgres", description="Default name for the database")
    collection_name: str = Field("mem0", description="Default name for the collection")
    embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model")
    user: Optional[str] = Field(None, description="Database user")
    password: Optional[str] = Field(None, description="Database password")
    host: Optional[str] = Field(None, description="Database host. Default is localhost")
    # The previous description cited 1536, which is the embedding dimension, not a port.
    port: Optional[int] = Field(None, description="Database port (e.g. 5432)")
    diskann: Optional[bool] = Field(False, description="Use diskann for approximate nearest neighbors search")
    hnsw: Optional[bool] = Field(True, description="Use hnsw for faster search")
    minconn: Optional[int] = Field(1, description="Minimum number of connections in the pool")
    maxconn: Optional[int] = Field(5, description="Maximum number of connections in the pool")
    # SSL and connection options
    sslmode: Optional[str] = Field(
        None, description="SSL mode for PostgreSQL connection (e.g., 'require', 'prefer', 'disable')"
    )
    connection_string: Optional[str] = Field(
        None, description="PostgreSQL connection string (overrides individual connection parameters)"
    )
    connection_pool: Optional[Any] = Field(
        None,
        description="psycopg connection pool object (overrides connection string and individual parameters)",
    )

    @model_validator(mode="before")
    @classmethod
    def check_auth_and_connection(cls, values):
        """Validate that the input describes a usable connection.

        A supplied ``connection_pool`` or ``connection_string`` short-circuits
        the check. Otherwise both members of each pair (``user``/``password``
        and ``host``/``port``) must be present.

        Raises:
            ValueError: If either credential or either address part is missing.
        """
        # A pool or a connection string supersedes the individual parameters.
        if values.get("connection_pool") is not None:
            return values
        if values.get("connection_string") is not None:
            return values

        user, password = values.get("user"), values.get("password")
        host, port = values.get("host"), values.get("port")
        # `or`, not `and`: the messages require BOTH values, so one missing value must fail.
        if not user or not password:
            raise ValueError("Both 'user' and 'password' must be provided when not using connection_string.")
        if not host or not port:
            raise ValueError("Both 'host' and 'port' must be provided when not using connection_string.")
        return values

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject input keys that are not declared model fields.

        Raises:
            ValueError: If the input contains any unknown key.
        """
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
            )
        return values
@@ -0,0 +1,55 @@
1
+ import os
2
+ from typing import Any, Dict, Optional
3
+
4
+ from pydantic import BaseModel, ConfigDict, Field, model_validator
5
+
6
+
7
class PineconeConfig(BaseModel):
    """Configuration for Pinecone vector database."""

    collection_name: str = Field(default="mem0", description="Name of the index/collection")
    embedding_model_dims: int = Field(default=1536, description="Dimensions of the embedding model")
    client: Optional[Any] = Field(default=None, description="Existing Pinecone client instance")
    api_key: Optional[str] = Field(default=None, description="API key for Pinecone")
    environment: Optional[str] = Field(default=None, description="Pinecone environment")
    serverless_config: Optional[Dict[str, Any]] = Field(
        default=None, description="Configuration for serverless deployment"
    )
    pod_config: Optional[Dict[str, Any]] = Field(default=None, description="Configuration for pod-based deployment")
    hybrid_search: bool = Field(default=False, description="Whether to enable hybrid search")
    metric: str = Field(default="cosine", description="Distance metric for vector similarity")
    batch_size: int = Field(default=100, description="Batch size for operations")
    extra_params: Optional[Dict[str, Any]] = Field(
        default=None, description="Additional parameters for Pinecone client"
    )
    namespace: Optional[str] = Field(default=None, description="Namespace for the collection")

    @model_validator(mode="before")
    @classmethod
    def check_api_key_or_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Require an API key, a client, or the PINECONE_API_KEY environment variable.

        Raises:
            ValueError: If none of the three credential sources is available.
        """
        has_credentials = bool(values.get("api_key")) or bool(values.get("client"))
        if has_credentials or "PINECONE_API_KEY" in os.environ:
            return values
        raise ValueError(
            "Either 'api_key' or 'client' must be provided, or PINECONE_API_KEY environment variable must be set."
        )

    @model_validator(mode="before")
    @classmethod
    def check_pod_or_serverless(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject input that specifies both deployment options at once.

        Raises:
            ValueError: If ``pod_config`` and ``serverless_config`` are both truthy.
        """
        if values.get("pod_config") and values.get("serverless_config"):
            raise ValueError(
                "Both 'pod_config' and 'serverless_config' cannot be specified. Choose one deployment option."
            )
        return values

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject input keys that are not declared model fields.

        Raises:
            ValueError: If the input contains any unknown key.
        """
        allowed_fields = set(cls.model_fields.keys())
        unexpected = set(values.keys()) - allowed_fields
        if not unexpected:
            return values
        raise ValueError(
            f"Extra fields not allowed: {', '.join(unexpected)}. Please input only the following fields: {', '.join(allowed_fields)}"
        )

    # `client` is typed as Any, but arbitrary types are explicitly permitted as well.
    model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -0,0 +1,47 @@
1
+ from typing import Any, ClassVar, Dict, Optional
2
+
3
+ from pydantic import BaseModel, ConfigDict, Field, model_validator
4
+
5
+
6
class QdrantConfig(BaseModel):
    """Settings for the Qdrant vector store.

    The raw input must contain one of: a ``path``, a ``host`` plus ``port``,
    or a ``url`` plus ``api_key``.
    """

    # Imported in the class body so the name is available for the `client` annotation below.
    from qdrant_client import QdrantClient

    QdrantClient: ClassVar[type] = QdrantClient

    collection_name: str = Field(default="mem0", description="Name of the collection")
    embedding_model_dims: Optional[int] = Field(default=1536, description="Dimensions of the embedding model")
    client: Optional[QdrantClient] = Field(default=None, description="Existing Qdrant client instance")
    host: Optional[str] = Field(default=None, description="Host address for Qdrant server")
    port: Optional[int] = Field(default=None, description="Port for Qdrant server")
    path: Optional[str] = Field(default="/tmp/qdrant", description="Path for local Qdrant database")
    url: Optional[str] = Field(default=None, description="Full URL for Qdrant server")
    api_key: Optional[str] = Field(default=None, description="API key for Qdrant server")
    on_disk: Optional[bool] = Field(default=False, description="Enables persistent storage")

    @model_validator(mode="before")
    @classmethod
    def check_host_port_or_path(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Require at least one complete way of locating the Qdrant instance.

        Raises:
            ValueError: If the input has no ``path``, no ``host``+``port``
                pair, and no ``url``+``api_key`` pair.
        """
        has_path = values.get("path")
        has_host_port = values.get("host") and values.get("port")
        has_url_key = values.get("url") and values.get("api_key")
        if not (has_path or has_host_port or has_url_key):
            raise ValueError("Either 'host' and 'port' or 'url' and 'api_key' or 'path' must be provided.")
        return values

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject input keys that are not declared model fields.

        Raises:
            ValueError: If the input contains any unknown key.
        """
        allowed_fields = set(cls.model_fields.keys())
        unexpected = set(values.keys()) - allowed_fields
        if not unexpected:
            return values
        raise ValueError(
            f"Extra fields not allowed: {', '.join(unexpected)}. Please input only the following fields: {', '.join(allowed_fields)}"
        )

    # Needed because `client` is annotated with a non-pydantic class.
    model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -0,0 +1,24 @@
1
+ from typing import Any, Dict
2
+
3
+ from pydantic import BaseModel, ConfigDict, Field, model_validator
4
+
5
+
6
+ # TODO: Upgrade to latest pydantic version
7
class RedisDBConfig(BaseModel):
    """Settings for the Redis vector store; ``redis_url`` is required."""

    redis_url: str = Field(..., description="Redis URL")
    collection_name: str = Field(default="mem0", description="Collection name")
    embedding_model_dims: int = Field(default=1536, description="Embedding model dimensions")

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject input keys that are not declared model fields.

        Raises:
            ValueError: If the input contains any unknown key.
        """
        allowed_fields = set(cls.model_fields.keys())
        unexpected = set(values.keys()) - allowed_fields
        if not unexpected:
            return values
        raise ValueError(
            f"Extra fields not allowed: {', '.join(unexpected)}. Please input only the following fields: {', '.join(allowed_fields)}"
        )

    model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -0,0 +1,28 @@
1
+ from typing import Any, Dict, Optional
2
+
3
+ from pydantic import BaseModel, ConfigDict, Field, model_validator
4
+
5
+
6
class S3VectorsConfig(BaseModel):
    """Settings for the S3 Vectors store; ``vector_bucket_name`` is required."""

    vector_bucket_name: str = Field(description="Name of the S3 Vector bucket")
    collection_name: str = Field(default="mem0", description="Name of the vector index")
    embedding_model_dims: int = Field(default=1536, description="Dimension of the embedding vector")
    distance_metric: str = Field(
        default="cosine",
        description="Distance metric for similarity search. Options: 'cosine', 'euclidean'",
    )
    region_name: Optional[str] = Field(default=None, description="AWS region for the S3 Vectors client")

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject input keys that are not declared model fields.

        Raises:
            ValueError: If the input contains any unknown key.
        """
        allowed_fields = set(cls.model_fields.keys())
        unexpected = set(values.keys()) - allowed_fields
        if not unexpected:
            return values
        raise ValueError(
            f"Extra fields not allowed: {', '.join(unexpected)}. Please input only the following fields: {', '.join(allowed_fields)}"
        )

    model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -0,0 +1,44 @@
1
+ from enum import Enum
2
+ from typing import Any, Dict, Optional
3
+
4
+ from pydantic import BaseModel, Field, model_validator
5
+
6
+
7
class IndexMethod(str, Enum):
    """Index method names accepted by ``SupabaseConfig.index_method``."""

    AUTO = "auto"
    HNSW = "hnsw"
    IVFFLAT = "ivfflat"
11
+
12
+
13
class IndexMeasure(str, Enum):
    """Distance measure names accepted by ``SupabaseConfig.index_measure``."""

    COSINE = "cosine_distance"
    L2 = "l2_distance"
    L1 = "l1_distance"
    MAX_INNER_PRODUCT = "max_inner_product"
18
+
19
+
20
class SupabaseConfig(BaseModel):
    """Settings for the Supabase (PostgreSQL) vector store.

    ``connection_string`` is required and must start with ``postgresql://``.
    """

    connection_string: str = Field(..., description="PostgreSQL connection string")
    collection_name: str = Field("mem0", description="Name for the vector collection")
    embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model")
    index_method: Optional[IndexMethod] = Field(IndexMethod.AUTO, description="Index method to use")
    index_measure: Optional[IndexMeasure] = Field(IndexMeasure.COSINE, description="Distance measure to use")

    @model_validator(mode="before")
    @classmethod
    def check_connection_string(cls, values):
        """Require a ``postgresql://`` connection string in the raw input.

        Raises:
            ValueError: If ``connection_string`` is missing, empty, not a
                string, or does not start with ``postgresql://``.
        """
        conn_str = values.get("connection_string")
        # The isinstance guard turns a non-string value into a validation error
        # rather than an AttributeError from `.startswith`.
        if not isinstance(conn_str, str) or not conn_str.startswith("postgresql://"):
            raise ValueError("A valid PostgreSQL connection string must be provided")
        return values

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject input keys that are not declared model fields.

        Raises:
            ValueError: If the input contains any unknown key.
        """
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
            )
        return values
@@ -0,0 +1,34 @@
1
+ import os
2
+ from typing import Any, ClassVar, Dict, Optional
3
+
4
+ from pydantic import BaseModel, ConfigDict, Field, model_validator
5
+
6
+ try:
7
+ from upstash_vector import Index
8
+ except ImportError:
9
+ raise ImportError("The 'upstash_vector' library is required. Please install it using 'pip install upstash_vector'.")
10
+
11
+
12
class UpstashVectorConfig(BaseModel):
    """Settings for the Upstash Vector store.

    Either an existing ``client`` or a ``url``/``token`` pair must be
    available; the pair may come from the input or from the
    ``UPSTASH_VECTOR_REST_URL`` / ``UPSTASH_VECTOR_REST_TOKEN`` environment
    variables.
    """

    Index: ClassVar[type] = Index

    url: Optional[str] = Field(None, description="URL for Upstash Vector index")
    token: Optional[str] = Field(None, description="Token for Upstash Vector index")
    client: Optional[Index] = Field(None, description="Existing `upstash_vector.Index` client instance")
    collection_name: str = Field("mem0", description="Namespace to use for the index")
    # The description previously claimed "Default is True", contradicting the actual default.
    enable_embeddings: bool = Field(
        False, description="Whether to use built-in upstash embeddings or not. Default is False."
    )

    @model_validator(mode="before")
    @classmethod
    def check_credentials_or_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Require a client, or a URL and token from the input or the environment.

        Raises:
            ValueError: If there is no client and the URL/token pair is incomplete.
        """
        client = values.get("client")
        url = values.get("url") or os.environ.get("UPSTASH_VECTOR_REST_URL")
        token = values.get("token") or os.environ.get("UPSTASH_VECTOR_REST_TOKEN")

        if not client and not (url and token):
            raise ValueError("Either a client or URL and token must be provided.")
        return values

    # Needed because `client` is annotated with a non-pydantic class.
    model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -0,0 +1,15 @@
1
+ from pydantic import BaseModel
2
+
3
+
4
class ValkeyConfig(BaseModel):
    """Configuration for Valkey vector store."""

    # Required settings (no defaults).
    valkey_url: str
    collection_name: str
    embedding_model_dims: int

    # Optional settings.
    timezone: str = "UTC"
    # Index algorithm: 'hnsw' (the default) or 'flat'.
    index_type: str = "hnsw"

    # HNSW-specific tuning parameters.
    # Number of connections per layer.
    hnsw_m: int = 16
    # Search width used while building the index.
    hnsw_ef_construction: int = 200
    # Search width used while answering queries.
    hnsw_ef_runtime: int = 10
@@ -0,0 +1,28 @@
1
+ from typing import Dict, Optional
2
+
3
+ from pydantic import BaseModel, ConfigDict, Field
4
+
5
+
6
class GoogleMatchingEngineConfig(BaseModel):
    """Settings for the Vertex AI Vector Search (Matching Engine) store.

    Unknown keys are rejected via ``extra="forbid"``. When ``collection_name``
    is not supplied (or is empty) it falls back to ``index_id``.
    """

    project_id: str = Field(description="Google Cloud project ID")
    project_number: str = Field(description="Google Cloud project number")
    region: str = Field(description="Google Cloud region")
    endpoint_id: str = Field(description="Vertex AI Vector Search endpoint ID")
    index_id: str = Field(description="Vertex AI Vector Search index ID")
    deployment_index_id: str = Field(description="Deployment-specific index ID")
    collection_name: Optional[str] = Field(None, description="Collection name, defaults to index_id")
    credentials_path: Optional[str] = Field(None, description="Path to service account credentials JSON file")
    service_account_json: Optional[Dict] = Field(
        None,
        description="Service account credentials as dictionary (alternative to credentials_path)",
    )
    vector_search_api_endpoint: Optional[str] = Field(None, description="Vector search API endpoint")

    model_config = ConfigDict(extra="forbid")

    def model_post_init(self, _context) -> None:
        """Set collection_name to index_id if it was not provided or is empty.

        This single hook replaces the former ``__init__`` override, which
        repeated the same fallback with a slightly different condition. The
        falsy check keeps that override's handling of an empty string.
        """
        if not self.collection_name:
            self.collection_name = self.index_id
@@ -0,0 +1,41 @@
1
+ from typing import Any, ClassVar, Dict, Optional
2
+
3
+ from pydantic import BaseModel, ConfigDict, Field, model_validator
4
+
5
+
6
class WeaviateConfig(BaseModel):
    """Settings for the Weaviate vector store; ``cluster_url`` is required in the input."""

    # Imported in the class body and re-exposed as a class-level attribute.
    from weaviate import WeaviateClient

    WeaviateClient: ClassVar[type] = WeaviateClient

    collection_name: str = Field(default="mem0", description="Name of the collection")
    embedding_model_dims: int = Field(default=1536, description="Dimensions of the embedding model")
    cluster_url: Optional[str] = Field(default=None, description="URL for Weaviate server")
    auth_client_secret: Optional[str] = Field(default=None, description="API key for Weaviate authentication")
    additional_headers: Optional[Dict[str, str]] = Field(
        default=None, description="Additional headers for requests"
    )

    @model_validator(mode="before")
    @classmethod
    def check_connection_params(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Require a truthy ``cluster_url`` in the raw input.

        Raises:
            ValueError: If ``cluster_url`` is missing or falsy.
        """
        if values.get("cluster_url"):
            return values
        raise ValueError("'cluster_url' must be provided.")

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject input keys that are not declared model fields.

        Raises:
            ValueError: If the input contains any unknown key.
        """
        allowed_fields = set(cls.model_fields.keys())
        unexpected = set(values.keys()) - allowed_fields
        if not unexpected:
            return values
        raise ValueError(
            f"Extra fields not allowed: {', '.join(unexpected)}. Please input only the following fields: {', '.join(allowed_fields)}"
        )

    model_config = ConfigDict(arbitrary_types_allowed=True)
File without changes
@@ -0,0 +1,100 @@
1
+ import json
2
+ import os
3
+ from typing import Literal, Optional
4
+
5
+ try:
6
+ import boto3
7
+ except ImportError:
8
+ raise ImportError("The 'boto3' library is required. Please install it using 'pip install boto3'.")
9
+
10
+ import numpy as np
11
+
12
+ from agentrun_mem0.configs.embeddings.base import BaseEmbedderConfig
13
+ from agentrun_mem0.embeddings.base import EmbeddingBase
14
+
15
+
16
class AWSBedrockEmbedding(EmbeddingBase):
    """AWS Bedrock embedding implementation.

    This class uses AWS Bedrock's embedding models through a boto3
    ``bedrock-runtime`` client.
    """

    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        """Create the Bedrock runtime client.

        Credentials are taken from the config when it carries non-empty
        values, otherwise from the ``AWS_ACCESS_KEY_ID`` /
        ``AWS_SECRET_ACCESS_KEY`` environment variables. The session token is
        read from ``AWS_SESSION_TOKEN`` only. The model defaults to
        "amazon.titan-embed-text-v1" and the region to "us-west-2".

        Args:
            config: Embedder configuration; a default one is created by the
                base class when omitted.
        """
        super().__init__(config)

        self.config.model = self.config.model or "amazon.titan-embed-text-v1"

        # Prefer explicit config values, but fall back to the environment when the
        # config attribute exists and is unset. The earlier hasattr-based override
        # replaced the environment value with None in that case.
        aws_access_key = getattr(self.config, "aws_access_key_id", None) or os.environ.get("AWS_ACCESS_KEY_ID")
        aws_secret_key = getattr(self.config, "aws_secret_access_key", None) or os.environ.get(
            "AWS_SECRET_ACCESS_KEY"
        )
        aws_session_token = os.environ.get("AWS_SESSION_TOKEN")

        aws_region = self.config.aws_region or "us-west-2"

        # Empty strings are normalised to None before being handed to boto3.
        self.client = boto3.client(
            "bedrock-runtime",
            region_name=aws_region,
            aws_access_key_id=aws_access_key or None,
            aws_secret_access_key=aws_secret_key or None,
            aws_session_token=aws_session_token or None,
        )

    def _normalize_vector(self, embeddings):
        """Normalize the embedding to a unit vector.

        A zero-length vector is returned unchanged to avoid dividing by zero.
        """
        emb = np.array(embeddings)
        norm = np.linalg.norm(emb)
        if norm == 0:
            return emb.tolist()
        return (emb / norm).tolist()

    def _get_embedding(self, text):
        """Call out to Bedrock embedding endpoint.

        The request body depends on the provider prefix of the model id (the
        part before the first "."): "cohere" models receive ``texts`` and
        ``input_type``; every other provider receives ``inputText``.

        Raises:
            ValueError: If the request or the response parsing fails; the
                original exception is chained as the cause.
        """
        provider = self.config.model.split(".")[0]

        if provider == "cohere":
            input_body = {"input_type": "search_document", "texts": [text]}
        else:
            # Amazon and other providers
            input_body = {"inputText": text}

        body = json.dumps(input_body)

        try:
            response = self.client.invoke_model(
                body=body,
                modelId=self.config.model,
                accept="application/json",
                contentType="application/json",
            )

            response_body = json.loads(response.get("body").read())

            if provider == "cohere":
                embeddings = response_body.get("embeddings")[0]
            else:
                embeddings = response_body.get("embedding")

            return embeddings
        except Exception as e:
            raise ValueError(f"Error getting embedding from AWS Bedrock: {e}") from e

    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
        """
        Get the embedding for the given text using AWS Bedrock.

        Args:
            text (str): The text to embed.
            memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None. Currently not used by this implementation.
        Returns:
            list: The embedding vector.
        """
        return self._get_embedding(text)
@@ -0,0 +1,55 @@
1
+ import os
2
+ from typing import Literal, Optional
3
+
4
+ from azure.identity import DefaultAzureCredential, get_bearer_token_provider
5
+ from openai import AzureOpenAI
6
+
7
+ from agentrun_mem0.configs.embeddings.base import BaseEmbedderConfig
8
+ from agentrun_mem0.embeddings.base import EmbeddingBase
9
+
10
+ SCOPE = "https://cognitiveservices.azure.com/.default"
11
+
12
+
13
class AzureOpenAIEmbedding(EmbeddingBase):
    """Embedding provider backed by an ``AzureOpenAI`` client."""

    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        """Build the Azure OpenAI client from config values or environment variables.

        Each connection setting is taken from ``config.azure_kwargs`` and
        falls back to the matching ``EMBEDDING_AZURE_*`` environment variable.
        When no usable API key is available, authentication switches to
        ``DefaultAzureCredential`` with a bearer token provider.
        """
        super().__init__(config)

        api_key = self.config.azure_kwargs.api_key or os.getenv("EMBEDDING_AZURE_OPENAI_API_KEY")
        deployment = self.config.azure_kwargs.azure_deployment or os.getenv("EMBEDDING_AZURE_DEPLOYMENT")
        endpoint = self.config.azure_kwargs.azure_endpoint or os.getenv("EMBEDDING_AZURE_ENDPOINT")
        version = self.config.azure_kwargs.api_version or os.getenv("EMBEDDING_AZURE_API_VERSION")
        headers = self.config.azure_kwargs.default_headers

        # A missing, empty, or placeholder key means: authenticate through Azure AD instead.
        token_provider = None
        if api_key in (None, "", "your-api-key"):
            self.credential = DefaultAzureCredential()
            token_provider = get_bearer_token_provider(self.credential, SCOPE)
            api_key = None

        self.client = AzureOpenAI(
            azure_deployment=deployment,
            azure_endpoint=endpoint,
            azure_ad_token_provider=token_provider,
            api_version=version,
            api_key=api_key,
            http_client=self.config.http_client,
            default_headers=headers,
        )

    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
        """
        Return the embedding of ``text`` produced by the configured Azure OpenAI model.

        Newlines in the text are replaced with spaces before the request.

        Args:
            text (str): The text to embed.
            memory_action (optional): One of "add", "search", or "update". Defaults to None; not used here.
        Returns:
            list: The embedding vector.
        """
        single_line = text.replace("\n", " ")
        response = self.client.embeddings.create(input=[single_line], model=self.config.model)
        return response.data[0].embedding
@@ -0,0 +1,31 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Literal, Optional
3
+
4
+ from agentrun_mem0.configs.embeddings.base import BaseEmbedderConfig
5
+
6
+
7
class EmbeddingBase(ABC):
    """Base class for embedding providers.

    :param config: Embedding configuration option class, defaults to None
    :type config: Optional[BaseEmbedderConfig], optional
    """

    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        """Store ``config``, creating a default ``BaseEmbedderConfig`` when it is None."""
        self.config = BaseEmbedderConfig() if config is None else config

    @abstractmethod
    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
        """
        Get the embedding for the given text.

        ``memory_action`` defaults to None, as this docstring already stated
        and as the AWS Bedrock and Azure OpenAI implementations already declare.

        Args:
            text (str): The text to embed.
            memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.
        Returns:
            list: The embedding vector.
        """
        pass
@@ -0,0 +1,30 @@
1
+ from typing import Optional
2
+
3
+ from pydantic import BaseModel, Field, field_validator
4
+
5
+
6
class EmbedderConfig(BaseModel):
    """Selects an embedding provider and carries its provider-specific settings."""

    provider: str = Field(
        default="openai",
        description="Provider of the embedding model (e.g., 'ollama', 'openai')",
    )
    config: Optional[dict] = Field(
        default={},
        description="Configuration for the specific embedding model",
    )

    @field_validator("config")
    def validate_config(cls, v, values):
        """Accept ``config`` only when the already-validated provider is supported.

        Raises:
            ValueError: If the provider is not one of the supported names.
        """
        supported_providers = (
            "openai",
            "ollama",
            "huggingface",
            "azure_openai",
            "gemini",
            "vertexai",
            "together",
            "lmstudio",
            "langchain",
            "aws_bedrock",
        )
        provider = values.data.get("provider")
        if provider not in supported_providers:
            raise ValueError(f"Unsupported embedding provider: {provider}")
        return v