MemoryOS: memoryos-0.1.13-py3-none-any.whl → memoryos-0.2.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of MemoryOS might be problematic.
- {memoryos-0.1.13.dist-info → memoryos-0.2.1.dist-info}/METADATA +78 -49
- memoryos-0.2.1.dist-info/RECORD +152 -0
- memoryos-0.2.1.dist-info/entry_points.txt +3 -0
- memos/__init__.py +1 -1
- memos/api/config.py +471 -0
- memos/api/exceptions.py +28 -0
- memos/api/mcp_serve.py +502 -0
- memos/api/product_api.py +35 -0
- memos/api/product_models.py +159 -0
- memos/api/routers/__init__.py +1 -0
- memos/api/routers/product_router.py +358 -0
- memos/chunkers/sentence_chunker.py +8 -2
- memos/cli.py +113 -0
- memos/configs/embedder.py +27 -0
- memos/configs/graph_db.py +83 -2
- memos/configs/llm.py +48 -0
- memos/configs/mem_cube.py +1 -1
- memos/configs/mem_reader.py +4 -0
- memos/configs/mem_scheduler.py +91 -5
- memos/configs/memory.py +10 -4
- memos/dependency.py +52 -0
- memos/embedders/ark.py +92 -0
- memos/embedders/factory.py +4 -0
- memos/embedders/sentence_transformer.py +8 -2
- memos/embedders/universal_api.py +32 -0
- memos/graph_dbs/base.py +2 -2
- memos/graph_dbs/factory.py +2 -0
- memos/graph_dbs/item.py +46 -0
- memos/graph_dbs/neo4j.py +377 -101
- memos/graph_dbs/neo4j_community.py +300 -0
- memos/llms/base.py +9 -0
- memos/llms/deepseek.py +54 -0
- memos/llms/factory.py +10 -1
- memos/llms/hf.py +170 -13
- memos/llms/hf_singleton.py +114 -0
- memos/llms/ollama.py +4 -0
- memos/llms/openai.py +68 -1
- memos/llms/qwen.py +63 -0
- memos/llms/vllm.py +153 -0
- memos/mem_cube/general.py +77 -16
- memos/mem_cube/utils.py +102 -0
- memos/mem_os/core.py +131 -41
- memos/mem_os/main.py +93 -11
- memos/mem_os/product.py +1098 -35
- memos/mem_os/utils/default_config.py +352 -0
- memos/mem_os/utils/format_utils.py +1154 -0
- memos/mem_reader/simple_struct.py +13 -8
- memos/mem_scheduler/base_scheduler.py +467 -36
- memos/mem_scheduler/general_scheduler.py +125 -244
- memos/mem_scheduler/modules/base.py +9 -0
- memos/mem_scheduler/modules/dispatcher.py +68 -2
- memos/mem_scheduler/modules/misc.py +39 -0
- memos/mem_scheduler/modules/monitor.py +228 -49
- memos/mem_scheduler/modules/rabbitmq_service.py +317 -0
- memos/mem_scheduler/modules/redis_service.py +32 -22
- memos/mem_scheduler/modules/retriever.py +250 -23
- memos/mem_scheduler/modules/schemas.py +189 -7
- memos/mem_scheduler/mos_for_test_scheduler.py +143 -0
- memos/mem_scheduler/utils.py +51 -2
- memos/mem_user/persistent_user_manager.py +260 -0
- memos/memories/activation/item.py +25 -0
- memos/memories/activation/kv.py +10 -3
- memos/memories/activation/vllmkv.py +219 -0
- memos/memories/factory.py +2 -0
- memos/memories/textual/general.py +7 -5
- memos/memories/textual/item.py +3 -1
- memos/memories/textual/tree.py +14 -6
- memos/memories/textual/tree_text_memory/organize/conflict.py +198 -0
- memos/memories/textual/tree_text_memory/organize/manager.py +72 -23
- memos/memories/textual/tree_text_memory/organize/redundancy.py +193 -0
- memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +233 -0
- memos/memories/textual/tree_text_memory/organize/reorganizer.py +606 -0
- memos/memories/textual/tree_text_memory/retrieve/recall.py +0 -1
- memos/memories/textual/tree_text_memory/retrieve/reranker.py +2 -2
- memos/memories/textual/tree_text_memory/retrieve/searcher.py +6 -5
- memos/parsers/markitdown.py +8 -2
- memos/templates/mem_reader_prompts.py +105 -36
- memos/templates/mem_scheduler_prompts.py +96 -47
- memos/templates/tree_reorganize_prompts.py +223 -0
- memos/vec_dbs/base.py +12 -0
- memos/vec_dbs/qdrant.py +46 -20
- memoryos-0.1.13.dist-info/RECORD +0 -122
- {memoryos-0.1.13.dist-info → memoryos-0.2.1.dist-info}/LICENSE +0 -0
- {memoryos-0.1.13.dist-info → memoryos-0.2.1.dist-info}/WHEEL +0 -0
memos/llms/hf_singleton.py
ADDED
@@ -0,0 +1,114 @@
+import threading
+
+from typing import ClassVar
+
+from memos.configs.llm import HFLLMConfig
+from memos.llms.hf import HFLLM
+from memos.log import get_logger
+
+
+logger = get_logger(__name__)
+
+
+class HFSingletonLLM(HFLLM):
+    """
+    Singleton version of HFLLM that prevents multiple loading of the same model.
+    This class inherits from HFLLM and adds singleton behavior.
+    """
+
+    _instances: ClassVar[dict[str, "HFSingletonLLM"]] = {}
+    _lock: ClassVar[threading.Lock] = threading.Lock()
+
+    def __new__(cls, config: HFLLMConfig):
+        """
+        Singleton pattern implementation.
+        Returns existing instance if config already exists, otherwise creates new one.
+        """
+        config_key = cls._get_config_key(config)
+
+        if config_key in cls._instances:
+            logger.debug(f"Reusing existing HF model: {config.model_name_or_path}")
+            return cls._instances[config_key]
+
+        with cls._lock:
+            # Double-check pattern to prevent race conditions
+            if config_key in cls._instances:
+                logger.debug(f"Reusing existing HF model: {config.model_name_or_path}")
+                return cls._instances[config_key]
+
+            logger.info(f"Creating new HF model: {config.model_name_or_path}")
+            instance = super().__new__(cls)
+            cls._instances[config_key] = instance
+            return instance
+
+    def __init__(self, config: HFLLMConfig):
+        """
+        Initialize the singleton HFLLM instance.
+        Only initializes if this is a new instance.
+        """
+        # Check if already initialized
+        if hasattr(self, "_initialized"):
+            return
+
+        # Call parent constructor
+        super().__init__(config)
+        self._initialized = True
+
+    @classmethod
+    def _get_config_key(cls, config: HFLLMConfig) -> str:
+        """
+        Generate a unique key for the HF model configuration.
+
+        Args:
+            config: The HFLLM configuration
+
+        Returns:
+            A unique string key representing the configuration
+        """
+        # Create a unique key based on model path and key parameters
+        key_parts = [config.model_name_or_path]
+        return "|".join(key_parts)
+
+    @classmethod
+    def get_instance_count(cls) -> int:
+        """
+        Get the number of unique HF model instances currently managed.
+
+        Returns:
+            Number of HF model instances
+        """
+        return len(cls._instances)
+
+    @classmethod
+    def get_instance_info(cls) -> dict[str, str]:
+        """
+        Get information about all managed HF model instances.
+
+        Returns:
+            Dictionary mapping config keys to model paths
+        """
+        return {key: instance.config.model_name_or_path for key, instance in cls._instances.items()}
+
+    @classmethod
+    def clear_all(cls) -> None:
+        """
+        Clear all HF model instances from memory.
+        This should be used carefully as it will force reloading of models.
+        """
+        with cls._lock:
+            cls._instances.clear()
+            logger.info("All HF model instances cleared from singleton manager")
+
+
+# Convenience function to get singleton manager info
+def get_hf_singleton_info() -> dict[str, int]:
+    """
+    Get information about the HF singleton manager.
+
+    Returns:
+        Dictionary with instance count and info
+    """
+    return {
+        "instance_count": HFSingletonLLM.get_instance_count(),
+        "instance_info": HFSingletonLLM.get_instance_info(),
+    }
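The new HFSingletonLLM keys its cache on the model path, so constructing it twice with the same config returns the same object instead of loading the weights again. A minimal usage sketch (the model path and the HFLLMConfig construction are illustrative assumptions, not taken from this release):

    from memos.configs.llm import HFLLMConfig
    from memos.llms.hf_singleton import HFSingletonLLM

    config = HFLLMConfig(model_name_or_path="Qwen/Qwen2.5-0.5B-Instruct")  # hypothetical model
    llm_a = HFSingletonLLM(config)
    llm_b = HFSingletonLLM(config)   # same config key, so the weights are loaded only once

    assert llm_a is llm_b
    print(HFSingletonLLM.get_instance_count())  # 1

    HFSingletonLLM.clear_all()  # drops the cache; the next construction reloads the model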
memos/llms/ollama.py
CHANGED
@@ -1,3 +1,4 @@
+from collections.abc import Generator
 from typing import Any
 
 from ollama import Client
@@ -80,3 +81,6 @@ class OllamaLLM(BaseLLM):
             return remove_thinking_tags(str_response)
         else:
             return str_response
+
+    def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]:
+        raise NotImplementedError
memos/llms/openai.py
CHANGED
@@ -1,6 +1,8 @@
+from collections.abc import Generator
+
 import openai
 
-from memos.configs.llm import OpenAILLMConfig
+from memos.configs.llm import AzureLLMConfig, OpenAILLMConfig
 from memos.llms.base import BaseLLM
 from memos.llms.utils import remove_thinking_tags
 from memos.log import get_logger
@@ -22,6 +24,7 @@ class OpenAILLM(BaseLLM):
         response = self.client.chat.completions.create(
             model=self.config.model_name_or_path,
             messages=messages,
+            extra_body=self.config.extra_body,
             temperature=self.config.temperature,
             max_tokens=self.config.max_tokens,
             top_p=self.config.top_p,
@@ -32,3 +35,67 @@ class OpenAILLM(BaseLLM):
             return remove_thinking_tags(response_content)
         else:
             return response_content
+
+    def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]:
+        """Stream response from OpenAI LLM with optional reasoning support."""
+        response = self.client.chat.completions.create(
+            model=self.config.model_name_or_path,
+            messages=messages,
+            stream=True,
+            temperature=self.config.temperature,
+            max_tokens=self.config.max_tokens,
+            top_p=self.config.top_p,
+            extra_body=self.config.extra_body,
+        )
+
+        reasoning_started = False
+
+        for chunk in response:
+            delta = chunk.choices[0].delta
+
+            # Support for custom 'reasoning_content' (if present in OpenAI-compatible models like Qwen)
+            if hasattr(delta, "reasoning_content") and delta.reasoning_content:
+                if not reasoning_started and not self.config.remove_think_prefix:
+                    yield "<think>"
+                    reasoning_started = True
+                yield delta.reasoning_content
+            elif hasattr(delta, "content") and delta.content:
+                if reasoning_started and not self.config.remove_think_prefix:
+                    yield "</think>"
+                    reasoning_started = False
+                yield delta.content
+
+        # Ensure we close the <think> block if not already done
+        if reasoning_started and not self.config.remove_think_prefix:
+            yield "</think>"
+
+
+class AzureLLM(BaseLLM):
+    """Azure OpenAI LLM class."""
+
+    def __init__(self, config: AzureLLMConfig):
+        self.config = config
+        self.client = openai.AzureOpenAI(
+            azure_endpoint=config.base_url,
+            api_version=config.api_version,
+            api_key=config.api_key,
+        )
+
+    def generate(self, messages: MessageList) -> str:
+        """Generate a response from Azure OpenAI LLM."""
+        response = self.client.chat.completions.create(
+            model=self.config.model_name_or_path,
+            messages=messages,
+            temperature=self.config.temperature,
+            max_tokens=self.config.max_tokens,
+            top_p=self.config.top_p,
+        )
+        logger.info(f"Response from Azure OpenAI: {response.model_dump_json()}")
+        response_content = response.choices[0].message.content
+        if self.config.remove_think_prefix:
+            return remove_thinking_tags(response_content)
+        else:
+            return response_content
+
+    def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]:
+        raise NotImplementedError
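OpenAILLM.generate_stream yields plain text chunks and, when the backend exposes a reasoning_content delta, wraps the reasoning between literal <think> and </think> markers. A small consumer sketch (config values are placeholders, and the exact OpenAILLMConfig fields beyond those used in the diff are assumptions):

    from memos.configs.llm import OpenAILLMConfig
    from memos.llms.openai import OpenAILLM

    llm = OpenAILLM(OpenAILLMConfig(model_name_or_path="gpt-4o-mini", api_key="sk-..."))

    chunks = list(llm.generate_stream([{"role": "user", "content": "Hello!"}]))
    full = "".join(chunks)

    # Separate the optional reasoning block from the visible answer.
    reasoning, sep, answer = full.partition("</think>")
    if not sep:  # no <think> block was emitted by the model
        answer, reasoning = full, ""
    reasoning = reasoning.removeprefix("<think>")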
memos/llms/qwen.py
ADDED
@@ -0,0 +1,63 @@
+from collections.abc import Generator
+
+from memos.configs.llm import QwenLLMConfig
+from memos.llms.openai import OpenAILLM
+from memos.llms.utils import remove_thinking_tags
+from memos.log import get_logger
+from memos.types import MessageList
+
+
+logger = get_logger(__name__)
+
+
+class QwenLLM(OpenAILLM):
+    """Qwen (DashScope) LLM class via OpenAI-compatible API."""
+
+    def __init__(self, config: QwenLLMConfig):
+        super().__init__(config)
+
+    def generate(self, messages: MessageList) -> str:
+        """Generate a response from Qwen LLM."""
+        response = self.client.chat.completions.create(
+            model=self.config.model_name_or_path,
+            messages=messages,
+            extra_body=self.config.extra_body,
+            temperature=self.config.temperature,
+            max_tokens=self.config.max_tokens,
+            top_p=self.config.top_p,
+        )
+        logger.info(f"Response from Qwen: {response.model_dump_json()}")
+        response_content = response.choices[0].message.content
+        if self.config.remove_think_prefix:
+            return remove_thinking_tags(response_content)
+        else:
+            return response_content
+
+    def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]:
+        """Stream response from Qwen LLM."""
+        response = self.client.chat.completions.create(
+            model=self.config.model_name_or_path,
+            messages=messages,
+            stream=True,
+            temperature=self.config.temperature,
+            max_tokens=self.config.max_tokens,
+            top_p=self.config.top_p,
+            extra_body=self.config.extra_body,
+        )
+
+        reasoning_started = False
+        for chunk in response:
+            delta = chunk.choices[0].delta
+
+            # Some models may have separate `reasoning_content` vs `content`
+            # For Qwen (DashScope), likely only `content` is used
+            if hasattr(delta, "reasoning_content") and delta.reasoning_content:
+                if not reasoning_started and not self.config.remove_think_prefix:
+                    yield "<think>"
+                    reasoning_started = True
+                yield delta.reasoning_content
+            elif hasattr(delta, "content") and delta.content:
+                if reasoning_started and not self.config.remove_think_prefix:
+                    yield "</think>"
+                    reasoning_started = False
+                yield delta.content
memos/llms/vllm.py
ADDED
@@ -0,0 +1,153 @@
+from typing import Any, cast
+
+from memos.configs.llm import VLLMLLMConfig
+from memos.llms.base import BaseLLM
+from memos.llms.utils import remove_thinking_tags
+from memos.log import get_logger
+from memos.types import MessageDict
+
+
+logger = get_logger(__name__)
+
+
+class VLLMLLM(BaseLLM):
+    """
+    VLLM LLM class for connecting to existing vLLM servers.
+    """
+
+    def __init__(self, config: VLLMLLMConfig):
+        """
+        Initialize the VLLM LLM to connect to an existing vLLM server.
+        """
+        self.config = config
+
+        # Initialize OpenAI client for API calls
+        self.client = None
+        api_key = getattr(self.config, "api_key", "dummy")
+        if not api_key:
+            api_key = "dummy"
+
+        import openai
+
+        self.client = openai.Client(
+            api_key=api_key, base_url=getattr(self.config, "api_base", "http://localhost:8088/v1")
+        )
+
+    def build_vllm_kv_cache(self, messages: Any) -> str:
+        """
+        Build a KV cache from chat messages via one vLLM request.
+        Handles str, list[str], and MessageList formats.
+        """
+        # 1. Normalize input to a MessageList
+        processed_messages: list[MessageDict] = []
+        if isinstance(messages, str):
+            processed_messages = [
+                {
+                    "role": "system",
+                    "content": f"Below is some information about the user.\n{messages}",
+                }
+            ]
+        elif isinstance(messages, list):
+            if not messages:
+                pass  # Empty list
+            elif isinstance(messages[0], str):
+                str_content = " ".join(str(msg) for msg in messages)
+                processed_messages = [
+                    {
+                        "role": "system",
+                        "content": f"Below is some information about the user.\n{str_content}",
+                    }
+                ]
+            elif isinstance(messages[0], dict):
+                processed_messages = cast("list[MessageDict]", messages)
+
+        # 2. Convert to prompt for logging/return value.
+        prompt = self._messages_to_prompt(processed_messages)
+
+        if not prompt.strip():
+            raise ValueError("Prompt is empty, cannot build KV cache.")
+
+        # 3. Send request to vLLM server to preload the KV cache
+        if self.client:
+            try:
+                # Use the processed messages for the API call
+                prefill_kwargs = {
+                    "model": self.config.model_name_or_path,
+                    "messages": processed_messages,
+                    "max_tokens": 2,
+                    "temperature": 0.0,
+                    "top_p": 1.0,
+                }
+                self.client.chat.completions.create(**prefill_kwargs)
+                logger.info(f"vLLM KV cache prefill completed for prompt: '{prompt[:100]}...'")
+            except Exception as e:
+                logger.warning(f"Failed to prefill vLLM KV cache: {e}")
+
+        return prompt
+
+    def generate(self, messages: list[MessageDict]) -> str:
+        """
+        Generate a response from the model.
+        """
+        if self.client:
+            return self._generate_with_api_client(messages)
+        else:
+            raise RuntimeError("API client is not available")
+
+    def _generate_with_api_client(self, messages: list[MessageDict]) -> str:
+        """
+        Generate response using vLLM API client.
+        """
+        if self.client:
+            completion_kwargs = {
+                "model": self.config.model_name_or_path,
+                "messages": messages,
+                "temperature": float(getattr(self.config, "temperature", 0.8)),
+                "max_tokens": int(getattr(self.config, "max_tokens", 1024)),
+                "top_p": float(getattr(self.config, "top_p", 0.9)),
+            }
+
+            response = self.client.chat.completions.create(**completion_kwargs)
+            response_text = response.choices[0].message.content or ""
+            logger.info(f"VLLM API response: {response_text}")
+            return (
+                remove_thinking_tags(response_text)
+                if getattr(self.config, "remove_think_prefix", False)
+                else response_text
+            )
+        else:
+            raise RuntimeError("API client is not available")
+
+    def _messages_to_prompt(self, messages: list[MessageDict]) -> str:
+        """
+        Convert messages to prompt string.
+        """
+        prompt_parts = []
+        for msg in messages:
+            role = msg["role"]
+            content = msg["content"]
+            prompt_parts.append(f"{role.capitalize()}: {content}")
+        return "\n".join(prompt_parts)
+
+    def generate_stream(self, messages: list[MessageDict]):
+        """
+        Generate a response from the model using streaming.
+        Yields content chunks as they are received.
+        """
+        if self.client:
+            completion_kwargs = {
+                "model": self.config.model_name_or_path,
+                "messages": messages,
+                "temperature": float(getattr(self.config, "temperature", 0.8)),
+                "max_tokens": int(getattr(self.config, "max_tokens", 1024)),
+                "top_p": float(getattr(self.config, "top_p", 0.9)),
+                "stream": True,  # Enable streaming
+            }
+
+            stream = self.client.chat.completions.create(**completion_kwargs)
+            for chunk in stream:
+                content = chunk.choices[0].delta.content
+                if content:
+                    yield content
+        else:
+            raise RuntimeError("API client is not available")
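VLLMLLM talks to an already-running vLLM server through its OpenAI-compatible endpoint; build_vllm_kv_cache issues a tiny prefill request (max_tokens=2) so the server can cache the prompt prefix. A usage sketch, assuming VLLMLLMConfig accepts the fields shown and the server was started with vLLM's --enable-prefix-caching option:

    from memos.configs.llm import VLLMLLMConfig
    from memos.llms.vllm import VLLMLLM

    config = VLLMLLMConfig(
        model_name_or_path="Qwen/Qwen2.5-7B-Instruct",  # whatever model the server is serving
        api_base="http://localhost:8088/v1",            # matches the default used in the code above
    )
    llm = VLLMLLM(config)

    context = "The user prefers concise answers and lives in Berlin."

    # Warm the server's KV cache with the stable context (a 2-token prefill request).
    llm.build_vllm_kv_cache(context)

    # A later call that repeats the same system prefix can hit the cached KV blocks
    # when prefix caching is enabled on the server.
    reply = llm.generate([
        {"role": "system", "content": f"Below is some information about the user.\n{context}"},
        {"role": "user", "content": "Where does the user live?"},
    ])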
memos/mem_cube/general.py
CHANGED
@@ -1,11 +1,13 @@
 import os
 
+from typing import Literal
+
 from memos.configs.mem_cube import GeneralMemCubeConfig
 from memos.configs.utils import get_json_file_model_schema
 from memos.exceptions import ConfigurationError, MemCubeError
 from memos.log import get_logger
 from memos.mem_cube.base import BaseMemCube
-from memos.mem_cube.utils import download_repo
+from memos.mem_cube.utils import download_repo, merge_config_with_default
 from memos.memories.activation.base import BaseActMemory
 from memos.memories.factory import MemoryFactory
 from memos.memories.parametric.base import BaseParaMemory
@@ -37,10 +39,15 @@ class GeneralMemCube(BaseMemCube):
             else None
         )
 
-    def load(self, dir: str) -> None:
+    def load(
+        self, dir: str, memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None
+    ) -> None:
         """Load memories.
         Args:
             dir (str): The directory containing the memory files.
+            memory_types (list[str], optional): List of memory types to load.
+                If None, loads all available memory types.
+                Options: ["text_mem", "act_mem", "para_mem"]
         """
         loaded_schema = get_json_file_model_schema(os.path.join(dir, self.config.config_filename))
         if loaded_schema != self.config.model_schema:
@@ -48,60 +55,114 @@ class GeneralMemCube(BaseMemCube):
                 f"Configuration schema mismatch. Expected {self.config.model_schema}, "
                 f"but found {loaded_schema}."
             )
-        self.text_mem.load(dir) if self.text_mem else None
-        self.act_mem.load(dir) if self.act_mem else None
-        self.para_mem.load(dir) if self.para_mem else None
 
-
+        # If no specific memory types specified, load all
+        if memory_types is None:
+            memory_types = ["text_mem", "act_mem", "para_mem"]
+
+        # Load specified memory types
+        if "text_mem" in memory_types and self.text_mem:
+            self.text_mem.load(dir)
+            logger.debug(f"Loaded text_mem from {dir}")
+
+        if "act_mem" in memory_types and self.act_mem:
+            self.act_mem.load(dir)
+            logger.info(f"Loaded act_mem from {dir}")
 
-    def dump(self, dir: str) -> None:
+        if "para_mem" in memory_types and self.para_mem:
+            self.para_mem.load(dir)
+            logger.info(f"Loaded para_mem from {dir}")
+
+        logger.info(f"MemCube loaded successfully from {dir} (types: {memory_types})")
+
+    def dump(
+        self, dir: str, memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None
+    ) -> None:
         """Dump memories.
         Args:
             dir (str): The directory where the memory files will be saved.
+            memory_types (list[str], optional): List of memory types to dump.
+                If None, dumps all available memory types.
+                Options: ["text_mem", "act_mem", "para_mem"]
         """
         if os.path.exists(dir) and os.listdir(dir):
             raise MemCubeError(
                 f"Directory {dir} is not empty. Please provide an empty directory for dumping."
             )
 
+        # Always dump config
         self.config.to_json_file(os.path.join(dir, self.config.config_filename))
-        self.text_mem.dump(dir) if self.text_mem else None
-        self.act_mem.dump(dir) if self.act_mem else None
-        self.para_mem.dump(dir) if self.para_mem else None
 
-
+        # If no specific memory types specified, dump all
+        if memory_types is None:
+            memory_types = ["text_mem", "act_mem", "para_mem"]
+
+        # Dump specified memory types
+        if "text_mem" in memory_types and self.text_mem:
+            self.text_mem.dump(dir)
+            logger.info(f"Dumped text_mem to {dir}")
+
+        if "act_mem" in memory_types and self.act_mem:
+            self.act_mem.dump(dir)
+            logger.info(f"Dumped act_mem to {dir}")
+
+        if "para_mem" in memory_types and self.para_mem:
+            self.para_mem.dump(dir)
+            logger.info(f"Dumped para_mem to {dir}")
+
+        logger.info(f"MemCube dumped successfully to {dir} (types: {memory_types})")
 
     @staticmethod
-    def init_from_dir(dir: str) -> "GeneralMemCube":
+    def init_from_dir(
+        dir: str,
+        memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None,
+        default_config: GeneralMemCubeConfig | None = None,
+    ) -> "GeneralMemCube":
         """Create a MemCube instance from a MemCube directory.
 
         Args:
             dir (str): The directory containing the memory files.
+            memory_types (list[str], optional): List of memory types to load.
+                If None, loads all available memory types.
+            default_config (GeneralMemCubeConfig, optional): Default configuration to merge with existing config.
+                If provided, will merge general settings while preserving critical user-specific fields.
 
         Returns:
             MemCube: An instance of MemCube loaded with memories from the specified directory.
         """
         config_path = os.path.join(dir, "config.json")
         config = GeneralMemCubeConfig.from_json_file(config_path)
+
+        # Merge with default config if provided
+        if default_config is not None:
+            config = merge_config_with_default(config, default_config)
+            logger.info(f"Applied default config to cube {config.cube_id}")
+
         mem_cube = GeneralMemCube(config)
-        mem_cube.load(dir)
+        mem_cube.load(dir, memory_types)
         return mem_cube
 
     @staticmethod
     def init_from_remote_repo(
-        cube_id: str, base_url: str = "https://huggingface.co/datasets"
+        cube_id: str,
+        base_url: str = "https://huggingface.co/datasets",
+        memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None,
+        default_config: GeneralMemCubeConfig | None = None,
     ) -> "GeneralMemCube":
         """Create a MemCube instance from a remote repository.
 
         Args:
-
+            cube_id (str): The repository name.
             base_url (str): The base URL of the remote repository.
+            memory_types (list[str], optional): List of memory types to load.
+                If None, loads all available memory types.
+            default_config (GeneralMemCubeConfig, optional): Default configuration to merge with existing config.
 
         Returns:
             MemCube: An instance of MemCube loaded with memories from the specified remote repository.
         """
         dir = download_repo(cube_id, base_url)
-        return GeneralMemCube.init_from_dir(dir)
+        return GeneralMemCube.init_from_dir(dir, memory_types, default_config)
 
     @property
     def text_mem(self) -> "BaseTextMemory | None":
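The reworked load/dump/init_from_dir surface lets callers restrict which memory stores are touched. A short sketch using only the methods shown in this diff (the directory paths are illustrative):

    from memos.mem_cube.general import GeneralMemCube

    # Load only the textual memories, skipping activation (KV cache) and parametric stores.
    cube = GeneralMemCube.init_from_dir("/tmp/my_cube", memory_types=["text_mem"])

    # Dump only the textual memories to an empty directory; the cube config is always written.
    cube.dump("/tmp/my_cube_text_only", memory_types=["text_mem"])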
|