MemoryOS 2.0.3 (memoryos-2.0.3-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- memoryos-2.0.3.dist-info/METADATA +418 -0
- memoryos-2.0.3.dist-info/RECORD +315 -0
- memoryos-2.0.3.dist-info/WHEEL +4 -0
- memoryos-2.0.3.dist-info/entry_points.txt +3 -0
- memoryos-2.0.3.dist-info/licenses/LICENSE +201 -0
- memos/__init__.py +20 -0
- memos/api/client.py +571 -0
- memos/api/config.py +1018 -0
- memos/api/context/dependencies.py +50 -0
- memos/api/exceptions.py +53 -0
- memos/api/handlers/__init__.py +62 -0
- memos/api/handlers/add_handler.py +158 -0
- memos/api/handlers/base_handler.py +194 -0
- memos/api/handlers/chat_handler.py +1401 -0
- memos/api/handlers/component_init.py +388 -0
- memos/api/handlers/config_builders.py +190 -0
- memos/api/handlers/feedback_handler.py +93 -0
- memos/api/handlers/formatters_handler.py +237 -0
- memos/api/handlers/memory_handler.py +316 -0
- memos/api/handlers/scheduler_handler.py +497 -0
- memos/api/handlers/search_handler.py +222 -0
- memos/api/handlers/suggestion_handler.py +117 -0
- memos/api/mcp_serve.py +614 -0
- memos/api/middleware/request_context.py +101 -0
- memos/api/product_api.py +38 -0
- memos/api/product_models.py +1206 -0
- memos/api/routers/__init__.py +1 -0
- memos/api/routers/product_router.py +477 -0
- memos/api/routers/server_router.py +394 -0
- memos/api/server_api.py +44 -0
- memos/api/start_api.py +433 -0
- memos/chunkers/__init__.py +4 -0
- memos/chunkers/base.py +24 -0
- memos/chunkers/charactertext_chunker.py +41 -0
- memos/chunkers/factory.py +24 -0
- memos/chunkers/markdown_chunker.py +62 -0
- memos/chunkers/sentence_chunker.py +54 -0
- memos/chunkers/simple_chunker.py +50 -0
- memos/cli.py +113 -0
- memos/configs/__init__.py +0 -0
- memos/configs/base.py +82 -0
- memos/configs/chunker.py +59 -0
- memos/configs/embedder.py +88 -0
- memos/configs/graph_db.py +236 -0
- memos/configs/internet_retriever.py +100 -0
- memos/configs/llm.py +151 -0
- memos/configs/mem_agent.py +54 -0
- memos/configs/mem_chat.py +81 -0
- memos/configs/mem_cube.py +105 -0
- memos/configs/mem_os.py +83 -0
- memos/configs/mem_reader.py +91 -0
- memos/configs/mem_scheduler.py +385 -0
- memos/configs/mem_user.py +70 -0
- memos/configs/memory.py +324 -0
- memos/configs/parser.py +38 -0
- memos/configs/reranker.py +18 -0
- memos/configs/utils.py +8 -0
- memos/configs/vec_db.py +80 -0
- memos/context/context.py +355 -0
- memos/dependency.py +52 -0
- memos/deprecation.py +262 -0
- memos/embedders/__init__.py +0 -0
- memos/embedders/ark.py +95 -0
- memos/embedders/base.py +106 -0
- memos/embedders/factory.py +29 -0
- memos/embedders/ollama.py +77 -0
- memos/embedders/sentence_transformer.py +49 -0
- memos/embedders/universal_api.py +51 -0
- memos/exceptions.py +30 -0
- memos/graph_dbs/__init__.py +0 -0
- memos/graph_dbs/base.py +274 -0
- memos/graph_dbs/factory.py +27 -0
- memos/graph_dbs/item.py +46 -0
- memos/graph_dbs/nebular.py +1794 -0
- memos/graph_dbs/neo4j.py +1942 -0
- memos/graph_dbs/neo4j_community.py +1058 -0
- memos/graph_dbs/polardb.py +5446 -0
- memos/hello_world.py +97 -0
- memos/llms/__init__.py +0 -0
- memos/llms/base.py +25 -0
- memos/llms/deepseek.py +13 -0
- memos/llms/factory.py +38 -0
- memos/llms/hf.py +443 -0
- memos/llms/hf_singleton.py +114 -0
- memos/llms/ollama.py +135 -0
- memos/llms/openai.py +222 -0
- memos/llms/openai_new.py +198 -0
- memos/llms/qwen.py +13 -0
- memos/llms/utils.py +14 -0
- memos/llms/vllm.py +218 -0
- memos/log.py +237 -0
- memos/mem_agent/base.py +19 -0
- memos/mem_agent/deepsearch_agent.py +391 -0
- memos/mem_agent/factory.py +36 -0
- memos/mem_chat/__init__.py +0 -0
- memos/mem_chat/base.py +30 -0
- memos/mem_chat/factory.py +21 -0
- memos/mem_chat/simple.py +200 -0
- memos/mem_cube/__init__.py +0 -0
- memos/mem_cube/base.py +30 -0
- memos/mem_cube/general.py +240 -0
- memos/mem_cube/navie.py +172 -0
- memos/mem_cube/utils.py +169 -0
- memos/mem_feedback/base.py +15 -0
- memos/mem_feedback/feedback.py +1192 -0
- memos/mem_feedback/simple_feedback.py +40 -0
- memos/mem_feedback/utils.py +230 -0
- memos/mem_os/client.py +5 -0
- memos/mem_os/core.py +1203 -0
- memos/mem_os/main.py +582 -0
- memos/mem_os/product.py +1608 -0
- memos/mem_os/product_server.py +455 -0
- memos/mem_os/utils/default_config.py +359 -0
- memos/mem_os/utils/format_utils.py +1403 -0
- memos/mem_os/utils/reference_utils.py +162 -0
- memos/mem_reader/__init__.py +0 -0
- memos/mem_reader/base.py +47 -0
- memos/mem_reader/factory.py +53 -0
- memos/mem_reader/memory.py +298 -0
- memos/mem_reader/multi_modal_struct.py +965 -0
- memos/mem_reader/read_multi_modal/__init__.py +43 -0
- memos/mem_reader/read_multi_modal/assistant_parser.py +311 -0
- memos/mem_reader/read_multi_modal/base.py +273 -0
- memos/mem_reader/read_multi_modal/file_content_parser.py +826 -0
- memos/mem_reader/read_multi_modal/image_parser.py +359 -0
- memos/mem_reader/read_multi_modal/multi_modal_parser.py +252 -0
- memos/mem_reader/read_multi_modal/string_parser.py +139 -0
- memos/mem_reader/read_multi_modal/system_parser.py +327 -0
- memos/mem_reader/read_multi_modal/text_content_parser.py +131 -0
- memos/mem_reader/read_multi_modal/tool_parser.py +210 -0
- memos/mem_reader/read_multi_modal/user_parser.py +218 -0
- memos/mem_reader/read_multi_modal/utils.py +358 -0
- memos/mem_reader/simple_struct.py +912 -0
- memos/mem_reader/strategy_struct.py +163 -0
- memos/mem_reader/utils.py +157 -0
- memos/mem_scheduler/__init__.py +0 -0
- memos/mem_scheduler/analyzer/__init__.py +0 -0
- memos/mem_scheduler/analyzer/api_analyzer.py +714 -0
- memos/mem_scheduler/analyzer/eval_analyzer.py +219 -0
- memos/mem_scheduler/analyzer/mos_for_test_scheduler.py +571 -0
- memos/mem_scheduler/analyzer/scheduler_for_eval.py +280 -0
- memos/mem_scheduler/base_scheduler.py +1319 -0
- memos/mem_scheduler/general_modules/__init__.py +0 -0
- memos/mem_scheduler/general_modules/api_misc.py +137 -0
- memos/mem_scheduler/general_modules/base.py +80 -0
- memos/mem_scheduler/general_modules/init_components_for_scheduler.py +425 -0
- memos/mem_scheduler/general_modules/misc.py +313 -0
- memos/mem_scheduler/general_modules/scheduler_logger.py +389 -0
- memos/mem_scheduler/general_modules/task_threads.py +315 -0
- memos/mem_scheduler/general_scheduler.py +1495 -0
- memos/mem_scheduler/memory_manage_modules/__init__.py +5 -0
- memos/mem_scheduler/memory_manage_modules/memory_filter.py +306 -0
- memos/mem_scheduler/memory_manage_modules/retriever.py +547 -0
- memos/mem_scheduler/monitors/__init__.py +0 -0
- memos/mem_scheduler/monitors/dispatcher_monitor.py +366 -0
- memos/mem_scheduler/monitors/general_monitor.py +394 -0
- memos/mem_scheduler/monitors/task_schedule_monitor.py +254 -0
- memos/mem_scheduler/optimized_scheduler.py +410 -0
- memos/mem_scheduler/orm_modules/__init__.py +0 -0
- memos/mem_scheduler/orm_modules/api_redis_model.py +518 -0
- memos/mem_scheduler/orm_modules/base_model.py +729 -0
- memos/mem_scheduler/orm_modules/monitor_models.py +261 -0
- memos/mem_scheduler/orm_modules/redis_model.py +699 -0
- memos/mem_scheduler/scheduler_factory.py +23 -0
- memos/mem_scheduler/schemas/__init__.py +0 -0
- memos/mem_scheduler/schemas/analyzer_schemas.py +52 -0
- memos/mem_scheduler/schemas/api_schemas.py +233 -0
- memos/mem_scheduler/schemas/general_schemas.py +55 -0
- memos/mem_scheduler/schemas/message_schemas.py +173 -0
- memos/mem_scheduler/schemas/monitor_schemas.py +406 -0
- memos/mem_scheduler/schemas/task_schemas.py +132 -0
- memos/mem_scheduler/task_schedule_modules/__init__.py +0 -0
- memos/mem_scheduler/task_schedule_modules/dispatcher.py +740 -0
- memos/mem_scheduler/task_schedule_modules/local_queue.py +247 -0
- memos/mem_scheduler/task_schedule_modules/orchestrator.py +74 -0
- memos/mem_scheduler/task_schedule_modules/redis_queue.py +1385 -0
- memos/mem_scheduler/task_schedule_modules/task_queue.py +162 -0
- memos/mem_scheduler/utils/__init__.py +0 -0
- memos/mem_scheduler/utils/api_utils.py +77 -0
- memos/mem_scheduler/utils/config_utils.py +100 -0
- memos/mem_scheduler/utils/db_utils.py +50 -0
- memos/mem_scheduler/utils/filter_utils.py +176 -0
- memos/mem_scheduler/utils/metrics.py +125 -0
- memos/mem_scheduler/utils/misc_utils.py +290 -0
- memos/mem_scheduler/utils/monitor_event_utils.py +67 -0
- memos/mem_scheduler/utils/status_tracker.py +229 -0
- memos/mem_scheduler/webservice_modules/__init__.py +0 -0
- memos/mem_scheduler/webservice_modules/rabbitmq_service.py +485 -0
- memos/mem_scheduler/webservice_modules/redis_service.py +380 -0
- memos/mem_user/factory.py +94 -0
- memos/mem_user/mysql_persistent_user_manager.py +271 -0
- memos/mem_user/mysql_user_manager.py +502 -0
- memos/mem_user/persistent_factory.py +98 -0
- memos/mem_user/persistent_user_manager.py +260 -0
- memos/mem_user/redis_persistent_user_manager.py +225 -0
- memos/mem_user/user_manager.py +488 -0
- memos/memories/__init__.py +0 -0
- memos/memories/activation/__init__.py +0 -0
- memos/memories/activation/base.py +42 -0
- memos/memories/activation/item.py +56 -0
- memos/memories/activation/kv.py +292 -0
- memos/memories/activation/vllmkv.py +219 -0
- memos/memories/base.py +19 -0
- memos/memories/factory.py +42 -0
- memos/memories/parametric/__init__.py +0 -0
- memos/memories/parametric/base.py +19 -0
- memos/memories/parametric/item.py +11 -0
- memos/memories/parametric/lora.py +41 -0
- memos/memories/textual/__init__.py +0 -0
- memos/memories/textual/base.py +92 -0
- memos/memories/textual/general.py +236 -0
- memos/memories/textual/item.py +304 -0
- memos/memories/textual/naive.py +187 -0
- memos/memories/textual/prefer_text_memory/__init__.py +0 -0
- memos/memories/textual/prefer_text_memory/adder.py +504 -0
- memos/memories/textual/prefer_text_memory/config.py +106 -0
- memos/memories/textual/prefer_text_memory/extractor.py +221 -0
- memos/memories/textual/prefer_text_memory/factory.py +85 -0
- memos/memories/textual/prefer_text_memory/retrievers.py +177 -0
- memos/memories/textual/prefer_text_memory/spliter.py +132 -0
- memos/memories/textual/prefer_text_memory/utils.py +93 -0
- memos/memories/textual/preference.py +344 -0
- memos/memories/textual/simple_preference.py +161 -0
- memos/memories/textual/simple_tree.py +69 -0
- memos/memories/textual/tree.py +459 -0
- memos/memories/textual/tree_text_memory/__init__.py +0 -0
- memos/memories/textual/tree_text_memory/organize/__init__.py +0 -0
- memos/memories/textual/tree_text_memory/organize/handler.py +184 -0
- memos/memories/textual/tree_text_memory/organize/manager.py +518 -0
- memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +238 -0
- memos/memories/textual/tree_text_memory/organize/reorganizer.py +622 -0
- memos/memories/textual/tree_text_memory/retrieve/__init__.py +0 -0
- memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py +364 -0
- memos/memories/textual/tree_text_memory/retrieve/bm25_util.py +186 -0
- memos/memories/textual/tree_text_memory/retrieve/bochasearch.py +419 -0
- memos/memories/textual/tree_text_memory/retrieve/internet_retriever.py +270 -0
- memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py +102 -0
- memos/memories/textual/tree_text_memory/retrieve/reasoner.py +61 -0
- memos/memories/textual/tree_text_memory/retrieve/recall.py +497 -0
- memos/memories/textual/tree_text_memory/retrieve/reranker.py +111 -0
- memos/memories/textual/tree_text_memory/retrieve/retrieval_mid_structs.py +16 -0
- memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py +472 -0
- memos/memories/textual/tree_text_memory/retrieve/searcher.py +848 -0
- memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py +135 -0
- memos/memories/textual/tree_text_memory/retrieve/utils.py +54 -0
- memos/memories/textual/tree_text_memory/retrieve/xinyusearch.py +387 -0
- memos/memos_tools/dinding_report_bot.py +453 -0
- memos/memos_tools/lockfree_dict.py +120 -0
- memos/memos_tools/notification_service.py +44 -0
- memos/memos_tools/notification_utils.py +142 -0
- memos/memos_tools/singleton.py +174 -0
- memos/memos_tools/thread_safe_dict.py +310 -0
- memos/memos_tools/thread_safe_dict_segment.py +382 -0
- memos/multi_mem_cube/__init__.py +0 -0
- memos/multi_mem_cube/composite_cube.py +86 -0
- memos/multi_mem_cube/single_cube.py +874 -0
- memos/multi_mem_cube/views.py +54 -0
- memos/parsers/__init__.py +0 -0
- memos/parsers/base.py +15 -0
- memos/parsers/factory.py +21 -0
- memos/parsers/markitdown.py +28 -0
- memos/reranker/__init__.py +4 -0
- memos/reranker/base.py +25 -0
- memos/reranker/concat.py +103 -0
- memos/reranker/cosine_local.py +102 -0
- memos/reranker/factory.py +72 -0
- memos/reranker/http_bge.py +324 -0
- memos/reranker/http_bge_strategy.py +327 -0
- memos/reranker/noop.py +19 -0
- memos/reranker/strategies/__init__.py +4 -0
- memos/reranker/strategies/base.py +61 -0
- memos/reranker/strategies/concat_background.py +94 -0
- memos/reranker/strategies/concat_docsource.py +110 -0
- memos/reranker/strategies/dialogue_common.py +109 -0
- memos/reranker/strategies/factory.py +31 -0
- memos/reranker/strategies/single_turn.py +107 -0
- memos/reranker/strategies/singleturn_outmem.py +98 -0
- memos/settings.py +10 -0
- memos/templates/__init__.py +0 -0
- memos/templates/advanced_search_prompts.py +211 -0
- memos/templates/cloud_service_prompt.py +107 -0
- memos/templates/instruction_completion.py +66 -0
- memos/templates/mem_agent_prompts.py +85 -0
- memos/templates/mem_feedback_prompts.py +822 -0
- memos/templates/mem_reader_prompts.py +1096 -0
- memos/templates/mem_reader_strategy_prompts.py +238 -0
- memos/templates/mem_scheduler_prompts.py +626 -0
- memos/templates/mem_search_prompts.py +93 -0
- memos/templates/mos_prompts.py +403 -0
- memos/templates/prefer_complete_prompt.py +735 -0
- memos/templates/tool_mem_prompts.py +139 -0
- memos/templates/tree_reorganize_prompts.py +230 -0
- memos/types/__init__.py +34 -0
- memos/types/general_types.py +151 -0
- memos/types/openai_chat_completion_types/__init__.py +15 -0
- memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py +56 -0
- memos/types/openai_chat_completion_types/chat_completion_content_part_image_param.py +27 -0
- memos/types/openai_chat_completion_types/chat_completion_content_part_input_audio_param.py +23 -0
- memos/types/openai_chat_completion_types/chat_completion_content_part_param.py +43 -0
- memos/types/openai_chat_completion_types/chat_completion_content_part_refusal_param.py +16 -0
- memos/types/openai_chat_completion_types/chat_completion_content_part_text_param.py +16 -0
- memos/types/openai_chat_completion_types/chat_completion_message_custom_tool_call_param.py +27 -0
- memos/types/openai_chat_completion_types/chat_completion_message_function_tool_call_param.py +32 -0
- memos/types/openai_chat_completion_types/chat_completion_message_param.py +18 -0
- memos/types/openai_chat_completion_types/chat_completion_message_tool_call_union_param.py +15 -0
- memos/types/openai_chat_completion_types/chat_completion_system_message_param.py +36 -0
- memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py +30 -0
- memos/types/openai_chat_completion_types/chat_completion_user_message_param.py +34 -0
- memos/utils.py +123 -0
- memos/vec_dbs/__init__.py +0 -0
- memos/vec_dbs/base.py +117 -0
- memos/vec_dbs/factory.py +23 -0
- memos/vec_dbs/item.py +50 -0
- memos/vec_dbs/milvus.py +654 -0
- memos/vec_dbs/qdrant.py +355 -0
memos/hello_world.py
ADDED
@@ -0,0 +1,97 @@
from memos import log


logger = log.get_logger(__name__)


def memos_hello_world() -> str:
    logger.info("memos_hello_world function called.")
    return "Hello world from memos!"


def memos_chend_hello_world() -> str:
    logger.info("memos_chend_hello_world function called.")
    return "Hello world from memos-chend!"


def memos_wanghy_hello_world() -> str:
    logger.info("memos_wanghy_hello_world function called.")
    return "Hello world from memos-wanghy!"


def memos_niusm_hello_world() -> str:
    logger.info("memos_niusm_hello_world function called.")
    return "Hello world from memos-niusm!"


def memos_huojh_hello_world(arr: list) -> list:
    logger.info("memos_huojh_hello_world function called.")
    if len(arr) <= 1:
        return arr
    else:
        pivot = arr[0]
        left = [x for x in arr[1:] if x < pivot]
        right = [x for x in arr[1:] if x >= pivot]
        return [*memos_huojh_hello_world(left), pivot, *memos_huojh_hello_world(right)]


def memos_dany_hello_world(para_1: int, para_2: str) -> str:
    logger.info(f"logger.info: para_1 is {para_1}")
    logger.debug(f"logger.debug: para_2 is {para_2}")
    return f"return_value_{para_1}"


def memos_wangyzh_hello_world() -> str:
    logger.info("memos_wangyzh_hello_world function called.")
    return "Hello world from memos-wangyzh!"


def memos_zhaojihao_hello_world() -> str:
    logger.info("memos_zhaojihao_hello_world function called.")
    return "Hello world from memos-zhaojihao!"


def memos_yuqingchen_hello_world() -> str:
    logger.info("memos_yuqingchen_hello_world function called.")
    return "Hello world from memos-yuqingchen!"


def memos_chentang_hello_world(user_id: str = "locomo_exp_user_1", version: str = "default"):
    import os

    from memos.configs.memory import MemoryConfigFactory
    from memos.memories.factory import MemoryFactory

    config = MemoryConfigFactory(
        backend="general_text",
        config={
            "extractor_llm": {
                "backend": "openai",
                "config": {
                    "model_name_or_path": os.getenv("MODEL"),
                    "temperature": 0,
                    "max_tokens": 8192,
                    "api_key": os.getenv("OPENAI_API_KEY"),
                    "api_base": os.getenv("OPENAI_BASE_URL"),
                },
            },
            "vector_db": {
                "backend": "qdrant",
                "config": {
                    "path": f"outputs/locomo/memos-{version}/storages/{user_id}/qdrant",
                    "collection_name": "test_textual_memory",
                    "distance_metric": "cosine",
                    "vector_dimension": 768,  # nomic-embed-text model's embedding dimension is 768
                },
            },
            "embedder": {
                "backend": "ollama",
                "config": {
                    "model_name_or_path": os.getenv("EMBEDDING_MODEL"),
                },
            },
        },
    )
    memory = MemoryFactory.from_config(config)

    return memory
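A minimal usage sketch for memos_chentang_hello_world, assuming an OpenAI-compatible endpoint and a local Ollama embedder are reachable; the environment values below are placeholders, not defaults shipped in the wheel.

import os

os.environ.setdefault("MODEL", "gpt-4o-mini")                 # placeholder chat model name
os.environ.setdefault("OPENAI_API_KEY", "sk-placeholder")     # placeholder key
os.environ.setdefault("OPENAI_BASE_URL", "https://api.openai.com/v1")
os.environ.setdefault("EMBEDDING_MODEL", "nomic-embed-text")  # 768-dim, matches vector_dimension above

from memos.hello_world import memos_chentang_hello_world, memos_huojh_hello_world

print(memos_huojh_hello_world([3, 1, 2]))  # [1, 2, 3] -- the recursive quicksort helper
memory = memos_chentang_hello_world(user_id="demo_user", version="default")  # general_text memory instance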
memos/llms/__init__.py
ADDED
File without changes
memos/llms/base.py
ADDED
@@ -0,0 +1,25 @@
from abc import ABC, abstractmethod
from collections.abc import Generator

from memos.configs.llm import BaseLLMConfig
from memos.types import MessageList


class BaseLLM(ABC):
    """Base class for all LLMs."""

    @abstractmethod
    def __init__(self, config: BaseLLMConfig):
        """Initialize the LLM with the given configuration."""

    @abstractmethod
    def generate(self, messages: MessageList, **kwargs) -> str:
        """Generate a response from the LLM."""

    @abstractmethod
    def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]:
        """
        (Optional) Generate a streaming response from the LLM.
        Subclasses should override this if they support streaming.
        By default, this raises NotImplementedError.
        """
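A minimal sketch of a concrete backend against this interface, assuming MessageList is a list of role/content dicts as the OpenAI-style message types elsewhere in the package suggest; EchoLLM is illustrative only and is not part of the wheel.

from collections.abc import Generator

from memos.configs.llm import BaseLLMConfig
from memos.llms.base import BaseLLM
from memos.types import MessageList


class EchoLLM(BaseLLM):
    """Toy backend: echoes the last user message instead of calling a real model."""

    def __init__(self, config: BaseLLMConfig):
        self.config = config

    def generate(self, messages: MessageList, **kwargs) -> str:
        return messages[-1]["content"] if messages else ""

    def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]:
        yield self.generate(messages, **kwargs)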
memos/llms/deepseek.py
ADDED
@@ -0,0 +1,13 @@
from memos.configs.llm import DeepSeekLLMConfig
from memos.llms.openai import OpenAILLM
from memos.log import get_logger


logger = get_logger(__name__)


class DeepSeekLLM(OpenAILLM):
    """DeepSeek LLM via OpenAI-compatible API."""

    def __init__(self, config: DeepSeekLLMConfig):
        super().__init__(config)
memos/llms/factory.py
ADDED
@@ -0,0 +1,38 @@
from typing import Any, ClassVar

from memos.configs.llm import LLMConfigFactory
from memos.llms.base import BaseLLM
from memos.llms.deepseek import DeepSeekLLM
from memos.llms.hf import HFLLM
from memos.llms.hf_singleton import HFSingletonLLM
from memos.llms.ollama import OllamaLLM
from memos.llms.openai import AzureLLM, OpenAILLM
from memos.llms.openai_new import OpenAIResponsesLLM
from memos.llms.qwen import QwenLLM
from memos.llms.vllm import VLLMLLM
from memos.memos_tools.singleton import singleton_factory


class LLMFactory(BaseLLM):
    """Factory class for creating LLM instances."""

    backend_to_class: ClassVar[dict[str, Any]] = {
        "openai": OpenAILLM,
        "azure": AzureLLM,
        "ollama": OllamaLLM,
        "huggingface": HFLLM,
        "huggingface_singleton": HFSingletonLLM,  # Add singleton version
        "vllm": VLLMLLM,
        "qwen": QwenLLM,
        "deepseek": DeepSeekLLM,
        "openai_new": OpenAIResponsesLLM,
    }

    @classmethod
    @singleton_factory()
    def from_config(cls, config_factory: LLMConfigFactory) -> BaseLLM:
        backend = config_factory.backend
        if backend not in cls.backend_to_class:
            raise ValueError(f"Invalid backend: {backend}")
        llm_class = cls.backend_to_class[backend]
        return llm_class(config_factory.config)
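A usage sketch for the factory, assuming LLMConfigFactory takes a backend name plus a backend-specific config dict, mirroring the MemoryConfigFactory call in hello_world.py (the exact fields live in memos/configs/llm.py, which is not expanded in this diff); values are placeholders and the "ollama" backend needs a running Ollama server.

from memos.configs.llm import LLMConfigFactory
from memos.llms.factory import LLMFactory

config_factory = LLMConfigFactory(
    backend="ollama",  # any key of LLMFactory.backend_to_class
    config={"model_name_or_path": "qwen3:1.7b", "temperature": 0.7, "max_tokens": 1024},
)

llm = LLMFactory.from_config(config_factory)  # dispatches via backend_to_class
print(llm.generate([{"role": "user", "content": "Hello!"}]))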
memos/llms/hf.py
ADDED
@@ -0,0 +1,443 @@
from collections.abc import Generator
from typing import Any

from transformers import (
    DynamicCache,
)

from memos.configs.llm import HFLLMConfig
from memos.llms.base import BaseLLM
from memos.llms.utils import remove_thinking_tags
from memos.log import get_logger
from memos.types import MessageList


logger = get_logger(__name__)


class HFLLM(BaseLLM):
    """
    HFLLM: Transformers LLM class supporting cache-augmented generation (CAG) and sampling.
    """

    def __init__(self, config: HFLLMConfig):
        """
        Initialize the HFLLM model and tokenizer, and set up logits processors for sampling.
        """
        import torch

        from transformers import (
            AutoModelForCausalLM,
            AutoTokenizer,
            LogitsProcessorList,
            TemperatureLogitsWarper,
            TopKLogitsWarper,
            TopPLogitsWarper,
        )

        self.config = config

        # Default model if not specified
        if not self.config.model_name_or_path:
            self.config.model_name_or_path = "Qwen/Qwen3-1.7B"

        # Initialize hf model
        if torch.backends.mps.is_available():
            self.model = AutoModelForCausalLM.from_pretrained(
                self.config.model_name_or_path, torch_dtype="auto"
            ).to("mps")
        else:
            self.model = AutoModelForCausalLM.from_pretrained(
                self.config.model_name_or_path, torch_dtype="auto", device_map="auto"
            )
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.model_name_or_path, use_fast=True, force_download=True
        )

        # Logits processors for sampling
        processors = []
        if getattr(self.config, "temperature", 1.0) != 1.0:
            processors.append(TemperatureLogitsWarper(self.config.temperature))
        if getattr(self.config, "top_k", 0) > 0:
            processors.append(TopKLogitsWarper(self.config.top_k))
        if 0.0 < getattr(self.config, "top_p", 1.0) < 1.0:
            processors.append(TopPLogitsWarper(self.config.top_p))
        self.logits_processors = LogitsProcessorList(processors)

    def generate(
        self, messages: MessageList, past_key_values: DynamicCache | None = None, **kwargs
    ):
        """
        Generate a response from the model. If past_key_values is provided, use cache-augmented generation.
        Args:
            messages (MessageList): Chat messages for prompt construction.
            past_key_values (DynamicCache | None): Optional KV cache for fast generation.
        Returns:
            str: Model response.
        """
        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=self.config.add_generation_prompt
        )
        logger.info(f"HFLLM prompt: {prompt}")
        if past_key_values is None:
            return self._generate_full(prompt, **kwargs)
        else:
            return self._generate_with_cache(prompt, past_key_values, **kwargs)

    def generate_stream(
        self, messages: MessageList, past_key_values: DynamicCache | None = None, **kwargs
    ) -> Generator[str, None, None]:
        """
        Generate a streaming response from the model.
        Args:
            messages (MessageList): Chat messages for prompt construction.
            past_key_values (DynamicCache | None): Optional KV cache for fast generation.
        Yields:
            str: Streaming model response chunks.
        """
        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=self.config.add_generation_prompt
        )
        logger.info(f"HFLLM streaming prompt: {prompt}")
        if past_key_values is None:
            yield from self._generate_full_stream(prompt)
        else:
            yield from self._generate_with_cache_stream(prompt, past_key_values)

    def _generate_full(self, prompt: str, **kwargs) -> str:
        """
        Generate output from scratch using the full prompt.
        Args:
            prompt (str): The input prompt string.
        Returns:
            str: Model response.
        """
        inputs = self.tokenizer([prompt], return_tensors="pt").to(self.model.device)
        gen_kwargs = {
            "max_new_tokens": kwargs.get("max_tokens", self.config.max_tokens),
            "do_sample": getattr(self.config, "do_sample", True),
        }
        if self.config.do_sample:
            gen_kwargs["temperature"] = kwargs.get("temperature", self.config.temperature)
            gen_kwargs["top_k"] = kwargs.get("top_k", self.config.top_k)
            gen_kwargs["top_p"] = kwargs.get("top_p", self.config.top_p)
        gen_ids = self.model.generate(
            **inputs,
            **gen_kwargs,
        )
        new_ids = [
            out_ids[len(src_ids) :]
            for src_ids, out_ids in zip(inputs.input_ids, gen_ids, strict=False)
        ]
        response = self.tokenizer.batch_decode(new_ids, skip_special_tokens=True)[0]
        logger.info(f"Full-gen raw response: {response}")
        return (
            remove_thinking_tags(response)
            if getattr(self.config, "remove_think_prefix", False)
            else response
        )

    def _generate_full_stream(self, prompt: str, **kwargs) -> Generator[str, None, None]:
        """
        Generate output from scratch using the full prompt with streaming.
        Args:
            prompt (str): The input prompt string.
        Yields:
            str: Streaming response chunks.
        """
        import torch

        inputs = self.tokenizer([prompt], return_tensors="pt").to(self.model.device)

        # Get generation parameters
        max_new_tokens = kwargs.get("max_tokens", self.config.max_tokens)
        remove_think_prefix = getattr(self.config, "remove_think_prefix", False)

        # Manual streaming generation
        generated_ids = inputs.input_ids.clone()
        accumulated_text = ""

        for _ in range(max_new_tokens):
            # Forward pass
            with torch.no_grad():
                outputs = self.model(
                    input_ids=generated_ids,
                    use_cache=True,
                    return_dict=True,
                )

            # Get next token logits
            next_token_logits = outputs.logits[:, -1, :]

            # Apply logits processors if sampling
            if getattr(self.config, "do_sample", True):
                batch_size, _ = next_token_logits.size()
                dummy_ids = torch.zeros(
                    (batch_size, 1), dtype=torch.long, device=next_token_logits.device
                )
                filtered_logits = self.logits_processors(dummy_ids, next_token_logits)
                probs = torch.softmax(filtered_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

            # Check for EOS token
            if self._should_stop(next_token):
                break

            # Append new token
            generated_ids = torch.cat([generated_ids, next_token], dim=-1)

            # Decode and yield the new token
            new_token_text = self.tokenizer.decode(next_token[0], skip_special_tokens=True)
            if new_token_text:  # Only yield non-empty tokens
                accumulated_text += new_token_text

                # Apply thinking tag removal if enabled
                if remove_think_prefix:
                    processed_text = remove_thinking_tags(accumulated_text)
                    # Only yield the difference (new content)
                    if len(processed_text) > len(accumulated_text) - len(new_token_text):
                        yield processed_text[len(accumulated_text) - len(new_token_text) :]
                    else:
                        yield new_token_text
                else:
                    yield new_token_text

    def _generate_with_cache(self, query: str, kv: DynamicCache, **kwargs) -> str:
        """
        Generate output incrementally using an existing KV cache.
        Args:
            query (str): The new user query string.
            kv (DynamicCache): The prefilled KV cache.
        Returns:
            str: Model response.
        """
        import torch

        query_ids = self.tokenizer(
            query, return_tensors="pt", add_special_tokens=False
        ).input_ids.to(self.model.device)
        logits, kv = self._prefill(query_ids, kv)
        next_token = self._select_next_token(logits)
        generated = [next_token]
        for _ in range(kwargs.get("max_tokens", self.config.max_tokens) - 1):
            if self._should_stop(next_token):
                break
            logits, kv = self._prefill(next_token, kv)
            next_token = self._select_next_token(logits)
            generated.append(next_token)
        if generated:
            concat = torch.cat(generated, dim=-1)
            response = self.tokenizer.decode(concat[0], skip_special_tokens=True)
        else:
            response = ""
        logger.info(f"Cache-gen raw response: {response}")
        return (
            remove_thinking_tags(response)
            if getattr(self.config, "remove_think_prefix", False)
            else response
        )

    def _generate_with_cache_stream(
        self, query: str, kv: DynamicCache, **kwargs
    ) -> Generator[str, None, None]:
        """
        Generate output incrementally using an existing KV cache with streaming.
        Args:
            query (str): The new user query string.
            kv (DynamicCache): The prefilled KV cache.
        Yields:
            str: Streaming response chunks.
        """
        query_ids = self.tokenizer(
            query, return_tensors="pt", add_special_tokens=False
        ).input_ids.to(self.model.device)

        max_new_tokens = kwargs.get("max_tokens", self.config.max_tokens)
        remove_think_prefix = getattr(self.config, "remove_think_prefix", False)

        # Initial forward pass
        logits, kv = self._prefill(query_ids, kv)
        next_token = self._select_next_token(logits)

        # Yield first token
        first_token_text = self.tokenizer.decode(next_token[0], skip_special_tokens=True)
        accumulated_text = ""
        if first_token_text:
            accumulated_text += first_token_text
            if remove_think_prefix:
                processed_text = remove_thinking_tags(accumulated_text)
                if len(processed_text) > len(accumulated_text) - len(first_token_text):
                    yield processed_text[len(accumulated_text) - len(first_token_text) :]
                else:
                    yield first_token_text
            else:
                yield first_token_text

        generated = [next_token]

        # Continue generation
        for _ in range(max_new_tokens - 1):
            if self._should_stop(next_token):
                break
            logits, kv = self._prefill(next_token, kv)
            next_token = self._select_next_token(logits)

            # Decode and yield the new token
            new_token_text = self.tokenizer.decode(next_token[0], skip_special_tokens=True)
            if new_token_text:
                accumulated_text += new_token_text

                # Apply thinking tag removal if enabled
                if remove_think_prefix:
                    processed_text = remove_thinking_tags(accumulated_text)
                    # Only yield the difference (new content)
                    if len(processed_text) > len(accumulated_text) - len(new_token_text):
                        yield processed_text[len(accumulated_text) - len(new_token_text) :]
                    else:
                        yield new_token_text
                else:
                    yield new_token_text

            generated.append(next_token)

    def _prefill(self, input_ids: Any, kv: DynamicCache) -> tuple[Any, DynamicCache]:
        """
        Forward the model once, returning last-step logits and updated KV cache.
        Args:
            input_ids (torch.Tensor): Input token IDs.
            kv (DynamicCache): Existing KV cache.
        Returns:
            tuple[torch.Tensor, DynamicCache]: (last-step logits, updated KV cache)
        """
        import torch

        with torch.no_grad():
            out = self.model(
                input_ids=input_ids,
                use_cache=True,
                past_key_values=kv,
                return_dict=True,
            )
        return out.logits[:, -1, :], out.past_key_values

    def _select_next_token(self, logits: Any) -> Any:
        """
        Select the next token from logits using sampling or argmax, depending on config.
        Args:
            logits (torch.Tensor): Logits for the next token.
        Returns:
            torch.Tensor: Selected token ID(s).
        """
        import torch

        if getattr(self.config, "do_sample", True):
            batch_size, _ = logits.size()
            dummy_ids = torch.zeros((batch_size, 1), dtype=torch.long, device=logits.device)
            filtered = self.logits_processors(dummy_ids, logits)
            probs = torch.softmax(filtered, dim=-1)
            return torch.multinomial(probs, num_samples=1)
        return torch.argmax(logits, dim=-1, keepdim=True)

    def _should_stop(self, token: Any) -> bool:
        """
        Check if the given token is the EOS (end-of-sequence) token.
        Args:
            token (torch.Tensor): Token ID to check.
        Returns:
            bool: True if token is EOS, else False.
        """
        eos_id = self.tokenizer.eos_token_id
        return eos_id is not None and token.item() == eos_id

    def build_kv_cache(self, messages) -> DynamicCache:
        """
        Build a KV cache from chat messages via one forward pass.
        Supports the following input types:
        - str: Used as a system prompt.
        - list[str]: Concatenated and used as a system prompt.
        - list[dict]: Used directly as chat messages.
        The messages are always converted to a standard chat template.
        Raises:
            ValueError: If the resulting prompt is empty after template processing.
        Returns:
            DynamicCache: The constructed KV cache object.
        """
        import torch
        import transformers

        # Accept multiple input types and convert to standard chat messages
        if isinstance(messages, str):
            messages = [
                {
                    "role": "system",
                    "content": f"Below is some information about the user.\n{messages}",
                }
            ]
        elif isinstance(messages, list) and messages and isinstance(messages[0], str):
            messages = [
                {
                    "role": "system",
                    "content": f"Below is some information about the user.\n{' '.join(messages)}",
                }
            ]
        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False
        )
        inputs = self.tokenizer(prompt, return_tensors="pt")
        inputs["input_ids"] = inputs["input_ids"].to(self.model.device, dtype=torch.long)
        seq_len = inputs["input_ids"].size(-1)
        if seq_len == 0:
            raise ValueError(
                "Prompt after chat template is empty, cannot build KV cache. Check your messages input."
            )
        # Create cache and perform forward pass without pre-existing cache
        with torch.no_grad():
            outputs = self.model(**inputs, use_cache=True)

        # Get the cache from model outputs
        if hasattr(outputs, "past_key_values") and outputs.past_key_values is not None:
            kv = outputs.past_key_values

            # Convert from legacy tuple format to DynamicCache if needed
            if isinstance(kv, tuple):
                kv = transformers.DynamicCache.from_legacy_cache(kv)

            # Handle compatibility between old and new transformers versions
            # In newer versions, DynamicCache uses 'layers' attribute
            # In older versions, it uses 'key_cache' and 'value_cache' attributes
            if hasattr(kv, "layers"):
                # New version: trim cache using layers attribute
                for layer in kv.layers:
                    if hasattr(layer, "key_cache") and hasattr(layer, "value_cache"):
                        # Trim each layer's cache to the sequence length
                        if layer.key_cache is not None:
                            layer.key_cache = layer.key_cache[:, :, :seq_len, :]
                        if layer.value_cache is not None:
                            layer.value_cache = layer.value_cache[:, :, :seq_len, :]
                    elif hasattr(layer, "keys") and hasattr(layer, "values"):
                        # Alternative attribute names in some versions
                        if layer.keys is not None:
                            layer.keys = layer.keys[:, :, :seq_len, :]
                        if layer.values is not None:
                            layer.values = layer.values[:, :, :seq_len, :]
            elif hasattr(kv, "key_cache") and hasattr(kv, "value_cache"):
                # Old version: trim cache using key_cache and value_cache attributes
                for i in range(len(kv.key_cache)):
                    if kv.key_cache[i] is not None:
                        kv.key_cache[i] = kv.key_cache[i][:, :, :seq_len, :]
                    if kv.value_cache[i] is not None:
                        kv.value_cache[i] = kv.value_cache[i][:, :, :seq_len, :]
            else:
                # Fallback: log warning but continue without trimming
                logger.warning(
                    f"DynamicCache object of type {type(kv)} has unexpected structure. "
                    f"Cache trimming skipped. Available attributes: {dir(kv)}"
                )

            return kv
        else:
            raise RuntimeError(
                "Failed to build KV cache: no cache data available from model outputs"
            )
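A sketch of the cache-augmented generation (CAG) path above, with placeholder config values: build_kv_cache prefills a DynamicCache once from background text, and generate / generate_stream decode incrementally against it whenever past_key_values is passed. The LLMConfigFactory fields are assumptions based on the rest of this diff, not documented defaults.

from memos.configs.llm import LLMConfigFactory
from memos.llms.factory import LLMFactory

llm = LLMFactory.from_config(
    LLMConfigFactory(
        backend="huggingface",
        config={"model_name_or_path": "Qwen/Qwen3-1.7B", "max_tokens": 256},  # placeholder fields
    )
)

# Prefill the user's background once; the returned DynamicCache is reused at query time.
kv = llm.build_kv_cache("The user is a vegetarian and lives in Berlin.")

messages = [{"role": "user", "content": "Suggest a dinner recipe."}]
print(llm.generate(messages, past_key_values=kv))  # incremental decode against the cache
# Streaming variant: for chunk in llm.generate_stream(messages, past_key_values=kv): print(chunk, end="")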