agentrun-mem0ai 0.0.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentrun_mem0/__init__.py +6 -0
- agentrun_mem0/client/__init__.py +0 -0
- agentrun_mem0/client/main.py +1747 -0
- agentrun_mem0/client/project.py +931 -0
- agentrun_mem0/client/utils.py +115 -0
- agentrun_mem0/configs/__init__.py +0 -0
- agentrun_mem0/configs/base.py +90 -0
- agentrun_mem0/configs/embeddings/__init__.py +0 -0
- agentrun_mem0/configs/embeddings/base.py +110 -0
- agentrun_mem0/configs/enums.py +7 -0
- agentrun_mem0/configs/llms/__init__.py +0 -0
- agentrun_mem0/configs/llms/anthropic.py +56 -0
- agentrun_mem0/configs/llms/aws_bedrock.py +192 -0
- agentrun_mem0/configs/llms/azure.py +57 -0
- agentrun_mem0/configs/llms/base.py +62 -0
- agentrun_mem0/configs/llms/deepseek.py +56 -0
- agentrun_mem0/configs/llms/lmstudio.py +59 -0
- agentrun_mem0/configs/llms/ollama.py +56 -0
- agentrun_mem0/configs/llms/openai.py +79 -0
- agentrun_mem0/configs/llms/vllm.py +56 -0
- agentrun_mem0/configs/prompts.py +459 -0
- agentrun_mem0/configs/rerankers/__init__.py +0 -0
- agentrun_mem0/configs/rerankers/base.py +17 -0
- agentrun_mem0/configs/rerankers/cohere.py +15 -0
- agentrun_mem0/configs/rerankers/config.py +12 -0
- agentrun_mem0/configs/rerankers/huggingface.py +17 -0
- agentrun_mem0/configs/rerankers/llm.py +48 -0
- agentrun_mem0/configs/rerankers/sentence_transformer.py +16 -0
- agentrun_mem0/configs/rerankers/zero_entropy.py +28 -0
- agentrun_mem0/configs/vector_stores/__init__.py +0 -0
- agentrun_mem0/configs/vector_stores/alibabacloud_mysql.py +64 -0
- agentrun_mem0/configs/vector_stores/aliyun_tablestore.py +32 -0
- agentrun_mem0/configs/vector_stores/azure_ai_search.py +57 -0
- agentrun_mem0/configs/vector_stores/azure_mysql.py +84 -0
- agentrun_mem0/configs/vector_stores/baidu.py +27 -0
- agentrun_mem0/configs/vector_stores/chroma.py +58 -0
- agentrun_mem0/configs/vector_stores/databricks.py +61 -0
- agentrun_mem0/configs/vector_stores/elasticsearch.py +65 -0
- agentrun_mem0/configs/vector_stores/faiss.py +37 -0
- agentrun_mem0/configs/vector_stores/langchain.py +30 -0
- agentrun_mem0/configs/vector_stores/milvus.py +42 -0
- agentrun_mem0/configs/vector_stores/mongodb.py +25 -0
- agentrun_mem0/configs/vector_stores/neptune.py +27 -0
- agentrun_mem0/configs/vector_stores/opensearch.py +41 -0
- agentrun_mem0/configs/vector_stores/pgvector.py +52 -0
- agentrun_mem0/configs/vector_stores/pinecone.py +55 -0
- agentrun_mem0/configs/vector_stores/qdrant.py +47 -0
- agentrun_mem0/configs/vector_stores/redis.py +24 -0
- agentrun_mem0/configs/vector_stores/s3_vectors.py +28 -0
- agentrun_mem0/configs/vector_stores/supabase.py +44 -0
- agentrun_mem0/configs/vector_stores/upstash_vector.py +34 -0
- agentrun_mem0/configs/vector_stores/valkey.py +15 -0
- agentrun_mem0/configs/vector_stores/vertex_ai_vector_search.py +28 -0
- agentrun_mem0/configs/vector_stores/weaviate.py +41 -0
- agentrun_mem0/embeddings/__init__.py +0 -0
- agentrun_mem0/embeddings/aws_bedrock.py +100 -0
- agentrun_mem0/embeddings/azure_openai.py +55 -0
- agentrun_mem0/embeddings/base.py +31 -0
- agentrun_mem0/embeddings/configs.py +30 -0
- agentrun_mem0/embeddings/gemini.py +39 -0
- agentrun_mem0/embeddings/huggingface.py +44 -0
- agentrun_mem0/embeddings/langchain.py +35 -0
- agentrun_mem0/embeddings/lmstudio.py +29 -0
- agentrun_mem0/embeddings/mock.py +11 -0
- agentrun_mem0/embeddings/ollama.py +53 -0
- agentrun_mem0/embeddings/openai.py +49 -0
- agentrun_mem0/embeddings/together.py +31 -0
- agentrun_mem0/embeddings/vertexai.py +64 -0
- agentrun_mem0/exceptions.py +503 -0
- agentrun_mem0/graphs/__init__.py +0 -0
- agentrun_mem0/graphs/configs.py +105 -0
- agentrun_mem0/graphs/neptune/__init__.py +0 -0
- agentrun_mem0/graphs/neptune/base.py +497 -0
- agentrun_mem0/graphs/neptune/neptunedb.py +511 -0
- agentrun_mem0/graphs/neptune/neptunegraph.py +474 -0
- agentrun_mem0/graphs/tools.py +371 -0
- agentrun_mem0/graphs/utils.py +97 -0
- agentrun_mem0/llms/__init__.py +0 -0
- agentrun_mem0/llms/anthropic.py +87 -0
- agentrun_mem0/llms/aws_bedrock.py +665 -0
- agentrun_mem0/llms/azure_openai.py +141 -0
- agentrun_mem0/llms/azure_openai_structured.py +91 -0
- agentrun_mem0/llms/base.py +131 -0
- agentrun_mem0/llms/configs.py +34 -0
- agentrun_mem0/llms/deepseek.py +107 -0
- agentrun_mem0/llms/gemini.py +201 -0
- agentrun_mem0/llms/groq.py +88 -0
- agentrun_mem0/llms/langchain.py +94 -0
- agentrun_mem0/llms/litellm.py +87 -0
- agentrun_mem0/llms/lmstudio.py +114 -0
- agentrun_mem0/llms/ollama.py +117 -0
- agentrun_mem0/llms/openai.py +147 -0
- agentrun_mem0/llms/openai_structured.py +52 -0
- agentrun_mem0/llms/sarvam.py +89 -0
- agentrun_mem0/llms/together.py +88 -0
- agentrun_mem0/llms/vllm.py +107 -0
- agentrun_mem0/llms/xai.py +52 -0
- agentrun_mem0/memory/__init__.py +0 -0
- agentrun_mem0/memory/base.py +63 -0
- agentrun_mem0/memory/graph_memory.py +698 -0
- agentrun_mem0/memory/kuzu_memory.py +713 -0
- agentrun_mem0/memory/main.py +2229 -0
- agentrun_mem0/memory/memgraph_memory.py +689 -0
- agentrun_mem0/memory/setup.py +56 -0
- agentrun_mem0/memory/storage.py +218 -0
- agentrun_mem0/memory/telemetry.py +90 -0
- agentrun_mem0/memory/utils.py +208 -0
- agentrun_mem0/proxy/__init__.py +0 -0
- agentrun_mem0/proxy/main.py +189 -0
- agentrun_mem0/reranker/__init__.py +9 -0
- agentrun_mem0/reranker/base.py +20 -0
- agentrun_mem0/reranker/cohere_reranker.py +85 -0
- agentrun_mem0/reranker/huggingface_reranker.py +147 -0
- agentrun_mem0/reranker/llm_reranker.py +142 -0
- agentrun_mem0/reranker/sentence_transformer_reranker.py +107 -0
- agentrun_mem0/reranker/zero_entropy_reranker.py +96 -0
- agentrun_mem0/utils/factory.py +283 -0
- agentrun_mem0/utils/gcp_auth.py +167 -0
- agentrun_mem0/vector_stores/__init__.py +0 -0
- agentrun_mem0/vector_stores/alibabacloud_mysql.py +547 -0
- agentrun_mem0/vector_stores/aliyun_tablestore.py +252 -0
- agentrun_mem0/vector_stores/azure_ai_search.py +396 -0
- agentrun_mem0/vector_stores/azure_mysql.py +463 -0
- agentrun_mem0/vector_stores/baidu.py +368 -0
- agentrun_mem0/vector_stores/base.py +58 -0
- agentrun_mem0/vector_stores/chroma.py +332 -0
- agentrun_mem0/vector_stores/configs.py +67 -0
- agentrun_mem0/vector_stores/databricks.py +761 -0
- agentrun_mem0/vector_stores/elasticsearch.py +237 -0
- agentrun_mem0/vector_stores/faiss.py +479 -0
- agentrun_mem0/vector_stores/langchain.py +180 -0
- agentrun_mem0/vector_stores/milvus.py +250 -0
- agentrun_mem0/vector_stores/mongodb.py +310 -0
- agentrun_mem0/vector_stores/neptune_analytics.py +467 -0
- agentrun_mem0/vector_stores/opensearch.py +292 -0
- agentrun_mem0/vector_stores/pgvector.py +404 -0
- agentrun_mem0/vector_stores/pinecone.py +382 -0
- agentrun_mem0/vector_stores/qdrant.py +270 -0
- agentrun_mem0/vector_stores/redis.py +295 -0
- agentrun_mem0/vector_stores/s3_vectors.py +176 -0
- agentrun_mem0/vector_stores/supabase.py +237 -0
- agentrun_mem0/vector_stores/upstash_vector.py +293 -0
- agentrun_mem0/vector_stores/valkey.py +824 -0
- agentrun_mem0/vector_stores/vertex_ai_vector_search.py +635 -0
- agentrun_mem0/vector_stores/weaviate.py +343 -0
- agentrun_mem0ai-0.0.11.data/data/README.md +205 -0
- agentrun_mem0ai-0.0.11.dist-info/METADATA +277 -0
- agentrun_mem0ai-0.0.11.dist-info/RECORD +150 -0
- agentrun_mem0ai-0.0.11.dist-info/WHEEL +4 -0
- agentrun_mem0ai-0.0.11.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
from typing import List, Dict, Any, Union
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
from agentrun_mem0.reranker.base import BaseReranker
|
|
5
|
+
from agentrun_mem0.configs.rerankers.base import BaseRerankerConfig
|
|
6
|
+
from agentrun_mem0.configs.rerankers.sentence_transformer import SentenceTransformerRerankerConfig
|
|
7
|
+
|
|
8
|
+
try:
|
|
9
|
+
from sentence_transformers import SentenceTransformer
|
|
10
|
+
SENTENCE_TRANSFORMERS_AVAILABLE = True
|
|
11
|
+
except ImportError:
|
|
12
|
+
SENTENCE_TRANSFORMERS_AVAILABLE = False
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class SentenceTransformerReranker(BaseReranker):
    """Sentence Transformer based reranker implementation.

    Each document is scored against the query with the loaded
    ``sentence_transformers`` model and the documents are returned in
    descending score order.
    """

    def __init__(self, config: Union[BaseRerankerConfig, SentenceTransformerRerankerConfig, Dict]):
        """
        Initialize Sentence Transformer reranker.

        Args:
            config: A ``SentenceTransformerRerankerConfig``; a generic
                ``BaseRerankerConfig`` (converted using sentence-transformer
                defaults); or a dict of ``SentenceTransformerRerankerConfig``
                fields.

        Raises:
            ImportError: If the ``sentence-transformers`` package is not installed.
        """
        if not SENTENCE_TRANSFORMERS_AVAILABLE:
            raise ImportError("sentence-transformers package is required for SentenceTransformerReranker. Install with: pip install sentence-transformers")

        # Convert to SentenceTransformerRerankerConfig if needed
        if isinstance(config, dict):
            config = SentenceTransformerRerankerConfig(**config)
        elif isinstance(config, BaseRerankerConfig) and not isinstance(config, SentenceTransformerRerankerConfig):
            # A generic config may carry ``model=None``. Fall back to the default
            # model name instead of handing None to SentenceTransformer.
            config = SentenceTransformerRerankerConfig(
                provider=getattr(config, 'provider', 'sentence_transformer'),
                model=getattr(config, 'model', None) or 'cross-encoder/ms-marco-MiniLM-L-6-v2',
                api_key=getattr(config, 'api_key', None),
                top_k=getattr(config, 'top_k', None),
                device=None,  # Will auto-detect
                batch_size=32,  # Default
                show_progress_bar=False,  # Default
            )

        self.config = config
        self.model = SentenceTransformer(self.config.model, device=self.config.device)

    @staticmethod
    def _extract_text(doc: Dict[str, Any]) -> str:
        """Return the text to score: the first of 'memory', 'text', 'content', else ``str(doc)``."""
        for key in ('memory', 'text', 'content'):
            if key in doc:
                return doc[key]
        return str(doc)

    def _score(self, query: str, doc_texts: List[str]) -> List[float]:
        """Return one relevance score per entry of ``doc_texts`` (higher is more relevant)."""
        if hasattr(self.model, 'predict'):
            # Cross-encoder style models score (query, document) pairs directly.
            scores = self.model.predict([[query, doc_text] for doc_text in doc_texts])
        else:
            # ``SentenceTransformer`` (a bi-encoder) has no ``predict``. Score by
            # cosine similarity between the query and document embeddings.
            batch_size = getattr(self.config, 'batch_size', 32)
            show_progress_bar = getattr(self.config, 'show_progress_bar', False)
            query_emb = np.asarray(self.model.encode([query]), dtype=float).reshape(-1)
            doc_embs = np.asarray(
                self.model.encode(doc_texts, batch_size=batch_size, show_progress_bar=show_progress_bar),
                dtype=float,
            ).reshape(len(doc_texts), -1)
            norms = np.linalg.norm(doc_embs, axis=1) * np.linalg.norm(query_emb)
            # Avoid division by zero for all-zero embeddings (their score becomes 0).
            norms = np.where(norms == 0, 1.0, norms)
            scores = (doc_embs @ query_emb) / norms
        return np.asarray(scores, dtype=float).reshape(-1).tolist()

    def rerank(self, query: str, documents: List[Dict[str, Any]], top_k: int = None) -> List[Dict[str, Any]]:
        """
        Rerank documents by relevance to ``query`` using the sentence transformer model.

        Args:
            query: The search query
            documents: List of documents to rerank
            top_k: Number of top documents to return. Falls back to
                ``config.top_k``. When both are falsy, all documents are returned.

        Returns:
            Shallow copies of the documents, sorted by descending
            ``rerank_score``. If scoring fails, copies in their original order
            with ``rerank_score`` set to 0.0. The input dicts are never modified.
        """
        if not documents:
            return documents

        doc_texts = [self._extract_text(doc) for doc in documents]
        final_top_k = top_k or self.config.top_k

        try:
            scores = self._score(query, doc_texts)

            # Sort by score (descending); sorted() is stable, so ties keep input order.
            doc_score_pairs = sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)
            if final_top_k:
                doc_score_pairs = doc_score_pairs[:final_top_k]

            reranked_docs = []
            for doc, score in doc_score_pairs:
                reranked_doc = doc.copy()
                reranked_doc['rerank_score'] = float(score)
                reranked_docs.append(reranked_doc)
            return reranked_docs

        except Exception as exc:
            import logging

            # Best-effort: reranking must not break retrieval, but the failure is logged.
            logging.getLogger(__name__).warning(
                "Sentence transformer reranking failed; returning documents in original order: %s", exc
            )
            fallback_docs = []
            for doc in documents:
                fallback_doc = doc.copy()
                fallback_doc['rerank_score'] = 0.0
                fallback_docs.append(fallback_doc)
            return fallback_docs[:final_top_k] if final_top_k else fallback_docs
|
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import List, Dict, Any
|
|
3
|
+
|
|
4
|
+
from agentrun_mem0.reranker.base import BaseReranker
|
|
5
|
+
|
|
6
|
+
try:
|
|
7
|
+
from zeroentropy import ZeroEntropy
|
|
8
|
+
ZERO_ENTROPY_AVAILABLE = True
|
|
9
|
+
except ImportError:
|
|
10
|
+
ZERO_ENTROPY_AVAILABLE = False
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ZeroEntropyReranker(BaseReranker):
    """Zero Entropy-based reranker implementation."""

    def __init__(self, config):
        """
        Initialize Zero Entropy reranker.

        Args:
            config: ZeroEntropyRerankerConfig object. Its ``api_key``, ``model``
                and ``top_k`` attributes are read.

        Raises:
            ImportError: If the ``zeroentropy`` package is not installed.
            ValueError: If no API key is given in ``config`` or in the
                ``ZERO_ENTROPY_API_KEY`` environment variable.
        """
        if not ZERO_ENTROPY_AVAILABLE:
            raise ImportError("zeroentropy package is required for ZeroEntropyReranker. Install with: pip install zeroentropy")

        self.config = config
        self.api_key = config.api_key or os.getenv("ZERO_ENTROPY_API_KEY")
        if not self.api_key:
            raise ValueError("Zero Entropy API key is required. Set ZERO_ENTROPY_API_KEY environment variable or pass api_key in config.")

        self.model = config.model or "zerank-1"

        # The key is guaranteed non-empty at this point, so always pass it explicitly.
        self.client = ZeroEntropy(api_key=self.api_key)

    @staticmethod
    def _extract_text(doc: Dict[str, Any]) -> str:
        """Return the text to score: the first of 'memory', 'text', 'content', else ``str(doc)``."""
        for key in ('memory', 'text', 'content'):
            if key in doc:
                return doc[key]
        return str(doc)

    def rerank(self, query: str, documents: List[Dict[str, Any]], top_k: int = None) -> List[Dict[str, Any]]:
        """
        Rerank documents using Zero Entropy's rerank API.

        Args:
            query: The search query
            documents: List of documents to rerank
            top_k: Number of top documents to return. Falls back to
                ``config.top_k``. When both are falsy, all documents are returned.

        Returns:
            Shallow copies of the documents with ``rerank_score`` taken from the
            API, sorted by descending score. If the call fails, copies in their
            original order with ``rerank_score`` set to 0.0, limited by the same
            top_k rule. The input dicts are never modified.
        """
        if not documents:
            return documents

        doc_texts = [self._extract_text(doc) for doc in documents]
        final_top_k = top_k or self.config.top_k

        try:
            # Call Zero Entropy rerank API
            response = self.client.models.rerank(
                model=self.model,
                query=query,
                documents=doc_texts,
            )

            # Map each result back to its source document via ``result.index``.
            reranked_docs = []
            for result in response.results:
                original_doc = documents[result.index].copy()
                original_doc['rerank_score'] = result.relevance_score
                reranked_docs.append(original_doc)

            # Sort by relevance score in descending order
            reranked_docs.sort(key=lambda x: x['rerank_score'], reverse=True)

            if final_top_k:
                reranked_docs = reranked_docs[:final_top_k]
            return reranked_docs

        except Exception as exc:
            import logging

            # Best-effort: reranking must not break retrieval, but the failure is logged.
            logging.getLogger(__name__).warning(
                "Zero Entropy reranking failed; returning documents in original order: %s", exc
            )
            fallback_docs = []
            for doc in documents:
                fallback_doc = doc.copy()
                fallback_doc['rerank_score'] = 0.0
                fallback_docs.append(fallback_doc)
            return fallback_docs[:final_top_k] if final_top_k else fallback_docs
|
|
@@ -0,0 +1,283 @@
|
|
|
1
|
+
import importlib
|
|
2
|
+
from typing import Dict, Optional, Union
|
|
3
|
+
|
|
4
|
+
from agentrun_mem0.configs.embeddings.base import BaseEmbedderConfig
|
|
5
|
+
from agentrun_mem0.configs.llms.anthropic import AnthropicConfig
|
|
6
|
+
from agentrun_mem0.configs.llms.azure import AzureOpenAIConfig
|
|
7
|
+
from agentrun_mem0.configs.llms.base import BaseLlmConfig
|
|
8
|
+
from agentrun_mem0.configs.llms.deepseek import DeepSeekConfig
|
|
9
|
+
from agentrun_mem0.configs.llms.lmstudio import LMStudioConfig
|
|
10
|
+
from agentrun_mem0.configs.llms.ollama import OllamaConfig
|
|
11
|
+
from agentrun_mem0.configs.llms.openai import OpenAIConfig
|
|
12
|
+
from agentrun_mem0.configs.llms.vllm import VllmConfig
|
|
13
|
+
from agentrun_mem0.configs.rerankers.base import BaseRerankerConfig
|
|
14
|
+
from agentrun_mem0.configs.rerankers.cohere import CohereRerankerConfig
|
|
15
|
+
from agentrun_mem0.configs.rerankers.sentence_transformer import SentenceTransformerRerankerConfig
|
|
16
|
+
from agentrun_mem0.configs.rerankers.zero_entropy import ZeroEntropyRerankerConfig
|
|
17
|
+
from agentrun_mem0.configs.rerankers.llm import LLMRerankerConfig
|
|
18
|
+
from agentrun_mem0.configs.rerankers.huggingface import HuggingFaceRerankerConfig
|
|
19
|
+
from agentrun_mem0.embeddings.mock import MockEmbeddings
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def load_class(class_type):
    """
    Resolve a dotted path such as ``"package.module.ClassName"`` to the object it names.

    Args:
        class_type: Fully qualified dotted path; the last component is the
            attribute to fetch from the module named by the rest.

    Returns:
        The attribute (typically a class) found in the imported module.

    Raises:
        ValueError: If ``class_type`` contains no dot.
        ImportError: If the module part cannot be imported.
        AttributeError: If the module has no attribute with that name.
    """
    module_path, attr_name = class_type.rsplit(".", maxsplit=1)
    return getattr(importlib.import_module(module_path), attr_name)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class LlmFactory:
    """
    Factory for creating LLM instances with appropriate configurations.
    Supports both old-style BaseLlmConfig and new provider-specific configs.
    """

    # Provider name -> (dotted path of the LLM class, config class it expects)
    provider_to_class = {
        "ollama": ("agentrun_mem0.llms.ollama.OllamaLLM", OllamaConfig),
        "openai": ("agentrun_mem0.llms.openai.OpenAILLM", OpenAIConfig),
        "groq": ("agentrun_mem0.llms.groq.GroqLLM", BaseLlmConfig),
        "together": ("agentrun_mem0.llms.together.TogetherLLM", BaseLlmConfig),
        "aws_bedrock": ("agentrun_mem0.llms.aws_bedrock.AWSBedrockLLM", BaseLlmConfig),
        "litellm": ("agentrun_mem0.llms.litellm.LiteLLM", BaseLlmConfig),
        "azure_openai": ("agentrun_mem0.llms.azure_openai.AzureOpenAILLM", AzureOpenAIConfig),
        "openai_structured": ("agentrun_mem0.llms.openai_structured.OpenAIStructuredLLM", OpenAIConfig),
        "anthropic": ("agentrun_mem0.llms.anthropic.AnthropicLLM", AnthropicConfig),
        "azure_openai_structured": ("agentrun_mem0.llms.azure_openai_structured.AzureOpenAIStructuredLLM", AzureOpenAIConfig),
        "gemini": ("agentrun_mem0.llms.gemini.GeminiLLM", BaseLlmConfig),
        "deepseek": ("agentrun_mem0.llms.deepseek.DeepSeekLLM", DeepSeekConfig),
        "xai": ("agentrun_mem0.llms.xai.XAILLM", BaseLlmConfig),
        "sarvam": ("agentrun_mem0.llms.sarvam.SarvamLLM", BaseLlmConfig),
        "lmstudio": ("agentrun_mem0.llms.lmstudio.LMStudioLLM", LMStudioConfig),
        "vllm": ("agentrun_mem0.llms.vllm.VllmLLM", VllmConfig),
        "langchain": ("agentrun_mem0.llms.langchain.LangchainLLM", BaseLlmConfig),
    }

    @classmethod
    def create(cls, provider_name: str, config: Optional[Union[BaseLlmConfig, Dict]] = None, **kwargs):
        """
        Create an LLM instance with the appropriate configuration.

        Args:
            provider_name (str): The provider name (e.g., 'openai', 'anthropic')
            config: Configuration object or dict. If None, a default config is
                built from ``kwargs``. A dict is not modified; keys in ``kwargs``
                override its keys. A ``BaseLlmConfig`` that is not yet the
                provider's config class is converted, and ``kwargs`` are applied.
                Any other config object is used as-is, and ``kwargs`` are ignored.
            **kwargs: Additional configuration parameters

        Returns:
            Configured LLM instance

        Raises:
            ValueError: If provider is not supported
        """
        if provider_name not in cls.provider_to_class:
            raise ValueError(f"Unsupported Llm provider: {provider_name}")

        class_type, config_class = cls.provider_to_class[provider_name]
        llm_class = load_class(class_type)

        if config is None:
            # Create default config with kwargs
            config = config_class(**kwargs)
        elif isinstance(config, dict):
            # Merge into a new dict so the caller's config dict is not mutated.
            config = config_class(**{**config, **kwargs})
        elif isinstance(config, BaseLlmConfig) and not isinstance(config, config_class):
            # Upgrade a generic config to the provider-specific class. A config
            # that already is the provider class is kept, so its
            # provider-specific fields are not lost.
            config_dict = {
                "model": config.model,
                "temperature": config.temperature,
                "api_key": config.api_key,
                "max_tokens": config.max_tokens,
                "top_p": config.top_p,
                "top_k": config.top_k,
                "enable_vision": config.enable_vision,
                "vision_details": config.vision_details,
                # NOTE(review): this forwards the config's ``http_client`` under the
                # ``http_client_proxies`` key; confirm provider configs accept a
                # client object here rather than a proxies mapping.
                "http_client_proxies": config.http_client,
            }
            config_dict.update(kwargs)
            config = config_class(**config_dict)
        # Otherwise assume it is already the correct config type and use it as-is.

        return llm_class(config)

    @classmethod
    def register_provider(cls, name: str, class_path: str, config_class=None):
        """
        Register a new provider.

        Args:
            name (str): Provider name
            class_path (str): Full path to LLM class
            config_class: Configuration class for the provider (defaults to BaseLlmConfig)
        """
        if config_class is None:
            config_class = BaseLlmConfig
        cls.provider_to_class[name] = (class_path, config_class)

    @classmethod
    def get_supported_providers(cls) -> list:
        """
        Get list of supported providers.

        Returns:
            list: List of supported provider names
        """
        return list(cls.provider_to_class.keys())
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
class EmbedderFactory:
    """Factory for creating embedding model instances by provider name."""

    # Provider name -> dotted path of the embedding class.
    provider_to_class = {
        "openai": "agentrun_mem0.embeddings.openai.OpenAIEmbedding",
        "ollama": "agentrun_mem0.embeddings.ollama.OllamaEmbedding",
        "huggingface": "agentrun_mem0.embeddings.huggingface.HuggingFaceEmbedding",
        "azure_openai": "agentrun_mem0.embeddings.azure_openai.AzureOpenAIEmbedding",
        "gemini": "agentrun_mem0.embeddings.gemini.GoogleGenAIEmbedding",
        "vertexai": "agentrun_mem0.embeddings.vertexai.VertexAIEmbedding",
        "together": "agentrun_mem0.embeddings.together.TogetherEmbedding",
        "lmstudio": "agentrun_mem0.embeddings.lmstudio.LMStudioEmbedding",
        "langchain": "agentrun_mem0.embeddings.langchain.LangchainEmbedding",
        "aws_bedrock": "agentrun_mem0.embeddings.aws_bedrock.AWSBedrockEmbedding",
    }

    @classmethod
    def create(cls, provider_name, config, vector_config: Optional[dict] = None):
        """
        Create an embedder instance.

        Args:
            provider_name: Embedder provider key (see ``provider_to_class``).
            config: A dict of ``BaseEmbedderConfig`` fields, an existing
                ``BaseEmbedderConfig`` (used as-is), or None for defaults.
            vector_config: The vector store config, either as an object or a dict.
                When ``provider_name`` is ``"upstash_vector"`` and its
                ``enable_embeddings`` value is truthy, ``MockEmbeddings`` is returned.

        Returns:
            The embedder instance.

        Raises:
            ValueError: If the provider is not supported.
        """
        if provider_name == "upstash_vector" and cls._embeddings_enabled(vector_config):
            return MockEmbeddings()

        class_type = cls.provider_to_class.get(provider_name)
        if not class_type:
            raise ValueError(f"Unsupported Embedder provider: {provider_name}")

        embedder_class = load_class(class_type)
        if isinstance(config, BaseEmbedderConfig):
            base_config = config
        else:
            base_config = BaseEmbedderConfig(**(config or {}))
        return embedder_class(base_config)

    @staticmethod
    def _embeddings_enabled(vector_config) -> bool:
        """Read ``enable_embeddings`` from either a dict or an attribute-style vector config."""
        if not vector_config:
            return False
        if isinstance(vector_config, dict):
            return bool(vector_config.get("enable_embeddings"))
        return bool(getattr(vector_config, "enable_embeddings", False))
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
class VectorStoreFactory:
    """Factory for creating vector store instances by provider name."""

    # Provider name -> dotted path of the vector store class.
    provider_to_class = {
        "qdrant": "agentrun_mem0.vector_stores.qdrant.Qdrant",
        "chroma": "agentrun_mem0.vector_stores.chroma.ChromaDB",
        "pgvector": "agentrun_mem0.vector_stores.pgvector.PGVector",
        "milvus": "agentrun_mem0.vector_stores.milvus.MilvusDB",
        "upstash_vector": "agentrun_mem0.vector_stores.upstash_vector.UpstashVector",
        "azure_ai_search": "agentrun_mem0.vector_stores.azure_ai_search.AzureAISearch",
        "azure_mysql": "agentrun_mem0.vector_stores.azure_mysql.AzureMySQL",
        "pinecone": "agentrun_mem0.vector_stores.pinecone.PineconeDB",
        "mongodb": "agentrun_mem0.vector_stores.mongodb.MongoDB",
        "redis": "agentrun_mem0.vector_stores.redis.RedisDB",
        "valkey": "agentrun_mem0.vector_stores.valkey.ValkeyDB",
        "databricks": "agentrun_mem0.vector_stores.databricks.Databricks",
        "elasticsearch": "agentrun_mem0.vector_stores.elasticsearch.ElasticsearchDB",
        "vertex_ai_vector_search": "agentrun_mem0.vector_stores.vertex_ai_vector_search.GoogleMatchingEngine",
        "opensearch": "agentrun_mem0.vector_stores.opensearch.OpenSearchDB",
        "supabase": "agentrun_mem0.vector_stores.supabase.Supabase",
        "weaviate": "agentrun_mem0.vector_stores.weaviate.Weaviate",
        "faiss": "agentrun_mem0.vector_stores.faiss.FAISS",
        "langchain": "agentrun_mem0.vector_stores.langchain.Langchain",
        "s3_vectors": "agentrun_mem0.vector_stores.s3_vectors.S3Vectors",
        "baidu": "agentrun_mem0.vector_stores.baidu.BaiduDB",
        "neptune": "agentrun_mem0.vector_stores.neptune_analytics.NeptuneAnalyticsVector",
        "aliyun_tablestore": "agentrun_mem0.vector_stores.aliyun_tablestore.AliyunTableStore",
        "alibabacloud_mysql": "agentrun_mem0.vector_stores.alibabacloud_mysql.MySQLVector",
    }

    @classmethod
    def create(cls, provider_name, config):
        """
        Instantiate the vector store registered under ``provider_name``.

        Args:
            provider_name: Vector store provider key (see ``provider_to_class``).
            config: Either a dict of constructor keyword arguments, or an object
                whose ``model_dump()`` returns such a dict.

        Returns:
            The constructed vector store instance.

        Raises:
            ValueError: If the provider is not supported.
        """
        class_type = cls.provider_to_class.get(provider_name)
        if not class_type:
            raise ValueError(f"Unsupported VectorStore provider: {provider_name}")

        store_kwargs = config if isinstance(config, dict) else config.model_dump()
        store_class = load_class(class_type)
        return store_class(**store_kwargs)

    @classmethod
    def reset(cls, instance):
        """Call ``instance.reset()`` and return the same instance."""
        instance.reset()
        return instance
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
class GraphStoreFactory:
    """
    Factory for creating MemoryGraph instances for different graph store providers.
    Usage: GraphStoreFactory.create(provider_name, config)
    """

    # Provider name -> dotted path of its MemoryGraph class. Unknown provider
    # names resolve to the "default" entry.
    provider_to_class = {
        "memgraph": "agentrun_mem0.memory.memgraph_memory.MemoryGraph",
        "neptune": "agentrun_mem0.graphs.neptune.neptunegraph.MemoryGraph",
        "neptunedb": "agentrun_mem0.graphs.neptune.neptunedb.MemoryGraph",
        "kuzu": "agentrun_mem0.memory.kuzu_memory.MemoryGraph",
        "default": "agentrun_mem0.memory.graph_memory.MemoryGraph",
    }

    @classmethod
    def create(cls, provider_name, config):
        """
        Build the MemoryGraph for ``provider_name``.

        Args:
            provider_name: Graph store provider key; unknown keys use "default".
            config: Passed unchanged to the MemoryGraph constructor.

        Returns:
            The MemoryGraph instance.

        Raises:
            ImportError: If the provider's class cannot be imported or found.
        """
        registry = cls.provider_to_class
        class_type = registry[provider_name] if provider_name in registry else registry["default"]
        try:
            graph_class = load_class(class_type)
        except (ImportError, AttributeError) as e:
            raise ImportError(f"Could not import MemoryGraph for provider '{provider_name}': {e}")
        return graph_class(config)
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
class RerankerFactory:
    """
    Factory for creating reranker instances with appropriate configurations.
    Supports provider-specific configs following the same pattern as other factories.
    """

    # Provider name -> (dotted path of the reranker class, config class it expects)
    provider_to_class = {
        "cohere": ("agentrun_mem0.reranker.cohere_reranker.CohereReranker", CohereRerankerConfig),
        "sentence_transformer": ("agentrun_mem0.reranker.sentence_transformer_reranker.SentenceTransformerReranker", SentenceTransformerRerankerConfig),
        "zero_entropy": ("agentrun_mem0.reranker.zero_entropy_reranker.ZeroEntropyReranker", ZeroEntropyRerankerConfig),
        "llm_reranker": ("agentrun_mem0.reranker.llm_reranker.LLMReranker", LLMRerankerConfig),
        "huggingface": ("agentrun_mem0.reranker.huggingface_reranker.HuggingFaceReranker", HuggingFaceRerankerConfig),
    }

    @classmethod
    def create(cls, provider_name: str, config: Optional[Union[BaseRerankerConfig, Dict]] = None, **kwargs):
        """
        Create a reranker instance based on the provider and configuration.

        Args:
            provider_name: The reranker provider (e.g., 'cohere', 'sentence_transformer')
            config: Configuration object or dictionary. A dict is not modified;
                keys in ``kwargs`` override keys in it. A ``BaseRerankerConfig``
                instance is passed through as-is, and ``kwargs`` are ignored.
            **kwargs: Additional configuration parameters

        Returns:
            Reranker instance configured for the specified provider

        Raises:
            ImportError: If the provider class cannot be imported
            ValueError: If the provider is not supported, or ``config`` is not
                None, a dict, or a ``BaseRerankerConfig``
        """
        if provider_name not in cls.provider_to_class:
            raise ValueError(f"Unsupported reranker provider: {provider_name}")

        class_path, config_class = cls.provider_to_class[provider_name]

        # Handle configuration
        if config is None:
            config = config_class(**kwargs)
        elif isinstance(config, dict):
            # Merge first: passing both mappings as ``**`` arguments raises
            # TypeError when a key appears in both config and kwargs.
            config = config_class(**{**config, **kwargs})
        elif not isinstance(config, BaseRerankerConfig):
            raise ValueError(f"Config must be a {config_class.__name__} instance or dict")

        # Import and create the reranker class
        try:
            reranker_class = load_class(class_path)
        except (ImportError, AttributeError) as e:
            raise ImportError(f"Could not import reranker for provider '{provider_name}': {e}") from e

        return reranker_class(config)
|
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import json
|
|
3
|
+
from typing import Optional, Dict, Any
|
|
4
|
+
|
|
5
|
+
try:
|
|
6
|
+
from google.oauth2 import service_account
|
|
7
|
+
from google.auth import default
|
|
8
|
+
import google.auth.credentials
|
|
9
|
+
except ImportError:
|
|
10
|
+
raise ImportError("google-auth is required for GCP authentication. Install with: pip install google-auth")
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class GCPAuthenticator:
    """
    Centralized GCP authentication handler that supports multiple credential methods.

    Priority order:
    1. service_account_json (dict) - In-memory service account credentials
    2. credentials_path (str) - Path to service account JSON file
    3. Environment variables (GOOGLE_APPLICATION_CREDENTIALS)
    4. Default credentials (for environments like GCE, Cloud Run, etc.)
    """

    @staticmethod
    def _read_credentials_file(path: str) -> Dict[str, Any]:
        """Parse a JSON credentials file once (UTF-8) and return its contents."""
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    @staticmethod
    def get_credentials(
        service_account_json: Optional[Dict[str, Any]] = None,
        credentials_path: Optional[str] = None,
        scopes: Optional[list] = None,
    ) -> tuple[google.auth.credentials.Credentials, Optional[str]]:
        """
        Get Google credentials using the priority order defined on the class.

        Args:
            service_account_json: Service account credentials as a dictionary.
                An empty dict is treated the same as None.
            credentials_path: Path to service account JSON file. If the path
                does not point at an existing file it is skipped (no error) and
                resolution continues with the next method.
            scopes: List of OAuth scopes (optional)

        Returns:
            tuple: (credentials, project_id). project_id may be None when the
            chosen credential source does not carry one.

        Raises:
            ValueError: If no valid credentials are found (the underlying
                google-auth error is chained as the cause).
        """
        credentials = None
        project_id = None

        # Method 1: Service account JSON (in-memory)
        if service_account_json:
            credentials = service_account.Credentials.from_service_account_info(
                service_account_json, scopes=scopes
            )
            project_id = service_account_json.get("project_id")

        # Method 2: Service account file path
        elif credentials_path and os.path.isfile(credentials_path):
            # Parse the file once and build credentials from the parsed dict.
            # This avoids google-auth reading it and then reading it again
            # just to obtain project_id.
            cred_data = GCPAuthenticator._read_credentials_file(credentials_path)
            credentials = service_account.Credentials.from_service_account_info(
                cred_data, scopes=scopes
            )
            project_id = cred_data.get("project_id")

        # Method 3: Environment variable path
        elif os.getenv("GOOGLE_APPLICATION_CREDENTIALS"):
            env_path = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
            if os.path.isfile(env_path):
                cred_data = GCPAuthenticator._read_credentials_file(env_path)
                # GOOGLE_APPLICATION_CREDENTIALS can also name non-service-account
                # ADC files (e.g. "authorized_user" from
                # `gcloud auth application-default login`, or "external_account").
                # Those cannot be parsed as service accounts, so they are left to
                # default() below, which supports every ADC file type. Files with
                # no "type" key are still treated as service accounts, as before.
                if cred_data.get("type", "service_account") == "service_account":
                    credentials = service_account.Credentials.from_service_account_info(
                        cred_data, scopes=scopes
                    )
                    project_id = cred_data.get("project_id")

        # Method 4: Default credentials (GCE, Cloud Run, etc.)
        if credentials is None:
            try:
                credentials, project_id = default(scopes=scopes)
            except Exception as e:
                raise ValueError(
                    "No valid GCP credentials found. Please provide one of:\n"
                    "1. service_account_json parameter (dict)\n"
                    "2. credentials_path parameter (file path)\n"
                    "3. GOOGLE_APPLICATION_CREDENTIALS environment variable\n"
                    "4. Default credentials (if running on GCP)\n"
                    f"Error: {e}"
                ) from e

        return credentials, project_id

    @staticmethod
    def setup_vertex_ai(
        service_account_json: Optional[Dict[str, Any]] = None,
        credentials_path: Optional[str] = None,
        project_id: Optional[str] = None,
        location: str = "us-central1",
    ) -> str:
        """
        Initialize Vertex AI with proper authentication.

        Args:
            service_account_json: Service account credentials as dict
            credentials_path: Path to service account JSON file
            project_id: GCP project ID (optional; falls back to the project
                found in the credentials, then GOOGLE_CLOUD_PROJECT)
            location: GCP location/region

        Returns:
            str: The project ID being used

        Raises:
            ImportError: If google-cloud-aiplatform (vertexai) is not installed.
            ValueError: If authentication fails or no project ID can be determined.
        """
        try:
            import vertexai
        except ImportError as err:
            raise ImportError(
                "google-cloud-aiplatform is required for Vertex AI. Install with: pip install google-cloud-aiplatform"
            ) from err

        credentials, detected_project_id = GCPAuthenticator.get_credentials(
            service_account_json=service_account_json,
            credentials_path=credentials_path,
            scopes=["https://www.googleapis.com/auth/cloud-platform"],
        )

        # Use provided project_id or fall back to detected one
        final_project_id = project_id or detected_project_id or os.getenv("GOOGLE_CLOUD_PROJECT")

        if not final_project_id:
            raise ValueError(
                "Project ID could not be determined. Please provide project_id parameter or set GOOGLE_CLOUD_PROJECT environment variable."
            )

        vertexai.init(project=final_project_id, location=location, credentials=credentials)
        return final_project_id

    @staticmethod
    def get_genai_client(
        service_account_json: Optional[Dict[str, Any]] = None,
        credentials_path: Optional[str] = None,
        api_key: Optional[str] = None,
    ):
        """
        Get a Google GenAI client with authentication.

        Args:
            service_account_json: Service account credentials as dict
            credentials_path: Path to service account JSON file
            api_key: API key (takes precedence over service account)

        Returns:
            Google GenAI client instance

        Raises:
            ImportError: If google-genai is not installed.
            ValueError: If no API key is given and no credentials can be found.
        """
        try:
            from google.genai import Client as GenAIClient
        except ImportError as err:
            raise ImportError("google-genai is required. Install with: pip install google-genai") from err

        # If API key is provided, use it directly
        if api_key:
            return GenAIClient(api_key=api_key)

        # Otherwise, try service account authentication.
        credentials, _ = GCPAuthenticator.get_credentials(
            service_account_json=service_account_json,
            credentials_path=credentials_path,
            scopes=["https://www.googleapis.com/auth/generative-language"],
        )

        # NOTE(review): google-genai appears to accept `credentials` only in
        # Vertex AI mode (vertexai=True + project/location, or the
        # GOOGLE_GENAI_USE_VERTEXAI env var). Without that it may demand an
        # api_key. Confirm against the google-genai Client docs.
        return GenAIClient(credentials=credentials)
|
|
File without changes
|