memoryos-2.0.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- memoryos-2.0.3.dist-info/METADATA +418 -0
- memoryos-2.0.3.dist-info/RECORD +315 -0
- memoryos-2.0.3.dist-info/WHEEL +4 -0
- memoryos-2.0.3.dist-info/entry_points.txt +3 -0
- memoryos-2.0.3.dist-info/licenses/LICENSE +201 -0
- memos/__init__.py +20 -0
- memos/api/client.py +571 -0
- memos/api/config.py +1018 -0
- memos/api/context/dependencies.py +50 -0
- memos/api/exceptions.py +53 -0
- memos/api/handlers/__init__.py +62 -0
- memos/api/handlers/add_handler.py +158 -0
- memos/api/handlers/base_handler.py +194 -0
- memos/api/handlers/chat_handler.py +1401 -0
- memos/api/handlers/component_init.py +388 -0
- memos/api/handlers/config_builders.py +190 -0
- memos/api/handlers/feedback_handler.py +93 -0
- memos/api/handlers/formatters_handler.py +237 -0
- memos/api/handlers/memory_handler.py +316 -0
- memos/api/handlers/scheduler_handler.py +497 -0
- memos/api/handlers/search_handler.py +222 -0
- memos/api/handlers/suggestion_handler.py +117 -0
- memos/api/mcp_serve.py +614 -0
- memos/api/middleware/request_context.py +101 -0
- memos/api/product_api.py +38 -0
- memos/api/product_models.py +1206 -0
- memos/api/routers/__init__.py +1 -0
- memos/api/routers/product_router.py +477 -0
- memos/api/routers/server_router.py +394 -0
- memos/api/server_api.py +44 -0
- memos/api/start_api.py +433 -0
- memos/chunkers/__init__.py +4 -0
- memos/chunkers/base.py +24 -0
- memos/chunkers/charactertext_chunker.py +41 -0
- memos/chunkers/factory.py +24 -0
- memos/chunkers/markdown_chunker.py +62 -0
- memos/chunkers/sentence_chunker.py +54 -0
- memos/chunkers/simple_chunker.py +50 -0
- memos/cli.py +113 -0
- memos/configs/__init__.py +0 -0
- memos/configs/base.py +82 -0
- memos/configs/chunker.py +59 -0
- memos/configs/embedder.py +88 -0
- memos/configs/graph_db.py +236 -0
- memos/configs/internet_retriever.py +100 -0
- memos/configs/llm.py +151 -0
- memos/configs/mem_agent.py +54 -0
- memos/configs/mem_chat.py +81 -0
- memos/configs/mem_cube.py +105 -0
- memos/configs/mem_os.py +83 -0
- memos/configs/mem_reader.py +91 -0
- memos/configs/mem_scheduler.py +385 -0
- memos/configs/mem_user.py +70 -0
- memos/configs/memory.py +324 -0
- memos/configs/parser.py +38 -0
- memos/configs/reranker.py +18 -0
- memos/configs/utils.py +8 -0
- memos/configs/vec_db.py +80 -0
- memos/context/context.py +355 -0
- memos/dependency.py +52 -0
- memos/deprecation.py +262 -0
- memos/embedders/__init__.py +0 -0
- memos/embedders/ark.py +95 -0
- memos/embedders/base.py +106 -0
- memos/embedders/factory.py +29 -0
- memos/embedders/ollama.py +77 -0
- memos/embedders/sentence_transformer.py +49 -0
- memos/embedders/universal_api.py +51 -0
- memos/exceptions.py +30 -0
- memos/graph_dbs/__init__.py +0 -0
- memos/graph_dbs/base.py +274 -0
- memos/graph_dbs/factory.py +27 -0
- memos/graph_dbs/item.py +46 -0
- memos/graph_dbs/nebular.py +1794 -0
- memos/graph_dbs/neo4j.py +1942 -0
- memos/graph_dbs/neo4j_community.py +1058 -0
- memos/graph_dbs/polardb.py +5446 -0
- memos/hello_world.py +97 -0
- memos/llms/__init__.py +0 -0
- memos/llms/base.py +25 -0
- memos/llms/deepseek.py +13 -0
- memos/llms/factory.py +38 -0
- memos/llms/hf.py +443 -0
- memos/llms/hf_singleton.py +114 -0
- memos/llms/ollama.py +135 -0
- memos/llms/openai.py +222 -0
- memos/llms/openai_new.py +198 -0
- memos/llms/qwen.py +13 -0
- memos/llms/utils.py +14 -0
- memos/llms/vllm.py +218 -0
- memos/log.py +237 -0
- memos/mem_agent/base.py +19 -0
- memos/mem_agent/deepsearch_agent.py +391 -0
- memos/mem_agent/factory.py +36 -0
- memos/mem_chat/__init__.py +0 -0
- memos/mem_chat/base.py +30 -0
- memos/mem_chat/factory.py +21 -0
- memos/mem_chat/simple.py +200 -0
- memos/mem_cube/__init__.py +0 -0
- memos/mem_cube/base.py +30 -0
- memos/mem_cube/general.py +240 -0
- memos/mem_cube/navie.py +172 -0
- memos/mem_cube/utils.py +169 -0
- memos/mem_feedback/base.py +15 -0
- memos/mem_feedback/feedback.py +1192 -0
- memos/mem_feedback/simple_feedback.py +40 -0
- memos/mem_feedback/utils.py +230 -0
- memos/mem_os/client.py +5 -0
- memos/mem_os/core.py +1203 -0
- memos/mem_os/main.py +582 -0
- memos/mem_os/product.py +1608 -0
- memos/mem_os/product_server.py +455 -0
- memos/mem_os/utils/default_config.py +359 -0
- memos/mem_os/utils/format_utils.py +1403 -0
- memos/mem_os/utils/reference_utils.py +162 -0
- memos/mem_reader/__init__.py +0 -0
- memos/mem_reader/base.py +47 -0
- memos/mem_reader/factory.py +53 -0
- memos/mem_reader/memory.py +298 -0
- memos/mem_reader/multi_modal_struct.py +965 -0
- memos/mem_reader/read_multi_modal/__init__.py +43 -0
- memos/mem_reader/read_multi_modal/assistant_parser.py +311 -0
- memos/mem_reader/read_multi_modal/base.py +273 -0
- memos/mem_reader/read_multi_modal/file_content_parser.py +826 -0
- memos/mem_reader/read_multi_modal/image_parser.py +359 -0
- memos/mem_reader/read_multi_modal/multi_modal_parser.py +252 -0
- memos/mem_reader/read_multi_modal/string_parser.py +139 -0
- memos/mem_reader/read_multi_modal/system_parser.py +327 -0
- memos/mem_reader/read_multi_modal/text_content_parser.py +131 -0
- memos/mem_reader/read_multi_modal/tool_parser.py +210 -0
- memos/mem_reader/read_multi_modal/user_parser.py +218 -0
- memos/mem_reader/read_multi_modal/utils.py +358 -0
- memos/mem_reader/simple_struct.py +912 -0
- memos/mem_reader/strategy_struct.py +163 -0
- memos/mem_reader/utils.py +157 -0
- memos/mem_scheduler/__init__.py +0 -0
- memos/mem_scheduler/analyzer/__init__.py +0 -0
- memos/mem_scheduler/analyzer/api_analyzer.py +714 -0
- memos/mem_scheduler/analyzer/eval_analyzer.py +219 -0
- memos/mem_scheduler/analyzer/mos_for_test_scheduler.py +571 -0
- memos/mem_scheduler/analyzer/scheduler_for_eval.py +280 -0
- memos/mem_scheduler/base_scheduler.py +1319 -0
- memos/mem_scheduler/general_modules/__init__.py +0 -0
- memos/mem_scheduler/general_modules/api_misc.py +137 -0
- memos/mem_scheduler/general_modules/base.py +80 -0
- memos/mem_scheduler/general_modules/init_components_for_scheduler.py +425 -0
- memos/mem_scheduler/general_modules/misc.py +313 -0
- memos/mem_scheduler/general_modules/scheduler_logger.py +389 -0
- memos/mem_scheduler/general_modules/task_threads.py +315 -0
- memos/mem_scheduler/general_scheduler.py +1495 -0
- memos/mem_scheduler/memory_manage_modules/__init__.py +5 -0
- memos/mem_scheduler/memory_manage_modules/memory_filter.py +306 -0
- memos/mem_scheduler/memory_manage_modules/retriever.py +547 -0
- memos/mem_scheduler/monitors/__init__.py +0 -0
- memos/mem_scheduler/monitors/dispatcher_monitor.py +366 -0
- memos/mem_scheduler/monitors/general_monitor.py +394 -0
- memos/mem_scheduler/monitors/task_schedule_monitor.py +254 -0
- memos/mem_scheduler/optimized_scheduler.py +410 -0
- memos/mem_scheduler/orm_modules/__init__.py +0 -0
- memos/mem_scheduler/orm_modules/api_redis_model.py +518 -0
- memos/mem_scheduler/orm_modules/base_model.py +729 -0
- memos/mem_scheduler/orm_modules/monitor_models.py +261 -0
- memos/mem_scheduler/orm_modules/redis_model.py +699 -0
- memos/mem_scheduler/scheduler_factory.py +23 -0
- memos/mem_scheduler/schemas/__init__.py +0 -0
- memos/mem_scheduler/schemas/analyzer_schemas.py +52 -0
- memos/mem_scheduler/schemas/api_schemas.py +233 -0
- memos/mem_scheduler/schemas/general_schemas.py +55 -0
- memos/mem_scheduler/schemas/message_schemas.py +173 -0
- memos/mem_scheduler/schemas/monitor_schemas.py +406 -0
- memos/mem_scheduler/schemas/task_schemas.py +132 -0
- memos/mem_scheduler/task_schedule_modules/__init__.py +0 -0
- memos/mem_scheduler/task_schedule_modules/dispatcher.py +740 -0
- memos/mem_scheduler/task_schedule_modules/local_queue.py +247 -0
- memos/mem_scheduler/task_schedule_modules/orchestrator.py +74 -0
- memos/mem_scheduler/task_schedule_modules/redis_queue.py +1385 -0
- memos/mem_scheduler/task_schedule_modules/task_queue.py +162 -0
- memos/mem_scheduler/utils/__init__.py +0 -0
- memos/mem_scheduler/utils/api_utils.py +77 -0
- memos/mem_scheduler/utils/config_utils.py +100 -0
- memos/mem_scheduler/utils/db_utils.py +50 -0
- memos/mem_scheduler/utils/filter_utils.py +176 -0
- memos/mem_scheduler/utils/metrics.py +125 -0
- memos/mem_scheduler/utils/misc_utils.py +290 -0
- memos/mem_scheduler/utils/monitor_event_utils.py +67 -0
- memos/mem_scheduler/utils/status_tracker.py +229 -0
- memos/mem_scheduler/webservice_modules/__init__.py +0 -0
- memos/mem_scheduler/webservice_modules/rabbitmq_service.py +485 -0
- memos/mem_scheduler/webservice_modules/redis_service.py +380 -0
- memos/mem_user/factory.py +94 -0
- memos/mem_user/mysql_persistent_user_manager.py +271 -0
- memos/mem_user/mysql_user_manager.py +502 -0
- memos/mem_user/persistent_factory.py +98 -0
- memos/mem_user/persistent_user_manager.py +260 -0
- memos/mem_user/redis_persistent_user_manager.py +225 -0
- memos/mem_user/user_manager.py +488 -0
- memos/memories/__init__.py +0 -0
- memos/memories/activation/__init__.py +0 -0
- memos/memories/activation/base.py +42 -0
- memos/memories/activation/item.py +56 -0
- memos/memories/activation/kv.py +292 -0
- memos/memories/activation/vllmkv.py +219 -0
- memos/memories/base.py +19 -0
- memos/memories/factory.py +42 -0
- memos/memories/parametric/__init__.py +0 -0
- memos/memories/parametric/base.py +19 -0
- memos/memories/parametric/item.py +11 -0
- memos/memories/parametric/lora.py +41 -0
- memos/memories/textual/__init__.py +0 -0
- memos/memories/textual/base.py +92 -0
- memos/memories/textual/general.py +236 -0
- memos/memories/textual/item.py +304 -0
- memos/memories/textual/naive.py +187 -0
- memos/memories/textual/prefer_text_memory/__init__.py +0 -0
- memos/memories/textual/prefer_text_memory/adder.py +504 -0
- memos/memories/textual/prefer_text_memory/config.py +106 -0
- memos/memories/textual/prefer_text_memory/extractor.py +221 -0
- memos/memories/textual/prefer_text_memory/factory.py +85 -0
- memos/memories/textual/prefer_text_memory/retrievers.py +177 -0
- memos/memories/textual/prefer_text_memory/spliter.py +132 -0
- memos/memories/textual/prefer_text_memory/utils.py +93 -0
- memos/memories/textual/preference.py +344 -0
- memos/memories/textual/simple_preference.py +161 -0
- memos/memories/textual/simple_tree.py +69 -0
- memos/memories/textual/tree.py +459 -0
- memos/memories/textual/tree_text_memory/__init__.py +0 -0
- memos/memories/textual/tree_text_memory/organize/__init__.py +0 -0
- memos/memories/textual/tree_text_memory/organize/handler.py +184 -0
- memos/memories/textual/tree_text_memory/organize/manager.py +518 -0
- memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +238 -0
- memos/memories/textual/tree_text_memory/organize/reorganizer.py +622 -0
- memos/memories/textual/tree_text_memory/retrieve/__init__.py +0 -0
- memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py +364 -0
- memos/memories/textual/tree_text_memory/retrieve/bm25_util.py +186 -0
- memos/memories/textual/tree_text_memory/retrieve/bochasearch.py +419 -0
- memos/memories/textual/tree_text_memory/retrieve/internet_retriever.py +270 -0
- memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py +102 -0
- memos/memories/textual/tree_text_memory/retrieve/reasoner.py +61 -0
- memos/memories/textual/tree_text_memory/retrieve/recall.py +497 -0
- memos/memories/textual/tree_text_memory/retrieve/reranker.py +111 -0
- memos/memories/textual/tree_text_memory/retrieve/retrieval_mid_structs.py +16 -0
- memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py +472 -0
- memos/memories/textual/tree_text_memory/retrieve/searcher.py +848 -0
- memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py +135 -0
- memos/memories/textual/tree_text_memory/retrieve/utils.py +54 -0
- memos/memories/textual/tree_text_memory/retrieve/xinyusearch.py +387 -0
- memos/memos_tools/dinding_report_bot.py +453 -0
- memos/memos_tools/lockfree_dict.py +120 -0
- memos/memos_tools/notification_service.py +44 -0
- memos/memos_tools/notification_utils.py +142 -0
- memos/memos_tools/singleton.py +174 -0
- memos/memos_tools/thread_safe_dict.py +310 -0
- memos/memos_tools/thread_safe_dict_segment.py +382 -0
- memos/multi_mem_cube/__init__.py +0 -0
- memos/multi_mem_cube/composite_cube.py +86 -0
- memos/multi_mem_cube/single_cube.py +874 -0
- memos/multi_mem_cube/views.py +54 -0
- memos/parsers/__init__.py +0 -0
- memos/parsers/base.py +15 -0
- memos/parsers/factory.py +21 -0
- memos/parsers/markitdown.py +28 -0
- memos/reranker/__init__.py +4 -0
- memos/reranker/base.py +25 -0
- memos/reranker/concat.py +103 -0
- memos/reranker/cosine_local.py +102 -0
- memos/reranker/factory.py +72 -0
- memos/reranker/http_bge.py +324 -0
- memos/reranker/http_bge_strategy.py +327 -0
- memos/reranker/noop.py +19 -0
- memos/reranker/strategies/__init__.py +4 -0
- memos/reranker/strategies/base.py +61 -0
- memos/reranker/strategies/concat_background.py +94 -0
- memos/reranker/strategies/concat_docsource.py +110 -0
- memos/reranker/strategies/dialogue_common.py +109 -0
- memos/reranker/strategies/factory.py +31 -0
- memos/reranker/strategies/single_turn.py +107 -0
- memos/reranker/strategies/singleturn_outmem.py +98 -0
- memos/settings.py +10 -0
- memos/templates/__init__.py +0 -0
- memos/templates/advanced_search_prompts.py +211 -0
- memos/templates/cloud_service_prompt.py +107 -0
- memos/templates/instruction_completion.py +66 -0
- memos/templates/mem_agent_prompts.py +85 -0
- memos/templates/mem_feedback_prompts.py +822 -0
- memos/templates/mem_reader_prompts.py +1096 -0
- memos/templates/mem_reader_strategy_prompts.py +238 -0
- memos/templates/mem_scheduler_prompts.py +626 -0
- memos/templates/mem_search_prompts.py +93 -0
- memos/templates/mos_prompts.py +403 -0
- memos/templates/prefer_complete_prompt.py +735 -0
- memos/templates/tool_mem_prompts.py +139 -0
- memos/templates/tree_reorganize_prompts.py +230 -0
- memos/types/__init__.py +34 -0
- memos/types/general_types.py +151 -0
- memos/types/openai_chat_completion_types/__init__.py +15 -0
- memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py +56 -0
- memos/types/openai_chat_completion_types/chat_completion_content_part_image_param.py +27 -0
- memos/types/openai_chat_completion_types/chat_completion_content_part_input_audio_param.py +23 -0
- memos/types/openai_chat_completion_types/chat_completion_content_part_param.py +43 -0
- memos/types/openai_chat_completion_types/chat_completion_content_part_refusal_param.py +16 -0
- memos/types/openai_chat_completion_types/chat_completion_content_part_text_param.py +16 -0
- memos/types/openai_chat_completion_types/chat_completion_message_custom_tool_call_param.py +27 -0
- memos/types/openai_chat_completion_types/chat_completion_message_function_tool_call_param.py +32 -0
- memos/types/openai_chat_completion_types/chat_completion_message_param.py +18 -0
- memos/types/openai_chat_completion_types/chat_completion_message_tool_call_union_param.py +15 -0
- memos/types/openai_chat_completion_types/chat_completion_system_message_param.py +36 -0
- memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py +30 -0
- memos/types/openai_chat_completion_types/chat_completion_user_message_param.py +34 -0
- memos/utils.py +123 -0
- memos/vec_dbs/__init__.py +0 -0
- memos/vec_dbs/base.py +117 -0
- memos/vec_dbs/factory.py +23 -0
- memos/vec_dbs/item.py +50 -0
- memos/vec_dbs/milvus.py +654 -0
- memos/vec_dbs/qdrant.py +355 -0
memos/mem_agent/deepsearch_agent.py
ADDED

@@ -0,0 +1,391 @@
"""
Deep Search Agent implementation for MemOS.

This module implements a sophisticated deep search agent that performs iterative
query refinement and memory retrieval to provide comprehensive answers.
"""

import json
import re

from typing import TYPE_CHECKING, Any

from memos.configs.mem_agent import DeepSearchAgentConfig
from memos.llms.base import BaseLLM
from memos.log import get_logger
from memos.mem_agent.base import BaseMemAgent
from memos.memories.textual.item import TextualMemoryItem
from memos.memories.textual.tree import TreeTextMemory
from memos.templates.mem_agent_prompts import (
    FINAL_GENERATION_PROMPT,
    QUERY_REWRITE_PROMPT,
    REFLECTION_PROMPT,
)


if TYPE_CHECKING:
    from memos.types import MessageList

logger = get_logger(__name__)


class JSONResponseParser:
    """Elegant JSON response parser for LLM outputs"""

    @staticmethod
    def parse(response: str) -> dict[str, Any]:
        """Parse JSON response from LLM output with fallback strategies"""
        # Clean response text by removing code block markers
        cleaned = re.sub(r"^```(?:json)?\s*\n?|```\s*$", "", response.strip(), flags=re.IGNORECASE)

        # Try parsing with multiple strategies
        for text in [cleaned, re.search(r"\{.*\}", cleaned, re.DOTALL)]:
            if not text:
                continue
            try:
                return json.loads(text if isinstance(text, str) else text.group())
            except json.JSONDecodeError:
                continue

        raise ValueError(f"Cannot parse JSON response: {response[:100]}...")


class QueryRewriter(BaseMemAgent):
    """Specialized agent for rewriting queries based on conversation history"""

    def __init__(self, llm: BaseLLM, name: str = "QueryRewriter"):
        self.llm = llm
        self.name = name

    def run(self, query: str, history: list[str] | None = None) -> str:
        """Rewrite query to be standalone and more searchable"""
        history = history or []
        history_context = self._format_history(history)

        prompt = QUERY_REWRITE_PROMPT.format(history=history_context, query=query)
        messages = [{"role": "user", "content": prompt}]
        try:
            response = self.llm.generate(messages)
            logger.info(f"[{self.name}] Rewritten query: {response.strip()}")
            return response.strip()
        except Exception as e:
            logger.error(f"[{self.name}] Query rewrite failed: {e}")
            return query

    def _format_history(self, history: list[str]) -> str:
        """Format conversation history for prompt context"""
        if not history:
            return "No previous conversation"
        return "\n".join(f"- {msg}" for msg in history[-5:])


class ReflectionAgent:
    """Specialized agent for analyzing information sufficiency"""

    def __init__(self, llm: BaseLLM, name: str = "Reflector"):
        self.llm = llm
        self.name = name

    def run(self, query: str, context: list[str]) -> dict[str, Any]:
        """Analyze whether retrieved context is sufficient to answer the query"""
        context_summary = self._format_context(context)
        prompt = REFLECTION_PROMPT.format(query=query, context=context_summary)

        try:
            response = self.llm.generate([{"role": "user", "content": prompt}])
            logger.info(f"[{self.name}] Reflection response: {response}")

            result = JSONResponseParser.parse(response.strip())
            logger.info(f"[{self.name}] Reflection result: {result}")
            return result

        except Exception as e:
            logger.error(f"[{self.name}] Reflection analysis failed: {e}")
            return self._fallback_response()

    def _format_context(self, context: list[str]) -> str:
        """Format context strings for analysis with length limits"""
        return "\n".join(
            f"- {ctx[:200]}..." if len(ctx) > 200 else f"- {ctx}" for ctx in context[:10]
        )

    def _fallback_response(self) -> dict[str, Any]:
        """Return safe fallback when reflection fails"""
        return {
            "status": "sufficient",
            "reasoning": "Unable to analyze, proceeding with available information",
            "missing_entities": [],
        }


class DeepSearchMemAgent(BaseMemAgent):
    """
    Main orchestrator agent implementing the deep search pipeline.

    This agent coordinates multiple sub-agents to perform iterative query refinement,
    memory retrieval, and information synthesis as shown in the architecture diagram.
    """

    def __init__(
        self,
        llm: BaseLLM,
        memory_retriever: TreeTextMemory | None = None,
        config: DeepSearchAgentConfig | None = None,
    ):
        """
        Initialize DeepSearchMemAgent.

        Args:
            llm: Language model for query rewriting and response generation
            memory_retriever: Memory retrieval interface (e.g., naive_mem_cube.text_mem)
            config: Configuration for deep search behavior
        """
        self.config = config or DeepSearchAgentConfig(agent_name="DeepSearchMemAgent")
        self.max_iterations = self.config.max_iterations
        self.timeout = self.config.timeout
        self.llm: BaseLLM = llm
        self.query_rewriter: QueryRewriter = QueryRewriter(llm, "QueryRewriter")
        self.reflector: ReflectionAgent = ReflectionAgent(llm, "Reflector")
        self.memory_retriever = memory_retriever

    def run(self, query: str, **kwargs) -> str | list[TextualMemoryItem]:
        """
        Main execution method implementing the deep search pipeline.

        Args:
            query: User query string
            **kwargs: Additional arguments (history, user_id, etc.)
        Returns:
            Comprehensive response string
        """
        if not self.llm:
            raise RuntimeError("LLM not initialized.")

        history = kwargs.get("history", [])
        user_id = kwargs.get("user_id")
        generated_answer = kwargs.get("generated_answer")

        # Step 1: Query Rewriting
        current_query = self.query_rewriter.run(query, history)

        accumulated_context = []
        accumulated_memories = []
        search_keywords = []  # Can be extended with keyword extraction

        # Step 2: Iterative Search and Reflection Loop
        for iteration in range(self.max_iterations):
            logger.info(f"Starting iteration {iteration + 1}/{self.max_iterations}")

            search_results = self._perform_memory_search(
                current_query, keywords=search_keywords, user_id=user_id, history=history
            )

            if search_results:
                context_batch = [self._extract_context_from_memory(mem) for mem in search_results]
                accumulated_context.extend(context_batch)
                reflection_result = self.reflector.run(current_query, context_batch)
                status = reflection_result.get("status", "sufficient")
                reasoning = reflection_result.get("reasoning", "")

                logger.info(f"Reflection status: {status} - {reasoning}")

                if status == "sufficient":
                    logger.info("Sufficient information collected")
                    accumulated_memories.extend(search_results)
                    break
                elif status == "needs_raw":
                    logger.info("Need original sources, retrieving raw content")
                    accumulated_memories.extend(self._set_source_from_memory(search_results))
                    break
                elif status == "missing_info":
                    accumulated_memories.extend(search_results)
                    missing_entities = reflection_result.get("missing_entities", [])
                    logger.info(f"Missing information: {missing_entities}")
                    current_query = reflection_result.get("new_search_query")
                    if not current_query:
                        refined_query = self._refine_query_for_missing_info(
                            current_query, missing_entities
                        )
                        current_query = refined_query
                    logger.info(f"Refined query: {current_query}")
            else:
                logger.warning(f"No search results for iteration {iteration + 1}")
                if iteration == 0:
                    current_query = query
                else:
                    break

        if not generated_answer:
            return self._remove_duplicate_memories(accumulated_memories)
        else:
            return self._generate_final_answer(
                query, accumulated_memories, accumulated_context, history
            )

    def _remove_duplicate_memories(
        self, memories: list[TextualMemoryItem]
    ) -> list[TextualMemoryItem]:
        """
        Remove duplicate memories based on memory content.

        Args:
            memories: List of TextualMemoryItem objects to deduplicate

        Returns:
            List of unique TextualMemoryItem objects (first occurrence kept)
        """
        seen = set()
        return [
            memory
            for memory in memories
            if (content := getattr(memory, "memory", "").strip())
            and content not in seen
            and not seen.add(content)
        ]

    def _generate_final_answer(
        self,
        original_query: str,
        search_results: list[TextualMemoryItem],
        context: list[str],
        history: list[str] | None = None,
        sources: list[str] | None = None,
        missing_info: str | None = None,
    ) -> str:
        """
        Generate the final answer.
        """
        context_str = "\n".join([f"- {ctx}" for ctx in context[:20]])
        prompt = FINAL_GENERATION_PROMPT.format(
            query=original_query,
            sources=sources,
            context=context_str if context_str else "No specific context retrieved",
            missing_info=missing_info if missing_info else "None identified",
        )
        messages: MessageList = [{"role": "user", "content": prompt}]
        response = self.llm.generate(messages)
        return response.strip()

    def _perform_memory_search(
        self,
        query: str,
        keywords: list[str] | None = None,
        user_id: str | None = None,
        history: list[str] | None = None,
        top_k: int = 10,
    ) -> list[TextualMemoryItem]:
        """
        Perform memory search using the configured retriever.

        Args:
            query: Search query
            keywords: Additional keywords for search
            user_id: User identifier
            top_k: Number of results to retrieve

        Returns:
            List of retrieved memory items
        """
        if not self.memory_retriever:
            logger.warning("Memory retriever not configured, returning empty results")
            return []

        try:
            # Use the memory retriever interface
            # This is a placeholder - actual implementation depends on the retriever interface
            search_query = query
            if keywords and len(keywords) > 1:
                search_query = f"{query} {' '.join(keywords[:3])}"  # Combine with top keywords

            # Assuming the retriever has a search method similar to TreeTextMemory
            results = self.memory_retriever.search(
                query=search_query,
                top_k=top_k,
                mode="fast",
                user_name=user_id,
                info={"history": history},
            )

            return results if isinstance(results, list) else []

        except Exception as e:
            logger.error(f"Error performing memory search: {e}")
            return []

    def _extract_context_from_memory(self, memory_item: TextualMemoryItem) -> str:
        """Extract readable context from a memory item."""
        if hasattr(memory_item, "memory"):
            return str(memory_item.memory)
        elif hasattr(memory_item, "content"):
            return str(memory_item.content)
        else:
            return str(memory_item)

    def _refine_query_for_missing_info(self, query: str, missing_entities: list[str]) -> str:
        """Refine the query to search for missing information."""
        if not missing_entities:
            return query

        # Simple refinement strategy - append missing entities
        entities_str = " ".join(missing_entities[:3])  # Limit to top 3 entities
        refined_query = f"{query} {entities_str}"

        return refined_query

    def _set_source_from_memory(
        self, memory_items: list[TextualMemoryItem]
    ) -> list[TextualMemoryItem]:
        """set source from memory item"""
        for memory_item in memory_items:
            if not hasattr(memory_item.metadata, "sources"):
                continue
            chat_sources = [
                f"{source.chat_time} {source.role}: {source.content}"
                for source in memory_item.metadata.sources
                if hasattr(source, "type") and source.type == "chat"
            ]
            if chat_sources:
                memory_item.memory = "\n".join(chat_sources) + "\n"
        return memory_items

    def _generate_final_answer(
        self,
        original_query: str,
        search_results: list[TextualMemoryItem],
        context: list[str],
        missing_info: str = "",
    ) -> str:
        """
        Generate the final comprehensive answer.

        Args:
            original_query: Original user query
            search_results: All retrieved memory items
            context: Extracted context strings
            missing_info: Information about missing data

        Returns:
            Final answer string
        """
        # Prepare context for the prompt
        context_str = "\n".join([f"- {ctx}" for ctx in context[:20]])  # Limit context
        sources = (
            f"Retrieved {len(search_results)} memory items"
            if search_results
            else "No specific sources"
        )

        prompt = FINAL_GENERATION_PROMPT.format(
            query=original_query,
            sources=sources,
            context=context_str if context_str else "No specific context retrieved",
            missing_info=missing_info if missing_info else "None identified",
        )
        messages: MessageList = [{"role": "user", "content": prompt}]

        try:
            response = self.llm.generate(messages)
            return response.strip()
        except Exception as e:
            logger.error(f"Error generating final answer: {e}")
            return f"I apologize, but I encountered an error while processing your query: {original_query}. Please try again."
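To make the pipeline above concrete, here is a minimal usage sketch. It assumes `llm` is an already-configured `BaseLLM` implementation and `tree_mem` a populated `TreeTextMemory` (construction of both is outside this diff); the `history`, `user_id`, and `generated_answer` keyword arguments are the ones `run()` actually reads.

```python
# Minimal sketch; `llm` and `tree_mem` are assumed to exist already.
from memos.configs.mem_agent import DeepSearchAgentConfig
from memos.mem_agent.deepsearch_agent import DeepSearchMemAgent

agent = DeepSearchMemAgent(
    llm=llm,                    # any BaseLLM implementation
    memory_retriever=tree_mem,  # searched via tree_mem.search(...)
    config=DeepSearchAgentConfig(agent_name="DeepSearchMemAgent"),
)

# Without generated_answer, run() returns deduplicated TextualMemoryItem objects.
memories = agent.run("Where did I say I wanted to travel?", user_id="u_123")

# With generated_answer set, it synthesizes a natural-language answer instead.
answer = agent.run(
    "Where did I say I wanted to travel?",
    user_id="u_123",
    history=["I love hiking.", "Iceland has been on my list for years."],
    generated_answer=True,
)
```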
memos/mem_agent/factory.py
ADDED

@@ -0,0 +1,36 @@
from typing import Any, ClassVar

from memos.configs.mem_agent import MemAgentConfigFactory
from memos.mem_agent.base import BaseMemAgent
from memos.mem_agent.deepsearch_agent import DeepSearchMemAgent


class MemAgentFactory:
    """Factory class for creating MemAgent instances."""

    backend_to_class: ClassVar[dict[str, Any]] = {
        "deep_search": DeepSearchMemAgent,
    }

    @classmethod
    def from_config(
        cls, config_factory: MemAgentConfigFactory, llm: Any, memory_retriever: Any | None = None
    ) -> BaseMemAgent:
        """
        Create a MemAgent instance from configuration.

        Args:
            config_factory: Configuration factory for the agent
            llm: Language model instance
            memory_retriever: Memory retrieval interface (e.g., naive_mem_cube.text_mem)

        Returns:
            Initialized MemAgent instance
        """
        backend = config_factory.backend
        if backend not in cls.backend_to_class:
            raise ValueError(f"Invalid backend: {backend}")
        mem_agent_class = cls.backend_to_class[backend]
        return mem_agent_class(
            llm=llm, memory_retriever=memory_retriever, config=config_factory.config
        )
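A sketch of driving the same agent through the factory follows. `MemAgentConfigFactory` is defined in `memos/configs/mem_agent.py`, which is not part of this diff, so the construction payload below is an assumption; what the factory itself guarantees is only that it dispatches on `config_factory.backend` and forwards `config_factory.config` to the agent class.

```python
# Sketch only: the MemAgentConfigFactory field layout is assumed, not verified here.
from memos.configs.mem_agent import MemAgentConfigFactory
from memos.mem_agent.factory import MemAgentFactory

config_factory = MemAgentConfigFactory(
    backend="deep_search",                         # must be a key of backend_to_class
    config={"agent_name": "DeepSearchMemAgent"},   # assumed field layout
)

agent = MemAgentFactory.from_config(
    config_factory,
    llm=llm,                    # assumed: an existing BaseLLM instance
    memory_retriever=tree_mem,  # assumed: e.g. naive_mem_cube.text_mem
)
```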
memos/mem_chat/__init__.py
File without changes
memos/mem_chat/base.py
ADDED

@@ -0,0 +1,30 @@
from abc import ABC, abstractmethod

from memos.configs.mem_chat import BaseMemChatConfig
from memos.mem_cube.base import BaseMemCube


class BaseMemChat(ABC):
    """Base class for all MemChat."""

    @abstractmethod
    def __init__(self, config: BaseMemChatConfig):
        """Initialize the MemChat with the given configuration."""

    @property
    @abstractmethod
    def mem_cube(self) -> BaseMemCube:
        """The memory cube associated with this MemChat."""

    @mem_cube.setter
    @abstractmethod
    def mem_cube(self, value: BaseMemCube) -> None:
        """The memory cube associated with this MemChat."""

    @abstractmethod
    def run(self) -> None:
        """Run the MemChat.

        This `run` method can represent the core logic of a MemChat.
        It could be an iterative chat process.
        """
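The ABC above pins down only three members: a config-taking `__init__`, a `mem_cube` property with a setter, and `run()`. A minimal hypothetical subclass therefore looks like the sketch below; `SimpleMemChat` later in this diff follows the same pattern with real memory handling.

```python
# Hypothetical minimal subclass, only to illustrate the required interface.
from memos.configs.mem_chat import BaseMemChatConfig
from memos.mem_chat.base import BaseMemChat
from memos.mem_cube.base import BaseMemCube


class EchoMemChat(BaseMemChat):
    """Toy MemChat that just echoes input; not part of the package."""

    def __init__(self, config: BaseMemChatConfig):
        self.config = config
        self._mem_cube = None

    @property
    def mem_cube(self) -> BaseMemCube:
        return self._mem_cube

    @mem_cube.setter
    def mem_cube(self, value: BaseMemCube) -> None:
        self._mem_cube = value

    def run(self) -> None:
        # Echo until the user types 'bye', mirroring the package's REPL style.
        while (text := input("> ").strip()) .lower() != "bye":
            print(text)
```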
memos/mem_chat/factory.py
ADDED

@@ -0,0 +1,21 @@
from typing import Any, ClassVar

from memos.configs.mem_chat import MemChatConfigFactory
from memos.mem_chat.base import BaseMemChat
from memos.mem_chat.simple import SimpleMemChat


class MemChatFactory(BaseMemChat):
    """Factory class for creating MemChat instances."""

    backend_to_class: ClassVar[dict[str, Any]] = {
        "simple": SimpleMemChat,
    }

    @classmethod
    def from_config(cls, config_factory: MemChatConfigFactory) -> BaseMemChat:
        backend = config_factory.backend
        if backend not in cls.backend_to_class:
            raise ValueError(f"Invalid backend: {backend}")
        mem_chat_class = cls.backend_to_class[backend]
        return mem_chat_class(config_factory.config)
memos/mem_chat/simple.py
ADDED

@@ -0,0 +1,200 @@
import os

from typing import Literal

from memos.configs.mem_chat import SimpleMemChatConfig
from memos.llms.factory import LLMFactory
from memos.log import get_logger
from memos.mem_chat.base import BaseMemChat
from memos.mem_cube.base import BaseMemCube
from memos.memories.activation.kv import move_dynamic_cache_htod
from memos.memories.textual.item import TextualMemoryItem
from memos.types import ChatHistory, MessageList


logger = get_logger(__name__)


class SimpleMemChat(BaseMemChat):
    """Simple MemChat class."""

    def __init__(self, config: SimpleMemChatConfig):
        """Initialize the MemChat with the given configuration."""
        self.config = config
        self.chat_llm = LLMFactory.from_config(config.chat_llm)
        self._mem_cube = None

    @property
    def mem_cube(self) -> BaseMemCube:
        """The memory cube associated with this MemChat."""
        return self._mem_cube

    @mem_cube.setter
    def mem_cube(self, value: BaseMemCube) -> None:
        """The memory cube associated with this MemChat."""
        self._mem_cube = value

    def run(self) -> None:
        """Run the MemChat."""

        # Start MemChat

        print(
            "\n📢 [System] " + "Simple MemChat is running.\n"
            "Commands: 'bye' to quit, 'clear' to clear chat history, 'mem' to show all memories, 'export' to export chat history\n",
        )

        messages = []
        while True:
            # Get user input

            user_input = input("👤 [You] ").strip()
            print()

            if user_input.lower() == "bye":
                break
            elif user_input.lower() == "clear":
                messages = []
                print("📢 [System] Chat history cleared.")
                continue
            elif user_input.lower() == "mem":
                if self.config.enable_textual_memory:
                    all_memories = self.mem_cube.text_mem.get_all()
                    print(f"🧠[Memory] \n{self._str_memories(all_memories)}\n")
                else:
                    print("📢 [System] Textual memory is not enabled.\n")
                continue
            elif user_input.lower() == "export":
                if messages:
                    filepath = self._export_chat_history(messages)
                    print(f"📢 [System] Chat history exported to: {filepath}\n")
                else:
                    print("📢 [System] No chat history to export.\n")
                continue
            elif user_input == "":
                continue

            # Get memories

            if self.config.enable_textual_memory:
                memories = self.mem_cube.text_mem.search(user_input, top_k=self.config.top_k)
                print(
                    f"🧠[Memory] Searched memories:\n{self._str_memories(memories, mode='concise')}\n"
                )
                system_prompt = self._build_system_prompt(memories)
            else:
                system_prompt = self._build_system_prompt()
            current_messages = [
                {"role": "system", "content": system_prompt},
                *messages,
                {"role": "user", "content": user_input},
            ]

            if self.config.enable_activation_memory:
                past_key_values = None
                loaded_kv_cache_item = next(
                    iter(self.mem_cube.act_mem.kv_cache_memories.values()), None
                )
                if loaded_kv_cache_item is not None:
                    # If has loaded kv cache, we move it to device before inferring.
                    # Currently, we move only single kv cache item
                    past_key_values = loaded_kv_cache_item
                    past_key_values.kv_cache = move_dynamic_cache_htod(
                        past_key_values.kv_cache, self.chat_llm.model.device
                    )

                # Generate response
                response = self.chat_llm.generate(
                    current_messages,
                    past_key_values=past_key_values.kv_cache if past_key_values else None,
                )
            else:
                # Generate response without activation memory
                response = self.chat_llm.generate(current_messages)

            print(f"🤖 [Assistant] {response}\n")
            messages.append({"role": "user", "content": user_input})
            messages.append({"role": "assistant", "content": response})
            messages = messages[
                -self.config.max_turns_window :
            ]  # Keep only recent messages to avoid context overflow

            # Extract memories

            if self.config.enable_textual_memory:
                new_memories = self.mem_cube.text_mem.extract(messages[-2:])
                for memory in new_memories:
                    memory.metadata.user_id = self.config.user_id
                    memory.metadata.session_id = self.config.session_id
                    memory.metadata.status = "activated"
                self.mem_cube.text_mem.add(new_memories)
                print(
                    f"🧠[Memory] Stored {len(new_memories)} new memory(ies):\n"
                    f"{self._str_memories(new_memories, 'concise')}\n"
                )

        # Stop MemChat

        print("📢 [System] MemChat has stopped.")

    def _build_system_prompt(self, memories: list | None = None) -> str:
        """Build system prompt with optional memories context."""
        base_prompt = (
            "You are a knowledgeable and helpful AI assistant. "
            "You have access to conversation memories that help you provide more personalized responses. "
            "Use the memories to understand the user's context, preferences, and past interactions. "
            "If memories are provided, reference them naturally when relevant, but don't explicitly mention having memories."
        )

        if memories:
            memory_context = "\n\n## Memories:\n"
            for i, memory in enumerate(memories, 1):
                memory_context += f"{i}. ({memory.metadata.memory_time}) {memory.memory}\n"
            return base_prompt + memory_context

        return base_prompt

    def _str_memories(
        self, memories: list[TextualMemoryItem], mode: Literal["concise", "full"] = "full"
    ) -> str:
        """Format memories for display."""
        if not memories:
            return "No memories."
        if mode == "concise":
            return "\n".join(f"{i + 1}. {memory.memory}" for i, memory in enumerate(memories))
        elif mode == "full":
            return "\n".join(f"{i + 1}. {memory}" for i, memory in enumerate(memories))

    def _export_chat_history(self, messages: MessageList, output_dir: str = "chat_exports") -> str:
        """Export chat history to JSON file.

        Args:
            messages: List of chat messages
            output_dir: Directory to save the export file

        Returns:
            Path to the exported JSON file
        """
        # Create output directory if it doesn't exist
        os.makedirs(output_dir, exist_ok=True)

        # Generate filename with user_id and timestamp
        timestamp = self.config.created_at.strftime("%Y%m%d_%H%M%S")
        filename = f"{self.config.user_id}_{timestamp}_chat_history.json"
        filepath = os.path.join(output_dir, filename)

        # Prepare export data
        export_data = ChatHistory(
            user_id=self.config.user_id,
            session_id=self.config.session_id,
            created_at=self.config.created_at,
            total_messages=len(messages),
            chat_history=messages,
        )

        # Write to JSON file
        with open(filepath, "w", encoding="utf-8") as f:
            f.write(export_data.model_dump_json(indent=4, exclude_none=True, warnings="none"))

        logger.info(f"Chat history exported to {filepath}")
        return filepath
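Putting the three mem_chat files together, a typical way to start the REPL-style chat is sketched below. The `MemChatConfigFactory` payload and the memory-cube object are assumptions (their definitions sit outside this diff); what the diff does establish is that a cube must be attached via the `mem_cube` setter before textual or activation memory is used, and that `run()` blocks until the user types 'bye'.

```python
# Usage sketch; the config payload and `my_mem_cube` are illustrative assumptions.
from memos.configs.mem_chat import MemChatConfigFactory
from memos.mem_chat.factory import MemChatFactory

config_factory = MemChatConfigFactory(
    backend="simple",              # dispatches to SimpleMemChat
    config={"user_id": "u_123"},   # assumed SimpleMemChatConfig fields (defined elsewhere)
)
mem_chat = MemChatFactory.from_config(config_factory)

# Attach a memory cube before chatting so text_mem / act_mem are available.
mem_chat.mem_cube = my_mem_cube  # assumed: any BaseMemCube implementation
mem_chat.run()                   # interactive loop: 'mem', 'export', 'clear', 'bye'
```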
memos/mem_cube/__init__.py
File without changes