MemoryOS 0.1.13-py3-none-any.whl → 0.2.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of MemoryOS might be problematic.
- {memoryos-0.1.13.dist-info → memoryos-0.2.1.dist-info}/METADATA +78 -49
- memoryos-0.2.1.dist-info/RECORD +152 -0
- memoryos-0.2.1.dist-info/entry_points.txt +3 -0
- memos/__init__.py +1 -1
- memos/api/config.py +471 -0
- memos/api/exceptions.py +28 -0
- memos/api/mcp_serve.py +502 -0
- memos/api/product_api.py +35 -0
- memos/api/product_models.py +159 -0
- memos/api/routers/__init__.py +1 -0
- memos/api/routers/product_router.py +358 -0
- memos/chunkers/sentence_chunker.py +8 -2
- memos/cli.py +113 -0
- memos/configs/embedder.py +27 -0
- memos/configs/graph_db.py +83 -2
- memos/configs/llm.py +48 -0
- memos/configs/mem_cube.py +1 -1
- memos/configs/mem_reader.py +4 -0
- memos/configs/mem_scheduler.py +91 -5
- memos/configs/memory.py +10 -4
- memos/dependency.py +52 -0
- memos/embedders/ark.py +92 -0
- memos/embedders/factory.py +4 -0
- memos/embedders/sentence_transformer.py +8 -2
- memos/embedders/universal_api.py +32 -0
- memos/graph_dbs/base.py +2 -2
- memos/graph_dbs/factory.py +2 -0
- memos/graph_dbs/item.py +46 -0
- memos/graph_dbs/neo4j.py +377 -101
- memos/graph_dbs/neo4j_community.py +300 -0
- memos/llms/base.py +9 -0
- memos/llms/deepseek.py +54 -0
- memos/llms/factory.py +10 -1
- memos/llms/hf.py +170 -13
- memos/llms/hf_singleton.py +114 -0
- memos/llms/ollama.py +4 -0
- memos/llms/openai.py +68 -1
- memos/llms/qwen.py +63 -0
- memos/llms/vllm.py +153 -0
- memos/mem_cube/general.py +77 -16
- memos/mem_cube/utils.py +102 -0
- memos/mem_os/core.py +131 -41
- memos/mem_os/main.py +93 -11
- memos/mem_os/product.py +1098 -35
- memos/mem_os/utils/default_config.py +352 -0
- memos/mem_os/utils/format_utils.py +1154 -0
- memos/mem_reader/simple_struct.py +13 -8
- memos/mem_scheduler/base_scheduler.py +467 -36
- memos/mem_scheduler/general_scheduler.py +125 -244
- memos/mem_scheduler/modules/base.py +9 -0
- memos/mem_scheduler/modules/dispatcher.py +68 -2
- memos/mem_scheduler/modules/misc.py +39 -0
- memos/mem_scheduler/modules/monitor.py +228 -49
- memos/mem_scheduler/modules/rabbitmq_service.py +317 -0
- memos/mem_scheduler/modules/redis_service.py +32 -22
- memos/mem_scheduler/modules/retriever.py +250 -23
- memos/mem_scheduler/modules/schemas.py +189 -7
- memos/mem_scheduler/mos_for_test_scheduler.py +143 -0
- memos/mem_scheduler/utils.py +51 -2
- memos/mem_user/persistent_user_manager.py +260 -0
- memos/memories/activation/item.py +25 -0
- memos/memories/activation/kv.py +10 -3
- memos/memories/activation/vllmkv.py +219 -0
- memos/memories/factory.py +2 -0
- memos/memories/textual/general.py +7 -5
- memos/memories/textual/item.py +3 -1
- memos/memories/textual/tree.py +14 -6
- memos/memories/textual/tree_text_memory/organize/conflict.py +198 -0
- memos/memories/textual/tree_text_memory/organize/manager.py +72 -23
- memos/memories/textual/tree_text_memory/organize/redundancy.py +193 -0
- memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +233 -0
- memos/memories/textual/tree_text_memory/organize/reorganizer.py +606 -0
- memos/memories/textual/tree_text_memory/retrieve/recall.py +0 -1
- memos/memories/textual/tree_text_memory/retrieve/reranker.py +2 -2
- memos/memories/textual/tree_text_memory/retrieve/searcher.py +6 -5
- memos/parsers/markitdown.py +8 -2
- memos/templates/mem_reader_prompts.py +105 -36
- memos/templates/mem_scheduler_prompts.py +96 -47
- memos/templates/tree_reorganize_prompts.py +223 -0
- memos/vec_dbs/base.py +12 -0
- memos/vec_dbs/qdrant.py +46 -20
- memoryos-0.1.13.dist-info/RECORD +0 -122
- {memoryos-0.1.13.dist-info → memoryos-0.2.1.dist-info}/LICENSE +0 -0
- {memoryos-0.1.13.dist-info → memoryos-0.2.1.dist-info}/WHEEL +0 -0
memos/graph_dbs/neo4j_community.py
ADDED
@@ -0,0 +1,300 @@
+from typing import Any
+
+from memos.configs.graph_db import Neo4jGraphDBConfig
+from memos.graph_dbs.neo4j import Neo4jGraphDB, _prepare_node_metadata
+from memos.log import get_logger
+from memos.vec_dbs.factory import VecDBFactory
+from memos.vec_dbs.item import VecDBItem
+
+
+logger = get_logger(__name__)
+
+
+class Neo4jCommunityGraphDB(Neo4jGraphDB):
+    """
+    Neo4j Community Edition graph memory store.
+
+    Note:
+        This class avoids Enterprise-only features:
+        - No multi-database support
+        - No vector index
+        - No CREATE DATABASE
+    """
+
+    def __init__(self, config: Neo4jGraphDBConfig):
+        assert config.auto_create is False
+        assert config.use_multi_db is False
+        # Init vector database
+        self.vec_db = VecDBFactory.from_config(config.vec_config)
+        # Call parent init
+        super().__init__(config)
+
+    def create_index(
+        self,
+        label: str = "Memory",
+        vector_property: str = "embedding",
+        dimensions: int = 1536,
+        index_name: str = "memory_vector_index",
+    ) -> None:
+        """
+        Create the vector index for embedding and datetime indexes for created_at and updated_at fields.
+        """
+        # Create indexes
+        self._create_basic_property_indexes()
+
+    def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None:
+        if not self.config.use_multi_db and self.config.user_name:
+            metadata["user_name"] = self.config.user_name
+
+        # Safely process metadata
+        metadata = _prepare_node_metadata(metadata)
+
+        # Extract required fields
+        embedding = metadata.pop("embedding", None)
+        if embedding is None:
+            raise ValueError(f"Missing 'embedding' in metadata for node {id}")
+
+        # Merge node and set metadata
+        created_at = metadata.pop("created_at")
+        updated_at = metadata.pop("updated_at")
+        vector_sync_status = "success"
+
+        try:
+            # Write to Vector DB
+            item = VecDBItem(
+                id=id,
+                vector=embedding,
+                payload={
+                    "memory": memory,
+                    "vector_sync": vector_sync_status,
+                    **metadata,  # unpack all metadata keys to top-level
+                },
+            )
+            self.vec_db.add([item])
+        except Exception as e:
+            logger.warning(f"[VecDB] Vector insert failed for node {id}: {e}")
+            vector_sync_status = "failed"
+
+        metadata["vector_sync"] = vector_sync_status
+        query = """
+        MERGE (n:Memory {id: $id})
+        SET n.memory = $memory,
+            n.created_at = datetime($created_at),
+            n.updated_at = datetime($updated_at),
+            n += $metadata
+        """
+        with self.driver.session(database=self.db_name) as session:
+            session.run(
+                query,
+                id=id,
+                memory=memory,
+                created_at=created_at,
+                updated_at=updated_at,
+                metadata=metadata,
+            )
+
+    def get_children_with_embeddings(self, id: str) -> list[dict[str, Any]]:
+        where_user = ""
+        params = {"id": id}
+
+        if not self.config.use_multi_db and self.config.user_name:
+            where_user = "AND p.user_name = $user_name AND c.user_name = $user_name"
+            params["user_name"] = self.config.user_name
+
+        query = f"""
+        MATCH (p:Memory)-[:PARENT]->(c:Memory)
+        WHERE p.id = $id {where_user}
+        RETURN c.id AS id, c.memory AS memory
+        """
+
+        with self.driver.session(database=self.db_name) as session:
+            result = session.run(query, params)
+            child_nodes = [{"id": r["id"], "memory": r["memory"]} for r in result]
+
+        # Get embeddings from vector DB
+        ids = [n["id"] for n in child_nodes]
+        vec_items = {v.id: v.vector for v in self.vec_db.get_by_ids(ids)}
+
+        # Merge results
+        for node in child_nodes:
+            node["embedding"] = vec_items.get(node["id"])
+
+        return child_nodes
+
+    # Search / recall operations
+    def search_by_embedding(
+        self,
+        vector: list[float],
+        top_k: int = 5,
+        scope: str | None = None,
+        status: str | None = None,
+        threshold: float | None = None,
+    ) -> list[dict]:
+        """
+        Retrieve node IDs based on vector similarity using external vector DB.
+
+        Args:
+            vector (list[float]): The embedding vector representing query semantics.
+            top_k (int): Number of top similar nodes to retrieve.
+            scope (str, optional): Memory type filter (e.g., 'WorkingMemory', 'LongTermMemory').
+            status (str, optional): Node status filter (e.g., 'activated', 'archived').
+            threshold (float, optional): Minimum similarity score threshold (0 ~ 1).
+
+        Returns:
+            list[dict]: A list of dicts with 'id' and 'score', ordered by similarity.
+
+        Notes:
+            - This method uses an external vector database (not Neo4j) to perform the search.
+            - If 'scope' is provided, it restricts results to nodes with matching memory_type.
+            - If 'status' is provided, it further filters nodes by status.
+            - If 'threshold' is provided, only results with score >= threshold will be returned.
+            - The returned IDs can be used to fetch full node data from Neo4j if needed.
+        """
+        # Build VecDB filter
+        vec_filter = {}
+        if scope:
+            vec_filter["memory_type"] = scope
+        if status:
+            vec_filter["status"] = status
+        vec_filter["vector_sync"] = "success"
+        vec_filter["user_name"] = self.config.user_name
+
+        # Perform vector search
+        results = self.vec_db.search(query_vector=vector, top_k=top_k, filter=vec_filter)
+
+        # Filter by threshold
+        if threshold is not None:
+            results = [r for r in results if r.score is None or r.score >= threshold]
+
+        # Return consistent format
+        return [{"id": r.id, "score": r.score} for r in results]
+
+    def get_all_memory_items(self, scope: str) -> list[dict]:
+        """
+        Retrieve all memory items of a specific memory_type.
+
+        Args:
+            scope (str): Must be one of 'WorkingMemory', 'LongTermMemory', or 'UserMemory'.
+
+        Returns:
+            list[dict]: Full list of memory items under this scope.
+        """
+        if scope not in {"WorkingMemory", "LongTermMemory", "UserMemory"}:
+            raise ValueError(f"Unsupported memory type scope: {scope}")
+
+        where_clause = "WHERE n.memory_type = $scope"
+        params = {"scope": scope}
+
+        if not self.config.use_multi_db and self.config.user_name:
+            where_clause += " AND n.user_name = $user_name"
+            params["user_name"] = self.config.user_name
+
+        query = f"""
+        MATCH (n:Memory)
+        {where_clause}
+        RETURN n
+        """
+
+        with self.driver.session(database=self.db_name) as session:
+            results = session.run(query, params)
+            return [self._parse_node(dict(record["n"])) for record in results]
+
+    def clear(self) -> None:
+        """
+        Clear the entire graph if the target database exists.
+        """
+        # Step 1: clear Neo4j part via parent logic
+        super().clear()
+
+        # Step2: Clear the vector db
+        try:
+            items = self.vec_db.get_by_filter({"user_name": self.config.user_name})
+            if items:
+                self.vec_db.delete([item.id for item in items])
+                logger.info(f"Cleared {len(items)} vectors for user '{self.config.user_name}'.")
+            else:
+                logger.info(f"No vectors to clear for user '{self.config.user_name}'.")
+        except Exception as e:
+            logger.warning(f"Failed to clear vector DB for user '{self.config.user_name}': {e}")
+
+    def drop_database(self) -> None:
+        """
+        Permanently delete the entire database this instance is using.
+        WARNING: This operation is destructive and cannot be undone.
+        """
+        raise ValueError(
+            f"Refusing to drop protected database: {self.db_name} in "
+            f"Shared Database Multi-Tenant mode"
+        )
+
+    # Avoid enterprise feature
+    def _ensure_database_exists(self):
+        pass
+
+    def _create_basic_property_indexes(self) -> None:
+        """
+        Create standard B-tree indexes on memory_type, created_at,
+        and updated_at fields.
+        Create standard B-tree indexes on user_name when use Shared Database
+        Multi-Tenant Mode
+        """
+        # Step 1: Neo4j indexes
+        try:
+            with self.driver.session(database=self.db_name) as session:
+                session.run("""
+                    CREATE INDEX memory_type_index IF NOT EXISTS
+                    FOR (n:Memory) ON (n.memory_type)
+                """)
+                logger.debug("Index 'memory_type_index' ensured.")
+
+                session.run("""
+                    CREATE INDEX memory_created_at_index IF NOT EXISTS
+                    FOR (n:Memory) ON (n.created_at)
+                """)
+                logger.debug("Index 'memory_created_at_index' ensured.")
+
+                session.run("""
+                    CREATE INDEX memory_updated_at_index IF NOT EXISTS
+                    FOR (n:Memory) ON (n.updated_at)
+                """)
+                logger.debug("Index 'memory_updated_at_index' ensured.")
+
+                if not self.config.use_multi_db and self.config.user_name:
+                    session.run(
+                        """
+                        CREATE INDEX memory_user_name_index IF NOT EXISTS
+                        FOR (n:Memory) ON (n.user_name)
+                        """
+                    )
+                    logger.debug("Index 'memory_user_name_index' ensured.")
+        except Exception as e:
+            logger.warning(f"Failed to create basic property indexes: {e}")
+
+        # Step 2: VectorDB indexes
+        try:
+            if hasattr(self.vec_db, "ensure_payload_indexes"):
+                self.vec_db.ensure_payload_indexes(["user_name", "memory_type", "status"])
+            else:
+                logger.debug("VecDB does not support payload index creation; skipping.")
+        except Exception as e:
+            logger.warning(f"Failed to create VecDB payload indexes: {e}")
+
+    def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]:
+        """Parse Neo4j node and optionally fetch embedding from vector DB."""
+        node = node_data.copy()
+
+        # Convert Neo4j datetime to string
+        for time_field in ("created_at", "updated_at"):
+            if time_field in node and hasattr(node[time_field], "isoformat"):
+                node[time_field] = node[time_field].isoformat()
+        node.pop("user_name", None)
+
+        new_node = {"id": node.pop("id"), "memory": node.pop("memory", ""), "metadata": node}
+        try:
+            vec_item = self.vec_db.get_by_id(new_node["id"])
+            if vec_item and vec_item.vector:
+                new_node["metadata"]["embedding"] = vec_item.vector
+        except Exception as e:
+            logger.warning(f"Failed to fetch vector for node {new_node['id']}: {e}")
+            new_node["metadata"]["embedding"] = None
+        return new_node
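For orientation, the division of labor above (Neo4j keeps node properties, an external vector DB keeps the embeddings) can be exercised roughly as follows. This is a minimal sketch, not code from the package: it assumes a fully populated Neo4jGraphDBConfig (use_multi_db=False, auto_create=False, and a vec_config for the vector store) plus a precomputed query vector, and it calls only methods visible in the diff above.

from memos.configs.graph_db import Neo4jGraphDBConfig
from memos.graph_dbs.neo4j_community import Neo4jCommunityGraphDB


def recall_working_memory(config: Neo4jGraphDBConfig, query_vector: list[float]) -> list[dict]:
    # Hypothetical helper for illustration only.
    db = Neo4jCommunityGraphDB(config)

    # Similarity search runs against the external vector DB, not Neo4j;
    # only ids and scores come back.
    hits = db.search_by_embedding(query_vector, top_k=5, scope="WorkingMemory", threshold=0.6)
    wanted = {hit["id"] for hit in hits}

    # Full node payloads are hydrated from Neo4j per scope via _parse_node,
    # which re-attaches embeddings from the vector DB when available.
    items = db.get_all_memory_items("WorkingMemory")
    return [item for item in items if item["id"] in wanted]

Keeping similarity search out of Neo4j is what lets this class run on Community Edition, which lacks the Enterprise vector index.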
memos/llms/base.py
CHANGED
@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from collections.abc import Generator
 
 from memos.configs.llm import BaseLLMConfig
 from memos.types import MessageList
@@ -14,3 +15,11 @@ class BaseLLM(ABC):
     @abstractmethod
     def generate(self, messages: MessageList, **kwargs) -> str:
         """Generate a response from the LLM."""
+
+    @abstractmethod
+    def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]:
+        """
+        (Optional) Generate a streaming response from the LLM.
+        Subclasses should override this if they support streaming.
+        By default, this raises NotImplementedError.
+        """
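The new abstract method makes streaming part of the BaseLLM contract: every backend now exposes a generator of text chunks next to the blocking generate. A minimal sketch of a conforming subclass is shown below; EchoLLM and its trivial config handling are hypothetical and exist only to illustrate the interface, while the method signatures come from the diff above.

from collections.abc import Generator

from memos.llms.base import BaseLLM
from memos.types import MessageList


class EchoLLM(BaseLLM):
    """Hypothetical toy backend that satisfies the streaming contract."""

    def __init__(self, config=None):
        # Real backends receive a BaseLLMConfig subclass; a stub is enough here.
        self.config = config

    def generate(self, messages: MessageList, **kwargs) -> str:
        # Blocking path: echo the last user message.
        return messages[-1]["content"]

    def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]:
        # Streaming path: emit the same reply word by word.
        for word in self.generate(messages).split():
            yield word + " "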
memos/llms/deepseek.py
ADDED
@@ -0,0 +1,54 @@
+from collections.abc import Generator
+
+from memos.configs.llm import DeepSeekLLMConfig
+from memos.llms.openai import OpenAILLM
+from memos.llms.utils import remove_thinking_tags
+from memos.log import get_logger
+from memos.types import MessageList
+
+
+logger = get_logger(__name__)
+
+
+class DeepSeekLLM(OpenAILLM):
+    """DeepSeek LLM via OpenAI-compatible API."""
+
+    def __init__(self, config: DeepSeekLLMConfig):
+        super().__init__(config)
+
+    def generate(self, messages: MessageList) -> str:
+        """Generate a response from DeepSeek."""
+        response = self.client.chat.completions.create(
+            model=self.config.model_name_or_path,
+            messages=messages,
+            temperature=self.config.temperature,
+            max_tokens=self.config.max_tokens,
+            top_p=self.config.top_p,
+            extra_body=self.config.extra_body,
+        )
+        logger.info(f"Response from DeepSeek: {response.model_dump_json()}")
+        response_content = response.choices[0].message.content
+        if self.config.remove_think_prefix:
+            return remove_thinking_tags(response_content)
+        else:
+            return response_content
+
+    def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]:
+        """Stream response from DeepSeek."""
+        response = self.client.chat.completions.create(
+            model=self.config.model_name_or_path,
+            messages=messages,
+            stream=True,
+            temperature=self.config.temperature,
+            max_tokens=self.config.max_tokens,
+            top_p=self.config.top_p,
+            extra_body=self.config.extra_body,
+        )
+        # Streaming chunks of text
+        for chunk in response:
+            delta = chunk.choices[0].delta
+            if hasattr(delta, "reasoning_content") and delta.reasoning_content:
+                yield delta.reasoning_content
+
+            if hasattr(delta, "content") and delta.content:
+                yield delta.content
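Consuming the stream is the same for DeepSeekLLM as for any other backend: iterate the generator and concatenate chunks. A usage sketch follows, assuming a DeepSeekLLMConfig has already been built with valid credentials (its exact fields are not part of this diff); note that reasoning chunks and answer chunks arrive interleaved in whatever order the API emits them.

from memos.configs.llm import DeepSeekLLMConfig
from memos.llms.deepseek import DeepSeekLLM


def stream_answer(config: DeepSeekLLMConfig, question: str) -> str:
    # Hypothetical driver function for illustration.
    llm = DeepSeekLLM(config)
    messages = [{"role": "user", "content": question}]

    parts: list[str] = []
    for chunk in llm.generate_stream(messages):
        print(chunk, end="", flush=True)  # live output
        parts.append(chunk)
    return "".join(parts)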
memos/llms/factory.py
CHANGED
@@ -2,9 +2,13 @@ from typing import Any, ClassVar
 
 from memos.configs.llm import LLMConfigFactory
 from memos.llms.base import BaseLLM
+from memos.llms.deepseek import DeepSeekLLM
 from memos.llms.hf import HFLLM
+from memos.llms.hf_singleton import HFSingletonLLM
 from memos.llms.ollama import OllamaLLM
-from memos.llms.openai import OpenAILLM
+from memos.llms.openai import AzureLLM, OpenAILLM
+from memos.llms.qwen import QwenLLM
+from memos.llms.vllm import VLLMLLM
 
 
 class LLMFactory(BaseLLM):
@@ -12,8 +16,13 @@ class LLMFactory(BaseLLM):
 
     backend_to_class: ClassVar[dict[str, Any]] = {
         "openai": OpenAILLM,
+        "azure": AzureLLM,
         "ollama": OllamaLLM,
         "huggingface": HFLLM,
+        "huggingface_singleton": HFSingletonLLM,  # Add singleton version
+        "vllm": VLLMLLM,
+        "qwen": QwenLLM,
+        "deepseek": DeepSeekLLM,
     }
 
     @classmethod
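The factory change is a registry extension: each backend string now maps to a concrete LLM class, including the new azure, huggingface_singleton, vllm, qwen, and deepseek entries. A small sketch of reading that mapping directly (the factory's construction classmethod is cut off in this diff, so only backend_to_class is relied on here):

from memos.llms.factory import LLMFactory

# Resolve an implementation class by backend name.
backend = "deepseek"
llm_cls = LLMFactory.backend_to_class[backend]
print(f"{backend} -> {llm_cls.__name__}")  # deepseek -> DeepSeekLLM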
memos/llms/hf.py
CHANGED
@@ -1,4 +1,5 @@
-import
+from collections.abc import Generator
+from typing import Any
 
 from transformers import (
     AutoModelForCausalLM,
@@ -71,6 +72,26 @@ class HFLLM(BaseLLM):
         else:
             return self._generate_with_cache(prompt, past_key_values)
 
+    def generate_stream(
+        self, messages: MessageList, past_key_values: DynamicCache | None = None
+    ) -> Generator[str, None, None]:
+        """
+        Generate a streaming response from the model.
+        Args:
+            messages (MessageList): Chat messages for prompt construction.
+            past_key_values (DynamicCache | None): Optional KV cache for fast generation.
+        Yields:
+            str: Streaming model response chunks.
+        """
+        prompt = self.tokenizer.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=self.config.add_generation_prompt
+        )
+        logger.info(f"HFLLM streaming prompt: {prompt}")
+        if past_key_values is None:
+            yield from self._generate_full_stream(prompt)
+        else:
+            yield from self._generate_with_cache_stream(prompt, past_key_values)
+
     def _generate_full(self, prompt: str) -> str:
         """
         Generate output from scratch using the full prompt.
@@ -104,6 +125,73 @@ class HFLLM(BaseLLM):
             else response
         )
 
+    def _generate_full_stream(self, prompt: str) -> Generator[str, None, None]:
+        """
+        Generate output from scratch using the full prompt with streaming.
+        Args:
+            prompt (str): The input prompt string.
+        Yields:
+            str: Streaming response chunks.
+        """
+        import torch
+
+        inputs = self.tokenizer([prompt], return_tensors="pt").to(self.model.device)
+
+        # Get generation parameters
+        max_new_tokens = getattr(self.config, "max_tokens", 128)
+        remove_think_prefix = getattr(self.config, "remove_think_prefix", False)
+
+        # Manual streaming generation
+        generated_ids = inputs.input_ids.clone()
+        accumulated_text = ""
+
+        for _ in range(max_new_tokens):
+            # Forward pass
+            with torch.no_grad():
+                outputs = self.model(
+                    input_ids=generated_ids,
+                    use_cache=True,
+                    return_dict=True,
+                )
+
+            # Get next token logits
+            next_token_logits = outputs.logits[:, -1, :]
+
+            # Apply logits processors if sampling
+            if getattr(self.config, "do_sample", True):
+                batch_size, _ = next_token_logits.size()
+                dummy_ids = torch.zeros(
+                    (batch_size, 1), dtype=torch.long, device=next_token_logits.device
+                )
+                filtered_logits = self.logits_processors(dummy_ids, next_token_logits)
+                probs = torch.softmax(filtered_logits, dim=-1)
+                next_token = torch.multinomial(probs, num_samples=1)
+            else:
+                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
+
+            # Check for EOS token
+            if self._should_stop(next_token):
+                break
+
+            # Append new token
+            generated_ids = torch.cat([generated_ids, next_token], dim=-1)
+
+            # Decode and yield the new token
+            new_token_text = self.tokenizer.decode(next_token[0], skip_special_tokens=True)
+            if new_token_text:  # Only yield non-empty tokens
+                accumulated_text += new_token_text
+
+                # Apply thinking tag removal if enabled
+                if remove_think_prefix:
+                    processed_text = remove_thinking_tags(accumulated_text)
+                    # Only yield the difference (new content)
+                    if len(processed_text) > len(accumulated_text) - len(new_token_text):
+                        yield processed_text[len(accumulated_text) - len(new_token_text) :]
+                    else:
+                        yield new_token_text
+                else:
+                    yield new_token_text
+
     def _generate_with_cache(self, query: str, kv: DynamicCache) -> str:
         """
         Generate output incrementally using an existing KV cache.
@@ -113,6 +201,8 @@ class HFLLM(BaseLLM):
         Returns:
             str: Model response.
         """
+        import torch
+
         query_ids = self.tokenizer(
             query, return_tensors="pt", add_special_tokens=False
         ).input_ids.to(self.model.device)
@@ -137,10 +227,70 @@ class HFLLM(BaseLLM):
             else response
         )
 
-
-
-
-
+    def _generate_with_cache_stream(
+        self, query: str, kv: DynamicCache
+    ) -> Generator[str, None, None]:
+        """
+        Generate output incrementally using an existing KV cache with streaming.
+        Args:
+            query (str): The new user query string.
+            kv (DynamicCache): The prefilled KV cache.
+        Yields:
+            str: Streaming response chunks.
+        """
+        query_ids = self.tokenizer(
+            query, return_tensors="pt", add_special_tokens=False
+        ).input_ids.to(self.model.device)
+
+        max_new_tokens = getattr(self.config, "max_tokens", 128)
+        remove_think_prefix = getattr(self.config, "remove_think_prefix", False)
+
+        # Initial forward pass
+        logits, kv = self._prefill(query_ids, kv)
+        next_token = self._select_next_token(logits)
+
+        # Yield first token
+        first_token_text = self.tokenizer.decode(next_token[0], skip_special_tokens=True)
+        accumulated_text = ""
+        if first_token_text:
+            accumulated_text += first_token_text
+            if remove_think_prefix:
+                processed_text = remove_thinking_tags(accumulated_text)
+                if len(processed_text) > len(accumulated_text) - len(first_token_text):
+                    yield processed_text[len(accumulated_text) - len(first_token_text) :]
+                else:
+                    yield first_token_text
+            else:
+                yield first_token_text
+
+        generated = [next_token]
+
+        # Continue generation
+        for _ in range(max_new_tokens - 1):
+            if self._should_stop(next_token):
+                break
+            logits, kv = self._prefill(next_token, kv)
+            next_token = self._select_next_token(logits)
+
+            # Decode and yield the new token
+            new_token_text = self.tokenizer.decode(next_token[0], skip_special_tokens=True)
+            if new_token_text:
+                accumulated_text += new_token_text
+
+                # Apply thinking tag removal if enabled
+                if remove_think_prefix:
+                    processed_text = remove_thinking_tags(accumulated_text)
+                    # Only yield the difference (new content)
+                    if len(processed_text) > len(accumulated_text) - len(new_token_text):
+                        yield processed_text[len(accumulated_text) - len(new_token_text) :]
+                    else:
+                        yield new_token_text
+                else:
+                    yield new_token_text
+
+            generated.append(next_token)
+
+    def _prefill(self, input_ids: Any, kv: DynamicCache) -> tuple[Any, DynamicCache]:
         """
         Forward the model once, returning last-step logits and updated KV cache.
         Args:
@@ -149,15 +299,18 @@ class HFLLM(BaseLLM):
         Returns:
             tuple[torch.Tensor, DynamicCache]: (last-step logits, updated KV cache)
         """
-
-
-
-
-
-
+        import torch
+
+        with torch.no_grad():
+            out = self.model(
+                input_ids=input_ids,
+                use_cache=True,
+                past_key_values=kv,
+                return_dict=True,
+            )
         return out.logits[:, -1, :], out.past_key_values
 
-    def _select_next_token(self, logits:
+    def _select_next_token(self, logits: Any) -> Any:
         """
         Select the next token from logits using sampling or argmax, depending on config.
         Args:
@@ -165,6 +318,8 @@ class HFLLM(BaseLLM):
         Returns:
             torch.Tensor: Selected token ID(s).
         """
+        import torch
+
         if getattr(self.config, "do_sample", True):
             batch_size, _ = logits.size()
             dummy_ids = torch.zeros((batch_size, 1), dtype=torch.long, device=logits.device)
@@ -173,7 +328,7 @@ class HFLLM(BaseLLM):
             return torch.multinomial(probs, num_samples=1)
         return torch.argmax(logits, dim=-1, keepdim=True)
 
-    def _should_stop(self, token:
+    def _should_stop(self, token: Any) -> bool:
         """
         Check if the given token is the EOS (end-of-sequence) token.
         Args:
@@ -197,6 +352,8 @@ class HFLLM(BaseLLM):
         Returns:
             DynamicCache: The constructed KV cache object.
         """
+        import torch
+
         # Accept multiple input types and convert to standard chat messages
         if isinstance(messages, str):
             messages = [