agentrun-mem0ai 0.0.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentrun_mem0/__init__.py +6 -0
- agentrun_mem0/client/__init__.py +0 -0
- agentrun_mem0/client/main.py +1747 -0
- agentrun_mem0/client/project.py +931 -0
- agentrun_mem0/client/utils.py +115 -0
- agentrun_mem0/configs/__init__.py +0 -0
- agentrun_mem0/configs/base.py +90 -0
- agentrun_mem0/configs/embeddings/__init__.py +0 -0
- agentrun_mem0/configs/embeddings/base.py +110 -0
- agentrun_mem0/configs/enums.py +7 -0
- agentrun_mem0/configs/llms/__init__.py +0 -0
- agentrun_mem0/configs/llms/anthropic.py +56 -0
- agentrun_mem0/configs/llms/aws_bedrock.py +192 -0
- agentrun_mem0/configs/llms/azure.py +57 -0
- agentrun_mem0/configs/llms/base.py +62 -0
- agentrun_mem0/configs/llms/deepseek.py +56 -0
- agentrun_mem0/configs/llms/lmstudio.py +59 -0
- agentrun_mem0/configs/llms/ollama.py +56 -0
- agentrun_mem0/configs/llms/openai.py +79 -0
- agentrun_mem0/configs/llms/vllm.py +56 -0
- agentrun_mem0/configs/prompts.py +459 -0
- agentrun_mem0/configs/rerankers/__init__.py +0 -0
- agentrun_mem0/configs/rerankers/base.py +17 -0
- agentrun_mem0/configs/rerankers/cohere.py +15 -0
- agentrun_mem0/configs/rerankers/config.py +12 -0
- agentrun_mem0/configs/rerankers/huggingface.py +17 -0
- agentrun_mem0/configs/rerankers/llm.py +48 -0
- agentrun_mem0/configs/rerankers/sentence_transformer.py +16 -0
- agentrun_mem0/configs/rerankers/zero_entropy.py +28 -0
- agentrun_mem0/configs/vector_stores/__init__.py +0 -0
- agentrun_mem0/configs/vector_stores/alibabacloud_mysql.py +64 -0
- agentrun_mem0/configs/vector_stores/aliyun_tablestore.py +32 -0
- agentrun_mem0/configs/vector_stores/azure_ai_search.py +57 -0
- agentrun_mem0/configs/vector_stores/azure_mysql.py +84 -0
- agentrun_mem0/configs/vector_stores/baidu.py +27 -0
- agentrun_mem0/configs/vector_stores/chroma.py +58 -0
- agentrun_mem0/configs/vector_stores/databricks.py +61 -0
- agentrun_mem0/configs/vector_stores/elasticsearch.py +65 -0
- agentrun_mem0/configs/vector_stores/faiss.py +37 -0
- agentrun_mem0/configs/vector_stores/langchain.py +30 -0
- agentrun_mem0/configs/vector_stores/milvus.py +42 -0
- agentrun_mem0/configs/vector_stores/mongodb.py +25 -0
- agentrun_mem0/configs/vector_stores/neptune.py +27 -0
- agentrun_mem0/configs/vector_stores/opensearch.py +41 -0
- agentrun_mem0/configs/vector_stores/pgvector.py +52 -0
- agentrun_mem0/configs/vector_stores/pinecone.py +55 -0
- agentrun_mem0/configs/vector_stores/qdrant.py +47 -0
- agentrun_mem0/configs/vector_stores/redis.py +24 -0
- agentrun_mem0/configs/vector_stores/s3_vectors.py +28 -0
- agentrun_mem0/configs/vector_stores/supabase.py +44 -0
- agentrun_mem0/configs/vector_stores/upstash_vector.py +34 -0
- agentrun_mem0/configs/vector_stores/valkey.py +15 -0
- agentrun_mem0/configs/vector_stores/vertex_ai_vector_search.py +28 -0
- agentrun_mem0/configs/vector_stores/weaviate.py +41 -0
- agentrun_mem0/embeddings/__init__.py +0 -0
- agentrun_mem0/embeddings/aws_bedrock.py +100 -0
- agentrun_mem0/embeddings/azure_openai.py +55 -0
- agentrun_mem0/embeddings/base.py +31 -0
- agentrun_mem0/embeddings/configs.py +30 -0
- agentrun_mem0/embeddings/gemini.py +39 -0
- agentrun_mem0/embeddings/huggingface.py +44 -0
- agentrun_mem0/embeddings/langchain.py +35 -0
- agentrun_mem0/embeddings/lmstudio.py +29 -0
- agentrun_mem0/embeddings/mock.py +11 -0
- agentrun_mem0/embeddings/ollama.py +53 -0
- agentrun_mem0/embeddings/openai.py +49 -0
- agentrun_mem0/embeddings/together.py +31 -0
- agentrun_mem0/embeddings/vertexai.py +64 -0
- agentrun_mem0/exceptions.py +503 -0
- agentrun_mem0/graphs/__init__.py +0 -0
- agentrun_mem0/graphs/configs.py +105 -0
- agentrun_mem0/graphs/neptune/__init__.py +0 -0
- agentrun_mem0/graphs/neptune/base.py +497 -0
- agentrun_mem0/graphs/neptune/neptunedb.py +511 -0
- agentrun_mem0/graphs/neptune/neptunegraph.py +474 -0
- agentrun_mem0/graphs/tools.py +371 -0
- agentrun_mem0/graphs/utils.py +97 -0
- agentrun_mem0/llms/__init__.py +0 -0
- agentrun_mem0/llms/anthropic.py +87 -0
- agentrun_mem0/llms/aws_bedrock.py +665 -0
- agentrun_mem0/llms/azure_openai.py +141 -0
- agentrun_mem0/llms/azure_openai_structured.py +91 -0
- agentrun_mem0/llms/base.py +131 -0
- agentrun_mem0/llms/configs.py +34 -0
- agentrun_mem0/llms/deepseek.py +107 -0
- agentrun_mem0/llms/gemini.py +201 -0
- agentrun_mem0/llms/groq.py +88 -0
- agentrun_mem0/llms/langchain.py +94 -0
- agentrun_mem0/llms/litellm.py +87 -0
- agentrun_mem0/llms/lmstudio.py +114 -0
- agentrun_mem0/llms/ollama.py +117 -0
- agentrun_mem0/llms/openai.py +147 -0
- agentrun_mem0/llms/openai_structured.py +52 -0
- agentrun_mem0/llms/sarvam.py +89 -0
- agentrun_mem0/llms/together.py +88 -0
- agentrun_mem0/llms/vllm.py +107 -0
- agentrun_mem0/llms/xai.py +52 -0
- agentrun_mem0/memory/__init__.py +0 -0
- agentrun_mem0/memory/base.py +63 -0
- agentrun_mem0/memory/graph_memory.py +698 -0
- agentrun_mem0/memory/kuzu_memory.py +713 -0
- agentrun_mem0/memory/main.py +2229 -0
- agentrun_mem0/memory/memgraph_memory.py +689 -0
- agentrun_mem0/memory/setup.py +56 -0
- agentrun_mem0/memory/storage.py +218 -0
- agentrun_mem0/memory/telemetry.py +90 -0
- agentrun_mem0/memory/utils.py +208 -0
- agentrun_mem0/proxy/__init__.py +0 -0
- agentrun_mem0/proxy/main.py +189 -0
- agentrun_mem0/reranker/__init__.py +9 -0
- agentrun_mem0/reranker/base.py +20 -0
- agentrun_mem0/reranker/cohere_reranker.py +85 -0
- agentrun_mem0/reranker/huggingface_reranker.py +147 -0
- agentrun_mem0/reranker/llm_reranker.py +142 -0
- agentrun_mem0/reranker/sentence_transformer_reranker.py +107 -0
- agentrun_mem0/reranker/zero_entropy_reranker.py +96 -0
- agentrun_mem0/utils/factory.py +283 -0
- agentrun_mem0/utils/gcp_auth.py +167 -0
- agentrun_mem0/vector_stores/__init__.py +0 -0
- agentrun_mem0/vector_stores/alibabacloud_mysql.py +547 -0
- agentrun_mem0/vector_stores/aliyun_tablestore.py +252 -0
- agentrun_mem0/vector_stores/azure_ai_search.py +396 -0
- agentrun_mem0/vector_stores/azure_mysql.py +463 -0
- agentrun_mem0/vector_stores/baidu.py +368 -0
- agentrun_mem0/vector_stores/base.py +58 -0
- agentrun_mem0/vector_stores/chroma.py +332 -0
- agentrun_mem0/vector_stores/configs.py +67 -0
- agentrun_mem0/vector_stores/databricks.py +761 -0
- agentrun_mem0/vector_stores/elasticsearch.py +237 -0
- agentrun_mem0/vector_stores/faiss.py +479 -0
- agentrun_mem0/vector_stores/langchain.py +180 -0
- agentrun_mem0/vector_stores/milvus.py +250 -0
- agentrun_mem0/vector_stores/mongodb.py +310 -0
- agentrun_mem0/vector_stores/neptune_analytics.py +467 -0
- agentrun_mem0/vector_stores/opensearch.py +292 -0
- agentrun_mem0/vector_stores/pgvector.py +404 -0
- agentrun_mem0/vector_stores/pinecone.py +382 -0
- agentrun_mem0/vector_stores/qdrant.py +270 -0
- agentrun_mem0/vector_stores/redis.py +295 -0
- agentrun_mem0/vector_stores/s3_vectors.py +176 -0
- agentrun_mem0/vector_stores/supabase.py +237 -0
- agentrun_mem0/vector_stores/upstash_vector.py +293 -0
- agentrun_mem0/vector_stores/valkey.py +824 -0
- agentrun_mem0/vector_stores/vertex_ai_vector_search.py +635 -0
- agentrun_mem0/vector_stores/weaviate.py +343 -0
- agentrun_mem0ai-0.0.11.data/data/README.md +205 -0
- agentrun_mem0ai-0.0.11.dist-info/METADATA +277 -0
- agentrun_mem0ai-0.0.11.dist-info/RECORD +150 -0
- agentrun_mem0ai-0.0.11.dist-info/WHEEL +4 -0
- agentrun_mem0ai-0.0.11.dist-info/licenses/LICENSE +201 -0

agentrun_mem0/configs/vector_stores/opensearch.py
@@ -0,0 +1,41 @@
from typing import Any, Dict, Optional, Type, Union

from pydantic import BaseModel, Field, model_validator


class OpenSearchConfig(BaseModel):
    collection_name: str = Field("mem0", description="Name of the index")
    host: str = Field("localhost", description="OpenSearch host")
    port: int = Field(9200, description="OpenSearch port")
    user: Optional[str] = Field(None, description="Username for authentication")
    password: Optional[str] = Field(None, description="Password for authentication")
    api_key: Optional[str] = Field(None, description="API key for authentication (if applicable)")
    embedding_model_dims: int = Field(1536, description="Dimension of the embedding vector")
    verify_certs: bool = Field(False, description="Verify SSL certificates (default False for OpenSearch)")
    use_ssl: bool = Field(False, description="Use SSL for connection (default False for OpenSearch)")
    http_auth: Optional[object] = Field(None, description="HTTP authentication method / AWS SigV4")
    connection_class: Optional[Union[str, Type]] = Field(
        "RequestsHttpConnection", description="Connection class for OpenSearch"
    )
    pool_maxsize: int = Field(20, description="Maximum number of connections in the pool")

    @model_validator(mode="before")
    @classmethod
    def validate_auth(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        # Runs on the raw input (mode="before"), i.e. before the "localhost"
        # field default has been applied, so only reject a host that was
        # explicitly set to an empty value.
        if "host" in values and not values["host"]:
            raise ValueError("Host must be provided for OpenSearch")

        return values

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. Allowed fields: {', '.join(allowed_fields)}"
            )
        return values
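
The two before-mode validators make construction strict about unknown keys while everything else falls back to the field defaults. A minimal usage sketch (the import path is inferred from the file list above; none of this is documented package API):

    from pydantic import ValidationError

    from agentrun_mem0.configs.vector_stores.opensearch import OpenSearchConfig

    config = OpenSearchConfig(host="localhost", port=9200)
    assert config.collection_name == "mem0"  # field default applies

    try:
        # "timeout" is not a declared field, so validate_extra_fields raises;
        # pydantic surfaces the ValueError as a ValidationError.
        OpenSearchConfig(host="localhost", timeout=30)
    except ValidationError as exc:
        print(exc)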

agentrun_mem0/configs/vector_stores/pgvector.py
@@ -0,0 +1,52 @@
from typing import Any, Dict, Optional

from pydantic import BaseModel, Field, model_validator


class PGVectorConfig(BaseModel):
    dbname: str = Field("postgres", description="Default name for the database")
    collection_name: str = Field("mem0", description="Default name for the collection")
    embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model")
    user: Optional[str] = Field(None, description="Database user")
    password: Optional[str] = Field(None, description="Database password")
    host: Optional[str] = Field(None, description="Database host. Default is localhost")
    port: Optional[int] = Field(None, description="Database port. Default is 5432")
    diskann: Optional[bool] = Field(False, description="Use diskann for approximate nearest neighbors search")
    hnsw: Optional[bool] = Field(True, description="Use hnsw for faster search")
    minconn: Optional[int] = Field(1, description="Minimum number of connections in the pool")
    maxconn: Optional[int] = Field(5, description="Maximum number of connections in the pool")
    # SSL and connection options
    sslmode: Optional[str] = Field(None, description="SSL mode for PostgreSQL connection (e.g., 'require', 'prefer', 'disable')")
    connection_string: Optional[str] = Field(None, description="PostgreSQL connection string (overrides individual connection parameters)")
    connection_pool: Optional[Any] = Field(None, description="psycopg connection pool object (overrides connection string and individual parameters)")

    @model_validator(mode="before")
    @classmethod
    def check_auth_and_connection(cls, values):
        # A ready-made connection pool takes precedence over everything else.
        if values.get("connection_pool") is not None:
            return values

        # A connection string likewise skips the individual parameter checks.
        if values.get("connection_string") is not None:
            return values

        # Otherwise all individual connection parameters are required.
        user, password = values.get("user"), values.get("password")
        host, port = values.get("host"), values.get("port")
        if not user or not password:
            raise ValueError("Both 'user' and 'password' must be provided when not using connection_string.")
        if not host or not port:
            raise ValueError("Both 'host' and 'port' must be provided when not using connection_string.")
        return values

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
            )
        return values
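
check_auth_and_connection establishes a precedence order: a live connection pool wins over a connection string, which wins over individual parameters, and only the last form is validated field by field. A sketch with hypothetical credentials, assuming the module path from the file list:

    from agentrun_mem0.configs.vector_stores.pgvector import PGVectorConfig

    # A full connection string bypasses the user/password/host/port checks.
    cfg = PGVectorConfig(connection_string="postgresql://mem0:secret@db.example.com:5432/postgres")

    # Without one, every individual connection parameter must be supplied.
    cfg = PGVectorConfig(user="mem0", password="secret", host="localhost", port=5432)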

agentrun_mem0/configs/vector_stores/pinecone.py
@@ -0,0 +1,55 @@
import os
from typing import Any, Dict, Optional

from pydantic import BaseModel, ConfigDict, Field, model_validator


class PineconeConfig(BaseModel):
    """Configuration for Pinecone vector database."""

    collection_name: str = Field("mem0", description="Name of the index/collection")
    embedding_model_dims: int = Field(1536, description="Dimensions of the embedding model")
    client: Optional[Any] = Field(None, description="Existing Pinecone client instance")
    api_key: Optional[str] = Field(None, description="API key for Pinecone")
    environment: Optional[str] = Field(None, description="Pinecone environment")
    serverless_config: Optional[Dict[str, Any]] = Field(None, description="Configuration for serverless deployment")
    pod_config: Optional[Dict[str, Any]] = Field(None, description="Configuration for pod-based deployment")
    hybrid_search: bool = Field(False, description="Whether to enable hybrid search")
    metric: str = Field("cosine", description="Distance metric for vector similarity")
    batch_size: int = Field(100, description="Batch size for operations")
    extra_params: Optional[Dict[str, Any]] = Field(None, description="Additional parameters for Pinecone client")
    namespace: Optional[str] = Field(None, description="Namespace for the collection")

    @model_validator(mode="before")
    @classmethod
    def check_api_key_or_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        api_key, client = values.get("api_key"), values.get("client")
        if not api_key and not client and "PINECONE_API_KEY" not in os.environ:
            raise ValueError(
                "Either 'api_key' or 'client' must be provided, or PINECONE_API_KEY environment variable must be set."
            )
        return values

    @model_validator(mode="before")
    @classmethod
    def check_pod_or_serverless(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        pod_config, serverless_config = values.get("pod_config"), values.get("serverless_config")
        if pod_config and serverless_config:
            raise ValueError(
                "Both 'pod_config' and 'serverless_config' cannot be specified. Choose one deployment option."
            )
        return values

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
            )
        return values

    model_config = ConfigDict(arbitrary_types_allowed=True)
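
Credentials can arrive three ways: an explicit api_key, a pre-built client, or the PINECONE_API_KEY environment variable; separately, pod and serverless deployment configs are mutually exclusive. A sketch with hypothetical values:

    import os

    from pydantic import ValidationError

    from agentrun_mem0.configs.vector_stores.pinecone import PineconeConfig

    os.environ["PINECONE_API_KEY"] = "pc-..."  # env var alone satisfies check_api_key_or_client

    cfg = PineconeConfig(serverless_config={"cloud": "aws", "region": "us-east-1"})

    try:
        PineconeConfig(api_key="pc-...", pod_config={"environment": "us-east1-gcp"},
                       serverless_config={"cloud": "aws", "region": "us-east-1"})
    except ValidationError as exc:
        print(exc)  # pod and serverless configs cannot be combined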

agentrun_mem0/configs/vector_stores/qdrant.py
@@ -0,0 +1,47 @@
from typing import Any, ClassVar, Dict, Optional

from pydantic import BaseModel, ConfigDict, Field, model_validator


class QdrantConfig(BaseModel):
    from qdrant_client import QdrantClient

    QdrantClient: ClassVar[type] = QdrantClient

    collection_name: str = Field("mem0", description="Name of the collection")
    embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model")
    client: Optional[QdrantClient] = Field(None, description="Existing Qdrant client instance")
    host: Optional[str] = Field(None, description="Host address for Qdrant server")
    port: Optional[int] = Field(None, description="Port for Qdrant server")
    path: Optional[str] = Field("/tmp/qdrant", description="Path for local Qdrant database")
    url: Optional[str] = Field(None, description="Full URL for Qdrant server")
    api_key: Optional[str] = Field(None, description="API key for Qdrant server")
    on_disk: Optional[bool] = Field(False, description="Enables persistent storage")

    @model_validator(mode="before")
    @classmethod
    def check_host_port_or_path(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        host, port, path, url, api_key = (
            values.get("host"),
            values.get("port"),
            values.get("path"),
            values.get("url"),
            values.get("api_key"),
        )
        if not path and not (host and port) and not (url and api_key):
            raise ValueError("Either 'host' and 'port' or 'url' and 'api_key' or 'path' must be provided.")
        return values

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
            )
        return values

    model_config = ConfigDict(arbitrary_types_allowed=True)
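
check_host_port_or_path accepts three connection shapes: a local path (embedded mode), host plus port (self-hosted server), or url plus api_key (managed cluster). Note that it inspects the raw input, so the "/tmp/qdrant" field default does not satisfy it; one of the three must be passed explicitly. A sketch (requires qdrant-client to be installed, since the module imports it at class-definition time):

    from agentrun_mem0.configs.vector_stores.qdrant import QdrantConfig

    QdrantConfig(path="/tmp/qdrant")                                    # embedded, on-disk
    QdrantConfig(host="localhost", port=6333)                           # self-hosted server
    QdrantConfig(url="https://example.cloud.qdrant.io", api_key="...")  # managed cluster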

agentrun_mem0/configs/vector_stores/redis.py
@@ -0,0 +1,24 @@
from typing import Any, Dict

from pydantic import BaseModel, ConfigDict, Field, model_validator


# TODO: Upgrade to latest pydantic version
class RedisDBConfig(BaseModel):
    redis_url: str = Field(..., description="Redis URL")
    collection_name: str = Field("mem0", description="Collection name")
    embedding_model_dims: int = Field(1536, description="Embedding model dimensions")

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
            )
        return values

    model_config = ConfigDict(arbitrary_types_allowed=True)

agentrun_mem0/configs/vector_stores/s3_vectors.py
@@ -0,0 +1,28 @@
from typing import Any, Dict, Optional

from pydantic import BaseModel, ConfigDict, Field, model_validator


class S3VectorsConfig(BaseModel):
    vector_bucket_name: str = Field(description="Name of the S3 Vector bucket")
    collection_name: str = Field("mem0", description="Name of the vector index")
    embedding_model_dims: int = Field(1536, description="Dimension of the embedding vector")
    distance_metric: str = Field(
        "cosine",
        description="Distance metric for similarity search. Options: 'cosine', 'euclidean'",
    )
    region_name: Optional[str] = Field(None, description="AWS region for the S3 Vectors client")

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
            )
        return values

    model_config = ConfigDict(arbitrary_types_allowed=True)

agentrun_mem0/configs/vector_stores/supabase.py
@@ -0,0 +1,44 @@
from enum import Enum
from typing import Any, Dict, Optional

from pydantic import BaseModel, Field, model_validator


class IndexMethod(str, Enum):
    AUTO = "auto"
    HNSW = "hnsw"
    IVFFLAT = "ivfflat"


class IndexMeasure(str, Enum):
    COSINE = "cosine_distance"
    L2 = "l2_distance"
    L1 = "l1_distance"
    MAX_INNER_PRODUCT = "max_inner_product"


class SupabaseConfig(BaseModel):
    connection_string: str = Field(..., description="PostgreSQL connection string")
    collection_name: str = Field("mem0", description="Name for the vector collection")
    embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model")
    index_method: Optional[IndexMethod] = Field(IndexMethod.AUTO, description="Index method to use")
    index_measure: Optional[IndexMeasure] = Field(IndexMeasure.COSINE, description="Distance measure to use")

    @model_validator(mode="before")
    @classmethod
    def check_connection_string(cls, values):
        conn_str = values.get("connection_string")
        if not conn_str or not conn_str.startswith("postgresql://"):
            raise ValueError("A valid PostgreSQL connection string must be provided")
        return values

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
            )
        return values
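
The enums give the two index knobs typed values, and the connection-string check insists on the postgresql:// scheme. A sketch with a hypothetical connection string:

    from agentrun_mem0.configs.vector_stores.supabase import (
        IndexMeasure,
        IndexMethod,
        SupabaseConfig,
    )

    cfg = SupabaseConfig(
        connection_string="postgresql://user:secret@db.example.supabase.co:5432/postgres",
        index_method=IndexMethod.HNSW,
        index_measure=IndexMeasure.COSINE,
    )
    # A string without the postgresql:// scheme fails check_connection_string.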

agentrun_mem0/configs/vector_stores/upstash_vector.py
@@ -0,0 +1,34 @@
import os
from typing import Any, ClassVar, Dict, Optional

from pydantic import BaseModel, ConfigDict, Field, model_validator

try:
    from upstash_vector import Index
except ImportError:
    raise ImportError("The 'upstash_vector' library is required. Please install it using 'pip install upstash_vector'.")


class UpstashVectorConfig(BaseModel):
    Index: ClassVar[type] = Index

    url: Optional[str] = Field(None, description="URL for Upstash Vector index")
    token: Optional[str] = Field(None, description="Token for Upstash Vector index")
    client: Optional[Index] = Field(None, description="Existing `upstash_vector.Index` client instance")
    collection_name: str = Field("mem0", description="Namespace to use for the index")
    enable_embeddings: bool = Field(
        False, description="Whether to use built-in upstash embeddings or not. Default is False."
    )

    @model_validator(mode="before")
    @classmethod
    def check_credentials_or_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        client = values.get("client")
        url = values.get("url") or os.environ.get("UPSTASH_VECTOR_REST_URL")
        token = values.get("token") or os.environ.get("UPSTASH_VECTOR_REST_TOKEN")

        if not client and not (url and token):
            raise ValueError("Either a client or URL and token must be provided.")
        return values

    model_config = ConfigDict(arbitrary_types_allowed=True)
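
check_credentials_or_client falls back to the standard Upstash environment variables, so a bare constructor call passes validation when they are set. A sketch with hypothetical values:

    import os

    from agentrun_mem0.configs.vector_stores.upstash_vector import UpstashVectorConfig

    os.environ["UPSTASH_VECTOR_REST_URL"] = "https://example-index.upstash.io"
    os.environ["UPSTASH_VECTOR_REST_TOKEN"] = "..."

    cfg = UpstashVectorConfig()  # validation passes via the environment variables
    # Note the validator only checks the variables; cfg.url and cfg.token remain None.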

agentrun_mem0/configs/vector_stores/valkey.py
@@ -0,0 +1,15 @@
from pydantic import BaseModel


class ValkeyConfig(BaseModel):
    """Configuration for Valkey vector store."""

    valkey_url: str
    collection_name: str
    embedding_model_dims: int
    timezone: str = "UTC"
    index_type: str = "hnsw"  # Default to HNSW, can be 'hnsw' or 'flat'
    # HNSW specific parameters with recommended defaults
    hnsw_m: int = 16  # Number of connections per layer (default from Valkey docs)
    hnsw_ef_construction: int = 200  # Search width during construction
    hnsw_ef_runtime: int = 10  # Search width during queries
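
ValkeyConfig carries no custom validators; the HNSW knobs trade index build time and memory (hnsw_m, hnsw_ef_construction) against query-time recall (hnsw_ef_runtime). A sketch with hypothetical values:

    from agentrun_mem0.configs.vector_stores.valkey import ValkeyConfig

    cfg = ValkeyConfig(
        valkey_url="redis://localhost:6379",  # Valkey speaks the Redis protocol
        collection_name="mem0",
        embedding_model_dims=1536,
        index_type="flat",  # exact search; omit to keep the default "hnsw"
    )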

agentrun_mem0/configs/vector_stores/vertex_ai_vector_search.py
@@ -0,0 +1,28 @@
from typing import Dict, Optional

from pydantic import BaseModel, ConfigDict, Field


class GoogleMatchingEngineConfig(BaseModel):
    project_id: str = Field(description="Google Cloud project ID")
    project_number: str = Field(description="Google Cloud project number")
    region: str = Field(description="Google Cloud region")
    endpoint_id: str = Field(description="Vertex AI Vector Search endpoint ID")
    index_id: str = Field(description="Vertex AI Vector Search index ID")
    deployment_index_id: str = Field(description="Deployment-specific index ID")
    collection_name: Optional[str] = Field(None, description="Collection name, defaults to index_id")
    credentials_path: Optional[str] = Field(None, description="Path to service account credentials JSON file")
    service_account_json: Optional[Dict] = Field(None, description="Service account credentials as dictionary (alternative to credentials_path)")
    vector_search_api_endpoint: Optional[str] = Field(None, description="Vector search API endpoint")

    model_config = ConfigDict(extra="forbid")

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        if not self.collection_name:
            self.collection_name = self.index_id

    def model_post_init(self, _context) -> None:
        """Set collection_name to index_id if not provided"""
        if self.collection_name is None:
            self.collection_name = self.index_id
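
Both the __init__ override and model_post_init back-fill collection_name from index_id, covering either initialization path, and extra="forbid" rejects unknown keys at the model level instead of via a hand-written validator. A sketch with hypothetical identifiers:

    from agentrun_mem0.configs.vector_stores.vertex_ai_vector_search import (
        GoogleMatchingEngineConfig,
    )

    cfg = GoogleMatchingEngineConfig(
        project_id="my-project",
        project_number="123456789012",
        region="us-central1",
        endpoint_id="my-endpoint",
        index_id="my-index",
        deployment_index_id="my-deployed-index",
    )
    assert cfg.collection_name == "my-index"  # defaulted from index_id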

agentrun_mem0/configs/vector_stores/weaviate.py
@@ -0,0 +1,41 @@
from typing import Any, ClassVar, Dict, Optional

from pydantic import BaseModel, ConfigDict, Field, model_validator


class WeaviateConfig(BaseModel):
    from weaviate import WeaviateClient

    WeaviateClient: ClassVar[type] = WeaviateClient

    collection_name: str = Field("mem0", description="Name of the collection")
    embedding_model_dims: int = Field(1536, description="Dimensions of the embedding model")
    cluster_url: Optional[str] = Field(None, description="URL for Weaviate server")
    auth_client_secret: Optional[str] = Field(None, description="API key for Weaviate authentication")
    additional_headers: Optional[Dict[str, str]] = Field(None, description="Additional headers for requests")

    @model_validator(mode="before")
    @classmethod
    def check_connection_params(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        cluster_url = values.get("cluster_url")

        if not cluster_url:
            raise ValueError("'cluster_url' must be provided.")

        return values

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields

        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
            )

        return values

    model_config = ConfigDict(arbitrary_types_allowed=True)

agentrun_mem0/embeddings/__init__.py: File without changes

agentrun_mem0/embeddings/aws_bedrock.py
@@ -0,0 +1,100 @@
import json
import os
from typing import Literal, Optional

try:
    import boto3
except ImportError:
    raise ImportError("The 'boto3' library is required. Please install it using 'pip install boto3'.")

import numpy as np

from agentrun_mem0.configs.embeddings.base import BaseEmbedderConfig
from agentrun_mem0.embeddings.base import EmbeddingBase


class AWSBedrockEmbedding(EmbeddingBase):
    """AWS Bedrock embedding implementation.

    This class uses AWS Bedrock's embedding models.
    """

    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        super().__init__(config)

        self.config.model = self.config.model or "amazon.titan-embed-text-v1"

        # Get AWS config from environment variables or use defaults
        aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID", "")
        aws_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY", "")
        aws_session_token = os.environ.get("AWS_SESSION_TOKEN", "")

        # Check if AWS config is provided in the config
        if hasattr(self.config, "aws_access_key_id"):
            aws_access_key = self.config.aws_access_key_id
        if hasattr(self.config, "aws_secret_access_key"):
            aws_secret_key = self.config.aws_secret_access_key

        # AWS region is always set in config - see BaseEmbedderConfig
        aws_region = self.config.aws_region or "us-west-2"

        self.client = boto3.client(
            "bedrock-runtime",
            region_name=aws_region,
            aws_access_key_id=aws_access_key if aws_access_key else None,
            aws_secret_access_key=aws_secret_key if aws_secret_key else None,
            aws_session_token=aws_session_token if aws_session_token else None,
        )

    def _normalize_vector(self, embeddings):
        """Normalize the embedding to a unit vector."""
        emb = np.array(embeddings)
        norm_emb = emb / np.linalg.norm(emb)
        return norm_emb.tolist()

    def _get_embedding(self, text):
        """Call out to Bedrock embedding endpoint."""

        # Format input body based on the provider
        provider = self.config.model.split(".")[0]
        input_body = {}

        if provider == "cohere":
            input_body["input_type"] = "search_document"
            input_body["texts"] = [text]
        else:
            # Amazon and other providers
            input_body["inputText"] = text

        body = json.dumps(input_body)

        try:
            response = self.client.invoke_model(
                body=body,
                modelId=self.config.model,
                accept="application/json",
                contentType="application/json",
            )

            response_body = json.loads(response.get("body").read())

            if provider == "cohere":
                embeddings = response_body.get("embeddings")[0]
            else:
                embeddings = response_body.get("embedding")

            return embeddings
        except Exception as e:
            raise ValueError(f"Error getting embedding from AWS Bedrock: {e}")

    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
        """
        Get the embedding for the given text using AWS Bedrock.

        Args:
            text (str): The text to embed.
            memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.
        Returns:
            list: The embedding vector.
        """
        return self._get_embedding(text)
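
The request body _get_embedding sends depends on the provider prefix of the model id: Cohere models expect a list under "texts" plus an input_type, while Amazon Titan (and the fallback branch) expects a single "inputText" string. A standalone sketch of that dispatch, with a helper name of our own choosing (no AWS call involved):

    import json


    def bedrock_embed_body(model_id: str, text: str) -> str:
        # Mirrors the branching in AWSBedrockEmbedding._get_embedding.
        provider = model_id.split(".")[0]
        if provider == "cohere":
            return json.dumps({"input_type": "search_document", "texts": [text]})
        return json.dumps({"inputText": text})


    print(bedrock_embed_body("amazon.titan-embed-text-v1", "hello"))
    # {"inputText": "hello"}
    print(bedrock_embed_body("cohere.embed-english-v3", "hello"))
    # {"input_type": "search_document", "texts": ["hello"]}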

agentrun_mem0/embeddings/azure_openai.py
@@ -0,0 +1,55 @@
import os
from typing import Literal, Optional

from azure.identity import DefaultAzureCredential, get_bearer_token_provider
from openai import AzureOpenAI

from agentrun_mem0.configs.embeddings.base import BaseEmbedderConfig
from agentrun_mem0.embeddings.base import EmbeddingBase

SCOPE = "https://cognitiveservices.azure.com/.default"


class AzureOpenAIEmbedding(EmbeddingBase):
    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        super().__init__(config)

        api_key = self.config.azure_kwargs.api_key or os.getenv("EMBEDDING_AZURE_OPENAI_API_KEY")
        azure_deployment = self.config.azure_kwargs.azure_deployment or os.getenv("EMBEDDING_AZURE_DEPLOYMENT")
        azure_endpoint = self.config.azure_kwargs.azure_endpoint or os.getenv("EMBEDDING_AZURE_ENDPOINT")
        api_version = self.config.azure_kwargs.api_version or os.getenv("EMBEDDING_AZURE_API_VERSION")
        default_headers = self.config.azure_kwargs.default_headers

        # If the API key is not provided or is a placeholder, use DefaultAzureCredential.
        if api_key is None or api_key == "" or api_key == "your-api-key":
            self.credential = DefaultAzureCredential()
            azure_ad_token_provider = get_bearer_token_provider(
                self.credential,
                SCOPE,
            )
            api_key = None
        else:
            azure_ad_token_provider = None

        self.client = AzureOpenAI(
            azure_deployment=azure_deployment,
            azure_endpoint=azure_endpoint,
            azure_ad_token_provider=azure_ad_token_provider,
            api_version=api_version,
            api_key=api_key,
            http_client=self.config.http_client,
            default_headers=default_headers,
        )

    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
        """
        Get the embedding for the given text using Azure OpenAI.

        Args:
            text (str): The text to embed.
            memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.
        Returns:
            list: The embedding vector.
        """
        text = text.replace("\n", " ")
        return self.client.embeddings.create(input=[text], model=self.config.model).data[0].embedding
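
When no usable API key is found (missing, empty, or the literal placeholder "your-api-key"), the constructor switches to Microsoft Entra ID authentication via DefaultAzureCredential and a bearer-token provider scoped to Cognitive Services. A sketch of the keyless path, assuming BaseEmbedderConfig accepts an azure_kwargs mapping (as the attribute access above suggests) and that an Azure identity is available at runtime; deployment and model names are hypothetical:

    from agentrun_mem0.configs.embeddings.base import BaseEmbedderConfig
    from agentrun_mem0.embeddings.azure_openai import AzureOpenAIEmbedding

    config = BaseEmbedderConfig(
        model="text-embedding-3-small",
        azure_kwargs={
            "azure_deployment": "embeddings",
            "azure_endpoint": "https://example.openai.azure.com",
            "api_version": "2024-02-01",
        },
    )
    embedder = AzureOpenAIEmbedding(config)  # no api_key -> DefaultAzureCredential branch
    vector = embedder.embed("hello world")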

agentrun_mem0/embeddings/base.py
@@ -0,0 +1,31 @@
from abc import ABC, abstractmethod
from typing import Literal, Optional

from agentrun_mem0.configs.embeddings.base import BaseEmbedderConfig


class EmbeddingBase(ABC):
    """Base class for embedding implementations.

    :param config: Embedding configuration option class, defaults to None
    :type config: Optional[BaseEmbedderConfig], optional
    """

    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        if config is None:
            self.config = BaseEmbedderConfig()
        else:
            self.config = config

    @abstractmethod
    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]]):
        """
        Get the embedding for the given text.

        Args:
            text (str): The text to embed.
            memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.
        Returns:
            list: The embedding vector.
        """
        pass
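
Concrete embedders only have to supply embed; the base class fills in a default config when none is given. The package's own mock embedder (embeddings/mock.py in the file list) presumably follows the same shape as this illustrative subclass:

    from typing import Literal, Optional

    from agentrun_mem0.embeddings.base import EmbeddingBase


    class FixedEmbedder(EmbeddingBase):
        """Toy embedder returning a constant vector; illustrative only."""

        def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
            return [0.0] * 8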

agentrun_mem0/embeddings/configs.py
@@ -0,0 +1,30 @@
from typing import Optional

from pydantic import BaseModel, Field, field_validator


class EmbedderConfig(BaseModel):
    provider: str = Field(
        description="Provider of the embedding model (e.g., 'ollama', 'openai')",
        default="openai",
    )
    config: Optional[dict] = Field(description="Configuration for the specific embedding model", default={})

    @field_validator("config")
    def validate_config(cls, v, values):
        provider = values.data.get("provider")
        if provider in [
            "openai",
            "ollama",
            "huggingface",
            "azure_openai",
            "gemini",
            "vertexai",
            "together",
            "lmstudio",
            "langchain",
            "aws_bedrock",
        ]:
            return v
        else:
            raise ValueError(f"Unsupported embedding provider: {provider}")
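
One wrinkle: validate_config is a field validator on config, and pydantic does not run field validators on default values unless validate_default is set, so the provider whitelist is only enforced when config is passed explicitly. A sketch (the ollama model name is hypothetical):

    from pydantic import ValidationError

    from agentrun_mem0.embeddings.configs import EmbedderConfig

    EmbedderConfig()  # provider defaults to "openai", config to {}
    EmbedderConfig(provider="ollama", config={"model": "nomic-embed-text"})

    try:
        EmbedderConfig(provider="not-a-provider", config={})
    except ValidationError as exc:
        print(exc)  # Unsupported embedding provider: not-a-provider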