agentrun-mem0ai 0.0.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150) hide show
  1. agentrun_mem0/__init__.py +6 -0
  2. agentrun_mem0/client/__init__.py +0 -0
  3. agentrun_mem0/client/main.py +1747 -0
  4. agentrun_mem0/client/project.py +931 -0
  5. agentrun_mem0/client/utils.py +115 -0
  6. agentrun_mem0/configs/__init__.py +0 -0
  7. agentrun_mem0/configs/base.py +90 -0
  8. agentrun_mem0/configs/embeddings/__init__.py +0 -0
  9. agentrun_mem0/configs/embeddings/base.py +110 -0
  10. agentrun_mem0/configs/enums.py +7 -0
  11. agentrun_mem0/configs/llms/__init__.py +0 -0
  12. agentrun_mem0/configs/llms/anthropic.py +56 -0
  13. agentrun_mem0/configs/llms/aws_bedrock.py +192 -0
  14. agentrun_mem0/configs/llms/azure.py +57 -0
  15. agentrun_mem0/configs/llms/base.py +62 -0
  16. agentrun_mem0/configs/llms/deepseek.py +56 -0
  17. agentrun_mem0/configs/llms/lmstudio.py +59 -0
  18. agentrun_mem0/configs/llms/ollama.py +56 -0
  19. agentrun_mem0/configs/llms/openai.py +79 -0
  20. agentrun_mem0/configs/llms/vllm.py +56 -0
  21. agentrun_mem0/configs/prompts.py +459 -0
  22. agentrun_mem0/configs/rerankers/__init__.py +0 -0
  23. agentrun_mem0/configs/rerankers/base.py +17 -0
  24. agentrun_mem0/configs/rerankers/cohere.py +15 -0
  25. agentrun_mem0/configs/rerankers/config.py +12 -0
  26. agentrun_mem0/configs/rerankers/huggingface.py +17 -0
  27. agentrun_mem0/configs/rerankers/llm.py +48 -0
  28. agentrun_mem0/configs/rerankers/sentence_transformer.py +16 -0
  29. agentrun_mem0/configs/rerankers/zero_entropy.py +28 -0
  30. agentrun_mem0/configs/vector_stores/__init__.py +0 -0
  31. agentrun_mem0/configs/vector_stores/alibabacloud_mysql.py +64 -0
  32. agentrun_mem0/configs/vector_stores/aliyun_tablestore.py +32 -0
  33. agentrun_mem0/configs/vector_stores/azure_ai_search.py +57 -0
  34. agentrun_mem0/configs/vector_stores/azure_mysql.py +84 -0
  35. agentrun_mem0/configs/vector_stores/baidu.py +27 -0
  36. agentrun_mem0/configs/vector_stores/chroma.py +58 -0
  37. agentrun_mem0/configs/vector_stores/databricks.py +61 -0
  38. agentrun_mem0/configs/vector_stores/elasticsearch.py +65 -0
  39. agentrun_mem0/configs/vector_stores/faiss.py +37 -0
  40. agentrun_mem0/configs/vector_stores/langchain.py +30 -0
  41. agentrun_mem0/configs/vector_stores/milvus.py +42 -0
  42. agentrun_mem0/configs/vector_stores/mongodb.py +25 -0
  43. agentrun_mem0/configs/vector_stores/neptune.py +27 -0
  44. agentrun_mem0/configs/vector_stores/opensearch.py +41 -0
  45. agentrun_mem0/configs/vector_stores/pgvector.py +52 -0
  46. agentrun_mem0/configs/vector_stores/pinecone.py +55 -0
  47. agentrun_mem0/configs/vector_stores/qdrant.py +47 -0
  48. agentrun_mem0/configs/vector_stores/redis.py +24 -0
  49. agentrun_mem0/configs/vector_stores/s3_vectors.py +28 -0
  50. agentrun_mem0/configs/vector_stores/supabase.py +44 -0
  51. agentrun_mem0/configs/vector_stores/upstash_vector.py +34 -0
  52. agentrun_mem0/configs/vector_stores/valkey.py +15 -0
  53. agentrun_mem0/configs/vector_stores/vertex_ai_vector_search.py +28 -0
  54. agentrun_mem0/configs/vector_stores/weaviate.py +41 -0
  55. agentrun_mem0/embeddings/__init__.py +0 -0
  56. agentrun_mem0/embeddings/aws_bedrock.py +100 -0
  57. agentrun_mem0/embeddings/azure_openai.py +55 -0
  58. agentrun_mem0/embeddings/base.py +31 -0
  59. agentrun_mem0/embeddings/configs.py +30 -0
  60. agentrun_mem0/embeddings/gemini.py +39 -0
  61. agentrun_mem0/embeddings/huggingface.py +44 -0
  62. agentrun_mem0/embeddings/langchain.py +35 -0
  63. agentrun_mem0/embeddings/lmstudio.py +29 -0
  64. agentrun_mem0/embeddings/mock.py +11 -0
  65. agentrun_mem0/embeddings/ollama.py +53 -0
  66. agentrun_mem0/embeddings/openai.py +49 -0
  67. agentrun_mem0/embeddings/together.py +31 -0
  68. agentrun_mem0/embeddings/vertexai.py +64 -0
  69. agentrun_mem0/exceptions.py +503 -0
  70. agentrun_mem0/graphs/__init__.py +0 -0
  71. agentrun_mem0/graphs/configs.py +105 -0
  72. agentrun_mem0/graphs/neptune/__init__.py +0 -0
  73. agentrun_mem0/graphs/neptune/base.py +497 -0
  74. agentrun_mem0/graphs/neptune/neptunedb.py +511 -0
  75. agentrun_mem0/graphs/neptune/neptunegraph.py +474 -0
  76. agentrun_mem0/graphs/tools.py +371 -0
  77. agentrun_mem0/graphs/utils.py +97 -0
  78. agentrun_mem0/llms/__init__.py +0 -0
  79. agentrun_mem0/llms/anthropic.py +87 -0
  80. agentrun_mem0/llms/aws_bedrock.py +665 -0
  81. agentrun_mem0/llms/azure_openai.py +141 -0
  82. agentrun_mem0/llms/azure_openai_structured.py +91 -0
  83. agentrun_mem0/llms/base.py +131 -0
  84. agentrun_mem0/llms/configs.py +34 -0
  85. agentrun_mem0/llms/deepseek.py +107 -0
  86. agentrun_mem0/llms/gemini.py +201 -0
  87. agentrun_mem0/llms/groq.py +88 -0
  88. agentrun_mem0/llms/langchain.py +94 -0
  89. agentrun_mem0/llms/litellm.py +87 -0
  90. agentrun_mem0/llms/lmstudio.py +114 -0
  91. agentrun_mem0/llms/ollama.py +117 -0
  92. agentrun_mem0/llms/openai.py +147 -0
  93. agentrun_mem0/llms/openai_structured.py +52 -0
  94. agentrun_mem0/llms/sarvam.py +89 -0
  95. agentrun_mem0/llms/together.py +88 -0
  96. agentrun_mem0/llms/vllm.py +107 -0
  97. agentrun_mem0/llms/xai.py +52 -0
  98. agentrun_mem0/memory/__init__.py +0 -0
  99. agentrun_mem0/memory/base.py +63 -0
  100. agentrun_mem0/memory/graph_memory.py +698 -0
  101. agentrun_mem0/memory/kuzu_memory.py +713 -0
  102. agentrun_mem0/memory/main.py +2229 -0
  103. agentrun_mem0/memory/memgraph_memory.py +689 -0
  104. agentrun_mem0/memory/setup.py +56 -0
  105. agentrun_mem0/memory/storage.py +218 -0
  106. agentrun_mem0/memory/telemetry.py +90 -0
  107. agentrun_mem0/memory/utils.py +208 -0
  108. agentrun_mem0/proxy/__init__.py +0 -0
  109. agentrun_mem0/proxy/main.py +189 -0
  110. agentrun_mem0/reranker/__init__.py +9 -0
  111. agentrun_mem0/reranker/base.py +20 -0
  112. agentrun_mem0/reranker/cohere_reranker.py +85 -0
  113. agentrun_mem0/reranker/huggingface_reranker.py +147 -0
  114. agentrun_mem0/reranker/llm_reranker.py +142 -0
  115. agentrun_mem0/reranker/sentence_transformer_reranker.py +107 -0
  116. agentrun_mem0/reranker/zero_entropy_reranker.py +96 -0
  117. agentrun_mem0/utils/factory.py +283 -0
  118. agentrun_mem0/utils/gcp_auth.py +167 -0
  119. agentrun_mem0/vector_stores/__init__.py +0 -0
  120. agentrun_mem0/vector_stores/alibabacloud_mysql.py +547 -0
  121. agentrun_mem0/vector_stores/aliyun_tablestore.py +252 -0
  122. agentrun_mem0/vector_stores/azure_ai_search.py +396 -0
  123. agentrun_mem0/vector_stores/azure_mysql.py +463 -0
  124. agentrun_mem0/vector_stores/baidu.py +368 -0
  125. agentrun_mem0/vector_stores/base.py +58 -0
  126. agentrun_mem0/vector_stores/chroma.py +332 -0
  127. agentrun_mem0/vector_stores/configs.py +67 -0
  128. agentrun_mem0/vector_stores/databricks.py +761 -0
  129. agentrun_mem0/vector_stores/elasticsearch.py +237 -0
  130. agentrun_mem0/vector_stores/faiss.py +479 -0
  131. agentrun_mem0/vector_stores/langchain.py +180 -0
  132. agentrun_mem0/vector_stores/milvus.py +250 -0
  133. agentrun_mem0/vector_stores/mongodb.py +310 -0
  134. agentrun_mem0/vector_stores/neptune_analytics.py +467 -0
  135. agentrun_mem0/vector_stores/opensearch.py +292 -0
  136. agentrun_mem0/vector_stores/pgvector.py +404 -0
  137. agentrun_mem0/vector_stores/pinecone.py +382 -0
  138. agentrun_mem0/vector_stores/qdrant.py +270 -0
  139. agentrun_mem0/vector_stores/redis.py +295 -0
  140. agentrun_mem0/vector_stores/s3_vectors.py +176 -0
  141. agentrun_mem0/vector_stores/supabase.py +237 -0
  142. agentrun_mem0/vector_stores/upstash_vector.py +293 -0
  143. agentrun_mem0/vector_stores/valkey.py +824 -0
  144. agentrun_mem0/vector_stores/vertex_ai_vector_search.py +635 -0
  145. agentrun_mem0/vector_stores/weaviate.py +343 -0
  146. agentrun_mem0ai-0.0.11.data/data/README.md +205 -0
  147. agentrun_mem0ai-0.0.11.dist-info/METADATA +277 -0
  148. agentrun_mem0ai-0.0.11.dist-info/RECORD +150 -0
  149. agentrun_mem0ai-0.0.11.dist-info/WHEEL +4 -0
  150. agentrun_mem0ai-0.0.11.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,189 @@
1
+ import logging
2
+ import subprocess
3
+ import sys
4
+ import threading
5
+ from typing import List, Optional, Union
6
+
7
+ import httpx
8
+
9
+ import agentrun_mem0
10
+
11
+ try:
12
+ import litellm
13
+ except ImportError:
14
+ try:
15
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "litellm"])
16
+ import litellm
17
+ except subprocess.CalledProcessError:
18
+ print("Failed to install 'litellm'. Please install it manually using 'pip install litellm'.")
19
+ sys.exit(1)
20
+
21
+ from agentrun_mem0 import Memory, MemoryClient
22
+ from agentrun_mem0.configs.prompts import MEMORY_ANSWER_PROMPT
23
+ from agentrun_mem0.memory.telemetry import capture_client_event, capture_event
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
class Mem0:
    """Pair a mem0 memory backend with a chat-completions style facade.

    Args:
        config: Optional configuration dict. Only consulted when ``api_key``
            is falsy; when truthy it is passed to ``Memory.from_config``,
            otherwise a default ``Memory()`` is created.
        api_key: When truthy, a ``MemoryClient`` is built instead of a
            ``Memory`` instance.
        host: Passed to ``MemoryClient`` as its second positional argument.

    Attributes are set in ``__init__``: ``mem0_client`` holds the selected
    backend and ``chat`` wraps it in a ``Chat`` object.
    """

    def __init__(
        self,
        config: Optional[dict] = None,
        api_key: Optional[str] = None,
        host: Optional[str] = None,
    ):
        if not api_key:
            # No API key supplied: use a Memory instance, configured only
            # when a config dict was actually given.
            if config:
                backend = Memory.from_config(config)
            else:
                backend = Memory()
        else:
            backend = MemoryClient(api_key, host)

        self.mem0_client = backend
        self.chat = Chat(backend)
41
+
42
+
43
class Chat:
    """Thin namespace exposing ``completions`` for a given mem0 client."""

    def __init__(self, mem0_client):
        # All of the behaviour lives in Completions; this class only provides
        # the ``chat.completions`` attribute path.
        completions = Completions(mem0_client)
        self.completions = completions
46
+
47
+
48
class Completions:
    """Memory-augmented wrapper around ``litellm.completion``.

    When the final message of a conversation comes from the user, the
    conversation is stored through ``mem0_client.add`` on a background thread,
    ``mem0_client.search`` is queried, and the search results are folded into
    the final message before the request is sent to the model.
    """

    def __init__(self, mem0_client):
        self.mem0_client = mem0_client

    def create(
        self,
        model: str,
        messages: Optional[List] = None,
        # Mem0 arguments
        user_id: Optional[str] = None,
        agent_id: Optional[str] = None,
        run_id: Optional[str] = None,
        metadata: Optional[dict] = None,
        filters: Optional[dict] = None,
        limit: Optional[int] = 10,
        # LLM arguments
        timeout: Optional[Union[float, str, httpx.Timeout]] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        n: Optional[int] = None,
        stream: Optional[bool] = None,
        stream_options: Optional[dict] = None,
        stop=None,
        max_tokens: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
        logit_bias: Optional[dict] = None,
        user: Optional[str] = None,
        # openai v1.0+ new params
        response_format: Optional[dict] = None,
        seed: Optional[int] = None,
        tools: Optional[List] = None,
        tool_choice: Optional[Union[str, dict]] = None,
        logprobs: Optional[bool] = None,
        top_logprobs: Optional[int] = None,
        parallel_tool_calls: Optional[bool] = None,
        deployment_id=None,
        extra_headers: Optional[dict] = None,
        # soon to be deprecated params by OpenAI
        functions: Optional[List] = None,
        function_call: Optional[str] = None,
        # set api_base, api_version, api_key
        base_url: Optional[str] = None,
        api_version: Optional[str] = None,
        api_key: Optional[str] = None,
        model_list: Optional[list] = None,  # pass in a list of api_base,keys, etc.
    ):
        """Run a chat completion, enriching the last user message with memories.

        Args:
            model: Model identifier passed to litellm.
            messages: Conversation as a list of ``{"role", "content"}``
                messages. ``None`` is treated as an empty conversation. The
                list and its messages are not modified by this call.
            user_id, agent_id, run_id: Scope identifiers forwarded to the mem0
                client; at least one must be truthy.
            metadata: Forwarded to ``mem0_client.add``.
            filters: Forwarded to ``mem0_client.add`` and ``mem0_client.search``.
            limit: Forwarded to ``mem0_client.search``.
            All remaining keyword arguments are forwarded unchanged to
            ``litellm.completion``.

        Returns:
            Whatever ``litellm.completion`` returns.

        Raises:
            ValueError: If none of ``user_id``/``agent_id``/``run_id`` is
                given, or if litellm reports that ``model`` does not support
                function calling.
        """
        import copy

        if not any([user_id, agent_id, run_id]):
            raise ValueError("One of user_id, agent_id, run_id must be provided")

        if not litellm.supports_function_calling(model):
            raise ValueError(
                f"Model '{model}' does not support function calling. Please use a model that supports function calling."
            )

        if messages is None:
            messages = []

        prepared_messages = self._prepare_messages(messages)
        if prepared_messages[-1]["role"] == "user":
            self._async_add_to_memory(messages, user_id, agent_id, run_id, metadata, filters)
            relevant_memories = self._fetch_relevant_memories(messages, user_id, agent_id, run_id, filters, limit)
            results, _ = self._split_search_results(relevant_memories)
            logger.debug(f"Retrieved {len(results)} relevant memories")
            # Swap in a copy of the last message instead of editing it in
            # place: the original message object is shared with the caller and
            # with the background add() thread, which must see the user's
            # actual text rather than the memory-augmented prompt.
            augmented = copy.copy(prepared_messages[-1])
            augmented["content"] = self._format_query_with_memories(messages, relevant_memories)
            prepared_messages[-1] = augmented

        response = litellm.completion(
            model=model,
            messages=prepared_messages,
            temperature=temperature,
            top_p=top_p,
            n=n,
            timeout=timeout,
            stream=stream,
            stream_options=stream_options,
            stop=stop,
            max_tokens=max_tokens,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            logit_bias=logit_bias,
            user=user,
            response_format=response_format,
            seed=seed,
            tools=tools,
            tool_choice=tool_choice,
            logprobs=logprobs,
            top_logprobs=top_logprobs,
            parallel_tool_calls=parallel_tool_calls,
            deployment_id=deployment_id,
            extra_headers=extra_headers,
            functions=functions,
            function_call=function_call,
            base_url=base_url,
            api_version=api_version,
            api_key=api_key,
            model_list=model_list,
        )
        # Telemetry uses a different capture helper per client type.
        if isinstance(self.mem0_client, Memory):
            capture_event("mem0.chat.create", self.mem0_client)
        else:
            capture_client_event("mem0.chat.create", self.mem0_client)
        return response

    def _prepare_messages(self, messages: List[dict]) -> List[dict]:
        """Return a new message list that is guaranteed to start with a system prompt.

        ``MEMORY_ANSWER_PROMPT`` is prepended unless the first message already
        has the ``system`` role. A fresh list is always returned so callers of
        this helper can replace entries without touching the input list.
        """
        if not messages or messages[0]["role"] != "system":
            return [{"role": "system", "content": MEMORY_ANSWER_PROMPT}, *messages]
        return list(messages)

    def _async_add_to_memory(self, messages, user_id, agent_id, run_id, metadata, filters):
        """Call ``mem0_client.add`` on a daemon thread without waiting for it."""

        def add_task():
            logger.debug("Adding to memory asynchronously")
            self.mem0_client.add(
                messages=messages,
                user_id=user_id,
                agent_id=agent_id,
                run_id=run_id,
                metadata=metadata,
                filters=filters,
            )

        threading.Thread(target=add_task, daemon=True).start()

    def _fetch_relevant_memories(self, messages, user_id, agent_id, run_id, filters, limit):
        """Search the mem0 client using the tail of the conversation as the query."""
        # Currently, only pass the last 6 messages to the search API to prevent long query
        recent_messages = messages[-6:]
        message_input = [f"{message['role']}: {message['content']}" for message in recent_messages]
        # TODO: Make it better by summarizing the past conversation
        return self.mem0_client.search(
            query="\n".join(message_input),
            user_id=user_id,
            agent_id=agent_id,
            run_id=run_id,
            filters=filters,
            limit=limit,
        )

    @staticmethod
    def _split_search_results(relevant_memories):
        """Normalise a search response into ``(results, relations)``.

        A dict response is read as ``{"results": [...], "relations": [...]}``
        (``relations`` optional); any other response is treated as the plain
        list of results with no relations.
        """
        if isinstance(relevant_memories, dict):
            relations = relevant_memories.get("relations") or []
            return relevant_memories["results"], list(relations)
        return relevant_memories, []

    def _format_query_with_memories(self, messages, relevant_memories):
        """Build the prompt text combining memories, entities and the user question."""
        # Dispatch on the shape of the response rather than on the client
        # class, so every client type yields a defined ``memories_text``.
        results, entities = self._split_search_results(relevant_memories)
        memories_text = "\n".join(memory["memory"] for memory in results)
        return f"- Relevant Memories/Facts: {memories_text}\n\n- Entities: {entities}\n\n- User Question: {messages[-1]['content']}"
@@ -0,0 +1,9 @@
1
+ """
2
+ Reranker implementations for mem0 search functionality.
3
+ """
4
+
5
+ from .base import BaseReranker
6
+ from .cohere_reranker import CohereReranker
7
+ from .sentence_transformer_reranker import SentenceTransformerReranker
8
+
9
+ __all__ = ["BaseReranker", "CohereReranker", "SentenceTransformerReranker"]
@@ -0,0 +1,20 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import List, Dict, Any
3
+
4
class BaseReranker(ABC):
    """Interface that every reranker implementation has to provide."""

    @abstractmethod
    def rerank(self, query: str, documents: List[Dict[str, Any]], top_k: int = None) -> List[Dict[str, Any]]:
        """
        Order ``documents`` by their relevance to ``query``.

        Args:
            query: The search query
            documents: Candidate documents; each is expected to carry a 'memory' field
            top_k: Maximum number of documents to return (None = return all)

        Returns:
            The reordered documents, each with a 'rerank_score' field added
        """
@@ -0,0 +1,85 @@
1
+ import os
2
+ from typing import List, Dict, Any
3
+
4
+ from agentrun_mem0.reranker.base import BaseReranker
5
+
6
+ try:
7
+ import cohere
8
+ COHERE_AVAILABLE = True
9
+ except ImportError:
10
+ COHERE_AVAILABLE = False
11
+
12
+
13
class CohereReranker(BaseReranker):
    """Cohere-based reranker implementation."""

    def __init__(self, config):
        """
        Initialize Cohere reranker.

        Args:
            config: CohereRerankerConfig object with configuration parameters

        Raises:
            ImportError: If the ``cohere`` package could not be imported.
            ValueError: If no API key is available from the config or the
                ``COHERE_API_KEY`` environment variable.
        """
        if not COHERE_AVAILABLE:
            raise ImportError("cohere package is required for CohereReranker. Install with: pip install cohere")

        self.config = config
        self.api_key = config.api_key or os.getenv("COHERE_API_KEY")
        if not self.api_key:
            raise ValueError("Cohere API key is required. Set COHERE_API_KEY environment variable or pass api_key in config.")

        self.model = config.model
        self.client = cohere.Client(self.api_key)

    @staticmethod
    def _document_text(doc: Dict[str, Any]) -> str:
        """Return the text to rank for ``doc``.

        The first of the 'memory', 'text' and 'content' keys that is present
        wins; a document with none of them is ranked on ``str(doc)``.
        """
        for key in ('memory', 'text', 'content'):
            if key in doc:
                return doc[key]
        return str(doc)

    def rerank(self, query: str, documents: List[Dict[str, Any]], top_k: int = None) -> List[Dict[str, Any]]:
        """
        Rerank documents using Cohere's rerank API.

        Args:
            query: The search query
            documents: List of documents to rerank
            top_k: Number of top documents to return; falls back to the
                configured ``top_k`` and then to all documents

        Returns:
            List of reranked documents with rerank_score. The returned
            documents are copies; the input documents are left untouched.
            If the API call fails, the documents are returned in their
            original order with a rerank_score of 0.0.
        """
        if not documents:
            return documents

        doc_texts = [self._document_text(doc) for doc in documents]
        # One limit for both the API call and the fallback path so they
        # truncate identically.
        limit = top_k or getattr(self.config, 'top_k', None)

        try:
            # Call Cohere rerank API
            response = self.client.rerank(
                model=self.model,
                query=query,
                documents=doc_texts,
                top_n=limit or len(documents),
                return_documents=self.config.return_documents,
                max_chunks_per_doc=self.config.max_chunks_per_doc,
            )

            # Create reranked results; result.index refers back to the
            # position of the document in the submitted list.
            reranked_docs = []
            for result in response.results:
                original_doc = documents[result.index].copy()
                original_doc['rerank_score'] = result.relevance_score
                reranked_docs.append(original_doc)

            return reranked_docs

        except Exception:
            # Best-effort fallback: keep the original order, but record why
            # reranking was skipped instead of hiding the failure.
            import logging

            logging.getLogger(__name__).warning(
                "Cohere reranking failed; returning documents in original order", exc_info=True
            )
            fallback_docs = []
            for doc in documents:
                fallback_doc = doc.copy()
                fallback_doc['rerank_score'] = 0.0
                fallback_docs.append(fallback_doc)
            return fallback_docs[:limit] if limit else fallback_docs
@@ -0,0 +1,147 @@
1
+ from typing import List, Dict, Any, Union
2
+ import numpy as np
3
+
4
+ from agentrun_mem0.reranker.base import BaseReranker
5
+ from agentrun_mem0.configs.rerankers.base import BaseRerankerConfig
6
+ from agentrun_mem0.configs.rerankers.huggingface import HuggingFaceRerankerConfig
7
+
8
+ try:
9
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
10
+ import torch
11
+ TRANSFORMERS_AVAILABLE = True
12
+ except ImportError:
13
+ TRANSFORMERS_AVAILABLE = False
14
+
15
+
16
class HuggingFaceReranker(BaseReranker):
    """HuggingFace Transformers based reranker implementation."""

    def __init__(self, config: Union[BaseRerankerConfig, HuggingFaceRerankerConfig, Dict]):
        """
        Initialize HuggingFace reranker.

        Args:
            config: Configuration object with reranker parameters. A dict or a
                plain BaseRerankerConfig is converted to a
                HuggingFaceRerankerConfig first.

        Raises:
            ImportError: If ``transformers``/``torch`` could not be imported.
        """
        if not TRANSFORMERS_AVAILABLE:
            raise ImportError("transformers package is required for HuggingFaceReranker. Install with: pip install transformers torch")

        # Convert to HuggingFaceRerankerConfig if needed
        if isinstance(config, dict):
            config = HuggingFaceRerankerConfig(**config)
        elif isinstance(config, BaseRerankerConfig) and not isinstance(config, HuggingFaceRerankerConfig):
            # Convert BaseRerankerConfig to HuggingFaceRerankerConfig with defaults
            config = HuggingFaceRerankerConfig(
                provider=getattr(config, 'provider', 'huggingface'),
                model=getattr(config, 'model', 'BAAI/bge-reranker-base'),
                api_key=getattr(config, 'api_key', None),
                top_k=getattr(config, 'top_k', None),
                device=None,  # Will auto-detect
                batch_size=32,  # Default
                max_length=512,  # Default
                normalize=True,  # Default
            )

        self.config = config

        # Set device: an explicit config value wins, otherwise prefer CUDA
        # when torch reports it as available.
        if self.config.device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = self.config.device

        # Load model and tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model)
        self.model = AutoModelForSequenceClassification.from_pretrained(self.config.model)
        self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def _document_text(doc: Dict[str, Any]) -> str:
        """Return the text to rank for ``doc``.

        The first of the 'memory', 'text' and 'content' keys that is present
        wins; a document with none of them is ranked on ``str(doc)``.
        """
        for key in ('memory', 'text', 'content'):
            if key in doc:
                return doc[key]
        return str(doc)

    def rerank(self, query: str, documents: List[Dict[str, Any]], top_k: int = None) -> List[Dict[str, Any]]:
        """
        Rerank documents using HuggingFace cross-encoder model.

        Args:
            query: The search query
            documents: List of documents to rerank
            top_k: Number of top documents to return; falls back to the
                configured ``top_k``

        Returns:
            List of reranked documents with rerank_score. The returned
            documents are copies; the input documents are left untouched.
            If scoring fails, the documents are returned in their original
            order with a rerank_score of 0.0.
        """
        if not documents:
            return documents

        doc_texts = [self._document_text(doc) for doc in documents]
        final_top_k = top_k or self.config.top_k

        try:
            scores = []

            # Process documents in batches
            for i in range(0, len(doc_texts), self.config.batch_size):
                batch_docs = doc_texts[i:i + self.config.batch_size]
                batch_pairs = [[query, doc] for doc in batch_docs]

                # Tokenize batch
                inputs = self.tokenizer(
                    batch_pairs,
                    padding=True,
                    truncation=True,
                    max_length=self.config.max_length,
                    return_tensors="pt"
                ).to(self.device)

                # Get scores
                with torch.no_grad():
                    outputs = self.model(**inputs)
                    batch_scores = outputs.logits.squeeze(-1).cpu().numpy()

                # Handle single item case
                if batch_scores.ndim == 0:
                    batch_scores = [float(batch_scores)]
                else:
                    batch_scores = batch_scores.tolist()

                scores.extend(batch_scores)

            # Normalize scores if requested (min-max; the epsilon avoids a
            # division by zero when all scores are equal)
            if self.config.normalize:
                scores = np.array(scores)
                scores = (scores - scores.min()) / (scores.max() - scores.min() + 1e-8)
                scores = scores.tolist()

            # Combine documents with scores
            doc_score_pairs = list(zip(documents, scores))

            # Sort by score (descending)
            doc_score_pairs.sort(key=lambda x: x[1], reverse=True)

            # Apply top_k limit
            if final_top_k:
                doc_score_pairs = doc_score_pairs[:final_top_k]

            # Create reranked results
            reranked_docs = []
            for doc, score in doc_score_pairs:
                reranked_doc = doc.copy()
                reranked_doc['rerank_score'] = float(score)
                reranked_docs.append(reranked_doc)

            return reranked_docs

        except Exception:
            # Best-effort fallback: keep the original order, but record why
            # reranking was skipped instead of hiding the failure.
            import logging

            logging.getLogger(__name__).warning(
                "HuggingFace reranking failed; returning documents in original order", exc_info=True
            )
            fallback_docs = []
            for doc in documents:
                fallback_doc = doc.copy()
                fallback_doc['rerank_score'] = 0.0
                fallback_docs.append(fallback_doc)
            return fallback_docs[:final_top_k] if final_top_k else fallback_docs
@@ -0,0 +1,142 @@
1
+ import re
2
+ from typing import List, Dict, Any, Union
3
+
4
+ from agentrun_mem0.reranker.base import BaseReranker
5
+ from agentrun_mem0.utils.factory import LlmFactory
6
+ from agentrun_mem0.configs.rerankers.base import BaseRerankerConfig
7
+ from agentrun_mem0.configs.rerankers.llm import LLMRerankerConfig
8
+
9
+
10
class LLMReranker(BaseReranker):
    """LLM-based reranker implementation."""

    # A standalone number whose integer part is 0 or 1. The lookbehind rejects
    # digits that are the tail of a larger number ("2.1" must not yield "1"),
    # and the lookahead rejects a leading digit that is followed by further
    # decimal groups, while still allowing a sentence-ending period ("0.8.").
    _SCORE_PATTERN = re.compile(r'(?<![\d.])\b([01](?:\.\d+)?)\b(?!\.\d)')

    def __init__(self, config: Union[BaseRerankerConfig, LLMRerankerConfig, Dict]):
        """
        Initialize LLM reranker.

        Args:
            config: Configuration object with reranker parameters. A dict or a
                plain BaseRerankerConfig is converted to an LLMRerankerConfig
                first.
        """
        # Convert to LLMRerankerConfig if needed
        if isinstance(config, dict):
            config = LLMRerankerConfig(**config)
        elif isinstance(config, BaseRerankerConfig) and not isinstance(config, LLMRerankerConfig):
            # Convert BaseRerankerConfig to LLMRerankerConfig with defaults
            config = LLMRerankerConfig(
                provider=getattr(config, 'provider', 'openai'),
                model=getattr(config, 'model', 'gpt-4o-mini'),
                api_key=getattr(config, 'api_key', None),
                top_k=getattr(config, 'top_k', None),
                temperature=0.0,  # Default for reranking
                max_tokens=100,  # Default for reranking
            )

        self.config = config

        # Create LLM configuration for the factory
        llm_config = {
            "model": self.config.model,
            "temperature": self.config.temperature,
            "max_tokens": self.config.max_tokens,
        }

        # Add API key if provided
        if self.config.api_key:
            llm_config["api_key"] = self.config.api_key

        # Initialize LLM using the factory
        self.llm = LlmFactory.create(self.config.provider, llm_config)

        # A configured scoring prompt takes precedence over the built-in one
        self.scoring_prompt = getattr(self.config, 'scoring_prompt', None) or self._get_default_prompt()

    def _get_default_prompt(self) -> str:
        """Get the default scoring prompt template."""
        return """You are a relevance scoring assistant. Given a query and a document, you need to score how relevant the document is to the query.

Score the relevance on a scale from 0.0 to 1.0, where:
- 1.0 = Perfectly relevant and directly answers the query
- 0.8-0.9 = Highly relevant with good information
- 0.6-0.7 = Moderately relevant with some useful information
- 0.4-0.5 = Slightly relevant with limited useful information
- 0.0-0.3 = Not relevant or no useful information

Query: "{query}"
Document: "{document}"

Provide only a single numerical score between 0.0 and 1.0. Do not include any explanation or additional text."""

    def _extract_score(self, response_text: str) -> float:
        """Extract numerical score from LLM response.

        Returns the first standalone number with integer part 0 or 1, clamped
        to [0.0, 1.0]. Returns the neutral score 0.5 when the response holds
        no such number (including out-of-range values such as "2.1").
        """
        match = self._SCORE_PATTERN.search(response_text)
        if match is None:
            # Fallback: return 0.5 if no valid score found
            return 0.5
        return min(max(float(match.group(1)), 0.0), 1.0)  # Clamp between 0.0 and 1.0

    @staticmethod
    def _document_text(doc: Dict[str, Any]) -> str:
        """Return the text to score for ``doc``.

        The first of the 'memory', 'text' and 'content' keys that is present
        wins; a document with none of them is scored on ``str(doc)``.
        """
        for key in ('memory', 'text', 'content'):
            if key in doc:
                return doc[key]
        return str(doc)

    def rerank(self, query: str, documents: List[Dict[str, Any]], top_k: int = None) -> List[Dict[str, Any]]:
        """
        Rerank documents using LLM scoring.

        Each document is scored with its own LLM call. A document whose
        scoring fails receives the neutral score 0.5.

        Args:
            query: The search query
            documents: List of documents to rerank
            top_k: Number of top documents to return; falls back to the
                configured ``top_k``

        Returns:
            List of reranked documents with rerank_score. The returned
            documents are copies; the input documents are left untouched.
        """
        import logging

        if not documents:
            return documents

        scored_docs = []

        for doc in documents:
            doc_text = self._document_text(doc)

            try:
                # Generate scoring prompt
                prompt = self.scoring_prompt.format(query=query, document=doc_text)

                # Get LLM response
                response = self.llm.generate_response(
                    messages=[{"role": "user", "content": prompt}]
                )

                # Extract score from response
                score = self._extract_score(response)

            except Exception:
                # Fallback: assign neutral score if scoring fails, but record
                # the failure instead of hiding it.
                logging.getLogger(__name__).warning(
                    "LLM reranking failed for a document; assigning neutral score", exc_info=True
                )
                score = 0.5

            # Create scored document
            scored_doc = doc.copy()
            scored_doc['rerank_score'] = score
            scored_docs.append(scored_doc)

        # Sort by relevance score in descending order (stable, so ties keep
        # their original relative order)
        scored_docs.sort(key=lambda x: x['rerank_score'], reverse=True)

        # Apply top_k limit
        final_top_k = top_k or self.config.top_k
        if final_top_k:
            scored_docs = scored_docs[:final_top_k]

        return scored_docs