MemoryOS 0.2.0-py3-none-any.whl → 0.2.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of MemoryOS might be problematic.

Files changed (114)
  1. {memoryos-0.2.0.dist-info → memoryos-0.2.2.dist-info}/METADATA +67 -26
  2. memoryos-0.2.2.dist-info/RECORD +169 -0
  3. memoryos-0.2.2.dist-info/entry_points.txt +3 -0
  4. memos/__init__.py +1 -1
  5. memos/api/config.py +562 -0
  6. memos/api/context/context.py +147 -0
  7. memos/api/context/dependencies.py +90 -0
  8. memos/api/exceptions.py +28 -0
  9. memos/api/mcp_serve.py +502 -0
  10. memos/api/product_api.py +35 -0
  11. memos/api/product_models.py +163 -0
  12. memos/api/routers/__init__.py +1 -0
  13. memos/api/routers/product_router.py +386 -0
  14. memos/chunkers/sentence_chunker.py +8 -2
  15. memos/cli.py +113 -0
  16. memos/configs/embedder.py +27 -0
  17. memos/configs/graph_db.py +132 -3
  18. memos/configs/internet_retriever.py +6 -0
  19. memos/configs/llm.py +47 -0
  20. memos/configs/mem_cube.py +1 -1
  21. memos/configs/mem_os.py +5 -0
  22. memos/configs/mem_reader.py +9 -0
  23. memos/configs/mem_scheduler.py +107 -7
  24. memos/configs/mem_user.py +58 -0
  25. memos/configs/memory.py +5 -4
  26. memos/dependency.py +52 -0
  27. memos/embedders/ark.py +92 -0
  28. memos/embedders/factory.py +4 -0
  29. memos/embedders/sentence_transformer.py +8 -2
  30. memos/embedders/universal_api.py +32 -0
  31. memos/graph_dbs/base.py +11 -3
  32. memos/graph_dbs/factory.py +4 -0
  33. memos/graph_dbs/nebular.py +1364 -0
  34. memos/graph_dbs/neo4j.py +333 -124
  35. memos/graph_dbs/neo4j_community.py +300 -0
  36. memos/llms/base.py +9 -0
  37. memos/llms/deepseek.py +54 -0
  38. memos/llms/factory.py +10 -1
  39. memos/llms/hf.py +170 -13
  40. memos/llms/hf_singleton.py +114 -0
  41. memos/llms/ollama.py +4 -0
  42. memos/llms/openai.py +67 -1
  43. memos/llms/qwen.py +63 -0
  44. memos/llms/vllm.py +153 -0
  45. memos/log.py +1 -1
  46. memos/mem_cube/general.py +77 -16
  47. memos/mem_cube/utils.py +109 -0
  48. memos/mem_os/core.py +251 -51
  49. memos/mem_os/main.py +94 -12
  50. memos/mem_os/product.py +1220 -43
  51. memos/mem_os/utils/default_config.py +352 -0
  52. memos/mem_os/utils/format_utils.py +1401 -0
  53. memos/mem_reader/simple_struct.py +18 -10
  54. memos/mem_scheduler/base_scheduler.py +441 -40
  55. memos/mem_scheduler/general_scheduler.py +249 -248
  56. memos/mem_scheduler/modules/base.py +14 -5
  57. memos/mem_scheduler/modules/dispatcher.py +67 -4
  58. memos/mem_scheduler/modules/misc.py +104 -0
  59. memos/mem_scheduler/modules/monitor.py +240 -50
  60. memos/mem_scheduler/modules/rabbitmq_service.py +319 -0
  61. memos/mem_scheduler/modules/redis_service.py +32 -22
  62. memos/mem_scheduler/modules/retriever.py +167 -23
  63. memos/mem_scheduler/modules/scheduler_logger.py +255 -0
  64. memos/mem_scheduler/mos_for_test_scheduler.py +140 -0
  65. memos/mem_scheduler/schemas/__init__.py +0 -0
  66. memos/mem_scheduler/schemas/general_schemas.py +43 -0
  67. memos/mem_scheduler/{modules/schemas.py → schemas/message_schemas.py} +63 -61
  68. memos/mem_scheduler/schemas/monitor_schemas.py +329 -0
  69. memos/mem_scheduler/utils/__init__.py +0 -0
  70. memos/mem_scheduler/utils/filter_utils.py +176 -0
  71. memos/mem_scheduler/utils/misc_utils.py +61 -0
  72. memos/mem_user/factory.py +94 -0
  73. memos/mem_user/mysql_persistent_user_manager.py +271 -0
  74. memos/mem_user/mysql_user_manager.py +500 -0
  75. memos/mem_user/persistent_factory.py +96 -0
  76. memos/mem_user/persistent_user_manager.py +260 -0
  77. memos/mem_user/user_manager.py +4 -4
  78. memos/memories/activation/item.py +29 -0
  79. memos/memories/activation/kv.py +10 -3
  80. memos/memories/activation/vllmkv.py +219 -0
  81. memos/memories/factory.py +2 -0
  82. memos/memories/textual/base.py +1 -1
  83. memos/memories/textual/general.py +43 -97
  84. memos/memories/textual/item.py +5 -33
  85. memos/memories/textual/tree.py +22 -12
  86. memos/memories/textual/tree_text_memory/organize/conflict.py +9 -5
  87. memos/memories/textual/tree_text_memory/organize/manager.py +26 -18
  88. memos/memories/textual/tree_text_memory/organize/redundancy.py +25 -44
  89. memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +50 -48
  90. memos/memories/textual/tree_text_memory/organize/reorganizer.py +81 -56
  91. memos/memories/textual/tree_text_memory/retrieve/internet_retriever.py +6 -3
  92. memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py +2 -0
  93. memos/memories/textual/tree_text_memory/retrieve/recall.py +0 -1
  94. memos/memories/textual/tree_text_memory/retrieve/reranker.py +2 -2
  95. memos/memories/textual/tree_text_memory/retrieve/retrieval_mid_structs.py +2 -0
  96. memos/memories/textual/tree_text_memory/retrieve/searcher.py +52 -28
  97. memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py +42 -15
  98. memos/memories/textual/tree_text_memory/retrieve/utils.py +11 -7
  99. memos/memories/textual/tree_text_memory/retrieve/xinyusearch.py +62 -58
  100. memos/memos_tools/dinding_report_bot.py +422 -0
  101. memos/memos_tools/notification_service.py +44 -0
  102. memos/memos_tools/notification_utils.py +96 -0
  103. memos/parsers/markitdown.py +8 -2
  104. memos/settings.py +3 -1
  105. memos/templates/mem_reader_prompts.py +66 -23
  106. memos/templates/mem_scheduler_prompts.py +126 -43
  107. memos/templates/mos_prompts.py +87 -0
  108. memos/templates/tree_reorganize_prompts.py +85 -30
  109. memos/vec_dbs/base.py +12 -0
  110. memos/vec_dbs/qdrant.py +46 -20
  111. memoryos-0.2.0.dist-info/RECORD +0 -128
  112. memos/mem_scheduler/utils.py +0 -26
  113. {memoryos-0.2.0.dist-info → memoryos-0.2.2.dist-info}/LICENSE +0 -0
  114. {memoryos-0.2.0.dist-info → memoryos-0.2.2.dist-info}/WHEEL +0 -0
memos/graph_dbs/neo4j_community.py ADDED
@@ -0,0 +1,300 @@
+ from typing import Any
+
+ from memos.configs.graph_db import Neo4jGraphDBConfig
+ from memos.graph_dbs.neo4j import Neo4jGraphDB, _prepare_node_metadata
+ from memos.log import get_logger
+ from memos.vec_dbs.factory import VecDBFactory
+ from memos.vec_dbs.item import VecDBItem
+
+
+ logger = get_logger(__name__)
+
+
+ class Neo4jCommunityGraphDB(Neo4jGraphDB):
+     """
+     Neo4j Community Edition graph memory store.
+
+     Note:
+         This class avoids Enterprise-only features:
+         - No multi-database support
+         - No vector index
+         - No CREATE DATABASE
+     """
+
+     def __init__(self, config: Neo4jGraphDBConfig):
+         assert config.auto_create is False
+         assert config.use_multi_db is False
+         # Init vector database
+         self.vec_db = VecDBFactory.from_config(config.vec_config)
+         # Call parent init
+         super().__init__(config)
+
+     def create_index(
+         self,
+         label: str = "Memory",
+         vector_property: str = "embedding",
+         dimensions: int = 1536,
+         index_name: str = "memory_vector_index",
+     ) -> None:
+         """
+         Create the vector index for embedding and datetime indexes for created_at and updated_at fields.
+         """
+         # Create indexes
+         self._create_basic_property_indexes()
+
+     def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None:
+         if not self.config.use_multi_db and self.config.user_name:
+             metadata["user_name"] = self.config.user_name
+
+         # Safely process metadata
+         metadata = _prepare_node_metadata(metadata)
+
+         # Extract required fields
+         embedding = metadata.pop("embedding", None)
+         if embedding is None:
+             raise ValueError(f"Missing 'embedding' in metadata for node {id}")
+
+         # Merge node and set metadata
+         created_at = metadata.pop("created_at")
+         updated_at = metadata.pop("updated_at")
+         vector_sync_status = "success"
+
+         try:
+             # Write to Vector DB
+             item = VecDBItem(
+                 id=id,
+                 vector=embedding,
+                 payload={
+                     "memory": memory,
+                     "vector_sync": vector_sync_status,
+                     **metadata,  # unpack all metadata keys to top-level
+                 },
+             )
+             self.vec_db.add([item])
+         except Exception as e:
+             logger.warning(f"[VecDB] Vector insert failed for node {id}: {e}")
+             vector_sync_status = "failed"
+
+         metadata["vector_sync"] = vector_sync_status
+         query = """
+         MERGE (n:Memory {id: $id})
+         SET n.memory = $memory,
+             n.created_at = datetime($created_at),
+             n.updated_at = datetime($updated_at),
+             n += $metadata
+         """
+         with self.driver.session(database=self.db_name) as session:
+             session.run(
+                 query,
+                 id=id,
+                 memory=memory,
+                 created_at=created_at,
+                 updated_at=updated_at,
+                 metadata=metadata,
+             )
+
+     def get_children_with_embeddings(self, id: str) -> list[dict[str, Any]]:
+         where_user = ""
+         params = {"id": id}
+
+         if not self.config.use_multi_db and self.config.user_name:
+             where_user = "AND p.user_name = $user_name AND c.user_name = $user_name"
+             params["user_name"] = self.config.user_name
+
+         query = f"""
+         MATCH (p:Memory)-[:PARENT]->(c:Memory)
+         WHERE p.id = $id {where_user}
+         RETURN c.id AS id, c.memory AS memory
+         """
+
+         with self.driver.session(database=self.db_name) as session:
+             result = session.run(query, params)
+             child_nodes = [{"id": r["id"], "memory": r["memory"]} for r in result]
+
+         # Get embeddings from vector DB
+         ids = [n["id"] for n in child_nodes]
+         vec_items = {v.id: v.vector for v in self.vec_db.get_by_ids(ids)}
+
+         # Merge results
+         for node in child_nodes:
+             node["embedding"] = vec_items.get(node["id"])
+
+         return child_nodes
+
+     # Search / recall operations
+     def search_by_embedding(
+         self,
+         vector: list[float],
+         top_k: int = 5,
+         scope: str | None = None,
+         status: str | None = None,
+         threshold: float | None = None,
+     ) -> list[dict]:
+         """
+         Retrieve node IDs based on vector similarity using external vector DB.
+
+         Args:
+             vector (list[float]): The embedding vector representing query semantics.
+             top_k (int): Number of top similar nodes to retrieve.
+             scope (str, optional): Memory type filter (e.g., 'WorkingMemory', 'LongTermMemory').
+             status (str, optional): Node status filter (e.g., 'activated', 'archived').
+             threshold (float, optional): Minimum similarity score threshold (0 ~ 1).
+
+         Returns:
+             list[dict]: A list of dicts with 'id' and 'score', ordered by similarity.
+
+         Notes:
+             - This method uses an external vector database (not Neo4j) to perform the search.
+             - If 'scope' is provided, it restricts results to nodes with matching memory_type.
+             - If 'status' is provided, it further filters nodes by status.
+             - If 'threshold' is provided, only results with score >= threshold will be returned.
+             - The returned IDs can be used to fetch full node data from Neo4j if needed.
+         """
+         # Build VecDB filter
+         vec_filter = {}
+         if scope:
+             vec_filter["memory_type"] = scope
+         if status:
+             vec_filter["status"] = status
+         vec_filter["vector_sync"] = "success"
+         vec_filter["user_name"] = self.config.user_name
+
+         # Perform vector search
+         results = self.vec_db.search(query_vector=vector, top_k=top_k, filter=vec_filter)
+
+         # Filter by threshold
+         if threshold is not None:
+             results = [r for r in results if r.score is None or r.score >= threshold]
+
+         # Return consistent format
+         return [{"id": r.id, "score": r.score} for r in results]
+
+     def get_all_memory_items(self, scope: str) -> list[dict]:
+         """
+         Retrieve all memory items of a specific memory_type.
+
+         Args:
+             scope (str): Must be one of 'WorkingMemory', 'LongTermMemory', or 'UserMemory'.
+
+         Returns:
+             list[dict]: Full list of memory items under this scope.
+         """
+         if scope not in {"WorkingMemory", "LongTermMemory", "UserMemory"}:
+             raise ValueError(f"Unsupported memory type scope: {scope}")
+
+         where_clause = "WHERE n.memory_type = $scope"
+         params = {"scope": scope}
+
+         if not self.config.use_multi_db and self.config.user_name:
+             where_clause += " AND n.user_name = $user_name"
+             params["user_name"] = self.config.user_name
+
+         query = f"""
+         MATCH (n:Memory)
+         {where_clause}
+         RETURN n
+         """
+
+         with self.driver.session(database=self.db_name) as session:
+             results = session.run(query, params)
+             return [self._parse_node(dict(record["n"])) for record in results]
+
+     def clear(self) -> None:
+         """
+         Clear the entire graph if the target database exists.
+         """
+         # Step 1: clear Neo4j part via parent logic
+         super().clear()
+
+         # Step 2: clear the vector db
+         try:
+             items = self.vec_db.get_by_filter({"user_name": self.config.user_name})
+             if items:
+                 self.vec_db.delete([item.id for item in items])
+                 logger.info(f"Cleared {len(items)} vectors for user '{self.config.user_name}'.")
+             else:
+                 logger.info(f"No vectors to clear for user '{self.config.user_name}'.")
+         except Exception as e:
+             logger.warning(f"Failed to clear vector DB for user '{self.config.user_name}': {e}")
+
+     def drop_database(self) -> None:
+         """
+         Permanently delete the entire database this instance is using.
+         WARNING: This operation is destructive and cannot be undone.
+         """
+         raise ValueError(
+             f"Refusing to drop protected database: {self.db_name} in "
+             f"Shared Database Multi-Tenant mode"
+         )
+
+     # Avoid enterprise feature
+     def _ensure_database_exists(self):
+         pass
+
+     def _create_basic_property_indexes(self) -> None:
+         """
+         Create standard B-tree indexes on memory_type, created_at,
+         and updated_at fields.
+         Create standard B-tree indexes on user_name when using Shared Database
+         Multi-Tenant mode.
+         """
+         # Step 1: Neo4j indexes
+         try:
+             with self.driver.session(database=self.db_name) as session:
+                 session.run("""
+                     CREATE INDEX memory_type_index IF NOT EXISTS
+                     FOR (n:Memory) ON (n.memory_type)
+                 """)
+                 logger.debug("Index 'memory_type_index' ensured.")
+
+                 session.run("""
+                     CREATE INDEX memory_created_at_index IF NOT EXISTS
+                     FOR (n:Memory) ON (n.created_at)
+                 """)
+                 logger.debug("Index 'memory_created_at_index' ensured.")
+
+                 session.run("""
+                     CREATE INDEX memory_updated_at_index IF NOT EXISTS
+                     FOR (n:Memory) ON (n.updated_at)
+                 """)
+                 logger.debug("Index 'memory_updated_at_index' ensured.")
+
+                 if not self.config.use_multi_db and self.config.user_name:
+                     session.run(
+                         """
+                         CREATE INDEX memory_user_name_index IF NOT EXISTS
+                         FOR (n:Memory) ON (n.user_name)
+                         """
+                     )
+                     logger.debug("Index 'memory_user_name_index' ensured.")
+         except Exception as e:
+             logger.warning(f"Failed to create basic property indexes: {e}")
+
+         # Step 2: VectorDB indexes
+         try:
+             if hasattr(self.vec_db, "ensure_payload_indexes"):
+                 self.vec_db.ensure_payload_indexes(["user_name", "memory_type", "status"])
+             else:
+                 logger.debug("VecDB does not support payload index creation; skipping.")
+         except Exception as e:
+             logger.warning(f"Failed to create VecDB payload indexes: {e}")
+
+     def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]:
+         """Parse Neo4j node and optionally fetch embedding from vector DB."""
+         node = node_data.copy()
+
+         # Convert Neo4j datetime to string
+         for time_field in ("created_at", "updated_at"):
+             if time_field in node and hasattr(node[time_field], "isoformat"):
+                 node[time_field] = node[time_field].isoformat()
+         node.pop("user_name", None)
+
+         new_node = {"id": node.pop("id"), "memory": node.pop("memory", ""), "metadata": node}
+         try:
+             vec_item = self.vec_db.get_by_id(new_node["id"])
+             if vec_item and vec_item.vector:
+                 new_node["metadata"]["embedding"] = vec_item.vector
+         except Exception as e:
+             logger.warning(f"Failed to fetch vector for node {new_node['id']}: {e}")
+             new_node["metadata"]["embedding"] = None
+         return new_node
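
The new `search_by_embedding` queries only the external vector store and returns bare `{"id", "score"}` pairs; per its docstring, callers hydrate full nodes from Neo4j afterwards. A minimal usage sketch under assumed names: `graph` is a configured `Neo4jCommunityGraphDB`, `query_vec` is a precomputed embedding, and `get_node` stands in for whichever node-fetching method the parent `Neo4jGraphDB` exposes (none of these are defined in this diff).

```python
# Illustrative sketch only -- `graph`, `query_vec`, and `get_node` are assumptions.
hits = graph.search_by_embedding(
    vector=query_vec,        # query embedding, e.g. a 1536-dim list[float]
    top_k=5,
    scope="LongTermMemory",  # filters on payload memory_type
    status="activated",      # filters on payload status
    threshold=0.3,           # drops results with score < 0.3
)

# Hydrate full node payloads from Neo4j using the returned IDs.
nodes = [graph.get_node(h["id"]) for h in hits]
```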
memos/llms/base.py CHANGED
@@ -1,4 +1,5 @@
  from abc import ABC, abstractmethod
+ from collections.abc import Generator

  from memos.configs.llm import BaseLLMConfig
  from memos.types import MessageList
@@ -14,3 +15,11 @@ class BaseLLM(ABC):
      @abstractmethod
      def generate(self, messages: MessageList, **kwargs) -> str:
          """Generate a response from the LLM."""
+
+     @abstractmethod
+     def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]:
+         """
+         (Optional) Generate a streaming response from the LLM.
+         Subclasses should override this if they support streaming.
+         By default, this raises NotImplementedError.
+         """
memos/llms/deepseek.py ADDED
@@ -0,0 +1,54 @@
+ from collections.abc import Generator
+
+ from memos.configs.llm import DeepSeekLLMConfig
+ from memos.llms.openai import OpenAILLM
+ from memos.llms.utils import remove_thinking_tags
+ from memos.log import get_logger
+ from memos.types import MessageList
+
+
+ logger = get_logger(__name__)
+
+
+ class DeepSeekLLM(OpenAILLM):
+     """DeepSeek LLM via OpenAI-compatible API."""
+
+     def __init__(self, config: DeepSeekLLMConfig):
+         super().__init__(config)
+
+     def generate(self, messages: MessageList) -> str:
+         """Generate a response from DeepSeek."""
+         response = self.client.chat.completions.create(
+             model=self.config.model_name_or_path,
+             messages=messages,
+             temperature=self.config.temperature,
+             max_tokens=self.config.max_tokens,
+             top_p=self.config.top_p,
+             extra_body=self.config.extra_body,
+         )
+         logger.info(f"Response from DeepSeek: {response.model_dump_json()}")
+         response_content = response.choices[0].message.content
+         if self.config.remove_think_prefix:
+             return remove_thinking_tags(response_content)
+         else:
+             return response_content
+
+     def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]:
+         """Stream response from DeepSeek."""
+         response = self.client.chat.completions.create(
+             model=self.config.model_name_or_path,
+             messages=messages,
+             stream=True,
+             temperature=self.config.temperature,
+             max_tokens=self.config.max_tokens,
+             top_p=self.config.top_p,
+             extra_body=self.config.extra_body,
+         )
+         # Streaming chunks of text
+         for chunk in response:
+             delta = chunk.choices[0].delta
+             if hasattr(delta, "reasoning_content") and delta.reasoning_content:
+                 yield delta.reasoning_content
+
+             if hasattr(delta, "content") and delta.content:
+                 yield delta.content
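
Note that within each chunk, `generate_stream` yields DeepSeek's `reasoning_content` delta (when present) before the regular `content` delta, so callers see the reasoning stream interleaved ahead of the answer text. A hedged consumption sketch; the config values are placeholders, and any `DeepSeekLLMConfig` fields beyond those referenced in the diff (such as credentials) are assumptions:

```python
from memos.configs.llm import DeepSeekLLMConfig
from memos.llms.deepseek import DeepSeekLLM

# Placeholder settings; the full DeepSeekLLMConfig schema (API key, base URL, ...) is assumed.
config = DeepSeekLLMConfig(model_name_or_path="deepseek-chat", temperature=0.7)
llm = DeepSeekLLM(config)

for chunk in llm.generate_stream([{"role": "user", "content": "Hello!"}]):
    print(chunk, end="", flush=True)  # reasoning deltas (if any) arrive before content deltas
```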
memos/llms/factory.py CHANGED
@@ -2,9 +2,13 @@ from typing import Any, ClassVar

  from memos.configs.llm import LLMConfigFactory
  from memos.llms.base import BaseLLM
+ from memos.llms.deepseek import DeepSeekLLM
  from memos.llms.hf import HFLLM
+ from memos.llms.hf_singleton import HFSingletonLLM
  from memos.llms.ollama import OllamaLLM
- from memos.llms.openai import OpenAILLM
+ from memos.llms.openai import AzureLLM, OpenAILLM
+ from memos.llms.qwen import QwenLLM
+ from memos.llms.vllm import VLLMLLM


  class LLMFactory(BaseLLM):
@@ -12,8 +16,13 @@ class LLMFactory(BaseLLM):

      backend_to_class: ClassVar[dict[str, Any]] = {
          "openai": OpenAILLM,
+         "azure": AzureLLM,
          "ollama": OllamaLLM,
          "huggingface": HFLLM,
+         "huggingface_singleton": HFSingletonLLM,  # Add singleton version
+         "vllm": VLLMLLM,
+         "qwen": QwenLLM,
+         "deepseek": DeepSeekLLM,
      }

      @classmethod
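
With the expanded `backend_to_class` table, the backend is picked purely by the `backend` string in the LLM config. A hedged sketch of routing to the new `deepseek` backend through the factory; the `from_config` classmethod and the exact config payload are assumptions inferred from the `@classmethod` stub above and from the other factories in this package (e.g. `VecDBFactory.from_config`):

```python
from memos.configs.llm import LLMConfigFactory
from memos.llms.factory import LLMFactory

# Assumed config shape: a backend key plus backend-specific settings.
config = LLMConfigFactory.model_validate({
    "backend": "deepseek",  # any key in backend_to_class: openai, azure, ollama, huggingface, ...
    "config": {"model_name_or_path": "deepseek-chat"},
})
llm = LLMFactory.from_config(config)  # from_config is assumed, mirroring the other factories
print(type(llm).__name__)             # expected: "DeepSeekLLM"
```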
memos/llms/hf.py CHANGED
@@ -1,4 +1,5 @@
- import torch
+ from collections.abc import Generator
+ from typing import Any

  from transformers import (
      AutoModelForCausalLM,
@@ -71,6 +72,26 @@ class HFLLM(BaseLLM):
          else:
              return self._generate_with_cache(prompt, past_key_values)

+     def generate_stream(
+         self, messages: MessageList, past_key_values: DynamicCache | None = None
+     ) -> Generator[str, None, None]:
+         """
+         Generate a streaming response from the model.
+         Args:
+             messages (MessageList): Chat messages for prompt construction.
+             past_key_values (DynamicCache | None): Optional KV cache for fast generation.
+         Yields:
+             str: Streaming model response chunks.
+         """
+         prompt = self.tokenizer.apply_chat_template(
+             messages, tokenize=False, add_generation_prompt=self.config.add_generation_prompt
+         )
+         logger.info(f"HFLLM streaming prompt: {prompt}")
+         if past_key_values is None:
+             yield from self._generate_full_stream(prompt)
+         else:
+             yield from self._generate_with_cache_stream(prompt, past_key_values)
+
      def _generate_full(self, prompt: str) -> str:
          """
          Generate output from scratch using the full prompt.
@@ -104,6 +125,73 @@ class HFLLM(BaseLLM):
              else response
          )

+     def _generate_full_stream(self, prompt: str) -> Generator[str, None, None]:
+         """
+         Generate output from scratch using the full prompt with streaming.
+         Args:
+             prompt (str): The input prompt string.
+         Yields:
+             str: Streaming response chunks.
+         """
+         import torch
+
+         inputs = self.tokenizer([prompt], return_tensors="pt").to(self.model.device)
+
+         # Get generation parameters
+         max_new_tokens = getattr(self.config, "max_tokens", 128)
+         remove_think_prefix = getattr(self.config, "remove_think_prefix", False)
+
+         # Manual streaming generation
+         generated_ids = inputs.input_ids.clone()
+         accumulated_text = ""
+
+         for _ in range(max_new_tokens):
+             # Forward pass
+             with torch.no_grad():
+                 outputs = self.model(
+                     input_ids=generated_ids,
+                     use_cache=True,
+                     return_dict=True,
+                 )
+
+             # Get next token logits
+             next_token_logits = outputs.logits[:, -1, :]
+
+             # Apply logits processors if sampling
+             if getattr(self.config, "do_sample", True):
+                 batch_size, _ = next_token_logits.size()
+                 dummy_ids = torch.zeros(
+                     (batch_size, 1), dtype=torch.long, device=next_token_logits.device
+                 )
+                 filtered_logits = self.logits_processors(dummy_ids, next_token_logits)
+                 probs = torch.softmax(filtered_logits, dim=-1)
+                 next_token = torch.multinomial(probs, num_samples=1)
+             else:
+                 next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
+
+             # Check for EOS token
+             if self._should_stop(next_token):
+                 break
+
+             # Append new token
+             generated_ids = torch.cat([generated_ids, next_token], dim=-1)
+
+             # Decode and yield the new token
+             new_token_text = self.tokenizer.decode(next_token[0], skip_special_tokens=True)
+             if new_token_text:  # Only yield non-empty tokens
+                 accumulated_text += new_token_text
+
+                 # Apply thinking tag removal if enabled
+                 if remove_think_prefix:
+                     processed_text = remove_thinking_tags(accumulated_text)
+                     # Only yield the difference (new content)
+                     if len(processed_text) > len(accumulated_text) - len(new_token_text):
+                         yield processed_text[len(accumulated_text) - len(new_token_text) :]
+                     else:
+                         yield new_token_text
+                 else:
+                     yield new_token_text
+
      def _generate_with_cache(self, query: str, kv: DynamicCache) -> str:
          """
          Generate output incrementally using an existing KV cache.
@@ -113,6 +201,8 @@ class HFLLM(BaseLLM):
          Returns:
              str: Model response.
          """
+         import torch
+
          query_ids = self.tokenizer(
              query, return_tensors="pt", add_special_tokens=False
          ).input_ids.to(self.model.device)
@@ -137,10 +227,70 @@ class HFLLM(BaseLLM):
              else response
          )

-     @torch.no_grad()
-     def _prefill(
-         self, input_ids: torch.Tensor, kv: DynamicCache
-     ) -> tuple[torch.Tensor, DynamicCache]:
+     def _generate_with_cache_stream(
+         self, query: str, kv: DynamicCache
+     ) -> Generator[str, None, None]:
+         """
+         Generate output incrementally using an existing KV cache with streaming.
+         Args:
+             query (str): The new user query string.
+             kv (DynamicCache): The prefilled KV cache.
+         Yields:
+             str: Streaming response chunks.
+         """
+         query_ids = self.tokenizer(
+             query, return_tensors="pt", add_special_tokens=False
+         ).input_ids.to(self.model.device)
+
+         max_new_tokens = getattr(self.config, "max_tokens", 128)
+         remove_think_prefix = getattr(self.config, "remove_think_prefix", False)
+
+         # Initial forward pass
+         logits, kv = self._prefill(query_ids, kv)
+         next_token = self._select_next_token(logits)
+
+         # Yield first token
+         first_token_text = self.tokenizer.decode(next_token[0], skip_special_tokens=True)
+         accumulated_text = ""
+         if first_token_text:
+             accumulated_text += first_token_text
+             if remove_think_prefix:
+                 processed_text = remove_thinking_tags(accumulated_text)
+                 if len(processed_text) > len(accumulated_text) - len(first_token_text):
+                     yield processed_text[len(accumulated_text) - len(first_token_text) :]
+                 else:
+                     yield first_token_text
+             else:
+                 yield first_token_text
+
+         generated = [next_token]
+
+         # Continue generation
+         for _ in range(max_new_tokens - 1):
+             if self._should_stop(next_token):
+                 break
+             logits, kv = self._prefill(next_token, kv)
+             next_token = self._select_next_token(logits)
+
+             # Decode and yield the new token
+             new_token_text = self.tokenizer.decode(next_token[0], skip_special_tokens=True)
+             if new_token_text:
+                 accumulated_text += new_token_text
+
+                 # Apply thinking tag removal if enabled
+                 if remove_think_prefix:
+                     processed_text = remove_thinking_tags(accumulated_text)
+                     # Only yield the difference (new content)
+                     if len(processed_text) > len(accumulated_text) - len(new_token_text):
+                         yield processed_text[len(accumulated_text) - len(new_token_text) :]
+                     else:
+                         yield new_token_text
+                 else:
+                     yield new_token_text
+
+             generated.append(next_token)
+
+     def _prefill(self, input_ids: Any, kv: DynamicCache) -> tuple[Any, DynamicCache]:
          """
          Forward the model once, returning last-step logits and updated KV cache.
          Args:
@@ -149,15 +299,18 @@ class HFLLM(BaseLLM):
          Returns:
              tuple[torch.Tensor, DynamicCache]: (last-step logits, updated KV cache)
          """
-         out = self.model(
-             input_ids=input_ids,
-             use_cache=True,
-             past_key_values=kv,
-             return_dict=True,
-         )
+         import torch
+
+         with torch.no_grad():
+             out = self.model(
+                 input_ids=input_ids,
+                 use_cache=True,
+                 past_key_values=kv,
+                 return_dict=True,
+             )
          return out.logits[:, -1, :], out.past_key_values

-     def _select_next_token(self, logits: torch.Tensor) -> torch.Tensor:
+     def _select_next_token(self, logits: Any) -> Any:
          """
          Select the next token from logits using sampling or argmax, depending on config.
          Args:
@@ -165,6 +318,8 @@ class HFLLM(BaseLLM):
          Returns:
              torch.Tensor: Selected token ID(s).
          """
+         import torch
+
          if getattr(self.config, "do_sample", True):
              batch_size, _ = logits.size()
              dummy_ids = torch.zeros((batch_size, 1), dtype=torch.long, device=logits.device)
@@ -173,7 +328,7 @@ class HFLLM(BaseLLM):
              return torch.multinomial(probs, num_samples=1)
          return torch.argmax(logits, dim=-1, keepdim=True)

-     def _should_stop(self, token: torch.Tensor) -> bool:
+     def _should_stop(self, token: Any) -> bool:
          """
          Check if the given token is the EOS (end-of-sequence) token.
          Args:
@@ -197,6 +352,8 @@ class HFLLM(BaseLLM):
          Returns:
              DynamicCache: The constructed KV cache object.
          """
+         import torch
+
          # Accept multiple input types and convert to standard chat messages
          if isinstance(messages, str):
              messages = [
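
Two patterns stand out in this hf.py change: `torch` is now imported lazily inside the methods that need it (so importing `memos.llms.hf` no longer requires torch at module load time), and the new `generate_stream` yields decoded tokens one at a time, with or without a prefilled KV cache. A hedged consumption sketch; it assumes `hf_llm` is an `HFLLM` instance already constructed elsewhere, since the config class and its full field set are not shown in this diff:

```python
# Assumes `hf_llm` is a configured HFLLM instance; shown only to illustrate
# how the new streaming API is consumed on the no-cache (full-prompt) path.
messages = [{"role": "user", "content": "Explain KV caching in one sentence."}]

chunks = []
for token_text in hf_llm.generate_stream(messages):
    chunks.append(token_text)
    print(token_text, end="", flush=True)  # tokens arrive as they are decoded

answer = "".join(chunks)
```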