agentrun-mem0ai 0.0.11 (py3-none-any.whl)
This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- agentrun_mem0/__init__.py +6 -0
- agentrun_mem0/client/__init__.py +0 -0
- agentrun_mem0/client/main.py +1747 -0
- agentrun_mem0/client/project.py +931 -0
- agentrun_mem0/client/utils.py +115 -0
- agentrun_mem0/configs/__init__.py +0 -0
- agentrun_mem0/configs/base.py +90 -0
- agentrun_mem0/configs/embeddings/__init__.py +0 -0
- agentrun_mem0/configs/embeddings/base.py +110 -0
- agentrun_mem0/configs/enums.py +7 -0
- agentrun_mem0/configs/llms/__init__.py +0 -0
- agentrun_mem0/configs/llms/anthropic.py +56 -0
- agentrun_mem0/configs/llms/aws_bedrock.py +192 -0
- agentrun_mem0/configs/llms/azure.py +57 -0
- agentrun_mem0/configs/llms/base.py +62 -0
- agentrun_mem0/configs/llms/deepseek.py +56 -0
- agentrun_mem0/configs/llms/lmstudio.py +59 -0
- agentrun_mem0/configs/llms/ollama.py +56 -0
- agentrun_mem0/configs/llms/openai.py +79 -0
- agentrun_mem0/configs/llms/vllm.py +56 -0
- agentrun_mem0/configs/prompts.py +459 -0
- agentrun_mem0/configs/rerankers/__init__.py +0 -0
- agentrun_mem0/configs/rerankers/base.py +17 -0
- agentrun_mem0/configs/rerankers/cohere.py +15 -0
- agentrun_mem0/configs/rerankers/config.py +12 -0
- agentrun_mem0/configs/rerankers/huggingface.py +17 -0
- agentrun_mem0/configs/rerankers/llm.py +48 -0
- agentrun_mem0/configs/rerankers/sentence_transformer.py +16 -0
- agentrun_mem0/configs/rerankers/zero_entropy.py +28 -0
- agentrun_mem0/configs/vector_stores/__init__.py +0 -0
- agentrun_mem0/configs/vector_stores/alibabacloud_mysql.py +64 -0
- agentrun_mem0/configs/vector_stores/aliyun_tablestore.py +32 -0
- agentrun_mem0/configs/vector_stores/azure_ai_search.py +57 -0
- agentrun_mem0/configs/vector_stores/azure_mysql.py +84 -0
- agentrun_mem0/configs/vector_stores/baidu.py +27 -0
- agentrun_mem0/configs/vector_stores/chroma.py +58 -0
- agentrun_mem0/configs/vector_stores/databricks.py +61 -0
- agentrun_mem0/configs/vector_stores/elasticsearch.py +65 -0
- agentrun_mem0/configs/vector_stores/faiss.py +37 -0
- agentrun_mem0/configs/vector_stores/langchain.py +30 -0
- agentrun_mem0/configs/vector_stores/milvus.py +42 -0
- agentrun_mem0/configs/vector_stores/mongodb.py +25 -0
- agentrun_mem0/configs/vector_stores/neptune.py +27 -0
- agentrun_mem0/configs/vector_stores/opensearch.py +41 -0
- agentrun_mem0/configs/vector_stores/pgvector.py +52 -0
- agentrun_mem0/configs/vector_stores/pinecone.py +55 -0
- agentrun_mem0/configs/vector_stores/qdrant.py +47 -0
- agentrun_mem0/configs/vector_stores/redis.py +24 -0
- agentrun_mem0/configs/vector_stores/s3_vectors.py +28 -0
- agentrun_mem0/configs/vector_stores/supabase.py +44 -0
- agentrun_mem0/configs/vector_stores/upstash_vector.py +34 -0
- agentrun_mem0/configs/vector_stores/valkey.py +15 -0
- agentrun_mem0/configs/vector_stores/vertex_ai_vector_search.py +28 -0
- agentrun_mem0/configs/vector_stores/weaviate.py +41 -0
- agentrun_mem0/embeddings/__init__.py +0 -0
- agentrun_mem0/embeddings/aws_bedrock.py +100 -0
- agentrun_mem0/embeddings/azure_openai.py +55 -0
- agentrun_mem0/embeddings/base.py +31 -0
- agentrun_mem0/embeddings/configs.py +30 -0
- agentrun_mem0/embeddings/gemini.py +39 -0
- agentrun_mem0/embeddings/huggingface.py +44 -0
- agentrun_mem0/embeddings/langchain.py +35 -0
- agentrun_mem0/embeddings/lmstudio.py +29 -0
- agentrun_mem0/embeddings/mock.py +11 -0
- agentrun_mem0/embeddings/ollama.py +53 -0
- agentrun_mem0/embeddings/openai.py +49 -0
- agentrun_mem0/embeddings/together.py +31 -0
- agentrun_mem0/embeddings/vertexai.py +64 -0
- agentrun_mem0/exceptions.py +503 -0
- agentrun_mem0/graphs/__init__.py +0 -0
- agentrun_mem0/graphs/configs.py +105 -0
- agentrun_mem0/graphs/neptune/__init__.py +0 -0
- agentrun_mem0/graphs/neptune/base.py +497 -0
- agentrun_mem0/graphs/neptune/neptunedb.py +511 -0
- agentrun_mem0/graphs/neptune/neptunegraph.py +474 -0
- agentrun_mem0/graphs/tools.py +371 -0
- agentrun_mem0/graphs/utils.py +97 -0
- agentrun_mem0/llms/__init__.py +0 -0
- agentrun_mem0/llms/anthropic.py +87 -0
- agentrun_mem0/llms/aws_bedrock.py +665 -0
- agentrun_mem0/llms/azure_openai.py +141 -0
- agentrun_mem0/llms/azure_openai_structured.py +91 -0
- agentrun_mem0/llms/base.py +131 -0
- agentrun_mem0/llms/configs.py +34 -0
- agentrun_mem0/llms/deepseek.py +107 -0
- agentrun_mem0/llms/gemini.py +201 -0
- agentrun_mem0/llms/groq.py +88 -0
- agentrun_mem0/llms/langchain.py +94 -0
- agentrun_mem0/llms/litellm.py +87 -0
- agentrun_mem0/llms/lmstudio.py +114 -0
- agentrun_mem0/llms/ollama.py +117 -0
- agentrun_mem0/llms/openai.py +147 -0
- agentrun_mem0/llms/openai_structured.py +52 -0
- agentrun_mem0/llms/sarvam.py +89 -0
- agentrun_mem0/llms/together.py +88 -0
- agentrun_mem0/llms/vllm.py +107 -0
- agentrun_mem0/llms/xai.py +52 -0
- agentrun_mem0/memory/__init__.py +0 -0
- agentrun_mem0/memory/base.py +63 -0
- agentrun_mem0/memory/graph_memory.py +698 -0
- agentrun_mem0/memory/kuzu_memory.py +713 -0
- agentrun_mem0/memory/main.py +2229 -0
- agentrun_mem0/memory/memgraph_memory.py +689 -0
- agentrun_mem0/memory/setup.py +56 -0
- agentrun_mem0/memory/storage.py +218 -0
- agentrun_mem0/memory/telemetry.py +90 -0
- agentrun_mem0/memory/utils.py +208 -0
- agentrun_mem0/proxy/__init__.py +0 -0
- agentrun_mem0/proxy/main.py +189 -0
- agentrun_mem0/reranker/__init__.py +9 -0
- agentrun_mem0/reranker/base.py +20 -0
- agentrun_mem0/reranker/cohere_reranker.py +85 -0
- agentrun_mem0/reranker/huggingface_reranker.py +147 -0
- agentrun_mem0/reranker/llm_reranker.py +142 -0
- agentrun_mem0/reranker/sentence_transformer_reranker.py +107 -0
- agentrun_mem0/reranker/zero_entropy_reranker.py +96 -0
- agentrun_mem0/utils/factory.py +283 -0
- agentrun_mem0/utils/gcp_auth.py +167 -0
- agentrun_mem0/vector_stores/__init__.py +0 -0
- agentrun_mem0/vector_stores/alibabacloud_mysql.py +547 -0
- agentrun_mem0/vector_stores/aliyun_tablestore.py +252 -0
- agentrun_mem0/vector_stores/azure_ai_search.py +396 -0
- agentrun_mem0/vector_stores/azure_mysql.py +463 -0
- agentrun_mem0/vector_stores/baidu.py +368 -0
- agentrun_mem0/vector_stores/base.py +58 -0
- agentrun_mem0/vector_stores/chroma.py +332 -0
- agentrun_mem0/vector_stores/configs.py +67 -0
- agentrun_mem0/vector_stores/databricks.py +761 -0
- agentrun_mem0/vector_stores/elasticsearch.py +237 -0
- agentrun_mem0/vector_stores/faiss.py +479 -0
- agentrun_mem0/vector_stores/langchain.py +180 -0
- agentrun_mem0/vector_stores/milvus.py +250 -0
- agentrun_mem0/vector_stores/mongodb.py +310 -0
- agentrun_mem0/vector_stores/neptune_analytics.py +467 -0
- agentrun_mem0/vector_stores/opensearch.py +292 -0
- agentrun_mem0/vector_stores/pgvector.py +404 -0
- agentrun_mem0/vector_stores/pinecone.py +382 -0
- agentrun_mem0/vector_stores/qdrant.py +270 -0
- agentrun_mem0/vector_stores/redis.py +295 -0
- agentrun_mem0/vector_stores/s3_vectors.py +176 -0
- agentrun_mem0/vector_stores/supabase.py +237 -0
- agentrun_mem0/vector_stores/upstash_vector.py +293 -0
- agentrun_mem0/vector_stores/valkey.py +824 -0
- agentrun_mem0/vector_stores/vertex_ai_vector_search.py +635 -0
- agentrun_mem0/vector_stores/weaviate.py +343 -0
- agentrun_mem0ai-0.0.11.data/data/README.md +205 -0
- agentrun_mem0ai-0.0.11.dist-info/METADATA +277 -0
- agentrun_mem0ai-0.0.11.dist-info/RECORD +150 -0
- agentrun_mem0ai-0.0.11.dist-info/WHEEL +4 -0
- agentrun_mem0ai-0.0.11.dist-info/licenses/LICENSE +201 -0

agentrun_mem0/proxy/main.py
@@ -0,0 +1,189 @@
+import logging
+import subprocess
+import sys
+import threading
+from typing import List, Optional, Union
+
+import httpx
+
+import agentrun_mem0
+
+try:
+    import litellm
+except ImportError:
+    try:
+        subprocess.check_call([sys.executable, "-m", "pip", "install", "litellm"])
+        import litellm
+    except subprocess.CalledProcessError:
+        print("Failed to install 'litellm'. Please install it manually using 'pip install litellm'.")
+        sys.exit(1)
+
+from agentrun_mem0 import Memory, MemoryClient
+from agentrun_mem0.configs.prompts import MEMORY_ANSWER_PROMPT
+from agentrun_mem0.memory.telemetry import capture_client_event, capture_event
+
+logger = logging.getLogger(__name__)
+
+
+class Mem0:
+    def __init__(
+        self,
+        config: Optional[dict] = None,
+        api_key: Optional[str] = None,
+        host: Optional[str] = None,
+    ):
+        if api_key:
+            self.mem0_client = MemoryClient(api_key, host)
+        else:
+            self.mem0_client = Memory.from_config(config) if config else Memory()
+
+        self.chat = Chat(self.mem0_client)
+
+
+class Chat:
+    def __init__(self, mem0_client):
+        self.completions = Completions(mem0_client)
+
+
+class Completions:
+    def __init__(self, mem0_client):
+        self.mem0_client = mem0_client
+
+    def create(
+        self,
+        model: str,
+        messages: List = [],
+        # Mem0 arguments
+        user_id: Optional[str] = None,
+        agent_id: Optional[str] = None,
+        run_id: Optional[str] = None,
+        metadata: Optional[dict] = None,
+        filters: Optional[dict] = None,
+        limit: Optional[int] = 10,
+        # LLM arguments
+        timeout: Optional[Union[float, str, httpx.Timeout]] = None,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        n: Optional[int] = None,
+        stream: Optional[bool] = None,
+        stream_options: Optional[dict] = None,
+        stop=None,
+        max_tokens: Optional[int] = None,
+        presence_penalty: Optional[float] = None,
+        frequency_penalty: Optional[float] = None,
+        logit_bias: Optional[dict] = None,
+        user: Optional[str] = None,
+        # openai v1.0+ new params
+        response_format: Optional[dict] = None,
+        seed: Optional[int] = None,
+        tools: Optional[List] = None,
+        tool_choice: Optional[Union[str, dict]] = None,
+        logprobs: Optional[bool] = None,
+        top_logprobs: Optional[int] = None,
+        parallel_tool_calls: Optional[bool] = None,
+        deployment_id=None,
+        extra_headers: Optional[dict] = None,
+        # soon to be deprecated params by OpenAI
+        functions: Optional[List] = None,
+        function_call: Optional[str] = None,
+        # set api_base, api_version, api_key
+        base_url: Optional[str] = None,
+        api_version: Optional[str] = None,
+        api_key: Optional[str] = None,
+        model_list: Optional[list] = None,  # pass in a list of api_base, keys, etc.
+    ):
+        if not any([user_id, agent_id, run_id]):
+            raise ValueError("One of user_id, agent_id, run_id must be provided")
+
+        if not litellm.supports_function_calling(model):
+            raise ValueError(
+                f"Model '{model}' does not support function calling. Please use a model that supports function calling."
+            )
+
+        prepared_messages = self._prepare_messages(messages)
+        if prepared_messages[-1]["role"] == "user":
+            self._async_add_to_memory(messages, user_id, agent_id, run_id, metadata, filters)
+            relevant_memories = self._fetch_relevant_memories(messages, user_id, agent_id, run_id, filters, limit)
+            logger.debug(f"Retrieved {len(relevant_memories)} relevant memories")
+            prepared_messages[-1]["content"] = self._format_query_with_memories(messages, relevant_memories)
+
+        response = litellm.completion(
+            model=model,
+            messages=prepared_messages,
+            temperature=temperature,
+            top_p=top_p,
+            n=n,
+            timeout=timeout,
+            stream=stream,
+            stream_options=stream_options,
+            stop=stop,
+            max_tokens=max_tokens,
+            presence_penalty=presence_penalty,
+            frequency_penalty=frequency_penalty,
+            logit_bias=logit_bias,
+            user=user,
+            response_format=response_format,
+            seed=seed,
+            tools=tools,
+            tool_choice=tool_choice,
+            logprobs=logprobs,
+            top_logprobs=top_logprobs,
+            parallel_tool_calls=parallel_tool_calls,
+            deployment_id=deployment_id,
+            extra_headers=extra_headers,
+            functions=functions,
+            function_call=function_call,
+            base_url=base_url,
+            api_version=api_version,
+            api_key=api_key,
+            model_list=model_list,
+        )
+        if isinstance(self.mem0_client, Memory):
+            capture_event("mem0.chat.create", self.mem0_client)
+        else:
+            capture_client_event("mem0.chat.create", self.mem0_client)
+        return response
+
+    def _prepare_messages(self, messages: List[dict]) -> List[dict]:
+        if not messages or messages[0]["role"] != "system":
+            return [{"role": "system", "content": MEMORY_ANSWER_PROMPT}] + messages
+        return messages
+
+    def _async_add_to_memory(self, messages, user_id, agent_id, run_id, metadata, filters):
+        def add_task():
+            logger.debug("Adding to memory asynchronously")
+            self.mem0_client.add(
+                messages=messages,
+                user_id=user_id,
+                agent_id=agent_id,
+                run_id=run_id,
+                metadata=metadata,
+                filters=filters,
+            )
+
+        threading.Thread(target=add_task, daemon=True).start()
+
+    def _fetch_relevant_memories(self, messages, user_id, agent_id, run_id, filters, limit):
+        # Currently, only pass the last 6 messages to the search API to prevent long query
+        message_input = [f"{message['role']}: {message['content']}" for message in messages][-6:]
+        # TODO: Make it better by summarizing the past conversation
+        return self.mem0_client.search(
+            query="\n".join(message_input),
+            user_id=user_id,
+            agent_id=agent_id,
+            run_id=run_id,
+            filters=filters,
+            limit=limit,
+        )
+
+    def _format_query_with_memories(self, messages, relevant_memories):
+        # Check if self.mem0_client is an instance of Memory or MemoryClient
+
+        entities = []
+        if isinstance(self.mem0_client, agentrun_mem0.memory.main.Memory):
+            memories_text = "\n".join(memory["memory"] for memory in relevant_memories["results"])
+            if relevant_memories.get("relations"):
+                entities = [entity for entity in relevant_memories["relations"]]
+        elif isinstance(self.mem0_client, agentrun_mem0.client.main.MemoryClient):
+            memories_text = "\n".join(memory["memory"] for memory in relevant_memories)
+        return f"- Relevant Memories/Facts: {memories_text}\n\n- Entities: {entities}\n\n- User Question: {messages[-1]['content']}"
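
Note (not part of the diff): a minimal usage sketch of the proxy above. The model name and prompt are illustrative, the model must be one litellm reports as supporting function calling, OpenAI credentials are assumed to be configured for the defaults, and one of user_id, agent_id, or run_id is required.

    from agentrun_mem0.proxy.main import Mem0

    # Local mode uses Memory(); pass api_key=... to route through the hosted MemoryClient instead.
    client = Mem0()

    response = client.chat.completions.create(
        model="gpt-4o-mini",  # illustrative model name
        messages=[{"role": "user", "content": "I'm vegetarian. Suggest a quick dinner."}],
        user_id="alice",  # scopes memory storage and retrieval to this user
    )
    print(response.choices[0].message.content)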

agentrun_mem0/reranker/__init__.py
@@ -0,0 +1,9 @@
+"""
+Reranker implementations for mem0 search functionality.
+"""
+
+from .base import BaseReranker
+from .cohere_reranker import CohereReranker
+from .sentence_transformer_reranker import SentenceTransformerReranker
+
+__all__ = ["BaseReranker", "CohereReranker", "SentenceTransformerReranker"]

agentrun_mem0/reranker/base.py
@@ -0,0 +1,20 @@
+from abc import ABC, abstractmethod
+from typing import List, Dict, Any
+
+class BaseReranker(ABC):
+    """Abstract base class for all rerankers."""
+
+    @abstractmethod
+    def rerank(self, query: str, documents: List[Dict[str, Any]], top_k: int = None) -> List[Dict[str, Any]]:
+        """
+        Rerank documents based on relevance to the query.
+
+        Args:
+            query: The search query
+            documents: List of documents to rerank, each with 'memory' field
+            top_k: Number of top documents to return (None = return all)
+
+        Returns:
+            List of reranked documents with added 'rerank_score' field
+        """
+        pass
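
Note (not part of the diff): the interface above only requires a rerank() method that returns the documents with a 'rerank_score' field. A toy subclass that scores by keyword overlap with the query, with a made-up class name and scoring rule, could look like this:

    from typing import Any, Dict, List

    from agentrun_mem0.reranker.base import BaseReranker

    class KeywordOverlapReranker(BaseReranker):
        """Toy example: rank documents by how many query terms they contain."""

        def rerank(self, query: str, documents: List[Dict[str, Any]], top_k: int = None) -> List[Dict[str, Any]]:
            terms = set(query.lower().split())
            scored = []
            for doc in documents:
                text = str(doc.get("memory", doc)).lower()
                doc = doc.copy()
                # Fraction of query terms that appear in the document text
                doc["rerank_score"] = sum(term in text for term in terms) / max(len(terms), 1)
                scored.append(doc)
            scored.sort(key=lambda d: d["rerank_score"], reverse=True)
            return scored[:top_k] if top_k else scored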

agentrun_mem0/reranker/cohere_reranker.py
@@ -0,0 +1,85 @@
+import os
+from typing import List, Dict, Any
+
+from agentrun_mem0.reranker.base import BaseReranker
+
+try:
+    import cohere
+    COHERE_AVAILABLE = True
+except ImportError:
+    COHERE_AVAILABLE = False
+
+
+class CohereReranker(BaseReranker):
+    """Cohere-based reranker implementation."""
+
+    def __init__(self, config):
+        """
+        Initialize Cohere reranker.
+
+        Args:
+            config: CohereRerankerConfig object with configuration parameters
+        """
+        if not COHERE_AVAILABLE:
+            raise ImportError("cohere package is required for CohereReranker. Install with: pip install cohere")
+
+        self.config = config
+        self.api_key = config.api_key or os.getenv("COHERE_API_KEY")
+        if not self.api_key:
+            raise ValueError("Cohere API key is required. Set COHERE_API_KEY environment variable or pass api_key in config.")
+
+        self.model = config.model
+        self.client = cohere.Client(self.api_key)
+
+    def rerank(self, query: str, documents: List[Dict[str, Any]], top_k: int = None) -> List[Dict[str, Any]]:
+        """
+        Rerank documents using Cohere's rerank API.
+
+        Args:
+            query: The search query
+            documents: List of documents to rerank
+            top_k: Number of top documents to return
+
+        Returns:
+            List of reranked documents with rerank_score
+        """
+        if not documents:
+            return documents
+
+        # Extract text content for reranking
+        doc_texts = []
+        for doc in documents:
+            if 'memory' in doc:
+                doc_texts.append(doc['memory'])
+            elif 'text' in doc:
+                doc_texts.append(doc['text'])
+            elif 'content' in doc:
+                doc_texts.append(doc['content'])
+            else:
+                doc_texts.append(str(doc))
+
+        try:
+            # Call Cohere rerank API
+            response = self.client.rerank(
+                model=self.model,
+                query=query,
+                documents=doc_texts,
+                top_n=top_k or self.config.top_k or len(documents),
+                return_documents=self.config.return_documents,
+                max_chunks_per_doc=self.config.max_chunks_per_doc,
+            )
+
+            # Create reranked results
+            reranked_docs = []
+            for result in response.results:
+                original_doc = documents[result.index].copy()
+                original_doc['rerank_score'] = result.relevance_score
+                reranked_docs.append(original_doc)
+
+            return reranked_docs
+
+        except Exception:
+            # Fallback to original order if reranking fails
+            for doc in documents:
+                doc['rerank_score'] = 0.0
+            return documents[:top_k] if top_k else documents
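
Note (not part of the diff): a usage sketch for the class above. The real CohereRerankerConfig lives in agentrun_mem0/configs/rerankers/cohere.py and is not shown in this diff, so a stand-in object carrying only the attributes the class reads is used here; the model name is illustrative.

    from types import SimpleNamespace

    from agentrun_mem0.reranker.cohere_reranker import CohereReranker

    # Stand-in for CohereRerankerConfig with just the attributes CohereReranker accesses.
    config = SimpleNamespace(
        api_key=None,                  # falls back to the COHERE_API_KEY environment variable
        model="rerank-english-v3.0",   # illustrative Cohere rerank model
        top_k=5,
        return_documents=False,
        max_chunks_per_doc=None,
    )

    reranker = CohereReranker(config)
    results = reranker.rerank(
        "favorite foods",
        [{"memory": "Loves Italian food"}, {"memory": "Works as a software engineer"}],
        top_k=2,
    )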

agentrun_mem0/reranker/huggingface_reranker.py
@@ -0,0 +1,147 @@
+from typing import List, Dict, Any, Union
+import numpy as np
+
+from agentrun_mem0.reranker.base import BaseReranker
+from agentrun_mem0.configs.rerankers.base import BaseRerankerConfig
+from agentrun_mem0.configs.rerankers.huggingface import HuggingFaceRerankerConfig
+
+try:
+    from transformers import AutoTokenizer, AutoModelForSequenceClassification
+    import torch
+    TRANSFORMERS_AVAILABLE = True
+except ImportError:
+    TRANSFORMERS_AVAILABLE = False
+
+
+class HuggingFaceReranker(BaseReranker):
+    """HuggingFace Transformers based reranker implementation."""
+
+    def __init__(self, config: Union[BaseRerankerConfig, HuggingFaceRerankerConfig, Dict]):
+        """
+        Initialize HuggingFace reranker.
+
+        Args:
+            config: Configuration object with reranker parameters
+        """
+        if not TRANSFORMERS_AVAILABLE:
+            raise ImportError("transformers package is required for HuggingFaceReranker. Install with: pip install transformers torch")
+
+        # Convert to HuggingFaceRerankerConfig if needed
+        if isinstance(config, dict):
+            config = HuggingFaceRerankerConfig(**config)
+        elif isinstance(config, BaseRerankerConfig) and not isinstance(config, HuggingFaceRerankerConfig):
+            # Convert BaseRerankerConfig to HuggingFaceRerankerConfig with defaults
+            config = HuggingFaceRerankerConfig(
+                provider=getattr(config, 'provider', 'huggingface'),
+                model=getattr(config, 'model', 'BAAI/bge-reranker-base'),
+                api_key=getattr(config, 'api_key', None),
+                top_k=getattr(config, 'top_k', None),
+                device=None,  # Will auto-detect
+                batch_size=32,  # Default
+                max_length=512,  # Default
+                normalize=True,  # Default
+            )
+
+        self.config = config
+
+        # Set device
+        if self.config.device is None:
+            self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        else:
+            self.device = self.config.device
+
+        # Load model and tokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model)
+        self.model = AutoModelForSequenceClassification.from_pretrained(self.config.model)
+        self.model.to(self.device)
+        self.model.eval()
+
+    def rerank(self, query: str, documents: List[Dict[str, Any]], top_k: int = None) -> List[Dict[str, Any]]:
+        """
+        Rerank documents using HuggingFace cross-encoder model.
+
+        Args:
+            query: The search query
+            documents: List of documents to rerank
+            top_k: Number of top documents to return
+
+        Returns:
+            List of reranked documents with rerank_score
+        """
+        if not documents:
+            return documents
+
+        # Extract text content for reranking
+        doc_texts = []
+        for doc in documents:
+            if 'memory' in doc:
+                doc_texts.append(doc['memory'])
+            elif 'text' in doc:
+                doc_texts.append(doc['text'])
+            elif 'content' in doc:
+                doc_texts.append(doc['content'])
+            else:
+                doc_texts.append(str(doc))
+
+        try:
+            scores = []
+
+            # Process documents in batches
+            for i in range(0, len(doc_texts), self.config.batch_size):
+                batch_docs = doc_texts[i:i + self.config.batch_size]
+                batch_pairs = [[query, doc] for doc in batch_docs]
+
+                # Tokenize batch
+                inputs = self.tokenizer(
+                    batch_pairs,
+                    padding=True,
+                    truncation=True,
+                    max_length=self.config.max_length,
+                    return_tensors="pt"
+                ).to(self.device)
+
+                # Get scores
+                with torch.no_grad():
+                    outputs = self.model(**inputs)
+                    batch_scores = outputs.logits.squeeze(-1).cpu().numpy()
+
+                # Handle single item case
+                if batch_scores.ndim == 0:
+                    batch_scores = [float(batch_scores)]
+                else:
+                    batch_scores = batch_scores.tolist()
+
+                scores.extend(batch_scores)
+
+            # Normalize scores if requested
+            if self.config.normalize:
+                scores = np.array(scores)
+                scores = (scores - scores.min()) / (scores.max() - scores.min() + 1e-8)
+                scores = scores.tolist()
+
+            # Combine documents with scores
+            doc_score_pairs = list(zip(documents, scores))
+
+            # Sort by score (descending)
+            doc_score_pairs.sort(key=lambda x: x[1], reverse=True)
+
+            # Apply top_k limit
+            final_top_k = top_k or self.config.top_k
+            if final_top_k:
+                doc_score_pairs = doc_score_pairs[:final_top_k]
+
+            # Create reranked results
+            reranked_docs = []
+            for doc, score in doc_score_pairs:
+                reranked_doc = doc.copy()
+                reranked_doc['rerank_score'] = float(score)
+                reranked_docs.append(reranked_doc)
+
+            return reranked_docs
+
+        except Exception:
+            # Fallback to original order if reranking fails
+            for doc in documents:
+                doc['rerank_score'] = 0.0
+            final_top_k = top_k or self.config.top_k
+            return documents[:final_top_k] if final_top_k else documents
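
Note (not part of the diff): a usage sketch for the cross-encoder reranker above. A dict is converted to HuggingFaceRerankerConfig internally; the keys below mirror the ones this module itself passes when converting a base config, but the full schema lives in agentrun_mem0/configs/rerankers/huggingface.py, which is not shown here.

    from agentrun_mem0.reranker.huggingface_reranker import HuggingFaceReranker

    reranker = HuggingFaceReranker({
        "model": "BAAI/bge-reranker-base",  # the default cross-encoder referenced above
        "device": None,                     # auto-detects CUDA, otherwise CPU
        "batch_size": 32,
        "max_length": 512,
        "normalize": True,                  # min-max normalize logits to [0, 1]
    })

    reranked = reranker.rerank(
        "user hobbies",
        [{"memory": "Enjoys hiking on weekends"}, {"memory": "Allergic to peanuts"}],
        top_k=1,
    )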

agentrun_mem0/reranker/llm_reranker.py
@@ -0,0 +1,142 @@
+import re
+from typing import List, Dict, Any, Union
+
+from agentrun_mem0.reranker.base import BaseReranker
+from agentrun_mem0.utils.factory import LlmFactory
+from agentrun_mem0.configs.rerankers.base import BaseRerankerConfig
+from agentrun_mem0.configs.rerankers.llm import LLMRerankerConfig
+
+
+class LLMReranker(BaseReranker):
+    """LLM-based reranker implementation."""
+
+    def __init__(self, config: Union[BaseRerankerConfig, LLMRerankerConfig, Dict]):
+        """
+        Initialize LLM reranker.
+
+        Args:
+            config: Configuration object with reranker parameters
+        """
+        # Convert to LLMRerankerConfig if needed
+        if isinstance(config, dict):
+            config = LLMRerankerConfig(**config)
+        elif isinstance(config, BaseRerankerConfig) and not isinstance(config, LLMRerankerConfig):
+            # Convert BaseRerankerConfig to LLMRerankerConfig with defaults
+            config = LLMRerankerConfig(
+                provider=getattr(config, 'provider', 'openai'),
+                model=getattr(config, 'model', 'gpt-4o-mini'),
+                api_key=getattr(config, 'api_key', None),
+                top_k=getattr(config, 'top_k', None),
+                temperature=0.0,  # Default for reranking
+                max_tokens=100,  # Default for reranking
+            )
+
+        self.config = config
+
+        # Create LLM configuration for the factory
+        llm_config = {
+            "model": self.config.model,
+            "temperature": self.config.temperature,
+            "max_tokens": self.config.max_tokens,
+        }
+
+        # Add API key if provided
+        if self.config.api_key:
+            llm_config["api_key"] = self.config.api_key
+
+        # Initialize LLM using the factory
+        self.llm = LlmFactory.create(self.config.provider, llm_config)
+
+        # Default scoring prompt
+        self.scoring_prompt = getattr(self.config, 'scoring_prompt', None) or self._get_default_prompt()
+
+    def _get_default_prompt(self) -> str:
+        """Get the default scoring prompt template."""
+        return """You are a relevance scoring assistant. Given a query and a document, you need to score how relevant the document is to the query.
+
+Score the relevance on a scale from 0.0 to 1.0, where:
+- 1.0 = Perfectly relevant and directly answers the query
+- 0.8-0.9 = Highly relevant with good information
+- 0.6-0.7 = Moderately relevant with some useful information
+- 0.4-0.5 = Slightly relevant with limited useful information
+- 0.0-0.3 = Not relevant or no useful information
+
+Query: "{query}"
+Document: "{document}"
+
+Provide only a single numerical score between 0.0 and 1.0. Do not include any explanation or additional text."""
+
+    def _extract_score(self, response_text: str) -> float:
+        """Extract numerical score from LLM response."""
+        # Look for decimal numbers between 0.0 and 1.0
+        pattern = r'\b([01](?:\.\d+)?)\b'
+        matches = re.findall(pattern, response_text)
+
+        if matches:
+            score = float(matches[0])
+            return min(max(score, 0.0), 1.0)  # Clamp between 0.0 and 1.0
+
+        # Fallback: return 0.5 if no valid score found
+        return 0.5
+
+    def rerank(self, query: str, documents: List[Dict[str, Any]], top_k: int = None) -> List[Dict[str, Any]]:
+        """
+        Rerank documents using LLM scoring.
+
+        Args:
+            query: The search query
+            documents: List of documents to rerank
+            top_k: Number of top documents to return
+
+        Returns:
+            List of reranked documents with rerank_score
+        """
+        if not documents:
+            return documents
+
+        scored_docs = []
+
+        for doc in documents:
+            # Extract text content
+            if 'memory' in doc:
+                doc_text = doc['memory']
+            elif 'text' in doc:
+                doc_text = doc['text']
+            elif 'content' in doc:
+                doc_text = doc['content']
+            else:
+                doc_text = str(doc)
+
+            try:
+                # Generate scoring prompt
+                prompt = self.scoring_prompt.format(query=query, document=doc_text)
+
+                # Get LLM response
+                response = self.llm.generate_response(
+                    messages=[{"role": "user", "content": prompt}]
+                )
+
+                # Extract score from response
+                score = self._extract_score(response)
+
+                # Create scored document
+                scored_doc = doc.copy()
+                scored_doc['rerank_score'] = score
+                scored_docs.append(scored_doc)
+
+            except Exception:
+                # Fallback: assign neutral score if scoring fails
+                scored_doc = doc.copy()
+                scored_doc['rerank_score'] = 0.5
+                scored_docs.append(scored_doc)
+
+        # Sort by relevance score in descending order
+        scored_docs.sort(key=lambda x: x['rerank_score'], reverse=True)
+
+        # Apply top_k limit
+        if top_k:
+            scored_docs = scored_docs[:top_k]
+        elif self.config.top_k:
+            scored_docs = scored_docs[:self.config.top_k]
+
+        return scored_docs
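
Note (not part of the diff): a usage sketch for the LLM reranker above. A dict is converted to LLMRerankerConfig internally; the keys mirror those this module passes when converting a base config, and the provider/model are illustrative values that LlmFactory must be able to construct (API credentials are assumed to be configured).

    from agentrun_mem0.reranker.llm_reranker import LLMReranker

    reranker = LLMReranker({
        "provider": "openai",      # illustrative provider name
        "model": "gpt-4o-mini",    # illustrative model name
        "temperature": 0.0,
        "max_tokens": 100,
        "top_k": 3,
    })

    reranked = reranker.rerank(
        "what does the user do for work",
        [{"memory": "Works as a data scientist"}, {"memory": "Has two cats"}],
    )
    # Each returned item carries a 'rerank_score' in [0.0, 1.0] parsed from the LLM's reply.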