MemoryOS 2.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- memoryos-2.0.3.dist-info/METADATA +418 -0
- memoryos-2.0.3.dist-info/RECORD +315 -0
- memoryos-2.0.3.dist-info/WHEEL +4 -0
- memoryos-2.0.3.dist-info/entry_points.txt +3 -0
- memoryos-2.0.3.dist-info/licenses/LICENSE +201 -0
- memos/__init__.py +20 -0
- memos/api/client.py +571 -0
- memos/api/config.py +1018 -0
- memos/api/context/dependencies.py +50 -0
- memos/api/exceptions.py +53 -0
- memos/api/handlers/__init__.py +62 -0
- memos/api/handlers/add_handler.py +158 -0
- memos/api/handlers/base_handler.py +194 -0
- memos/api/handlers/chat_handler.py +1401 -0
- memos/api/handlers/component_init.py +388 -0
- memos/api/handlers/config_builders.py +190 -0
- memos/api/handlers/feedback_handler.py +93 -0
- memos/api/handlers/formatters_handler.py +237 -0
- memos/api/handlers/memory_handler.py +316 -0
- memos/api/handlers/scheduler_handler.py +497 -0
- memos/api/handlers/search_handler.py +222 -0
- memos/api/handlers/suggestion_handler.py +117 -0
- memos/api/mcp_serve.py +614 -0
- memos/api/middleware/request_context.py +101 -0
- memos/api/product_api.py +38 -0
- memos/api/product_models.py +1206 -0
- memos/api/routers/__init__.py +1 -0
- memos/api/routers/product_router.py +477 -0
- memos/api/routers/server_router.py +394 -0
- memos/api/server_api.py +44 -0
- memos/api/start_api.py +433 -0
- memos/chunkers/__init__.py +4 -0
- memos/chunkers/base.py +24 -0
- memos/chunkers/charactertext_chunker.py +41 -0
- memos/chunkers/factory.py +24 -0
- memos/chunkers/markdown_chunker.py +62 -0
- memos/chunkers/sentence_chunker.py +54 -0
- memos/chunkers/simple_chunker.py +50 -0
- memos/cli.py +113 -0
- memos/configs/__init__.py +0 -0
- memos/configs/base.py +82 -0
- memos/configs/chunker.py +59 -0
- memos/configs/embedder.py +88 -0
- memos/configs/graph_db.py +236 -0
- memos/configs/internet_retriever.py +100 -0
- memos/configs/llm.py +151 -0
- memos/configs/mem_agent.py +54 -0
- memos/configs/mem_chat.py +81 -0
- memos/configs/mem_cube.py +105 -0
- memos/configs/mem_os.py +83 -0
- memos/configs/mem_reader.py +91 -0
- memos/configs/mem_scheduler.py +385 -0
- memos/configs/mem_user.py +70 -0
- memos/configs/memory.py +324 -0
- memos/configs/parser.py +38 -0
- memos/configs/reranker.py +18 -0
- memos/configs/utils.py +8 -0
- memos/configs/vec_db.py +80 -0
- memos/context/context.py +355 -0
- memos/dependency.py +52 -0
- memos/deprecation.py +262 -0
- memos/embedders/__init__.py +0 -0
- memos/embedders/ark.py +95 -0
- memos/embedders/base.py +106 -0
- memos/embedders/factory.py +29 -0
- memos/embedders/ollama.py +77 -0
- memos/embedders/sentence_transformer.py +49 -0
- memos/embedders/universal_api.py +51 -0
- memos/exceptions.py +30 -0
- memos/graph_dbs/__init__.py +0 -0
- memos/graph_dbs/base.py +274 -0
- memos/graph_dbs/factory.py +27 -0
- memos/graph_dbs/item.py +46 -0
- memos/graph_dbs/nebular.py +1794 -0
- memos/graph_dbs/neo4j.py +1942 -0
- memos/graph_dbs/neo4j_community.py +1058 -0
- memos/graph_dbs/polardb.py +5446 -0
- memos/hello_world.py +97 -0
- memos/llms/__init__.py +0 -0
- memos/llms/base.py +25 -0
- memos/llms/deepseek.py +13 -0
- memos/llms/factory.py +38 -0
- memos/llms/hf.py +443 -0
- memos/llms/hf_singleton.py +114 -0
- memos/llms/ollama.py +135 -0
- memos/llms/openai.py +222 -0
- memos/llms/openai_new.py +198 -0
- memos/llms/qwen.py +13 -0
- memos/llms/utils.py +14 -0
- memos/llms/vllm.py +218 -0
- memos/log.py +237 -0
- memos/mem_agent/base.py +19 -0
- memos/mem_agent/deepsearch_agent.py +391 -0
- memos/mem_agent/factory.py +36 -0
- memos/mem_chat/__init__.py +0 -0
- memos/mem_chat/base.py +30 -0
- memos/mem_chat/factory.py +21 -0
- memos/mem_chat/simple.py +200 -0
- memos/mem_cube/__init__.py +0 -0
- memos/mem_cube/base.py +30 -0
- memos/mem_cube/general.py +240 -0
- memos/mem_cube/navie.py +172 -0
- memos/mem_cube/utils.py +169 -0
- memos/mem_feedback/base.py +15 -0
- memos/mem_feedback/feedback.py +1192 -0
- memos/mem_feedback/simple_feedback.py +40 -0
- memos/mem_feedback/utils.py +230 -0
- memos/mem_os/client.py +5 -0
- memos/mem_os/core.py +1203 -0
- memos/mem_os/main.py +582 -0
- memos/mem_os/product.py +1608 -0
- memos/mem_os/product_server.py +455 -0
- memos/mem_os/utils/default_config.py +359 -0
- memos/mem_os/utils/format_utils.py +1403 -0
- memos/mem_os/utils/reference_utils.py +162 -0
- memos/mem_reader/__init__.py +0 -0
- memos/mem_reader/base.py +47 -0
- memos/mem_reader/factory.py +53 -0
- memos/mem_reader/memory.py +298 -0
- memos/mem_reader/multi_modal_struct.py +965 -0
- memos/mem_reader/read_multi_modal/__init__.py +43 -0
- memos/mem_reader/read_multi_modal/assistant_parser.py +311 -0
- memos/mem_reader/read_multi_modal/base.py +273 -0
- memos/mem_reader/read_multi_modal/file_content_parser.py +826 -0
- memos/mem_reader/read_multi_modal/image_parser.py +359 -0
- memos/mem_reader/read_multi_modal/multi_modal_parser.py +252 -0
- memos/mem_reader/read_multi_modal/string_parser.py +139 -0
- memos/mem_reader/read_multi_modal/system_parser.py +327 -0
- memos/mem_reader/read_multi_modal/text_content_parser.py +131 -0
- memos/mem_reader/read_multi_modal/tool_parser.py +210 -0
- memos/mem_reader/read_multi_modal/user_parser.py +218 -0
- memos/mem_reader/read_multi_modal/utils.py +358 -0
- memos/mem_reader/simple_struct.py +912 -0
- memos/mem_reader/strategy_struct.py +163 -0
- memos/mem_reader/utils.py +157 -0
- memos/mem_scheduler/__init__.py +0 -0
- memos/mem_scheduler/analyzer/__init__.py +0 -0
- memos/mem_scheduler/analyzer/api_analyzer.py +714 -0
- memos/mem_scheduler/analyzer/eval_analyzer.py +219 -0
- memos/mem_scheduler/analyzer/mos_for_test_scheduler.py +571 -0
- memos/mem_scheduler/analyzer/scheduler_for_eval.py +280 -0
- memos/mem_scheduler/base_scheduler.py +1319 -0
- memos/mem_scheduler/general_modules/__init__.py +0 -0
- memos/mem_scheduler/general_modules/api_misc.py +137 -0
- memos/mem_scheduler/general_modules/base.py +80 -0
- memos/mem_scheduler/general_modules/init_components_for_scheduler.py +425 -0
- memos/mem_scheduler/general_modules/misc.py +313 -0
- memos/mem_scheduler/general_modules/scheduler_logger.py +389 -0
- memos/mem_scheduler/general_modules/task_threads.py +315 -0
- memos/mem_scheduler/general_scheduler.py +1495 -0
- memos/mem_scheduler/memory_manage_modules/__init__.py +5 -0
- memos/mem_scheduler/memory_manage_modules/memory_filter.py +306 -0
- memos/mem_scheduler/memory_manage_modules/retriever.py +547 -0
- memos/mem_scheduler/monitors/__init__.py +0 -0
- memos/mem_scheduler/monitors/dispatcher_monitor.py +366 -0
- memos/mem_scheduler/monitors/general_monitor.py +394 -0
- memos/mem_scheduler/monitors/task_schedule_monitor.py +254 -0
- memos/mem_scheduler/optimized_scheduler.py +410 -0
- memos/mem_scheduler/orm_modules/__init__.py +0 -0
- memos/mem_scheduler/orm_modules/api_redis_model.py +518 -0
- memos/mem_scheduler/orm_modules/base_model.py +729 -0
- memos/mem_scheduler/orm_modules/monitor_models.py +261 -0
- memos/mem_scheduler/orm_modules/redis_model.py +699 -0
- memos/mem_scheduler/scheduler_factory.py +23 -0
- memos/mem_scheduler/schemas/__init__.py +0 -0
- memos/mem_scheduler/schemas/analyzer_schemas.py +52 -0
- memos/mem_scheduler/schemas/api_schemas.py +233 -0
- memos/mem_scheduler/schemas/general_schemas.py +55 -0
- memos/mem_scheduler/schemas/message_schemas.py +173 -0
- memos/mem_scheduler/schemas/monitor_schemas.py +406 -0
- memos/mem_scheduler/schemas/task_schemas.py +132 -0
- memos/mem_scheduler/task_schedule_modules/__init__.py +0 -0
- memos/mem_scheduler/task_schedule_modules/dispatcher.py +740 -0
- memos/mem_scheduler/task_schedule_modules/local_queue.py +247 -0
- memos/mem_scheduler/task_schedule_modules/orchestrator.py +74 -0
- memos/mem_scheduler/task_schedule_modules/redis_queue.py +1385 -0
- memos/mem_scheduler/task_schedule_modules/task_queue.py +162 -0
- memos/mem_scheduler/utils/__init__.py +0 -0
- memos/mem_scheduler/utils/api_utils.py +77 -0
- memos/mem_scheduler/utils/config_utils.py +100 -0
- memos/mem_scheduler/utils/db_utils.py +50 -0
- memos/mem_scheduler/utils/filter_utils.py +176 -0
- memos/mem_scheduler/utils/metrics.py +125 -0
- memos/mem_scheduler/utils/misc_utils.py +290 -0
- memos/mem_scheduler/utils/monitor_event_utils.py +67 -0
- memos/mem_scheduler/utils/status_tracker.py +229 -0
- memos/mem_scheduler/webservice_modules/__init__.py +0 -0
- memos/mem_scheduler/webservice_modules/rabbitmq_service.py +485 -0
- memos/mem_scheduler/webservice_modules/redis_service.py +380 -0
- memos/mem_user/factory.py +94 -0
- memos/mem_user/mysql_persistent_user_manager.py +271 -0
- memos/mem_user/mysql_user_manager.py +502 -0
- memos/mem_user/persistent_factory.py +98 -0
- memos/mem_user/persistent_user_manager.py +260 -0
- memos/mem_user/redis_persistent_user_manager.py +225 -0
- memos/mem_user/user_manager.py +488 -0
- memos/memories/__init__.py +0 -0
- memos/memories/activation/__init__.py +0 -0
- memos/memories/activation/base.py +42 -0
- memos/memories/activation/item.py +56 -0
- memos/memories/activation/kv.py +292 -0
- memos/memories/activation/vllmkv.py +219 -0
- memos/memories/base.py +19 -0
- memos/memories/factory.py +42 -0
- memos/memories/parametric/__init__.py +0 -0
- memos/memories/parametric/base.py +19 -0
- memos/memories/parametric/item.py +11 -0
- memos/memories/parametric/lora.py +41 -0
- memos/memories/textual/__init__.py +0 -0
- memos/memories/textual/base.py +92 -0
- memos/memories/textual/general.py +236 -0
- memos/memories/textual/item.py +304 -0
- memos/memories/textual/naive.py +187 -0
- memos/memories/textual/prefer_text_memory/__init__.py +0 -0
- memos/memories/textual/prefer_text_memory/adder.py +504 -0
- memos/memories/textual/prefer_text_memory/config.py +106 -0
- memos/memories/textual/prefer_text_memory/extractor.py +221 -0
- memos/memories/textual/prefer_text_memory/factory.py +85 -0
- memos/memories/textual/prefer_text_memory/retrievers.py +177 -0
- memos/memories/textual/prefer_text_memory/spliter.py +132 -0
- memos/memories/textual/prefer_text_memory/utils.py +93 -0
- memos/memories/textual/preference.py +344 -0
- memos/memories/textual/simple_preference.py +161 -0
- memos/memories/textual/simple_tree.py +69 -0
- memos/memories/textual/tree.py +459 -0
- memos/memories/textual/tree_text_memory/__init__.py +0 -0
- memos/memories/textual/tree_text_memory/organize/__init__.py +0 -0
- memos/memories/textual/tree_text_memory/organize/handler.py +184 -0
- memos/memories/textual/tree_text_memory/organize/manager.py +518 -0
- memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +238 -0
- memos/memories/textual/tree_text_memory/organize/reorganizer.py +622 -0
- memos/memories/textual/tree_text_memory/retrieve/__init__.py +0 -0
- memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py +364 -0
- memos/memories/textual/tree_text_memory/retrieve/bm25_util.py +186 -0
- memos/memories/textual/tree_text_memory/retrieve/bochasearch.py +419 -0
- memos/memories/textual/tree_text_memory/retrieve/internet_retriever.py +270 -0
- memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py +102 -0
- memos/memories/textual/tree_text_memory/retrieve/reasoner.py +61 -0
- memos/memories/textual/tree_text_memory/retrieve/recall.py +497 -0
- memos/memories/textual/tree_text_memory/retrieve/reranker.py +111 -0
- memos/memories/textual/tree_text_memory/retrieve/retrieval_mid_structs.py +16 -0
- memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py +472 -0
- memos/memories/textual/tree_text_memory/retrieve/searcher.py +848 -0
- memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py +135 -0
- memos/memories/textual/tree_text_memory/retrieve/utils.py +54 -0
- memos/memories/textual/tree_text_memory/retrieve/xinyusearch.py +387 -0
- memos/memos_tools/dinding_report_bot.py +453 -0
- memos/memos_tools/lockfree_dict.py +120 -0
- memos/memos_tools/notification_service.py +44 -0
- memos/memos_tools/notification_utils.py +142 -0
- memos/memos_tools/singleton.py +174 -0
- memos/memos_tools/thread_safe_dict.py +310 -0
- memos/memos_tools/thread_safe_dict_segment.py +382 -0
- memos/multi_mem_cube/__init__.py +0 -0
- memos/multi_mem_cube/composite_cube.py +86 -0
- memos/multi_mem_cube/single_cube.py +874 -0
- memos/multi_mem_cube/views.py +54 -0
- memos/parsers/__init__.py +0 -0
- memos/parsers/base.py +15 -0
- memos/parsers/factory.py +21 -0
- memos/parsers/markitdown.py +28 -0
- memos/reranker/__init__.py +4 -0
- memos/reranker/base.py +25 -0
- memos/reranker/concat.py +103 -0
- memos/reranker/cosine_local.py +102 -0
- memos/reranker/factory.py +72 -0
- memos/reranker/http_bge.py +324 -0
- memos/reranker/http_bge_strategy.py +327 -0
- memos/reranker/noop.py +19 -0
- memos/reranker/strategies/__init__.py +4 -0
- memos/reranker/strategies/base.py +61 -0
- memos/reranker/strategies/concat_background.py +94 -0
- memos/reranker/strategies/concat_docsource.py +110 -0
- memos/reranker/strategies/dialogue_common.py +109 -0
- memos/reranker/strategies/factory.py +31 -0
- memos/reranker/strategies/single_turn.py +107 -0
- memos/reranker/strategies/singleturn_outmem.py +98 -0
- memos/settings.py +10 -0
- memos/templates/__init__.py +0 -0
- memos/templates/advanced_search_prompts.py +211 -0
- memos/templates/cloud_service_prompt.py +107 -0
- memos/templates/instruction_completion.py +66 -0
- memos/templates/mem_agent_prompts.py +85 -0
- memos/templates/mem_feedback_prompts.py +822 -0
- memos/templates/mem_reader_prompts.py +1096 -0
- memos/templates/mem_reader_strategy_prompts.py +238 -0
- memos/templates/mem_scheduler_prompts.py +626 -0
- memos/templates/mem_search_prompts.py +93 -0
- memos/templates/mos_prompts.py +403 -0
- memos/templates/prefer_complete_prompt.py +735 -0
- memos/templates/tool_mem_prompts.py +139 -0
- memos/templates/tree_reorganize_prompts.py +230 -0
- memos/types/__init__.py +34 -0
- memos/types/general_types.py +151 -0
- memos/types/openai_chat_completion_types/__init__.py +15 -0
- memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py +56 -0
- memos/types/openai_chat_completion_types/chat_completion_content_part_image_param.py +27 -0
- memos/types/openai_chat_completion_types/chat_completion_content_part_input_audio_param.py +23 -0
- memos/types/openai_chat_completion_types/chat_completion_content_part_param.py +43 -0
- memos/types/openai_chat_completion_types/chat_completion_content_part_refusal_param.py +16 -0
- memos/types/openai_chat_completion_types/chat_completion_content_part_text_param.py +16 -0
- memos/types/openai_chat_completion_types/chat_completion_message_custom_tool_call_param.py +27 -0
- memos/types/openai_chat_completion_types/chat_completion_message_function_tool_call_param.py +32 -0
- memos/types/openai_chat_completion_types/chat_completion_message_param.py +18 -0
- memos/types/openai_chat_completion_types/chat_completion_message_tool_call_union_param.py +15 -0
- memos/types/openai_chat_completion_types/chat_completion_system_message_param.py +36 -0
- memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py +30 -0
- memos/types/openai_chat_completion_types/chat_completion_user_message_param.py +34 -0
- memos/utils.py +123 -0
- memos/vec_dbs/__init__.py +0 -0
- memos/vec_dbs/base.py +117 -0
- memos/vec_dbs/factory.py +23 -0
- memos/vec_dbs/item.py +50 -0
- memos/vec_dbs/milvus.py +654 -0
- memos/vec_dbs/qdrant.py +355 -0
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from memos.configs.memory import NaiveTextMemoryConfig
|
|
8
|
+
from memos.llms.factory import LLMFactory
|
|
9
|
+
from memos.log import get_logger
|
|
10
|
+
from memos.memories.textual.base import BaseTextMemory
|
|
11
|
+
from memos.memories.textual.item import TextualMemoryItem, TextualMemoryMetadata
|
|
12
|
+
from memos.types import MessageList
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
logger = get_logger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# NOTE(review): this is an f-string, so {datetime.now().isoformat()} is evaluated
# ONCE at import time; in a long-running process the "current date" in the prompt
# goes stale — consider injecting the timestamp at call time instead.
EXTRACTION_PROMPT_PART_1 = f"""You are a memory extractor. Your task is to extract memories from the given messages.
* You will receive a list of messages, each with a role (user or assistant) and content.
* Your job is to extract the memories from these messages.
* Each memory should be a dictionary with the following keys:
  - "memory": The content of the memory (string). Rephrase the content if necessary.
  - "metadata": A dictionary with a "type" key (string), e.g., "procedure", "fact", "event", "opinion", etc.
* Current date and time is {datetime.now().isoformat()}.
* Only return the list of memories in JSON format.
* Do not include any other text or explanation.

## Example

### Input

[
    {{"role": "user", "content": "I plan to visit Paris next week."}},
    {{"role": "assistant", "content": "Paris is a beautiful city with many attractions."}},
    {{"role": "user", "content": "I love the Eiffel Tower."}},
    {{"role": "assistant", "content": "The Eiffel Tower is a must-see landmark in Paris."}}
]

### Output

[
    {{"memory": "User plans to visit Paris next week.", "metadata": {{"type": "event"}}}},
    {{"memory": "User loves the Eiffel Tower.", "metadata": {{"type": "opinion"}}}}
]
"""

# Query section appended after PART_1; "{messages}" is filled in via str.format().
EXTRACTION_PROMPT_PART_2 = """
## Query

### Input

{messages}

### Output

"""
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class NaiveTextMemory(BaseTextMemory):
    """Naive textual memory implementation for storing and retrieving memories.

    Memories are held in an in-memory list of dicts (``self.memories``); search
    is a bag-of-words overlap count. Suitable for tests and small-scale use.
    """

    def __init__(self, config: NaiveTextMemoryConfig):
        """Initialize memory with the given configuration."""
        # Set mode from class default or override if needed.
        self.mode = getattr(self.__class__, "mode", "sync")
        self.config = config
        # LLM used by `extract` to turn raw messages into memory dicts.
        self.extractor_llm = LLMFactory.from_config(config.extractor_llm)
        self.memories: list[dict[str, Any]] = []

    def extract(self, messages: MessageList) -> list[TextualMemoryItem]:
        """Extract memories based on the messages.

        Args:
            messages (MessageList): Chat messages (role/content dicts).

        Returns:
            list[TextualMemoryItem]: Extracted memory items.

        Raises:
            json.JSONDecodeError: If the LLM response is not valid JSON even
                after stripping markdown fences.
        """
        str_messages = json.dumps(messages)
        user_query = EXTRACTION_PROMPT_PART_1 + EXTRACTION_PROMPT_PART_2.format(
            messages=str_messages
        )
        response = self.extractor_llm.generate([{"role": "user", "content": user_query}])
        # LLMs frequently wrap JSON in markdown code fences; strip them before
        # parsing, consistent with the other LLM-JSON consumers in this package.
        response = response.strip().replace("```json", "").replace("```", "").strip()
        raw_extracted_memories = json.loads(response)

        # Convert raw dictionaries to TextualMemoryItem objects.
        extracted_memories = []
        for memory_dict in raw_extracted_memories:
            # Ensure proper structure with memory and metadata.
            memory_content = memory_dict.get("memory", "")
            metadata_dict = memory_dict.get("metadata", {})

            # Create a TextualMemoryItem with properly structured metadata.
            memory_item = TextualMemoryItem(memory=memory_content, metadata=metadata_dict)
            extracted_memories.append(memory_item)

        return extracted_memories

    def add(self, memories: list[TextualMemoryItem | dict[str, Any]]) -> None:
        """Add memories, skipping any whose ID is already stored."""
        # Build the set of known IDs once instead of rescanning the whole store
        # per candidate (previously O(n*m)).
        existing_ids = {m["id"] for m in self.memories}
        for m in memories:
            # Convert dict to TextualMemoryItem if needed.
            memory_item = TextualMemoryItem(**m) if isinstance(m, dict) else m

            # Convert to dictionary for storage.
            memory_dict = memory_item.model_dump()

            if memory_dict["id"] not in existing_ids:
                self.memories.append(memory_dict)
                existing_ids.add(memory_dict["id"])

    def update(self, memory_id: str, new_memory: TextualMemoryItem | dict[str, Any]) -> None:
        """Update a memory by memory_id. No-op if the ID is not found."""
        # Convert dict to TextualMemoryItem if needed.
        memory_item = (
            TextualMemoryItem(**new_memory) if isinstance(new_memory, dict) else new_memory
        )

        # Force the replacement item to carry the ID being updated, regardless
        # of any ID the caller set on it.
        memory_item.id = memory_id
        memory_dict = memory_item.model_dump()

        for i, memory in enumerate(self.memories):
            if memory["id"] == memory_id:
                self.memories[i] = memory_dict
                break

    def search(self, query: str, top_k: int, **kwargs) -> list[TextualMemoryItem]:
        """Search for memories based on a query.

        Scoring is the count of words shared between the query and each stored
        memory (naive word overlap). Extra kwargs are accepted for interface
        compatibility and ignored.
        """
        sims = [
            (memory, len(set(query.split()) & set(memory["memory"].split())))
            for memory in self.memories
        ]
        sims.sort(key=lambda x: x[1], reverse=True)
        # Convert search results to TextualMemoryItem objects.
        return [TextualMemoryItem(**memory) for memory, _ in sims[:top_k]]

    def get(self, memory_id: str, user_name: str | None = None) -> TextualMemoryItem:
        """Get a memory by its ID.

        Returns an empty memory item (same ID, empty content) when not found,
        rather than raising.
        """
        for memory in self.memories:
            if memory["id"] == memory_id:
                return TextualMemoryItem(**memory)
        # Return empty memory item if not found.
        return TextualMemoryItem(id=memory_id, memory="", metadata=TextualMemoryMetadata())

    def get_all(self) -> list[TextualMemoryItem]:
        """Get all memories."""
        return [TextualMemoryItem(**memory) for memory in self.memories]

    def get_by_ids(self, memory_ids: list[str]) -> list[TextualMemoryItem]:
        """Get memories by their IDs.

        Args:
            memory_ids (list[str]): List of memory IDs to retrieve.

        Returns:
            list[TextualMemoryItem]: List of memories with the specified IDs
                (missing IDs yield empty placeholder items, see `get`).
        """
        return [self.get(memory_id) for memory_id in memory_ids]

    def delete(self, memory_ids: list[str]) -> None:
        """Delete memories.

        Args:
            memory_ids (list[str]): List of memory IDs to delete.
        """
        self.memories = [m for m in self.memories if m["id"] not in memory_ids]

    def delete_all(self) -> None:
        """Delete all memories."""
        self.memories = []

    def load(self, dir: str) -> None:
        """Load memories from `memory_filename` under `dir`; errors are logged, not raised."""
        try:
            with open(os.path.join(dir, self.config.memory_filename), encoding="utf-8") as file:
                raw_memories = json.load(file)
            self.add(raw_memories)
        except FileNotFoundError:
            logger.error(f"Directory not found: {dir}")
        except json.JSONDecodeError:
            logger.error(f"Error decoding JSON from file in directory: {dir}")
        except Exception as e:
            logger.error(f"An error occurred while loading memories: {e}")

    def dump(self, dir: str) -> None:
        """Write all memories as JSON to `memory_filename` under `dir` (creates `dir`)."""
        try:
            os.makedirs(dir, exist_ok=True)
            memory_file = os.path.join(dir, self.config.memory_filename)
            with open(memory_file, "w", encoding="utf-8") as file:
                json.dump(self.memories, file, indent=4, ensure_ascii=False)
        except Exception as e:
            logger.error(f"An error occurred while dumping memories: {e}")
            raise

    def drop(
        self,
    ) -> None:
        """No-op: this backend has no persistent resources to drop."""
|
|
File without changes
|
|
@@ -0,0 +1,504 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
from concurrent.futures import as_completed
|
|
6
|
+
from datetime import datetime
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
from memos.context.context import ContextThreadPoolExecutor
|
|
10
|
+
from memos.log import get_logger
|
|
11
|
+
from memos.memories.textual.item import TextualMemoryItem
|
|
12
|
+
from memos.templates.prefer_complete_prompt import (
|
|
13
|
+
NAIVE_JUDGE_DUP_WITH_TEXT_MEM_PROMPT,
|
|
14
|
+
NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT,
|
|
15
|
+
NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT_FINE,
|
|
16
|
+
NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT_OP_TRACE,
|
|
17
|
+
)
|
|
18
|
+
from memos.vec_dbs.item import MilvusVecDBItem
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
logger = get_logger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class BaseAdder(ABC):
    """Interface that all preference-memory adders must implement."""

    @abstractmethod
    def __init__(self, llm_provider=None, embedder=None, vector_db=None, text_mem=None):
        """Set up the adder with its collaborating components."""

    @abstractmethod
    def add(self, memories: list[TextualMemoryItem | dict[str, Any]], *args, **kwargs) -> list[str]:
        """Persist the given instruct preference memories.

        Args:
            memories (list[TextualMemoryItem | dict[str, Any]]): Memory items
                (or raw dicts) to store.
            **kwargs: Implementation-specific options.

        Returns:
            list[str]: IDs of the memories that were added.
        """
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class NaiveAdder(BaseAdder):
|
|
43
|
+
"""Naive adder."""
|
|
44
|
+
|
|
45
|
+
def __init__(self, llm_provider=None, embedder=None, vector_db=None, text_mem=None):
|
|
46
|
+
"""Initialize the naive adder."""
|
|
47
|
+
super().__init__(llm_provider, embedder, vector_db, text_mem)
|
|
48
|
+
self.llm_provider = llm_provider
|
|
49
|
+
self.embedder = embedder
|
|
50
|
+
self.vector_db = vector_db
|
|
51
|
+
self.text_mem = text_mem
|
|
52
|
+
|
|
53
|
+
def _judge_update_or_add_fast(self, old_msg: str, new_msg: str) -> bool:
|
|
54
|
+
"""Judge if the new message expresses the same core content as the old message."""
|
|
55
|
+
# Use the template prompt with placeholders
|
|
56
|
+
prompt = NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT.replace("{old_information}", old_msg).replace(
|
|
57
|
+
"{new_information}", new_msg
|
|
58
|
+
)
|
|
59
|
+
|
|
60
|
+
try:
|
|
61
|
+
response = self.llm_provider.generate([{"role": "user", "content": prompt}])
|
|
62
|
+
response = response.strip().replace("```json", "").replace("```", "").strip()
|
|
63
|
+
result = json.loads(response)
|
|
64
|
+
response = result.get("is_same", False)
|
|
65
|
+
return response if isinstance(response, bool) else response.lower() == "true"
|
|
66
|
+
except Exception as e:
|
|
67
|
+
logger.error(f"Error in judge_update_or_add: {e}")
|
|
68
|
+
# Fallback to simple string comparison
|
|
69
|
+
return old_msg == new_msg
|
|
70
|
+
|
|
71
|
+
def _judge_update_or_add_fine(self, new_mem: str, retrieved_mems: str) -> dict[str, Any] | None:
|
|
72
|
+
if not retrieved_mems:
|
|
73
|
+
return None
|
|
74
|
+
prompt = NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT_FINE.replace("{new_memory}", new_mem).replace(
|
|
75
|
+
"{retrieved_memories}", retrieved_mems
|
|
76
|
+
)
|
|
77
|
+
try:
|
|
78
|
+
response = self.llm_provider.generate([{"role": "user", "content": prompt}])
|
|
79
|
+
response = response.strip().replace("```json", "").replace("```", "").strip()
|
|
80
|
+
result = json.loads(response)
|
|
81
|
+
return result
|
|
82
|
+
except Exception as e:
|
|
83
|
+
logger.error(f"Error in judge_update_or_add_fine: {e}")
|
|
84
|
+
return None
|
|
85
|
+
|
|
86
|
+
    def _judge_dup_with_text_mem(self, new_pref: MilvusVecDBItem) -> bool:
        """Judge if the new message is the same as the text memory for a single preference.

        Returns True only when the LLM decides the explicit preference already
        exists in textual memory; returns False for non-explicit preferences,
        when nothing is recalled, or on any error (fail-open: never drop data).
        """
        # Only explicit preferences are deduplicated against textual memory.
        # NOTE(review): assumes payload always has "preference_type",
        # "user_id", "session_id", "mem_cube_id" and "preference" keys —
        # confirm against the upstream writer; a missing key raises KeyError here.
        if new_pref.payload["preference_type"] != "explicit_preference":
            return False
        # Recall candidate textual memories scoped to the same user/session.
        text_recalls = self.text_mem.search(
            query=new_pref.memory,
            top_k=5,
            info={
                "user_id": new_pref.payload["user_id"],
                "session_id": new_pref.payload["session_id"],
            },
            mode="fast",
            search_filter={"session_id": new_pref.payload["session_id"]},
            user_name=new_pref.payload["mem_cube_id"],
        )

        # Reduce recalls to the minimal fields the judge prompt needs.
        text_mem_recalls = [
            {"id": text_recall.id, "memory": text_recall.memory} for text_recall in text_recalls
        ]

        # Nothing recalled -> cannot be a duplicate.
        if not text_mem_recalls:
            return False

        new_preference = {"id": new_pref.id, "memory": new_pref.payload["preference"]}

        # Fill the prompt template with compact JSON (non-ASCII preserved).
        prompt = NAIVE_JUDGE_DUP_WITH_TEXT_MEM_PROMPT.replace(
            "{new_preference}", json.dumps(new_preference, ensure_ascii=False)
        ).replace("{retrieved_memories}", json.dumps(text_mem_recalls, ensure_ascii=False))
        try:
            response = self.llm_provider.generate([{"role": "user", "content": prompt}])
            # Strip markdown code fences before parsing the JSON verdict.
            response = response.strip().replace("```json", "").replace("```", "").strip()
            result = json.loads(response)
            # NOTE(review): "exists" is returned as-is; a non-bool JSON value
            # would propagate to the caller — presumed always bool, verify prompt.
            exists = result.get("exists", False)
            return exists
        except Exception as e:
            logger.error(f"Error in judge_dup_with_text_mem: {e}")
            return False
|
|
123
|
+
|
|
124
|
+
def _judge_update_or_add_trace_op(
|
|
125
|
+
self, new_mems: str, retrieved_mems: str
|
|
126
|
+
) -> dict[str, Any] | None:
|
|
127
|
+
if not retrieved_mems:
|
|
128
|
+
return None
|
|
129
|
+
prompt = NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT_OP_TRACE.replace(
|
|
130
|
+
"{new_memories}", new_mems
|
|
131
|
+
).replace("{retrieved_memories}", retrieved_mems)
|
|
132
|
+
try:
|
|
133
|
+
response = self.llm_provider.generate([{"role": "user", "content": prompt}])
|
|
134
|
+
response = response.strip().replace("```json", "").replace("```", "").strip()
|
|
135
|
+
result = json.loads(response)
|
|
136
|
+
return result
|
|
137
|
+
except Exception as e:
|
|
138
|
+
logger.error(f"Error in judge_update_or_add_trace_op: {e}")
|
|
139
|
+
return None
|
|
140
|
+
|
|
141
|
+
def _dedup_explicit_pref_by_textual(
|
|
142
|
+
self, new_prefs: list[MilvusVecDBItem]
|
|
143
|
+
) -> list[MilvusVecDBItem]:
|
|
144
|
+
"""Deduplicate explicit preferences by textual memory."""
|
|
145
|
+
if os.getenv("DEDUP_PREF_EXP_BY_TEXTUAL", "false").lower() != "true" or not self.text_mem:
|
|
146
|
+
return new_prefs
|
|
147
|
+
dedup_prefs = []
|
|
148
|
+
with ContextThreadPoolExecutor(max_workers=max(1, min(len(new_prefs), 5))) as executor:
|
|
149
|
+
future_to_idx = {
|
|
150
|
+
executor.submit(self._judge_dup_with_text_mem, new_pref): idx
|
|
151
|
+
for idx, new_pref in enumerate(new_prefs)
|
|
152
|
+
}
|
|
153
|
+
is_dup_flags = [False] * len(new_prefs)
|
|
154
|
+
for future in as_completed(future_to_idx):
|
|
155
|
+
idx = future_to_idx[future]
|
|
156
|
+
try:
|
|
157
|
+
is_dup_flags[idx] = future.result()
|
|
158
|
+
except Exception as e:
|
|
159
|
+
logger.error(
|
|
160
|
+
f"Error in _judge_dup_with_text_mem for pref {new_prefs[idx].id}: {e}"
|
|
161
|
+
)
|
|
162
|
+
is_dup_flags[idx] = False
|
|
163
|
+
|
|
164
|
+
dedup_prefs = [pref for idx, pref in enumerate(new_prefs) if not is_dup_flags[idx]]
|
|
165
|
+
return dedup_prefs
|
|
166
|
+
|
|
167
|
+
def _update_memory_op_trace(
    self,
    new_memories: list[TextualMemoryItem],
    retrieved_memories: list[MilvusVecDBItem],
    collection_name: str,
) -> list[str] | str:
    """Reconcile a batch of new memories against retrieved ones via an LLM trace.

    Converts the new textual memories into vector-DB items, asks the LLM for
    an operation trace (add/update/delete), and applies those operations in
    parallel. When no usable trace is returned, falls back to deduplicating
    the new items and inserting all survivors.

    Args:
        new_memories: Freshly extracted memories to persist.
        retrieved_memories: Existing DB items similar to the new ones.
        collection_name: Target vector-DB collection.

    Returns:
        The ids that were added or updated (empty list when nothing was written).
    """
    # Convert new textual memories into vector-DB items; these fields are
    # stored as dedicated columns, not payload entries.
    fields_to_remove = {"dialog_id", "original_text", "embedding"}
    new_vec_db_items: list[MilvusVecDBItem] = []
    for new_memory in new_memories:
        payload = new_memory.to_dict()["metadata"]
        payload = {k: v for k, v in payload.items() if k not in fields_to_remove}
        new_vec_db_items.append(
            MilvusVecDBItem(
                id=new_memory.id,
                memory=new_memory.memory,
                original_text=new_memory.metadata.original_text,
                vector=new_memory.metadata.embedding,
                payload=payload,
            )
        )

    # Only items that actually carry a preference are shown to the judge.
    new_mem_inputs = [
        {
            "id": new_memory.id,
            "context_summary": new_memory.memory,
            "preference": new_memory.payload["preference"],
        }
        for new_memory in new_vec_db_items
        if new_memory.payload.get("preference", None)
    ]
    retrieved_mem_inputs = [
        {
            "id": mem.id,
            "context_summary": mem.memory,
            "preference": mem.payload["preference"],
        }
        for mem in retrieved_memories
        if mem.payload.get("preference", None)
    ]

    rsp = self._judge_update_or_add_trace_op(
        new_mems=json.dumps(new_mem_inputs, ensure_ascii=False),
        retrieved_mems=json.dumps(retrieved_mem_inputs, ensure_ascii=False)
        if retrieved_mem_inputs
        else "",
    )

    # Fix: the original indexed rsp["trace"] directly, which raised KeyError
    # when the LLM reply lacked a "trace" key, and passed max_workers=0 to
    # the executor (ValueError) when the trace was empty. Treat both cases
    # like "no usable trace" and take the add-all fallback.
    trace_ops = (rsp or {}).get("trace") or []
    if not trace_ops:
        # No usable trace: deduplicate, then insert every surviving item.
        dedup_rsp = self._dedup_explicit_pref_by_textual(new_vec_db_items)
        if not dedup_rsp:
            return []
        new_vec_db_items = dedup_rsp
        with ContextThreadPoolExecutor(
            max_workers=max(1, min(len(new_vec_db_items), 5))
        ) as executor:
            futures = {
                executor.submit(self.vector_db.add, collection_name, [db_item]): db_item
                for db_item in new_vec_db_items
            }
            for future in as_completed(futures):
                future.result()  # re-raise any insertion error
        return [db_item.id for db_item in new_vec_db_items]

    new_mem_db_item_map = {db_item.id: db_item for db_item in new_vec_db_items}
    retrieved_mem_db_item_map = {db_item.id: db_item for db_item in retrieved_memories}

    def execute_op(
        op,
        new_mem_db_item_map: dict[str, MilvusVecDBItem],
        retrieved_mem_db_item_map: dict[str, MilvusVecDBItem],
    ) -> str | None:
        # Apply one add/update/delete operation; returns the written id for
        # add/update so the caller can report what changed, None otherwise.
        op_type = op["type"].lower()
        if op_type == "add":
            if op["target_id"] in new_mem_db_item_map:
                self.vector_db.add(collection_name, [new_mem_db_item_map[op["target_id"]]])
                return new_mem_db_item_map[op["target_id"]].id
            return None
        elif op_type == "update":
            if op["target_id"] in retrieved_mem_db_item_map:
                update_mem_db_item = retrieved_mem_db_item_map[op["target_id"]]
                update_mem_db_item.payload["preference"] = op["new_preference"]
                update_mem_db_item.payload["updated_at"] = datetime.now().isoformat()
                update_mem_db_item.memory = op["new_context_summary"]
                update_mem_db_item.original_text = op["new_context_summary"]
                update_mem_db_item.vector = self.embedder.embed([op["new_context_summary"]])[0]
                self.vector_db.update(collection_name, op["target_id"], update_mem_db_item)
                return op["target_id"]
            return None
        elif op_type == "delete":
            self.vector_db.delete(collection_name, [op["target_id"]])
            return None
        # Unknown op types are ignored (falls through to implicit None).

    with ContextThreadPoolExecutor(max_workers=max(1, min(len(trace_ops), 5))) as executor:
        future_to_op = {
            executor.submit(execute_op, op, new_mem_db_item_map, retrieved_mem_db_item_map): op
            for op in trace_ops
        }
        added_ids = []
        for future in as_completed(future_to_op):
            result = future.result()
            if result is not None:
                added_ids.append(result)

    return added_ids
|
|
269
|
+
|
|
270
|
+
def _update_memory_fine(
    self,
    new_memory: TextualMemoryItem,
    retrieved_memories: list[MilvusVecDBItem],
    collection_name: str,
) -> str:
    """Fine-grained reconciliation of one new memory against retrieved items.

    Asks the LLM whether an existing item should be updated with the new
    content; updates that item in place when so, otherwise deduplicates and
    inserts the new memory as a fresh item.

    Args:
        new_memory: The incoming memory to persist.
        retrieved_memories: Similar items already stored in the collection.
        collection_name: Target vector-DB collection.

    Returns:
        The id that was written, or "" when the item was dropped as a duplicate.
    """
    excluded = {"dialog_id", "original_text", "embedding"}
    meta = new_memory.to_dict()["metadata"]
    candidate = MilvusVecDBItem(
        id=new_memory.id,
        memory=new_memory.memory,
        original_text=new_memory.metadata.original_text,
        vector=new_memory.metadata.embedding,
        payload={k: v for k, v in meta.items() if k not in excluded},
    )

    judge_input = {"memory": new_memory.memory, "preference": new_memory.metadata.preference}
    stored_inputs = [
        {
            "id": mem.id,
            "memory": mem.memory,
            "preference": mem.payload["preference"],
        }
        for mem in retrieved_memories
        if mem.payload.get("preference", None)
    ]
    judgement = self._judge_update_or_add_fine(
        new_mem=json.dumps(judge_input, ensure_ascii=False),
        retrieved_mems=json.dumps(stored_inputs, ensure_ascii=False)
        if stored_inputs
        else "",
    )

    # The judge may answer with a bool or the string "true"/"false".
    should_update = judgement.get("need_update", False) if judgement else False
    if not isinstance(should_update, bool):
        should_update = should_update.lower() == "true"

    matched = None
    if judgement and "id" in judgement:
        matched = next(
            (mem for mem in retrieved_memories if mem.id == judgement["id"]), None
        )

    if should_update and matched is not None and judgement:
        # Rewrite the matched stored item with the judge's merged content.
        matched.payload["preference"] = judgement["new_preference"]
        matched.payload["updated_at"] = candidate.payload["updated_at"]
        matched.memory = judgement["new_memory"]
        matched.original_text = candidate.original_text
        matched.vector = self.embedder.embed([judgement["new_memory"]])[0]

        self.vector_db.update(collection_name, judgement["id"], matched)
        return judgement["id"]

    # No update requested: dedup against textual memory, then insert.
    if not self._dedup_explicit_pref_by_textual([candidate]):
        return ""
    self.vector_db.add(collection_name, [candidate])
    return candidate.id
|
|
328
|
+
|
|
329
|
+
def _update_memory_fast(
    self,
    new_memory: TextualMemoryItem,
    retrieved_memories: list[MilvusVecDBItem],
    collection_name: str,
) -> str:
    """Fast-path reconciliation of one new memory against the top recall.

    Compares the new memory only with the best-scoring retrieved item; when
    there is no sufficiently similar recall the item is inserted directly,
    otherwise an LLM same/different judgement decides whether the recalled
    item is overwritten.

    Args:
        new_memory: The incoming memory to persist.
        retrieved_memories: Search hits, expected best-first (only [0] is used).
        collection_name: Target vector-DB collection.

    Returns:
        The id of the written memory, or "" when dropped as a duplicate.
    """
    payload = new_memory.to_dict()["metadata"]
    # These fields live in dedicated columns, not in the payload.
    fields_to_remove = {"dialog_id", "original_text", "embedding"}
    payload = {k: v for k, v in payload.items() if k not in fields_to_remove}
    vec_db_item = MilvusVecDBItem(
        id=new_memory.id,
        memory=new_memory.memory,
        original_text=new_memory.metadata.original_text,
        vector=new_memory.metadata.embedding,
        payload=payload,
    )
    recall = retrieved_memories[0] if retrieved_memories else None
    # No recall, or similarity below 0.5: nothing comparable exists — insert.
    if not recall or (recall.score is not None and recall.score < 0.5):
        self.vector_db.add(collection_name, [vec_db_item])
        return new_memory.id

    old_msg_str = recall.memory
    new_msg_str = new_memory.memory
    is_same = self._judge_update_or_add_fast(old_msg=old_msg_str, new_msg=new_msg_str)
    dedup_rsp = self._dedup_explicit_pref_by_textual([vec_db_item])
    if not dedup_rsp:
        return ""
    if is_same:
        # Same preference: overwrite the recalled item under its existing id.
        vec_db_item.id = recall.id
        self.vector_db.update(collection_name, recall.id, vec_db_item)
    # NOTE(review): this add runs even after the update branch above, which
    # looks like it writes the same item twice (update then add under the
    # recalled id) — confirm whether the backend's add is an upsert or the
    # add was meant to be the else-branch of `if is_same`.
    self.vector_db.add(collection_name, [vec_db_item])
    return new_memory.id
|
|
361
|
+
|
|
362
|
+
def _update_memory(
    self,
    new_memory: TextualMemoryItem,
    retrieved_memories: list[MilvusVecDBItem],
    collection_name: str,
    update_mode: str = "fast",
) -> list[str] | str | None:
    """Dispatch memory reconciliation to the fast or fine strategy.

    Args:
        new_memory: The incoming memory to persist.
        retrieved_memories: Similar items already stored in the collection.
        collection_name: Target vector-DB collection.
        update_mode: "fast" (top-1 comparison) or "fine" (LLM-guided merge).

    Raises:
        ValueError: If update_mode is neither "fast" nor "fine".
    """
    strategies = {
        "fast": self._update_memory_fast,
        "fine": self._update_memory_fine,
    }
    strategy = strategies.get(update_mode)
    if strategy is None:
        raise ValueError(f"Invalid update mode: {update_mode}")
    return strategy(new_memory, retrieved_memories, collection_name)
|
|
382
|
+
|
|
383
|
+
def _process_single_memory(self, memory: TextualMemoryItem) -> list[str] | str | None:
    """Route one memory to its preference collection and reconcile it.

    Searches the collection for the user's similar items, then delegates to
    _update_memory using the mode from PREFERENCE_ADDER_MODE (default "fast").

    Returns:
        The written id(s) on success, or None when processing fails.
    """
    collection_by_pref = {
        "explicit_preference": "explicit_preference",
        "implicit_preference": "implicit_preference",
    }
    try:
        target_collection = collection_by_pref[memory.metadata.preference_type]

        neighbors = self.vector_db.search(
            query_vector=memory.metadata.embedding,
            query=memory.memory,
            collection_name=target_collection,
            top_k=5,
            filter={"user_id": memory.metadata.user_id},
        )
        # Best match first.
        neighbors = sorted(neighbors, key=lambda hit: hit.score, reverse=True)

        return self._update_memory(
            memory,
            neighbors,
            target_collection,
            update_mode=os.getenv("PREFERENCE_ADDER_MODE", "fast"),
        )
    except Exception as e:
        logger.error(f"Error processing memory {memory.id}: {e}")
        return None
|
|
412
|
+
|
|
413
|
+
def process_memory_batch(self, memories: list[TextualMemoryItem], *args, **kwargs) -> list[str]:
    """Reconcile a batch of memories, split by preference type.

    Searches the vector DB for each memory's neighbors, buckets memories and
    their (id-deduplicated) recalls by explicit/implicit preference type, and
    runs both reconciliation traces concurrently.

    Returns:
        The combined list of ids written for both preference types.
    """
    collection_by_pref = {
        "explicit_preference": "explicit_preference",
        "implicit_preference": "implicit_preference",
    }

    new_by_type = {"explicit_preference": [], "implicit_preference": []}
    recalls_by_type = {"explicit_preference": [], "implicit_preference": []}

    for memory in memories:
        pref_type = memory.metadata.preference_type
        collection = collection_by_pref[pref_type]
        hits = self.vector_db.search(
            query_vector=memory.metadata.embedding,
            query=memory.memory,
            collection_name=collection,
            top_k=5,
            filter={"user_id": memory.metadata.user_id},
        )
        if pref_type in new_by_type:
            recalls_by_type[pref_type].extend(hits)
            new_by_type[pref_type].append(memory)

    # Drop recall duplicates, keeping the last occurrence per id.
    for pref_type in recalls_by_type:
        recalls_by_type[pref_type] = list(
            {recall.id: recall for recall in recalls_by_type[pref_type]}.values()
        )

    # Process explicit and implicit preferences in parallel.
    with ContextThreadPoolExecutor(max_workers=2) as executor:
        pending = [
            executor.submit(
                self._update_memory_op_trace,
                new_by_type[pref_type],
                recalls_by_type[pref_type],
                collection_by_pref[pref_type],
            )
            for pref_type in ("explicit_preference", "implicit_preference")
        ]
        explicit_ids = pending[0].result()
        implicit_ids = pending[1].result()

    return explicit_ids + implicit_ids
|
|
463
|
+
|
|
464
|
+
def process_memory_single(
    self, memories: list[TextualMemoryItem], max_workers: int = 8, *args, **kwargs
) -> list[str]:
    """Process each memory independently in a thread pool.

    Fans out _process_single_memory over the batch and collects the ids that
    were written; per-memory failures are logged and skipped.

    Args:
        memories: Memories to persist.
        max_workers: Upper bound on concurrent workers.

    Returns:
        All ids written across the batch (empty list for an empty batch).
    """
    # Fix: guard the empty batch — min(max_workers, 0) would pass
    # max_workers=0 to the executor, which raises ValueError.
    if not memories:
        return []

    added_ids: list[str] = []
    with ContextThreadPoolExecutor(
        max_workers=max(1, min(max_workers, len(memories)))
    ) as executor:
        future_to_memory = {
            executor.submit(self._process_single_memory, memory): memory for memory in memories
        }

        for future in as_completed(future_to_memory):
            try:
                memory_id = future.result()
                if memory_id:
                    # _process_single_memory may return one id or a list of ids.
                    if isinstance(memory_id, list):
                        added_ids.extend(memory_id)
                    else:
                        added_ids.append(memory_id)
            except Exception as e:
                memory = future_to_memory[future]
                logger.error(f"Error processing memory {memory.id}: {e}")
                continue
    return added_ids
|
|
486
|
+
|
|
487
|
+
def add(
    self,
    memories: list[TextualMemoryItem | dict[str, Any]],
    max_workers: int = 8,
    *args,
    **kwargs,
) -> list[str]:
    """Add the instruct preference memories using thread pool for acceleration.

    Args:
        memories: Memories (items or raw dicts) to persist.
        max_workers: Thread-pool size forwarded to the processing strategy.

    Returns:
        The ids that were written; empty list for an empty input.
    """
    if not memories:
        return []

    # Two strategies are available; the per-item "single" strategy is the
    # one currently selected.
    strategies = {
        "single": self.process_memory_single,
        "batch": self.process_memory_batch,
    }
    return strategies["single"](memories, max_workers)
|