MemoryOS 0.2.0-py3-none-any.whl → 0.2.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of MemoryOS might be problematic.
- {memoryos-0.2.0.dist-info → memoryos-0.2.1.dist-info}/METADATA +66 -26
- {memoryos-0.2.0.dist-info → memoryos-0.2.1.dist-info}/RECORD +80 -56
- memoryos-0.2.1.dist-info/entry_points.txt +3 -0
- memos/__init__.py +1 -1
- memos/api/config.py +471 -0
- memos/api/exceptions.py +28 -0
- memos/api/mcp_serve.py +502 -0
- memos/api/product_api.py +35 -0
- memos/api/product_models.py +159 -0
- memos/api/routers/__init__.py +1 -0
- memos/api/routers/product_router.py +358 -0
- memos/chunkers/sentence_chunker.py +8 -2
- memos/cli.py +113 -0
- memos/configs/embedder.py +27 -0
- memos/configs/graph_db.py +83 -2
- memos/configs/llm.py +47 -0
- memos/configs/mem_cube.py +1 -1
- memos/configs/mem_scheduler.py +91 -5
- memos/configs/memory.py +5 -4
- memos/dependency.py +52 -0
- memos/embedders/ark.py +92 -0
- memos/embedders/factory.py +4 -0
- memos/embedders/sentence_transformer.py +8 -2
- memos/embedders/universal_api.py +32 -0
- memos/graph_dbs/base.py +2 -2
- memos/graph_dbs/factory.py +2 -0
- memos/graph_dbs/neo4j.py +331 -122
- memos/graph_dbs/neo4j_community.py +300 -0
- memos/llms/base.py +9 -0
- memos/llms/deepseek.py +54 -0
- memos/llms/factory.py +10 -1
- memos/llms/hf.py +170 -13
- memos/llms/hf_singleton.py +114 -0
- memos/llms/ollama.py +4 -0
- memos/llms/openai.py +67 -1
- memos/llms/qwen.py +63 -0
- memos/llms/vllm.py +153 -0
- memos/mem_cube/general.py +77 -16
- memos/mem_cube/utils.py +102 -0
- memos/mem_os/core.py +131 -41
- memos/mem_os/main.py +93 -11
- memos/mem_os/product.py +1098 -35
- memos/mem_os/utils/default_config.py +352 -0
- memos/mem_os/utils/format_utils.py +1154 -0
- memos/mem_reader/simple_struct.py +5 -5
- memos/mem_scheduler/base_scheduler.py +467 -36
- memos/mem_scheduler/general_scheduler.py +125 -244
- memos/mem_scheduler/modules/base.py +9 -0
- memos/mem_scheduler/modules/dispatcher.py +68 -2
- memos/mem_scheduler/modules/misc.py +39 -0
- memos/mem_scheduler/modules/monitor.py +228 -49
- memos/mem_scheduler/modules/rabbitmq_service.py +317 -0
- memos/mem_scheduler/modules/redis_service.py +32 -22
- memos/mem_scheduler/modules/retriever.py +250 -23
- memos/mem_scheduler/modules/schemas.py +189 -7
- memos/mem_scheduler/mos_for_test_scheduler.py +143 -0
- memos/mem_scheduler/utils.py +51 -2
- memos/mem_user/persistent_user_manager.py +260 -0
- memos/memories/activation/item.py +25 -0
- memos/memories/activation/kv.py +10 -3
- memos/memories/activation/vllmkv.py +219 -0
- memos/memories/factory.py +2 -0
- memos/memories/textual/general.py +7 -5
- memos/memories/textual/tree.py +9 -5
- memos/memories/textual/tree_text_memory/organize/conflict.py +5 -3
- memos/memories/textual/tree_text_memory/organize/manager.py +26 -18
- memos/memories/textual/tree_text_memory/organize/redundancy.py +25 -44
- memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +11 -13
- memos/memories/textual/tree_text_memory/organize/reorganizer.py +73 -51
- memos/memories/textual/tree_text_memory/retrieve/recall.py +0 -1
- memos/memories/textual/tree_text_memory/retrieve/reranker.py +2 -2
- memos/memories/textual/tree_text_memory/retrieve/searcher.py +6 -5
- memos/parsers/markitdown.py +8 -2
- memos/templates/mem_reader_prompts.py +65 -23
- memos/templates/mem_scheduler_prompts.py +96 -47
- memos/templates/tree_reorganize_prompts.py +85 -30
- memos/vec_dbs/base.py +12 -0
- memos/vec_dbs/qdrant.py +46 -20
- {memoryos-0.2.0.dist-info → memoryos-0.2.1.dist-info}/LICENSE +0 -0
- {memoryos-0.2.0.dist-info → memoryos-0.2.1.dist-info}/WHEEL +0 -0
memos/configs/llm.py
CHANGED
@@ -27,6 +27,40 @@ class OpenAILLMConfig(BaseLLMConfig):
     extra_body: Any = Field(default=None, description="extra body")
 
 
+class QwenLLMConfig(BaseLLMConfig):
+    api_key: str = Field(..., description="API key for DashScope (Qwen)")
+    api_base: str = Field(
+        default="https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
+        description="Base URL for Qwen OpenAI-compatible API",
+    )
+    extra_body: Any = Field(default=None, description="extra body")
+    model_name_or_path: str = Field(..., description="Model name for Qwen, e.g., 'qwen-plus'")
+
+
+class DeepSeekLLMConfig(BaseLLMConfig):
+    api_key: str = Field(..., description="API key for DeepSeek")
+    api_base: str = Field(
+        default="https://api.deepseek.com",
+        description="Base URL for DeepSeek OpenAI-compatible API",
+    )
+    extra_body: Any = Field(default=None, description="Extra options for API")
+    model_name_or_path: str = Field(
+        ..., description="Model name: 'deepseek-chat' or 'deepseek-reasoner'"
+    )
+
+
+class AzureLLMConfig(BaseLLMConfig):
+    base_url: str = Field(
+        default="https://api.openai.azure.com/",
+        description="Base URL for Azure OpenAI API",
+    )
+    api_version: str = Field(
+        default="2024-03-01-preview",
+        description="API version for Azure OpenAI",
+    )
+    api_key: str = Field(..., description="API key for Azure OpenAI")
+
+
 class OllamaLLMConfig(BaseLLMConfig):
     api_base: str = Field(
         default="http://localhost:11434",
@@ -45,6 +79,14 @@ class HFLLMConfig(BaseLLMConfig):
     )
 
 
+class VLLMLLMConfig(BaseLLMConfig):
+    api_key: str = Field(default="", description="API key for vLLM (optional for local server)")
+    api_base: str = Field(
+        default="http://localhost:8088/v1",
+        description="Base URL for vLLM API",
+    )
+
+
 class LLMConfigFactory(BaseConfig):
     """Factory class for creating LLM configurations."""
 
@@ -54,7 +96,12 @@ class LLMConfigFactory(BaseConfig):
     backend_to_class: ClassVar[dict[str, Any]] = {
         "openai": OpenAILLMConfig,
         "ollama": OllamaLLMConfig,
+        "azure": AzureLLMConfig,
         "huggingface": HFLLMConfig,
+        "vllm": VLLMLLMConfig,
+        "huggingface_singleton": HFLLMConfig,  # Add singleton support
+        "qwen": QwenLLMConfig,
+        "deepseek": DeepSeekLLMConfig,
     }
 
     @field_validator("backend")

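For context, a minimal sketch of selecting one of the new backends, assuming LLMConfigFactory resolves its `config` dict into the matching config class the same way SchedulerConfigFactory does later in this diff; field names come from the diff above and the key value is a placeholder:

```python
from memos.configs.llm import LLMConfigFactory

# Sketch only: the construction shape is an assumption; "api_key" and
# "model_name_or_path" are required fields per DeepSeekLLMConfig above.
factory = LLMConfigFactory(
    backend="deepseek",
    config={
        "api_key": "sk-...",                    # placeholder
        "model_name_or_path": "deepseek-chat",  # or "deepseek-reasoner"
        # "api_base" defaults to "https://api.deepseek.com"
    },
)
llm_config = factory.config  # resolved DeepSeekLLMConfig, assuming factory validation
```
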
memos/configs/mem_cube.py
CHANGED
@@ -70,7 +70,7 @@ class GeneralMemCubeConfig(BaseMemCubeConfig):
     @classmethod
     def validate_act_mem(cls, act_mem: MemoryConfigFactory) -> MemoryConfigFactory:
         """Validate the act_mem field."""
-        allowed_backends = ["kv_cache", "uninitialized"]
+        allowed_backends = ["kv_cache", "vllm_kv_cache", "uninitialized"]
         if act_mem.backend not in allowed_backends:
             raise ConfigurationError(
                 f"GeneralMemCubeConfig requires act_mem backend to be one of {allowed_backends}, got '{act_mem.backend}'"

memos/configs/mem_scheduler.py
CHANGED
@@ -1,13 +1,17 @@
+import os
+
+from pathlib import Path
 from typing import Any, ClassVar
 
 from pydantic import ConfigDict, Field, field_validator, model_validator
 
 from memos.configs.base import BaseConfig
 from memos.mem_scheduler.modules.schemas import (
+    BASE_DIR,
     DEFAULT_ACT_MEM_DUMP_PATH,
-    DEFAULT_ACTIVATION_MEM_SIZE,
     DEFAULT_CONSUME_INTERVAL_SECONDS,
     DEFAULT_THREAD__POOL_MAX_WORKERS,
+    DictConversionMixin,
 )
 
 
@@ -33,6 +37,10 @@ class BaseSchedulerConfig(BaseConfig):
         le=60,
         description=f"Interval for consuming messages from queue in seconds (default: {DEFAULT_CONSUME_INTERVAL_SECONDS})",
     )
+    auth_config_path: str | None = Field(
+        default=None,
+        description="Path to the authentication configuration file containing private credentials",
+    )
 
 
 class GeneralSchedulerConfig(BaseSchedulerConfig):
@@ -42,14 +50,13 @@ class GeneralSchedulerConfig(BaseSchedulerConfig):
     context_window_size: int | None = Field(
         default=5, description="Size of the context window for conversation history"
     )
-    activation_mem_size: int | None = Field(
-        default=DEFAULT_ACTIVATION_MEM_SIZE,  # Assuming DEFAULT_ACTIVATION_MEM_SIZE is 1000
-        description="Maximum size of the activation memory",
-    )
     act_mem_dump_path: str | None = Field(
         default=DEFAULT_ACT_MEM_DUMP_PATH,  # Replace with DEFAULT_ACT_MEM_DUMP_PATH
         description="File path for dumping activation memory",
     )
+    enable_act_memory_update: bool = Field(
+        default=False, description="Whether to enable automatic activation memory updates"
+    )
 
 
 class SchedulerConfigFactory(BaseConfig):
@@ -76,3 +83,82 @@ class SchedulerConfigFactory(BaseConfig):
         config_class = self.backend_to_class[self.backend]
         self.config = config_class(**self.config)
         return self
+
+
+# ************************* Auth *************************
+class RabbitMQConfig(
+    BaseConfig,
+):
+    host_name: str = Field(default="", description="Endpoint for RabbitMQ instance access")
+    user_name: str = Field(default="", description="Static username for RabbitMQ instance")
+    password: str = Field(default="", description="Password for the static username")
+    virtual_host: str = Field(default="", description="Vhost name for RabbitMQ instance")
+    erase_on_connect: bool = Field(
+        default=True, description="Whether to clear connection state or buffers upon connecting"
+    )
+    port: int = Field(
+        default=5672,
+        description="Port number for RabbitMQ instance access",
+        ge=1,  # Port must be >= 1
+        le=65535,  # Port must be <= 65535
+    )
+
+
+class GraphDBAuthConfig(BaseConfig):
+    uri: str = Field(default="localhost", description="URI for graph database access")
+
+
+class OpenAIConfig(BaseConfig):
+    api_key: str = Field(default="", description="API key for OpenAI service")
+    base_url: str = Field(default="", description="Base URL for API endpoint")
+    default_model: str = Field(default="", description="Default model to use")
+
+
+class AuthConfig(BaseConfig, DictConversionMixin):
+    rabbitmq: RabbitMQConfig
+    openai: OpenAIConfig
+    graph_db: GraphDBAuthConfig
+    default_config_path: ClassVar[str] = (
+        f"{BASE_DIR}/examples/data/config/mem_scheduler/scheduler_auth.yaml"
+    )
+
+    @classmethod
+    def from_local_yaml(cls, config_path: str | None = None) -> "AuthConfig":
+        """
+        Load configuration from YAML file
+
+        Args:
+            config_path: Path to YAML configuration file
+
+        Returns:
+            AuthConfig instance
+
+        Raises:
+            FileNotFoundError: If config file doesn't exist
+            ValueError: If YAML parsing or validation fails
+        """
+
+        if config_path is None:
+            config_path = cls.default_config_path
+
+        # Check file exists
+        if not Path(config_path).exists():
+            raise FileNotFoundError(f"Config file not found: {config_path}")
+
+        return cls.from_yaml_file(yaml_path=config_path)
+
+    def set_openai_config_to_environment(self):
+        # Set environment variables
+        os.environ["OPENAI_API_KEY"] = self.openai.api_key
+        os.environ["OPENAI_BASE_URL"] = self.openai.base_url
+        os.environ["MODEL"] = self.openai.default_model
+
+    @classmethod
+    def default_config_exists(cls) -> bool:
+        """
+        Check if the default configuration file exists.
+
+        Returns:
+            bool: True if the default config file exists, False otherwise
+        """
+        return Path(cls.default_config_path).exists()

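A short sketch of the new AuthConfig in use, based only on the methods shown above; the YAML layout is inferred from the nested config classes and should be treated as an assumption:

```python
from memos.configs.mem_scheduler import AuthConfig

# scheduler_auth.yaml is assumed to mirror the nested models above, e.g.:
#
#   rabbitmq:
#     host_name: "amqp.example.com"   # placeholder
#     user_name: "memos"
#     password: "secret"
#     virtual_host: "memos"
#     port: 5672
#   openai:
#     api_key: "sk-..."
#     base_url: "https://api.openai.com/v1"
#     default_model: "gpt-4o-mini"
#   graph_db:
#     uri: "bolt://localhost:7687"
if AuthConfig.default_config_exists():
    auth = AuthConfig.from_local_yaml()      # falls back to default_config_path
    auth.set_openai_config_to_environment()  # exports OPENAI_API_KEY, OPENAI_BASE_URL, MODEL
```
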
memos/configs/memory.py
CHANGED
@@ -52,9 +52,9 @@ class KVCacheMemoryConfig(BaseActMemoryConfig):
     @classmethod
     def validate_extractor_llm(cls, extractor_llm: LLMConfigFactory) -> LLMConfigFactory:
         """Validate the extractor_llm field."""
-        if extractor_llm.backend not in ["huggingface"]:
+        if extractor_llm.backend not in ["huggingface", "huggingface_singleton", "vllm"]:
             raise ConfigurationError(
-                f"KVCacheMemoryConfig requires extractor_llm backend to be 'huggingface', got '{extractor_llm.backend}'"
+                f"KVCacheMemoryConfig requires extractor_llm backend to be 'huggingface' or 'huggingface_singleton', got '{extractor_llm.backend}'"
             )
         return extractor_llm
 
@@ -84,9 +84,9 @@ class LoRAMemoryConfig(BaseParaMemoryConfig):
     @classmethod
     def validate_extractor_llm(cls, extractor_llm: LLMConfigFactory) -> LLMConfigFactory:
         """Validate the extractor_llm field."""
-        if extractor_llm.backend not in ["huggingface"]:
+        if extractor_llm.backend not in ["huggingface", "huggingface_singleton"]:
             raise ConfigurationError(
-                f"LoRAMemoryConfig requires extractor_llm backend to be 'huggingface', got '{extractor_llm.backend}'"
+                f"LoRAMemoryConfig requires extractor_llm backend to be 'huggingface' or 'huggingface_singleton', got '{extractor_llm.backend}'"
             )
         return extractor_llm
 
@@ -181,6 +181,7 @@ class MemoryConfigFactory(BaseConfig):
         "general_text": GeneralTextMemoryConfig,
         "tree_text": TreeTextMemoryConfig,
         "kv_cache": KVCacheMemoryConfig,
+        "vllm_kv_cache": KVCacheMemoryConfig,  # Use same config as kv_cache
         "lora": LoRAMemoryConfig,
         "uninitialized": UninitializedMemoryConfig,
     }

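To illustrate the widened validator, a sketch of a KV-cache memory config that now passes with the vllm backend; the factory call shape and the nested LLM fields (drawn from VLLMLLMConfig earlier in this diff) are assumptions, and other KVCacheMemoryConfig fields, if any, are omitted:

```python
from memos.configs.memory import MemoryConfigFactory

# "vllm_kv_cache" maps to KVCacheMemoryConfig, whose extractor_llm may now
# use the "vllm" backend (previously "huggingface" only).
mem_config = MemoryConfigFactory(
    backend="vllm_kv_cache",
    config={
        "extractor_llm": {
            "backend": "vllm",
            "config": {
                "api_base": "http://localhost:8088/v1",            # VLLMLLMConfig default
                "model_name_or_path": "Qwen/Qwen2.5-7B-Instruct",  # placeholder
            },
        },
    },
)
```
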
memos/dependency.py
ADDED
@@ -0,0 +1,52 @@
+"""
+This utility provides tools for managing dependencies in MemOS.
+"""
+
+import functools
+import importlib
+
+
+def require_python_package(
+    import_name: str, install_command: str | None = None, install_link: str | None = None
+):
+    """Check if a package is available and provide installation hints on import failure.
+
+    Args:
+        import_name (str): The top-level importable module name a package provides.
+        install_command (str, optional): Installation command.
+        install_link (str, optional): URL link to installation guide.
+
+    Returns:
+        Callable: A decorator function that wraps the target function with package availability check.
+
+    Raises:
+        ImportError: When the specified package is not available, with installation
+            instructions included in the error message.
+
+    Example:
+        >>> @require_python_package(
+        ...     import_name='faiss',
+        ...     install_command='pip install faiss-cpu',
+        ...     install_link='https://github.com/facebookresearch/faiss/blob/main/INSTALL.md'
+        ... )
+        ... def create_faiss_index():
+        ...     from faiss import IndexFlatL2  # Actual import in function
+        ...     return IndexFlatL2(128)
+    """
+
+    def decorator(func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            try:
+                importlib.import_module(import_name)
+            except ImportError:
+                error_msg = f"Missing required module - '{import_name}'\n"
+                error_msg += f"💡 Install command: {install_command}\n" if install_command else ""
+                error_msg += f"💡 Install guide: {install_link}\n" if install_link else ""
+
+                raise ImportError(error_msg) from None
+            return func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator

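The decorator defers the dependency check to call time, so importing a module that defines a decorated function stays cheap. A sketch of the failure path; `nonexistent_pkg` is a hypothetical module name used only to trigger the error:

```python
from memos.dependency import require_python_package

@require_python_package(
    import_name="nonexistent_pkg",  # hypothetical, assumed not installed
    install_command="pip install nonexistent-pkg",
)
def needs_missing_package():
    import nonexistent_pkg  # would only run if the package were installed
    return nonexistent_pkg

# Calling raises ImportError with the install hint embedded in the message:
# Missing required module - 'nonexistent_pkg'
# 💡 Install command: pip install nonexistent-pkg
try:
    needs_missing_package()
except ImportError as e:
    print(e)
```
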
memos/embedders/ark.py
ADDED
@@ -0,0 +1,92 @@
+from memos.configs.embedder import ArkEmbedderConfig
+from memos.dependency import require_python_package
+from memos.embedders.base import BaseEmbedder
+from memos.log import get_logger
+
+
+logger = get_logger(__name__)
+
+
+class ArkEmbedder(BaseEmbedder):
+    """Ark Embedder class."""
+
+    @require_python_package(
+        import_name="volcenginesdkarkruntime",
+        install_command="pip install 'volcengine-python-sdk[ark]'",
+        install_link="https://www.volcengine.com/docs/82379/1541595",
+    )
+    def __init__(self, config: ArkEmbedderConfig):
+        from volcenginesdkarkruntime import Ark
+
+        self.config = config
+
+        if self.config.embedding_dims is not None:
+            logger.warning(
+                "Ark does not support specifying embedding dimensions. "
+                "The embedding dimensions is determined by the model."
+                "`embedding_dims` will be set to None."
+            )
+            self.config.embedding_dims = None
+
+        # Default model if not specified
+        if not self.config.model_name_or_path:
+            self.config.model_name_or_path = "doubao-embedding-vision-250615"
+
+        # Initialize ark client
+        self.client = Ark(api_key=self.config.api_key, base_url=self.config.api_base)
+
+    def embed(self, texts: list[str]) -> list[list[float]]:
+        """
+        Generate embeddings for the given texts.
+
+        Args:
+            texts: List of texts to embed.
+
+        Returns:
+            List of embeddings, each represented as a list of floats.
+        """
+        from volcenginesdkarkruntime.types.multimodal_embedding import (
+            MultimodalEmbeddingContentPartTextParam,
+        )
+
+        if self.config.multi_modal:
+            texts_input = [
+                MultimodalEmbeddingContentPartTextParam(text=text, type="text") for text in texts
+            ]
+            return self.multimodal_embeddings(inputs=texts_input, chunk_size=self.config.chunk_size)
+        return self.text_embedding(texts, chunk_size=self.config.chunk_size)
+
+    def text_embedding(self, inputs: list[str], chunk_size: int | None = None) -> list[list[float]]:
+        chunk_size_ = chunk_size or self.config.chunk_size
+        embeddings: list[list[float]] = []
+        for i in range(0, len(inputs), chunk_size_):
+            response = self.client.embeddings.create(
+                model=self.config.model_name_or_path,
+                input=inputs[i : i + chunk_size_],
+            )
+
+            data = [response.data] if isinstance(response.data, dict) else response.data
+            embeddings.extend(r.embedding for r in data)
+
+        return embeddings
+
+    def multimodal_embeddings(
+        self, inputs: list, chunk_size: int | None = None
+    ) -> list[list[float]]:
+        from volcenginesdkarkruntime.types.multimodal_embedding import (
+            MultimodalEmbeddingResponse,  # noqa: TC002
+        )
+
+        chunk_size_ = chunk_size or self.config.chunk_size
+        embeddings: list[list[float]] = []
+
+        for i in range(0, len(inputs), chunk_size_):
+            response: MultimodalEmbeddingResponse = self.client.multimodal_embeddings.create(
+                model=self.config.model_name_or_path,
+                input=inputs[i : i + chunk_size_],
+            )
+
+            data = [response.data] if isinstance(response.data, dict) else response.data
+            embeddings.extend(r["embedding"] for r in data)
+
+        return embeddings

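A usage sketch, assuming ArkEmbedderConfig exposes the fields the code above reads (api_key, api_base, model_name_or_path, multi_modal, chunk_size); the key and endpoint values are placeholders:

```python
from memos.configs.embedder import ArkEmbedderConfig
from memos.embedders.ark import ArkEmbedder

config = ArkEmbedderConfig(
    api_key="ark-...",                                    # placeholder
    api_base="https://ark.cn-beijing.volces.com/api/v3",  # assumed endpoint
    model_name_or_path="doubao-embedding-vision-250615",
    multi_modal=False,  # True routes through multimodal_embeddings() instead
    chunk_size=16,      # texts per API call in the batching loops above
)

embedder = ArkEmbedder(config)
vectors = embedder.embed(["hello", "world"])  # -> list[list[float]]
```
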
memos/embedders/factory.py
CHANGED
@@ -1,9 +1,11 @@
 from typing import Any, ClassVar
 
 from memos.configs.embedder import EmbedderConfigFactory
+from memos.embedders.ark import ArkEmbedder
 from memos.embedders.base import BaseEmbedder
 from memos.embedders.ollama import OllamaEmbedder
 from memos.embedders.sentence_transformer import SenTranEmbedder
+from memos.embedders.universal_api import UniversalAPIEmbedder
 
 
 class EmbedderFactory(BaseEmbedder):
@@ -12,6 +14,8 @@ class EmbedderFactory(BaseEmbedder):
     backend_to_class: ClassVar[dict[str, Any]] = {
         "ollama": OllamaEmbedder,
         "sentence_transformer": SenTranEmbedder,
+        "ark": ArkEmbedder,
+        "universal_api": UniversalAPIEmbedder,
     }
 
     @classmethod

memos/embedders/sentence_transformer.py
CHANGED

@@ -1,6 +1,5 @@
-from sentence_transformers import SentenceTransformer
-
 from memos.configs.embedder import SenTranEmbedderConfig
+from memos.dependency import require_python_package
 from memos.embedders.base import BaseEmbedder
 from memos.log import get_logger
 
@@ -11,7 +10,14 @@ logger = get_logger(__name__)
 class SenTranEmbedder(BaseEmbedder):
     """Sentence Transformer Embedder class."""
 
+    @require_python_package(
+        import_name="sentence_transformers",
+        install_command="pip install sentence-transformers",
+        install_link="https://www.sbert.net/docs/installation.html",
+    )
     def __init__(self, config: SenTranEmbedderConfig):
+        from sentence_transformers import SentenceTransformer
+
         self.config = config
         self.model = SentenceTransformer(
             self.config.model_name_or_path, trust_remote_code=self.config.trust_remote_code

memos/embedders/universal_api.py
ADDED

@@ -0,0 +1,32 @@
+from openai import AzureOpenAI as AzureClient
+from openai import OpenAI as OpenAIClient
+
+from memos.configs.embedder import UniversalAPIEmbedderConfig
+from memos.embedders.base import BaseEmbedder
+
+
+class UniversalAPIEmbedder(BaseEmbedder):
+    def __init__(self, config: UniversalAPIEmbedderConfig):
+        self.provider = config.provider
+        self.config = config
+
+        if self.provider == "openai":
+            self.client = OpenAIClient(api_key=config.api_key, base_url=config.base_url)
+        elif self.provider == "azure":
+            self.client = AzureClient(
+                azure_endpoint=config.base_url,
+                api_version="2024-03-01-preview",
+                api_key=config.api_key,
+            )
+        else:
+            raise ValueError(f"Unsupported provider: {self.provider}")
+
+    def embed(self, texts: list[str]) -> list[list[float]]:
+        if self.provider == "openai" or self.provider == "azure":
+            response = self.client.embeddings.create(
+                model=getattr(self.config, "model_name_or_path", "text-embedding-3-large"),
+                input=texts,
+            )
+            return [r.embedding for r in response.data]
+        else:
+            raise ValueError(f"Unsupported provider: {self.provider}")

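A sketch of the new universal_api backend in use; UniversalAPIEmbedderConfig field names are inferred from the constructor above, and the key value is a placeholder:

```python
from memos.configs.embedder import UniversalAPIEmbedderConfig
from memos.embedders.universal_api import UniversalAPIEmbedder

config = UniversalAPIEmbedderConfig(
    provider="openai",                            # or "azure"
    api_key="sk-...",                             # placeholder
    base_url="https://api.openai.com/v1",
    model_name_or_path="text-embedding-3-large",  # embed() falls back to this name anyway
)
embedder = UniversalAPIEmbedder(config)
vectors = embedder.embed(["memory one", "memory two"])
```
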
memos/graph_dbs/base.py
CHANGED
@@ -9,12 +9,12 @@ class BaseGraphDB(ABC):
 
     # Node (Memory) Management
     @abstractmethod
-    def add_node(self, id: str,
+    def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None:
         """
         Add a memory node to the graph.
         Args:
             id: Unique identifier for the memory node.
-
+            memory: Raw memory content (e.g., text).
             metadata: Dictionary of metadata (e.g., timestamp, tags, source).
         """
 

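The corrected abstract signature implies call sites like the following sketch; the metadata keys shown are illustrative, not a schema defined in this diff:

```python
from memos.graph_dbs.base import BaseGraphDB

def store_memory(graph_db: BaseGraphDB) -> None:
    # Hypothetical call against any concrete implementation (e.g., Neo4jGraphDB).
    graph_db.add_node(
        id="mem-42",                       # unique node id
        memory="User prefers green tea.",  # raw memory content
        metadata={"source": "chat", "tags": ["preference"]},  # illustrative keys
    )
```
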
memos/graph_dbs/factory.py
CHANGED
@@ -3,6 +3,7 @@ from typing import Any, ClassVar
 from memos.configs.graph_db import GraphDBConfigFactory
 from memos.graph_dbs.base import BaseGraphDB
 from memos.graph_dbs.neo4j import Neo4jGraphDB
+from memos.graph_dbs.neo4j_community import Neo4jCommunityGraphDB
 
 
 class GraphStoreFactory(BaseGraphDB):
@@ -10,6 +11,7 @@ class GraphStoreFactory(BaseGraphDB):
 
     backend_to_class: ClassVar[dict[str, Any]] = {
         "neo4j": Neo4jGraphDB,
+        "neo4j-community": Neo4jCommunityGraphDB,
     }
 
     @classmethod

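Backend resolution for the new community-edition driver, using only the mapping shown above; constructor arguments come from GraphDBConfigFactory, which is outside this diff:

```python
from memos.graph_dbs.factory import GraphStoreFactory

# The factory now recognizes "neo4j-community" alongside "neo4j".
db_class = GraphStoreFactory.backend_to_class["neo4j-community"]
print(db_class.__name__)  # Neo4jCommunityGraphDB
```
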