MemoryOS 0.0.1__py3-none-any.whl → 0.1.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of MemoryOS might be problematic. Click here for more details.
- memoryos-0.1.12.dist-info/METADATA +257 -0
- memoryos-0.1.12.dist-info/RECORD +117 -0
- memos/__init__.py +20 -1
- memos/api/start_api.py +420 -0
- memos/chunkers/__init__.py +4 -0
- memos/chunkers/base.py +24 -0
- memos/chunkers/factory.py +22 -0
- memos/chunkers/sentence_chunker.py +35 -0
- memos/configs/__init__.py +0 -0
- memos/configs/base.py +82 -0
- memos/configs/chunker.py +45 -0
- memos/configs/embedder.py +53 -0
- memos/configs/graph_db.py +45 -0
- memos/configs/llm.py +71 -0
- memos/configs/mem_chat.py +81 -0
- memos/configs/mem_cube.py +89 -0
- memos/configs/mem_os.py +70 -0
- memos/configs/mem_reader.py +53 -0
- memos/configs/mem_scheduler.py +78 -0
- memos/configs/memory.py +190 -0
- memos/configs/parser.py +38 -0
- memos/configs/utils.py +8 -0
- memos/configs/vec_db.py +64 -0
- memos/deprecation.py +262 -0
- memos/embedders/__init__.py +0 -0
- memos/embedders/base.py +15 -0
- memos/embedders/factory.py +23 -0
- memos/embedders/ollama.py +74 -0
- memos/embedders/sentence_transformer.py +40 -0
- memos/exceptions.py +30 -0
- memos/graph_dbs/__init__.py +0 -0
- memos/graph_dbs/base.py +215 -0
- memos/graph_dbs/factory.py +21 -0
- memos/graph_dbs/neo4j.py +827 -0
- memos/hello_world.py +97 -0
- memos/llms/__init__.py +0 -0
- memos/llms/base.py +16 -0
- memos/llms/factory.py +25 -0
- memos/llms/hf.py +231 -0
- memos/llms/ollama.py +82 -0
- memos/llms/openai.py +34 -0
- memos/llms/utils.py +14 -0
- memos/log.py +78 -0
- memos/mem_chat/__init__.py +0 -0
- memos/mem_chat/base.py +30 -0
- memos/mem_chat/factory.py +21 -0
- memos/mem_chat/simple.py +200 -0
- memos/mem_cube/__init__.py +0 -0
- memos/mem_cube/base.py +29 -0
- memos/mem_cube/general.py +146 -0
- memos/mem_cube/utils.py +24 -0
- memos/mem_os/client.py +5 -0
- memos/mem_os/core.py +819 -0
- memos/mem_os/main.py +12 -0
- memos/mem_os/product.py +89 -0
- memos/mem_reader/__init__.py +0 -0
- memos/mem_reader/base.py +27 -0
- memos/mem_reader/factory.py +21 -0
- memos/mem_reader/memory.py +298 -0
- memos/mem_reader/simple_struct.py +241 -0
- memos/mem_scheduler/__init__.py +0 -0
- memos/mem_scheduler/base_scheduler.py +164 -0
- memos/mem_scheduler/general_scheduler.py +305 -0
- memos/mem_scheduler/modules/__init__.py +0 -0
- memos/mem_scheduler/modules/base.py +74 -0
- memos/mem_scheduler/modules/dispatcher.py +103 -0
- memos/mem_scheduler/modules/monitor.py +82 -0
- memos/mem_scheduler/modules/redis_service.py +146 -0
- memos/mem_scheduler/modules/retriever.py +41 -0
- memos/mem_scheduler/modules/schemas.py +146 -0
- memos/mem_scheduler/scheduler_factory.py +21 -0
- memos/mem_scheduler/utils.py +26 -0
- memos/mem_user/user_manager.py +478 -0
- memos/memories/__init__.py +0 -0
- memos/memories/activation/__init__.py +0 -0
- memos/memories/activation/base.py +42 -0
- memos/memories/activation/item.py +25 -0
- memos/memories/activation/kv.py +232 -0
- memos/memories/base.py +19 -0
- memos/memories/factory.py +34 -0
- memos/memories/parametric/__init__.py +0 -0
- memos/memories/parametric/base.py +19 -0
- memos/memories/parametric/item.py +11 -0
- memos/memories/parametric/lora.py +41 -0
- memos/memories/textual/__init__.py +0 -0
- memos/memories/textual/base.py +89 -0
- memos/memories/textual/general.py +286 -0
- memos/memories/textual/item.py +167 -0
- memos/memories/textual/naive.py +185 -0
- memos/memories/textual/tree.py +289 -0
- memos/memories/textual/tree_text_memory/__init__.py +0 -0
- memos/memories/textual/tree_text_memory/organize/__init__.py +0 -0
- memos/memories/textual/tree_text_memory/organize/manager.py +305 -0
- memos/memories/textual/tree_text_memory/retrieve/__init__.py +0 -0
- memos/memories/textual/tree_text_memory/retrieve/reasoner.py +64 -0
- memos/memories/textual/tree_text_memory/retrieve/recall.py +158 -0
- memos/memories/textual/tree_text_memory/retrieve/reranker.py +111 -0
- memos/memories/textual/tree_text_memory/retrieve/retrieval_mid_structs.py +13 -0
- memos/memories/textual/tree_text_memory/retrieve/searcher.py +166 -0
- memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py +68 -0
- memos/memories/textual/tree_text_memory/retrieve/utils.py +48 -0
- memos/parsers/__init__.py +0 -0
- memos/parsers/base.py +15 -0
- memos/parsers/factory.py +19 -0
- memos/parsers/markitdown.py +22 -0
- memos/settings.py +8 -0
- memos/templates/__init__.py +0 -0
- memos/templates/mem_reader_prompts.py +98 -0
- memos/templates/mem_scheduler_prompts.py +65 -0
- memos/types.py +55 -0
- memos/vec_dbs/__init__.py +0 -0
- memos/vec_dbs/base.py +105 -0
- memos/vec_dbs/factory.py +21 -0
- memos/vec_dbs/item.py +43 -0
- memos/vec_dbs/qdrant.py +292 -0
- memoryos-0.0.1.dist-info/METADATA +0 -53
- memoryos-0.0.1.dist-info/RECORD +0 -5
- {memoryos-0.0.1.dist-info → memoryos-0.1.12.dist-info}/LICENSE +0 -0
- {memoryos-0.0.1.dist-info → memoryos-0.1.12.dist-info}/WHEEL +0 -0
memos/configs/parser.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
from typing import Any, ClassVar
|
|
2
|
+
|
|
3
|
+
from pydantic import Field, field_validator, model_validator
|
|
4
|
+
|
|
5
|
+
from memos.configs.base import BaseConfig
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class BaseParserConfig(BaseConfig):
    """Common base for every parser backend's configuration model.

    Backend-specific parser configs subclass this class; it declares no fields
    of its own.
    """
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class MarkItDownParserConfig(BaseParserConfig):
    """Configuration for the ``markitdown`` parser backend (adds no extra fields)."""
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ParserConfigFactory(BaseConfig):
    """Factory class for creating Parser configurations.

    ``backend`` selects an entry in ``backend_to_class``. After validation the
    raw ``config`` mapping is replaced by an instance of that backend's config
    class.
    """

    backend: str = Field(..., description="Backend for parser")
    config: dict[str, Any] = Field(..., description="Configuration for the parser backend")

    # Registry mapping backend names to their configuration classes.
    backend_to_class: ClassVar[dict[str, Any]] = {
        "markitdown": MarkItDownParserConfig,
    }

    @field_validator("backend")
    @classmethod
    def validate_backend(cls, backend: str) -> str:
        """Reject backend names that have no registered configuration class.

        Raises:
            ValueError: If ``backend`` is not a key of ``backend_to_class``.
        """
        known_backends = cls.backend_to_class
        if backend in known_backends:
            return backend
        raise ValueError(f"Invalid backend: {backend}")

    @model_validator(mode="after")
    def create_config(self) -> "ParserConfigFactory":
        """Swap the raw ``config`` mapping for the backend-specific config model.

        Note: after this validator runs, ``config`` holds a config model
        instance rather than the ``dict`` its annotation declares.
        """
        backend_config_cls = self.backend_to_class[self.backend]
        self.config = backend_config_cls(**self.config)
        return self
|
memos/configs/utils.py
ADDED
memos/configs/vec_db.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
from typing import Any, ClassVar, Literal
|
|
2
|
+
|
|
3
|
+
from pydantic import Field, field_validator, model_validator
|
|
4
|
+
|
|
5
|
+
from memos import settings
|
|
6
|
+
from memos.configs.base import BaseConfig
|
|
7
|
+
from memos.log import get_logger
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# Module-level logger obtained from memos.log.get_logger; used by the Qdrant
# config validator below to report the default-path fallback.
logger = get_logger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class BaseVecDBConfig(BaseConfig):
    """Settings shared by every vector-database backend configuration."""

    # Required: the collection the backend operates on.
    collection_name: str = Field(..., description="Name of the collection")

    # Optional vector size; stays None when not provided.
    vector_dimension: int | None = Field(
        default=None,
        description="Dimension of the vectors",
    )

    # Optional similarity metric, restricted to the three literal choices.
    distance_metric: Literal["cosine", "euclidean", "dot"] | None = Field(
        default=None,
        description="Distance metric for vector similarity calculation. Options: 'cosine', 'euclidean', 'dot'",
    )
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class QdrantVecDBConfig(BaseVecDBConfig):
    """Configuration for Qdrant vector database.

    The connection may be described by ``host``/``port`` or by a local
    ``path``. If none of the three is given, ``path`` falls back to
    ``settings.MEMOS_DIR / "qdrant"``.
    """

    host: str | None = Field(default=None, description="Host for Qdrant")
    port: int | None = Field(default=None, description="Port for Qdrant")
    path: str | None = Field(default=None, description="Path for Qdrant")

    @model_validator(mode="after")
    def set_default_path(self):
        """Fill in a local storage path when no connection settings were given."""
        nothing_configured = self.host is None and self.port is None and self.path is None
        if not nothing_configured:
            return self

        fallback_dir = settings.MEMOS_DIR / "qdrant"
        logger.warning(
            "No host, port, or path provided for Qdrant. Defaulting to local path: %s",
            fallback_dir,
        )
        self.path = str(fallback_dir)
        return self
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class VectorDBConfigFactory(BaseConfig):
    """Factory class for creating vector database configurations.

    ``backend`` picks a class from ``backend_to_class``. Once the model is
    validated, the raw ``config`` mapping is replaced by an instance of that
    class.
    """

    backend: str = Field(..., description="Backend for vector database")
    config: dict[str, Any] = Field(..., description="Configuration for the vector database backend")

    # Maps each supported backend name to its configuration model.
    backend_to_class: ClassVar[dict[str, Any]] = {
        "qdrant": QdrantVecDBConfig,
    }

    @field_validator("backend")
    @classmethod
    def validate_backend(cls, backend: str) -> str:
        """Accept only backend names that appear in ``backend_to_class``.

        Raises:
            ValueError: If ``backend`` is not registered.
        """
        if backend in cls.backend_to_class:
            return backend
        raise ValueError(f"Invalid vector database backend: {backend}")

    @model_validator(mode="after")
    def create_config(self) -> "VectorDBConfigFactory":
        """Turn the raw ``config`` mapping into the selected backend's config model.

        Note: after this runs, ``config`` holds a model instance rather than
        the ``dict`` its annotation declares.
        """
        selected_cls = self.backend_to_class[self.backend]
        self.config = selected_cls(**self.config)
        return self
|
memos/deprecation.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module provides utilities for marking functions, classes, and parameters
|
|
3
|
+
as deprecated. It includes decorators for deprecation, a function to issue
|
|
4
|
+
warnings, and utilities to check deprecation status.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import functools
|
|
8
|
+
import warnings
|
|
9
|
+
|
|
10
|
+
from collections.abc import Callable
|
|
11
|
+
from typing import Any, TypeVar
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# Install a "default" filter for DeprecationWarning when this module is
# imported, so the warnings below are reported. Note that this changes the
# process-wide warnings configuration, not just this module's.
warnings.simplefilter(action="default", category=DeprecationWarning)


# Type variables that let the decorators return the same type they receive.
F = TypeVar("F", bound=Callable[..., Any])  # any callable
C = TypeVar("C", bound=type)  # any class
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def deprecated(
    reason: str | None = None,
    version: str | None = None,
    alternative: str | None = None,
    category: type[Warning] = DeprecationWarning,
    stacklevel: int = 2,
) -> Callable[[F], F]:
    """
    Decorator to mark functions as deprecated.

    Every call of the decorated function emits a warning of ``category`` and
    then calls the original function with the same arguments. The returned
    wrapper also carries ``__deprecated__ = True`` and a
    ``__deprecation_info__`` dict for introspection.

    The decorator can be used with arguments (``@deprecated(...)``) or bare
    (``@deprecated``).

    Args:
        reason: Optional reason for deprecation
        version: Version when the function was deprecated
        alternative: Suggested alternative function/method
        category: Warning category to use
        stacklevel: Stack level passed to ``warnings.warn`` from inside the
            wrapper; the default of 2 attributes the warning to the caller of
            the decorated function.

    Example:
        @deprecated(reason="Use new_function instead", version="1.2.0")
        def old_function():
            pass

        @deprecated
        def other_old_function():
            pass
    """
    # Bare ``@deprecated`` passes the decorated function in as ``reason``;
    # previously that silently replaced the function with ``decorator``.
    bare_target = None
    if callable(reason):
        bare_target, reason = reason, None

    def decorator(func: F) -> F:
        # Build the deprecation message once: it depends only on values that
        # are fixed at decoration time. getattr() keeps callables without a
        # __name__ (e.g. functools.partial) usable.
        func_name = getattr(func, "__name__", repr(func))
        msg_parts = [f"Function '{func_name}' is deprecated"]
        if version:
            msg_parts.append(f"since version {version}")
        if reason:
            msg_parts.append(f"- {reason}")
        if alternative:
            msg_parts.append(f"Use '{alternative}' instead")
        message = ". ".join(msg_parts) + "."

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            warnings.warn(message, category=category, stacklevel=stacklevel)
            return func(*args, **kwargs)

        # Mark the wrapper as deprecated for introspection
        wrapper.__deprecated__ = True
        wrapper.__deprecation_info__ = {
            "reason": reason,
            "version": version,
            "alternative": alternative,
            "category": category,
        }

        return wrapper

    if bare_target is not None:
        return decorator(bare_target)
    return decorator
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def deprecated_class(
    reason: str | None = None,
    version: str | None = None,
    alternative: str | None = None,
    category: type[Warning] = DeprecationWarning,
    stacklevel: int = 2,
) -> Callable[[C], C]:
    """
    Decorator to mark classes as deprecated.

    The class's ``__init__`` is replaced, in place, by a wrapper that emits a
    warning of ``category`` on every instantiation and then delegates to the
    original ``__init__``. The same class object is returned, with
    ``__deprecated__ = True`` and ``__deprecation_info__`` set on it.

    Args:
        reason: Optional reason for deprecation
        version: Version when the class was deprecated
        alternative: Suggested alternative class
        category: Warning category to use
        stacklevel: Stack level passed to ``warnings.warn`` from inside the
            replacement ``__init__``

    Example:
        @deprecated_class(reason="Use NewClass instead", version="1.2.0")
        class OldClass:
            pass
    """
    # Everything after the class name depends only on the decorator arguments.
    details: list[str] = []
    if version:
        details.append(f"since version {version}")
    if reason:
        details.append(f"- {reason}")
    if alternative:
        details.append(f"Use '{alternative}' instead")

    def decorator(cls: C) -> C:
        wrapped_init = cls.__init__

        @functools.wraps(wrapped_init)
        def new_init(self, *args, **kwargs):
            notice = ". ".join([f"Class '{cls.__name__}' is deprecated", *details]) + "."
            warnings.warn(notice, category=category, stacklevel=stacklevel)
            wrapped_init(self, *args, **kwargs)

        cls.__init__ = new_init

        # Introspection markers read by is_deprecated / get_deprecation_info.
        cls.__deprecated__ = True
        cls.__deprecation_info__ = {
            "reason": reason,
            "version": version,
            "alternative": alternative,
            "category": category,
        }
        return cls

    return decorator
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def deprecated_parameter(
    parameter_name: str,
    reason: str | None = None,
    version: str | None = None,
    alternative: str | None = None,
    category: type[Warning] = DeprecationWarning,
    stacklevel: int = 2,
) -> Callable[[F], F]:
    """
    Decorator to mark specific parameters as deprecated.

    A warning of ``category`` is emitted whenever a call supplies
    ``parameter_name``, whether by keyword or positionally. Calls that do not
    supply it pass through silently.

    Args:
        parameter_name: Name of the deprecated parameter
        reason: Optional reason for deprecation
        version: Version when the parameter was deprecated
        alternative: Suggested alternative parameter
        category: Warning category to use
        stacklevel: Stack level passed to ``warnings.warn`` from inside the
            wrapper; the default of 2 points at the caller of the function.

    Example:
        @deprecated_parameter("old_param", alternative="new_param", version="1.2.0")
        def my_function(new_param=None, old_param=None):
            pass
    """
    import inspect

    def decorator(func: F) -> F:
        # The message depends only on decoration-time values; build it once.
        func_name = getattr(func, "__name__", repr(func))
        msg_parts = [f"Parameter '{parameter_name}' in function '{func_name}' is deprecated"]
        if version:
            msg_parts.append(f"since version {version}")
        if reason:
            msg_parts.append(f"- {reason}")
        if alternative:
            msg_parts.append(f"Use parameter '{alternative}' instead")
        message = ". ".join(msg_parts) + "."

        try:
            signature = inspect.signature(func)
        except (TypeError, ValueError):
            # No introspectable signature: only keyword usage can be detected.
            signature = None

        def parameter_used(args: tuple, kwargs: dict) -> bool:
            """Return True if this call supplies ``parameter_name``."""
            if parameter_name in kwargs:
                return True
            if signature is None or not args:
                return False
            try:
                bound = signature.bind_partial(*args)
            except TypeError:
                # Positional args don't fit the signature; the real call will
                # raise the appropriate error itself.
                return False
            return parameter_name in bound.arguments

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if parameter_used(args, kwargs):
                warnings.warn(message, category=category, stacklevel=stacklevel)
            return func(*args, **kwargs)

        return wrapper

    return decorator
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def warn_deprecated(
|
|
197
|
+
item_name: str,
|
|
198
|
+
item_type: str = "feature",
|
|
199
|
+
reason: str | None = None,
|
|
200
|
+
version: str | None = None,
|
|
201
|
+
alternative: str | None = None,
|
|
202
|
+
category: type[Warning] = DeprecationWarning,
|
|
203
|
+
stacklevel: int = 2,
|
|
204
|
+
) -> None:
|
|
205
|
+
"""
|
|
206
|
+
Issue a deprecation warning for any item.
|
|
207
|
+
|
|
208
|
+
Args:
|
|
209
|
+
item_name: Name of the deprecated item
|
|
210
|
+
item_type: Type of item (e.g., "function", "class", "parameter", "feature")
|
|
211
|
+
reason: Optional reason for deprecation
|
|
212
|
+
version: Version when the item was deprecated
|
|
213
|
+
alternative: Suggested alternative
|
|
214
|
+
category: Warning category to use
|
|
215
|
+
stacklevel: Stack level for the warning
|
|
216
|
+
|
|
217
|
+
Example:
|
|
218
|
+
warn_deprecated("old_method", "method", version="1.2.0", alternative="new_method")
|
|
219
|
+
"""
|
|
220
|
+
# Build deprecation message
|
|
221
|
+
msg_parts = [f"{item_type.capitalize()} '{item_name}' is deprecated"]
|
|
222
|
+
|
|
223
|
+
if version:
|
|
224
|
+
msg_parts.append(f"since version {version}")
|
|
225
|
+
|
|
226
|
+
if reason:
|
|
227
|
+
msg_parts.append(f"- {reason}")
|
|
228
|
+
|
|
229
|
+
if alternative:
|
|
230
|
+
msg_parts.append(f"Use '{alternative}' instead")
|
|
231
|
+
|
|
232
|
+
message = ". ".join(msg_parts) + "."
|
|
233
|
+
|
|
234
|
+
warnings.warn(message, category=category, stacklevel=stacklevel)
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
def is_deprecated(obj: Any) -> bool:
    """
    Check if an object is marked as deprecated.

    An object counts as deprecated when it has a truthy ``__deprecated__``
    attribute, as set by the decorators in this module. The result is always
    a real ``bool``. A non-bool marker, such as the message string that
    ``warnings.deprecated`` (Python 3.13+) stores there, is coerced.

    Args:
        obj: Object to check

    Returns:
        True if the object is deprecated, False otherwise
    """
    return bool(getattr(obj, "__deprecated__", False))
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def get_deprecation_info(obj: Any) -> dict | None:
    """
    Return the deprecation metadata attached to ``obj``, if any.

    The metadata is the ``__deprecation_info__`` dict stored by this module's
    decorators (keys: reason, version, alternative, category). It is returned
    as-is, not copied.

    Args:
        obj: Object to get deprecation info for

    Returns:
        The info dict, or None when ``obj`` is not deprecated or carries no
        ``__deprecation_info__`` attribute.
    """
    if not is_deprecated(obj):
        return None
    return getattr(obj, "__deprecation_info__", None)
|
|
File without changes
|
memos/embedders/base.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
|
|
3
|
+
from memos.configs.embedder import BaseEmbedderConfig
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class BaseEmbedder(ABC):
    """Abstract interface implemented by every embedding backend."""

    @abstractmethod
    def __init__(self, config: BaseEmbedderConfig):
        """Set up the embedder from its backend-specific ``config``."""

    @abstractmethod
    def embed(self, texts: list[str]) -> list[list[float]]:
        """Generate embeddings (each a list of floats) for the given ``texts``."""
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
from typing import Any, ClassVar
|
|
2
|
+
|
|
3
|
+
from memos.configs.embedder import EmbedderConfigFactory
|
|
4
|
+
from memos.embedders.base import BaseEmbedder
|
|
5
|
+
from memos.embedders.ollama import OllamaEmbedder
|
|
6
|
+
from memos.embedders.sentence_transformer import SenTranEmbedder
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class EmbedderFactory(BaseEmbedder):
    """Factory class for creating embedder instances.

    Only ``from_config`` is meant to be used. This class does not implement
    ``BaseEmbedder``'s abstract methods, so it cannot itself be instantiated.
    """

    # Registry mapping backend names to embedder implementations.
    backend_to_class: ClassVar[dict[str, Any]] = {
        "ollama": OllamaEmbedder,
        "sentence_transformer": SenTranEmbedder,
    }

    @classmethod
    def from_config(cls, config_factory: EmbedderConfigFactory) -> BaseEmbedder:
        """Instantiate the embedder named by ``config_factory.backend``.

        The selected class is constructed with ``config_factory.config``.

        Raises:
            ValueError: If the backend is not registered in ``backend_to_class``.
        """
        backend = config_factory.backend
        embedder_class = cls.backend_to_class.get(backend)
        if embedder_class is None:
            raise ValueError(f"Invalid backend: {backend}")
        return embedder_class(config_factory.config)
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
from ollama import Client
|
|
2
|
+
|
|
3
|
+
from memos.configs.embedder import OllamaEmbedderConfig
|
|
4
|
+
from memos.embedders.base import BaseEmbedder
|
|
5
|
+
from memos.log import get_logger
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
# Module-level logger obtained from memos.log.get_logger; OllamaEmbedder uses
# it for configuration and model-availability warnings.
logger = get_logger(__name__)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class OllamaEmbedder(BaseEmbedder):
    """Ollama Embedder class.

    Talks to an Ollama server at ``config.api_base``. On construction, the
    configured model (default ``nomic-embed-text:latest``) is pulled if it is
    not found among the locally available models.
    """

    def __init__(self, config: OllamaEmbedderConfig):
        """
        Create the Ollama client and make sure the model is available.

        Side effects: mutates ``config`` (``embedding_dims`` is reset to None
        and ``model_name_or_path`` receives a default when empty), and may
        trigger a model pull on the server.
        """
        self.config = config
        self.api_base = config.api_base

        if self.config.embedding_dims is not None:
            # The literals are concatenated, so each needs a trailing space.
            logger.warning(
                "Ollama does not support specifying embedding dimensions. "
                "The embedding dimensions are determined by the model. "
                "`embedding_dims` will be set to None."
            )
            self.config.embedding_dims = None

        # Default model if not specified
        if not self.config.model_name_or_path:
            self.config.model_name_or_path = "nomic-embed-text:latest"

        # Initialize ollama client
        self.client = Client(host=self.api_base)

        # Ensure the model exists locally
        self._ensure_model_exists()

    @staticmethod
    def _with_default_tag(name: str) -> str:
        """
        Return ``name`` with an explicit tag, appending ``:latest`` if it has none.

        Only the last ``/``-separated segment is inspected, so a registry host
        with a port (``host:5000/model``) is not mistaken for a tag.
        """
        last_segment = name.rsplit("/", 1)[-1]
        return name if ":" in last_segment else f"{name}:latest"

    def _list_models(self) -> list[str]:
        """
        List all models available in the Ollama client.

        Returns:
            List of model names; entries without a name are skipped.
        """
        local_models = self.client.list()["models"]
        return [model.model for model in local_models if model.model]

    def _ensure_model_exists(self):
        """
        Ensure the specified model exists locally. If not, pull it from Ollama.

        Both sides are compared with an explicit tag. Ollama lists local
        models by tagged name (e.g. ``name:latest``), so an untagged configured
        name would otherwise never match and would be re-pulled on every
        construction. Any failure is logged and ignored (best effort).
        """
        try:
            wanted = self._with_default_tag(self.config.model_name_or_path)
            available = {self._with_default_tag(name) for name in self._list_models()}
            if wanted not in available:
                logger.warning(
                    f"Model {self.config.model_name_or_path} not found locally. Pulling from Ollama..."
                )
                self.client.pull(self.config.model_name_or_path)
        except Exception as e:
            logger.warning(f"Could not verify model existence: {e}")

    def embed(self, texts: list[str]) -> list[list[float]]:
        """
        Generate embeddings for the given texts.

        Args:
            texts: List of texts to embed.

        Returns:
            List of embeddings, each represented as a list of floats. An empty
            input yields an empty list without contacting the server.
        """
        if not texts:
            return []
        response = self.client.embed(
            model=self.config.model_name_or_path,
            input=texts,
        )
        return response.embeddings
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from sentence_transformers import SentenceTransformer
|
|
2
|
+
|
|
3
|
+
from memos.configs.embedder import SenTranEmbedderConfig
|
|
4
|
+
from memos.embedders.base import BaseEmbedder
|
|
5
|
+
from memos.log import get_logger
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
# Module-level logger obtained from memos.log.get_logger; SenTranEmbedder uses
# it to warn when a user-supplied embedding_dims will be overridden.
logger = get_logger(__name__)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class SenTranEmbedder(BaseEmbedder):
    """Sentence Transformer Embedder class.

    Loads ``config.model_name_or_path`` with ``SentenceTransformer`` and
    records the model's embedding size in ``config.embedding_dims``.
    """

    def __init__(self, config: SenTranEmbedderConfig):
        """
        Load the model described by ``config``.

        Side effect: ``config.embedding_dims`` is overwritten with the
        dimension reported by the loaded model.
        """
        self.config = config
        # NOTE(review): trust_remote_code lets the model repository execute its
        # own Python code; keep it disabled for untrusted models.
        self.model = SentenceTransformer(
            self.config.model_name_or_path, trust_remote_code=self.config.trust_remote_code
        )

        if self.config.embedding_dims is not None:
            # The literals are concatenated, so each needs a trailing space.
            logger.warning(
                "SentenceTransformer does not support specifying embedding dimensions directly. "
                "The embedding dimension is determined by the model. "
                "`embedding_dims` will be ignored."
            )
        # Get embedding dimensions from the model (overrides any user value)
        self.config.embedding_dims = self.model.get_sentence_embedding_dimension()

    def embed(self, texts: list[str]) -> list[list[float]]:
        """
        Generate embeddings for the given texts.

        Args:
            texts: List of texts to embed.

        Returns:
            List of embeddings, each represented as a list of floats. An empty
            input yields an empty list without invoking the model.
        """
        if not texts:
            return []
        embeddings = self.model.encode(texts, convert_to_numpy=True)
        return embeddings.tolist()
|
memos/exceptions.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
"""Custom exceptions for the MemOS library.
|
|
2
|
+
|
|
3
|
+
This module defines all custom exceptions used throughout the MemOS project.
|
|
4
|
+
All exceptions inherit from a base MemOSError class to provide a consistent
|
|
5
|
+
error handling interface.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class MemOSError(Exception):
    """Base class for all MemOS-specific exceptions."""


class ConfigurationError(MemOSError):
    """Error related to MemOS configuration."""


class MemOSMemoryError(MemOSError):
    """Error related to MemOS memories.

    Named so that it does not shadow Python's built-in ``MemoryError`` (which
    signals out-of-memory conditions), both in namespaces and in tracebacks.
    """


# Backward-compatible alias so existing ``from memos.exceptions import
# MemoryError`` imports keep working. Prefer ``MemOSMemoryError`` in new code:
# importing this alias shadows the built-in ``MemoryError`` in the importer.
MemoryError = MemOSMemoryError  # noqa: A001


class MemCubeError(MemOSError):
    """Error related to MemOS memory cubes."""


class VectorDBError(MemOSError):
    """Error related to vector database backends."""


class LLMError(MemOSError):
    """Error related to LLM backends."""


class EmbedderError(MemOSError):
    """Error related to embedder backends."""


class ParserError(MemOSError):
    """Error related to parser backends."""
|
|
File without changes
|