MemoryOS 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of MemoryOS might be problematic.

Files changed (80)
  1. {memoryos-0.2.0.dist-info → memoryos-0.2.1.dist-info}/METADATA +66 -26
  2. {memoryos-0.2.0.dist-info → memoryos-0.2.1.dist-info}/RECORD +80 -56
  3. memoryos-0.2.1.dist-info/entry_points.txt +3 -0
  4. memos/__init__.py +1 -1
  5. memos/api/config.py +471 -0
  6. memos/api/exceptions.py +28 -0
  7. memos/api/mcp_serve.py +502 -0
  8. memos/api/product_api.py +35 -0
  9. memos/api/product_models.py +159 -0
  10. memos/api/routers/__init__.py +1 -0
  11. memos/api/routers/product_router.py +358 -0
  12. memos/chunkers/sentence_chunker.py +8 -2
  13. memos/cli.py +113 -0
  14. memos/configs/embedder.py +27 -0
  15. memos/configs/graph_db.py +83 -2
  16. memos/configs/llm.py +47 -0
  17. memos/configs/mem_cube.py +1 -1
  18. memos/configs/mem_scheduler.py +91 -5
  19. memos/configs/memory.py +5 -4
  20. memos/dependency.py +52 -0
  21. memos/embedders/ark.py +92 -0
  22. memos/embedders/factory.py +4 -0
  23. memos/embedders/sentence_transformer.py +8 -2
  24. memos/embedders/universal_api.py +32 -0
  25. memos/graph_dbs/base.py +2 -2
  26. memos/graph_dbs/factory.py +2 -0
  27. memos/graph_dbs/neo4j.py +331 -122
  28. memos/graph_dbs/neo4j_community.py +300 -0
  29. memos/llms/base.py +9 -0
  30. memos/llms/deepseek.py +54 -0
  31. memos/llms/factory.py +10 -1
  32. memos/llms/hf.py +170 -13
  33. memos/llms/hf_singleton.py +114 -0
  34. memos/llms/ollama.py +4 -0
  35. memos/llms/openai.py +67 -1
  36. memos/llms/qwen.py +63 -0
  37. memos/llms/vllm.py +153 -0
  38. memos/mem_cube/general.py +77 -16
  39. memos/mem_cube/utils.py +102 -0
  40. memos/mem_os/core.py +131 -41
  41. memos/mem_os/main.py +93 -11
  42. memos/mem_os/product.py +1098 -35
  43. memos/mem_os/utils/default_config.py +352 -0
  44. memos/mem_os/utils/format_utils.py +1154 -0
  45. memos/mem_reader/simple_struct.py +5 -5
  46. memos/mem_scheduler/base_scheduler.py +467 -36
  47. memos/mem_scheduler/general_scheduler.py +125 -244
  48. memos/mem_scheduler/modules/base.py +9 -0
  49. memos/mem_scheduler/modules/dispatcher.py +68 -2
  50. memos/mem_scheduler/modules/misc.py +39 -0
  51. memos/mem_scheduler/modules/monitor.py +228 -49
  52. memos/mem_scheduler/modules/rabbitmq_service.py +317 -0
  53. memos/mem_scheduler/modules/redis_service.py +32 -22
  54. memos/mem_scheduler/modules/retriever.py +250 -23
  55. memos/mem_scheduler/modules/schemas.py +189 -7
  56. memos/mem_scheduler/mos_for_test_scheduler.py +143 -0
  57. memos/mem_scheduler/utils.py +51 -2
  58. memos/mem_user/persistent_user_manager.py +260 -0
  59. memos/memories/activation/item.py +25 -0
  60. memos/memories/activation/kv.py +10 -3
  61. memos/memories/activation/vllmkv.py +219 -0
  62. memos/memories/factory.py +2 -0
  63. memos/memories/textual/general.py +7 -5
  64. memos/memories/textual/tree.py +9 -5
  65. memos/memories/textual/tree_text_memory/organize/conflict.py +5 -3
  66. memos/memories/textual/tree_text_memory/organize/manager.py +26 -18
  67. memos/memories/textual/tree_text_memory/organize/redundancy.py +25 -44
  68. memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +11 -13
  69. memos/memories/textual/tree_text_memory/organize/reorganizer.py +73 -51
  70. memos/memories/textual/tree_text_memory/retrieve/recall.py +0 -1
  71. memos/memories/textual/tree_text_memory/retrieve/reranker.py +2 -2
  72. memos/memories/textual/tree_text_memory/retrieve/searcher.py +6 -5
  73. memos/parsers/markitdown.py +8 -2
  74. memos/templates/mem_reader_prompts.py +65 -23
  75. memos/templates/mem_scheduler_prompts.py +96 -47
  76. memos/templates/tree_reorganize_prompts.py +85 -30
  77. memos/vec_dbs/base.py +12 -0
  78. memos/vec_dbs/qdrant.py +46 -20
  79. {memoryos-0.2.0.dist-info → memoryos-0.2.1.dist-info}/LICENSE +0 -0
  80. {memoryos-0.2.0.dist-info → memoryos-0.2.1.dist-info}/WHEEL +0 -0
memos/configs/llm.py CHANGED
@@ -27,6 +27,40 @@ class OpenAILLMConfig(BaseLLMConfig):
     extra_body: Any = Field(default=None, description="extra body")


+class QwenLLMConfig(BaseLLMConfig):
+    api_key: str = Field(..., description="API key for DashScope (Qwen)")
+    api_base: str = Field(
+        default="https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
+        description="Base URL for Qwen OpenAI-compatible API",
+    )
+    extra_body: Any = Field(default=None, description="extra body")
+    model_name_or_path: str = Field(..., description="Model name for Qwen, e.g., 'qwen-plus'")
+
+
+class DeepSeekLLMConfig(BaseLLMConfig):
+    api_key: str = Field(..., description="API key for DeepSeek")
+    api_base: str = Field(
+        default="https://api.deepseek.com",
+        description="Base URL for DeepSeek OpenAI-compatible API",
+    )
+    extra_body: Any = Field(default=None, description="Extra options for API")
+    model_name_or_path: str = Field(
+        ..., description="Model name: 'deepseek-chat' or 'deepseek-reasoner'"
+    )
+
+
+class AzureLLMConfig(BaseLLMConfig):
+    base_url: str = Field(
+        default="https://api.openai.azure.com/",
+        description="Base URL for Azure OpenAI API",
+    )
+    api_version: str = Field(
+        default="2024-03-01-preview",
+        description="API version for Azure OpenAI",
+    )
+    api_key: str = Field(..., description="API key for Azure OpenAI")
+
+
 class OllamaLLMConfig(BaseLLMConfig):
     api_base: str = Field(
         default="http://localhost:11434",
@@ -45,6 +79,14 @@ class HFLLMConfig(BaseLLMConfig):
     )


+class VLLMLLMConfig(BaseLLMConfig):
+    api_key: str = Field(default="", description="API key for vLLM (optional for local server)")
+    api_base: str = Field(
+        default="http://localhost:8088/v1",
+        description="Base URL for vLLM API",
+    )
+
+
 class LLMConfigFactory(BaseConfig):
     """Factory class for creating LLM configurations."""

@@ -54,7 +96,12 @@ class LLMConfigFactory(BaseConfig):
     backend_to_class: ClassVar[dict[str, Any]] = {
         "openai": OpenAILLMConfig,
         "ollama": OllamaLLMConfig,
+        "azure": AzureLLMConfig,
         "huggingface": HFLLMConfig,
+        "vllm": VLLMLLMConfig,
+        "huggingface_singleton": HFLLMConfig,  # Add singleton support
+        "qwen": QwenLLMConfig,
+        "deepseek": DeepSeekLLMConfig,
     }

     @field_validator("backend")
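The new backends are plain Pydantic config classes, so they can be constructed directly once their required fields are supplied; the factory only adds a lookup by backend key. A minimal sketch (not part of the diff), assuming the remaining BaseLLMConfig fields all have defaults and using a placeholder API key:

    from memos.configs.llm import DeepSeekLLMConfig, LLMConfigFactory

    # api_base falls back to the default "https://api.deepseek.com" declared above.
    cfg = DeepSeekLLMConfig(
        api_key="sk-placeholder",            # placeholder credential
        model_name_or_path="deepseek-chat",  # or "deepseek-reasoner"
    )

    # The factory now maps the new backend keys to these classes.
    assert LLMConfigFactory.backend_to_class["deepseek"] is DeepSeekLLMConfig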
memos/configs/mem_cube.py CHANGED
@@ -70,7 +70,7 @@ class GeneralMemCubeConfig(BaseMemCubeConfig):
     @classmethod
     def validate_act_mem(cls, act_mem: MemoryConfigFactory) -> MemoryConfigFactory:
         """Validate the act_mem field."""
-        allowed_backends = ["kv_cache", "uninitialized"]
+        allowed_backends = ["kv_cache", "vllm_kv_cache", "uninitialized"]
         if act_mem.backend not in allowed_backends:
             raise ConfigurationError(
                 f"GeneralMemCubeConfig requires act_mem backend to be one of {allowed_backends}, got '{act_mem.backend}'"
memos/configs/mem_scheduler.py CHANGED
@@ -1,13 +1,17 @@
+import os
+
+from pathlib import Path
 from typing import Any, ClassVar

 from pydantic import ConfigDict, Field, field_validator, model_validator

 from memos.configs.base import BaseConfig
 from memos.mem_scheduler.modules.schemas import (
+    BASE_DIR,
     DEFAULT_ACT_MEM_DUMP_PATH,
-    DEFAULT_ACTIVATION_MEM_SIZE,
     DEFAULT_CONSUME_INTERVAL_SECONDS,
     DEFAULT_THREAD__POOL_MAX_WORKERS,
+    DictConversionMixin,
 )


@@ -33,6 +37,10 @@ class BaseSchedulerConfig(BaseConfig):
         le=60,
         description=f"Interval for consuming messages from queue in seconds (default: {DEFAULT_CONSUME_INTERVAL_SECONDS})",
     )
+    auth_config_path: str | None = Field(
+        default=None,
+        description="Path to the authentication configuration file containing private credentials",
+    )


 class GeneralSchedulerConfig(BaseSchedulerConfig):
@@ -42,14 +50,13 @@ class GeneralSchedulerConfig(BaseSchedulerConfig):
     context_window_size: int | None = Field(
         default=5, description="Size of the context window for conversation history"
     )
-    activation_mem_size: int | None = Field(
-        default=DEFAULT_ACTIVATION_MEM_SIZE,  # Assuming DEFAULT_ACTIVATION_MEM_SIZE is 1000
-        description="Maximum size of the activation memory",
-    )
     act_mem_dump_path: str | None = Field(
         default=DEFAULT_ACT_MEM_DUMP_PATH,  # Replace with DEFAULT_ACT_MEM_DUMP_PATH
         description="File path for dumping activation memory",
     )
+    enable_act_memory_update: bool = Field(
+        default=False, description="Whether to enable automatic activation memory updates"
+    )


 class SchedulerConfigFactory(BaseConfig):
@@ -76,3 +83,82 @@
         config_class = self.backend_to_class[self.backend]
         self.config = config_class(**self.config)
         return self
+
+
+# ************************* Auth *************************
+class RabbitMQConfig(
+    BaseConfig,
+):
+    host_name: str = Field(default="", description="Endpoint for RabbitMQ instance access")
+    user_name: str = Field(default="", description="Static username for RabbitMQ instance")
+    password: str = Field(default="", description="Password for the static username")
+    virtual_host: str = Field(default="", description="Vhost name for RabbitMQ instance")
+    erase_on_connect: bool = Field(
+        default=True, description="Whether to clear connection state or buffers upon connecting"
+    )
+    port: int = Field(
+        default=5672,
+        description="Port number for RabbitMQ instance access",
+        ge=1,  # Port must be >= 1
+        le=65535,  # Port must be <= 65535
+    )
+
+
+class GraphDBAuthConfig(BaseConfig):
+    uri: str = Field(default="localhost", description="URI for graph database access")
+
+
+class OpenAIConfig(BaseConfig):
+    api_key: str = Field(default="", description="API key for OpenAI service")
+    base_url: str = Field(default="", description="Base URL for API endpoint")
+    default_model: str = Field(default="", description="Default model to use")
+
+
+class AuthConfig(BaseConfig, DictConversionMixin):
+    rabbitmq: RabbitMQConfig
+    openai: OpenAIConfig
+    graph_db: GraphDBAuthConfig
+    default_config_path: ClassVar[str] = (
+        f"{BASE_DIR}/examples/data/config/mem_scheduler/scheduler_auth.yaml"
+    )
+
+    @classmethod
+    def from_local_yaml(cls, config_path: str | None = None) -> "AuthConfig":
+        """
+        Load configuration from YAML file
+
+        Args:
+            config_path: Path to YAML configuration file
+
+        Returns:
+            AuthConfig instance
+
+        Raises:
+            FileNotFoundError: If config file doesn't exist
+            ValueError: If YAML parsing or validation fails
+        """
+
+        if config_path is None:
+            config_path = cls.default_config_path
+
+        # Check file exists
+        if not Path(config_path).exists():
+            raise FileNotFoundError(f"Config file not found: {config_path}")
+
+        return cls.from_yaml_file(yaml_path=config_path)
+
+    def set_openai_config_to_environment(self):
+        # Set environment variables
+        os.environ["OPENAI_API_KEY"] = self.openai.api_key
+        os.environ["OPENAI_BASE_URL"] = self.openai.base_url
+        os.environ["MODEL"] = self.openai.default_model
+
+    @classmethod
+    def default_config_exists(cls) -> bool:
+        """
+        Check if the default configuration file exists.
+
+        Returns:
+            bool: True if the default config file exists, False otherwise
+        """
+        return Path(cls.default_config_path).exists()
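The new AuthConfig centralizes scheduler credentials (RabbitMQ, OpenAI, graph DB) behind a YAML file referenced by the new auth_config_path field. A usage sketch (not part of the diff) built only from the methods shown above; the explicit path is illustrative:

    from memos.configs.mem_scheduler import AuthConfig

    # Prefer the packaged default path when it exists; otherwise point at your own file.
    if AuthConfig.default_config_exists():
        auth = AuthConfig.from_local_yaml()  # uses AuthConfig.default_config_path
    else:
        auth = AuthConfig.from_local_yaml(config_path="/path/to/scheduler_auth.yaml")

    # Exports OPENAI_API_KEY, OPENAI_BASE_URL, and MODEL for downstream clients.
    auth.set_openai_config_to_environment()
    print(auth.rabbitmq.host_name, auth.graph_db.uri)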
memos/configs/memory.py CHANGED
@@ -52,9 +52,9 @@ class KVCacheMemoryConfig(BaseActMemoryConfig):
     @classmethod
     def validate_extractor_llm(cls, extractor_llm: LLMConfigFactory) -> LLMConfigFactory:
         """Validate the extractor_llm field."""
-        if extractor_llm.backend != "huggingface":
+        if extractor_llm.backend not in ["huggingface", "huggingface_singleton", "vllm"]:
             raise ConfigurationError(
-                f"KVCacheMemoryConfig requires extractor_llm backend to be 'huggingface', got '{extractor_llm.backend}'"
+                f"KVCacheMemoryConfig requires extractor_llm backend to be 'huggingface' or 'huggingface_singleton', got '{extractor_llm.backend}'"
             )
         return extractor_llm

@@ -84,9 +84,9 @@ class LoRAMemoryConfig(BaseParaMemoryConfig):
     @classmethod
     def validate_extractor_llm(cls, extractor_llm: LLMConfigFactory) -> LLMConfigFactory:
         """Validate the extractor_llm field."""
-        if extractor_llm.backend not in ["huggingface"]:
+        if extractor_llm.backend not in ["huggingface", "huggingface_singleton"]:
             raise ConfigurationError(
-                f"LoRAMemoryConfig requires extractor_llm backend to be 'huggingface', got '{extractor_llm.backend}'"
+                f"LoRAMemoryConfig requires extractor_llm backend to be 'huggingface' or 'huggingface_singleton', got '{extractor_llm.backend}'"
             )
         return extractor_llm

@@ -181,6 +181,7 @@ class MemoryConfigFactory(BaseConfig):
         "general_text": GeneralTextMemoryConfig,
         "tree_text": TreeTextMemoryConfig,
         "kv_cache": KVCacheMemoryConfig,
+        "vllm_kv_cache": KVCacheMemoryConfig,  # Use same config as kv_cache
         "lora": LoRAMemoryConfig,
         "uninitialized": UninitializedMemoryConfig,
     }
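The new "vllm_kv_cache" key is an alias that reuses KVCacheMemoryConfig, and the relaxed validator additionally accepts "huggingface_singleton" and "vllm" extractor backends. A quick check (not part of the diff), limited to what the mapping above guarantees:

    from memos.configs.memory import KVCacheMemoryConfig, MemoryConfigFactory

    # Both activation-memory backends share one config class.
    assert MemoryConfigFactory.backend_to_class["vllm_kv_cache"] is KVCacheMemoryConfig
    assert MemoryConfigFactory.backend_to_class["kv_cache"] is KVCacheMemoryConfig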
memos/dependency.py ADDED
@@ -0,0 +1,52 @@
+"""
+This utility provides tools for managing dependencies in MemOS.
+"""
+
+import functools
+import importlib
+
+
+def require_python_package(
+    import_name: str, install_command: str | None = None, install_link: str | None = None
+):
+    """Check if a package is available and provide installation hints on import failure.
+
+    Args:
+        import_name (str): The top-level importable module name a package provides.
+        install_command (str, optional): Installation command.
+        install_link (str, optional): URL link to installation guide.
+
+    Returns:
+        Callable: A decorator function that wraps the target function with package availability check.
+
+    Raises:
+        ImportError: When the specified package is not available, with installation
+            instructions included in the error message.
+
+    Example:
+        >>> @require_python_package(
+        ...     import_name='faiss',
+        ...     install_command='pip install faiss-cpu',
+        ...     install_link='https://github.com/facebookresearch/faiss/blob/main/INSTALL.md'
+        ... )
+        ... def create_faiss_index():
+        ...     from faiss import IndexFlatL2  # Actual import in function
+        ...     return IndexFlatL2(128)
+    """
+
+    def decorator(func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            try:
+                importlib.import_module(import_name)
+            except ImportError:
+                error_msg = f"Missing required module - '{import_name}'\n"
+                error_msg += f"💡 Install command: {install_command}\n" if install_command else ""
+                error_msg += f"💡 Install guide: {install_link}\n" if install_link else ""
+
+                raise ImportError(error_msg) from None
+            return func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
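Elsewhere in this release the decorator wraps __init__ methods so that heavy optional imports are deferred until a backend is actually constructed (see ArkEmbedder and SenTranEmbedder below). A small sketch of that pattern with a hypothetical class, reusing the faiss example from the docstring:

    from memos.dependency import require_python_package


    class FaissIndexer:  # hypothetical example class, not part of MemOS
        @require_python_package(
            import_name="faiss",
            install_command="pip install faiss-cpu",
        )
        def __init__(self, dim: int):
            # The real import happens only when the decorated method runs,
            # so importing this module never requires faiss to be installed.
            from faiss import IndexFlatL2

            self.index = IndexFlatL2(dim)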
memos/embedders/ark.py ADDED
@@ -0,0 +1,92 @@
+from memos.configs.embedder import ArkEmbedderConfig
+from memos.dependency import require_python_package
+from memos.embedders.base import BaseEmbedder
+from memos.log import get_logger
+
+
+logger = get_logger(__name__)
+
+
+class ArkEmbedder(BaseEmbedder):
+    """Ark Embedder class."""
+
+    @require_python_package(
+        import_name="volcenginesdkarkruntime",
+        install_command="pip install 'volcengine-python-sdk[ark]'",
+        install_link="https://www.volcengine.com/docs/82379/1541595",
+    )
+    def __init__(self, config: ArkEmbedderConfig):
+        from volcenginesdkarkruntime import Ark
+
+        self.config = config
+
+        if self.config.embedding_dims is not None:
+            logger.warning(
+                "Ark does not support specifying embedding dimensions. "
+                "The embedding dimensions is determined by the model."
+                "`embedding_dims` will be set to None."
+            )
+            self.config.embedding_dims = None
+
+        # Default model if not specified
+        if not self.config.model_name_or_path:
+            self.config.model_name_or_path = "doubao-embedding-vision-250615"
+
+        # Initialize ark client
+        self.client = Ark(api_key=self.config.api_key, base_url=self.config.api_base)
+
+    def embed(self, texts: list[str]) -> list[list[float]]:
+        """
+        Generate embeddings for the given texts.
+
+        Args:
+            texts: List of texts to embed.
+
+        Returns:
+            List of embeddings, each represented as a list of floats.
+        """
+        from volcenginesdkarkruntime.types.multimodal_embedding import (
+            MultimodalEmbeddingContentPartTextParam,
+        )
+
+        if self.config.multi_modal:
+            texts_input = [
+                MultimodalEmbeddingContentPartTextParam(text=text, type="text") for text in texts
+            ]
+            return self.multimodal_embeddings(inputs=texts_input, chunk_size=self.config.chunk_size)
+        return self.text_embedding(texts, chunk_size=self.config.chunk_size)
+
+    def text_embedding(self, inputs: list[str], chunk_size: int | None = None) -> list[list[float]]:
+        chunk_size_ = chunk_size or self.config.chunk_size
+        embeddings: list[list[float]] = []
+        for i in range(0, len(inputs), chunk_size_):
+            response = self.client.embeddings.create(
+                model=self.config.model_name_or_path,
+                input=inputs[i : i + chunk_size_],
+            )
+
+            data = [response.data] if isinstance(response.data, dict) else response.data
+            embeddings.extend(r.embedding for r in data)
+
+        return embeddings
+
+    def multimodal_embeddings(
+        self, inputs: list, chunk_size: int | None = None
+    ) -> list[list[float]]:
+        from volcenginesdkarkruntime.types.multimodal_embedding import (
+            MultimodalEmbeddingResponse,  # noqa: TC002
+        )
+
+        chunk_size_ = chunk_size or self.config.chunk_size
+        embeddings: list[list[float]] = []
+
+        for i in range(0, len(inputs), chunk_size_):
+            response: MultimodalEmbeddingResponse = self.client.multimodal_embeddings.create(
+                model=self.config.model_name_or_path,
+                input=inputs[i : i + chunk_size_],
+            )
+
+            data = [response.data] if isinstance(response.data, dict) else response.data
+            embeddings.extend(r["embedding"] for r in data)
+
+        return embeddings
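Usage follows the other embedders: build the config, construct the class, call embed(). A hedged sketch (not part of the diff); ArkEmbedderConfig lives in memos/configs/embedder.py, so the field names below are inferred from the attribute accesses above and the endpoint value is illustrative:

    from memos.configs.embedder import ArkEmbedderConfig
    from memos.embedders.ark import ArkEmbedder

    config = ArkEmbedderConfig(
        api_key="ark-placeholder-key",              # placeholder credential
        api_base="https://ark.example.com/api/v3",  # illustrative endpoint
        model_name_or_path="doubao-embedding-vision-250615",
        multi_modal=False,  # text-only path -> text_embedding()
        chunk_size=16,      # inputs are sent to the API in batches of this size
    )

    embedder = ArkEmbedder(config)
    vectors = embedder.embed(["hello", "world"])  # one float vector per input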
memos/embedders/factory.py CHANGED
@@ -1,9 +1,11 @@
 from typing import Any, ClassVar

 from memos.configs.embedder import EmbedderConfigFactory
+from memos.embedders.ark import ArkEmbedder
 from memos.embedders.base import BaseEmbedder
 from memos.embedders.ollama import OllamaEmbedder
 from memos.embedders.sentence_transformer import SenTranEmbedder
+from memos.embedders.universal_api import UniversalAPIEmbedder


 class EmbedderFactory(BaseEmbedder):
@@ -12,6 +14,8 @@ class EmbedderFactory(BaseEmbedder):
     backend_to_class: ClassVar[dict[str, Any]] = {
         "ollama": OllamaEmbedder,
         "sentence_transformer": SenTranEmbedder,
+        "ark": ArkEmbedder,
+        "universal_api": UniversalAPIEmbedder,
     }

     @classmethod
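With these two entries, the factory resolves the new backend strings to the classes added in this release; construction itself presumably still goes through the factory's existing from_config-style classmethod, which this hunk does not show. A minimal check (not part of the diff):

    from memos.embedders.ark import ArkEmbedder
    from memos.embedders.factory import EmbedderFactory
    from memos.embedders.universal_api import UniversalAPIEmbedder

    assert EmbedderFactory.backend_to_class["ark"] is ArkEmbedder
    assert EmbedderFactory.backend_to_class["universal_api"] is UniversalAPIEmbedder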
memos/embedders/sentence_transformer.py CHANGED
@@ -1,6 +1,5 @@
-from sentence_transformers import SentenceTransformer
-
 from memos.configs.embedder import SenTranEmbedderConfig
+from memos.dependency import require_python_package
 from memos.embedders.base import BaseEmbedder
 from memos.log import get_logger

@@ -11,7 +10,14 @@ logger = get_logger(__name__)
 class SenTranEmbedder(BaseEmbedder):
     """Sentence Transformer Embedder class."""

+    @require_python_package(
+        import_name="sentence_transformers",
+        install_command="pip install sentence-transformers",
+        install_link="https://www.sbert.net/docs/installation.html",
+    )
     def __init__(self, config: SenTranEmbedderConfig):
+        from sentence_transformers import SentenceTransformer
+
         self.config = config
         self.model = SentenceTransformer(
             self.config.model_name_or_path, trust_remote_code=self.config.trust_remote_code
memos/embedders/universal_api.py ADDED
@@ -0,0 +1,32 @@
+from openai import AzureOpenAI as AzureClient
+from openai import OpenAI as OpenAIClient
+
+from memos.configs.embedder import UniversalAPIEmbedderConfig
+from memos.embedders.base import BaseEmbedder
+
+
+class UniversalAPIEmbedder(BaseEmbedder):
+    def __init__(self, config: UniversalAPIEmbedderConfig):
+        self.provider = config.provider
+        self.config = config
+
+        if self.provider == "openai":
+            self.client = OpenAIClient(api_key=config.api_key, base_url=config.base_url)
+        elif self.provider == "azure":
+            self.client = AzureClient(
+                azure_endpoint=config.base_url,
+                api_version="2024-03-01-preview",
+                api_key=config.api_key,
+            )
+        else:
+            raise ValueError(f"Unsupported provider: {self.provider}")
+
+    def embed(self, texts: list[str]) -> list[list[float]]:
+        if self.provider == "openai" or self.provider == "azure":
+            response = self.client.embeddings.create(
+                model=getattr(self.config, "model_name_or_path", "text-embedding-3-large"),
+                input=texts,
+            )
+            return [r.embedding for r in response.data]
+        else:
+            raise ValueError(f"Unsupported provider: {self.provider}")
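UniversalAPIEmbedder sends both providers through the OpenAI-compatible embeddings endpoint. A minimal sketch for the "openai" provider (not part of the diff); the config field names provider, api_key, base_url, and model_name_or_path are inferred from the attribute accesses above and should be treated as assumptions about UniversalAPIEmbedderConfig:

    from memos.configs.embedder import UniversalAPIEmbedderConfig
    from memos.embedders.universal_api import UniversalAPIEmbedder

    config = UniversalAPIEmbedderConfig(
        provider="openai",  # or "azure", which uses the AzureOpenAI client
        api_key="sk-placeholder",
        base_url="https://api.openai.com/v1",
        model_name_or_path="text-embedding-3-large",  # also the fallback default above
    )

    embedder = UniversalAPIEmbedder(config)
    vectors = embedder.embed(["memory one", "memory two"])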
memos/graph_dbs/base.py CHANGED
@@ -9,12 +9,12 @@ class BaseGraphDB(ABC):

     # Node (Memory) Management
     @abstractmethod
-    def add_node(self, id: str, content: str, metadata: dict[str, Any]) -> None:
+    def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None:
         """
         Add a memory node to the graph.
         Args:
             id: Unique identifier for the memory node.
-            content: Raw memory content (e.g., text).
+            memory: Raw memory content (e.g., text).
             metadata: Dictionary of metadata (e.g., timestamp, tags, source).
         """

memos/graph_dbs/factory.py CHANGED
@@ -3,6 +3,7 @@ from typing import Any, ClassVar
 from memos.configs.graph_db import GraphDBConfigFactory
 from memos.graph_dbs.base import BaseGraphDB
 from memos.graph_dbs.neo4j import Neo4jGraphDB
+from memos.graph_dbs.neo4j_community import Neo4jCommunityGraphDB


 class GraphStoreFactory(BaseGraphDB):
@@ -10,6 +11,7 @@ class GraphStoreFactory(BaseGraphDB):

     backend_to_class: ClassVar[dict[str, Any]] = {
         "neo4j": Neo4jGraphDB,
+        "neo4j-community": Neo4jCommunityGraphDB,
     }

     @classmethod